In [1]:
import networkx as nx
from functools import lru_cache
from itertools import combinations, product

In [2]:
def parse_input(input_gen) -> tuple[list[str], dict[str, int], list[list[str]]]:
    """Parse the valve scan into valve names, flow rates and tunnel targets.

    Each non-blank line is expected to look like
    ``Valve AA has flow rate=0; tunnels lead to valves DD, II, BB``
    (the singular ``tunnel leads to valve GG`` form works as well).

    Args:
        input_gen: Iterable of text lines, e.g. an open file object.

    Returns:
        ``(valves, flow_rates, targets)``: ``valves`` lists the valve names in input
        order, ``flow_rates`` maps valve name -> integer flow rate, and ``targets[i]``
        lists the valves reachable in one step from ``valves[i]``.
    """
    valves = []
    flow_rates = {}
    targets = []
    for line in input_gen:
        # split() (no argument) also absorbs repeated spaces and "\r\n" endings.
        row = line.split()
        if not row:
            # Skip blank lines (e.g. a trailing empty line) instead of failing
            # with IndexError on row[1].
            continue
        valve = row[1]
        valves.append(valve)
        # Tokens from index 9 on are the targets, e.g. ["DD,", "II,", "BB"].
        targets.append("".join(row[9:]).split(","))
        # Token 4 is "rate=<n>;": drop the "rate=" prefix and the trailing ";".
        flow_rates[valve] = int(row[4][5:-1])
    return valves, flow_rates, targets

In [3]:
def task_one(filename: str, max_time: int = 30) -> int:
    """Return the most pressure a single agent can release within ``max_time`` minutes.

    The agent starts at valve "AA". It repeatedly walks to a closed valve that has a
    positive flow rate and spends one minute opening it. A valve opened at minute
    ``t`` contributes ``(max_time - t) * flow_rate`` in total.

    Args:
        filename: Path to the puzzle input.
        max_time: Number of minutes available.

    Returns:
        The maximum total pressure that can be released.
    """
    with open(filename) as f:
        valves, flow_rates, targets = parse_input(f)

    graph = nx.DiGraph()
    graph.add_nodes_from(valves)
    for valve, target in zip(valves, targets):
        for t in target:
            graph.add_edge(valve, t)

    # Only valves with a positive flow rate are worth visiting; "AA" is kept as the start.
    useful_valves = frozenset(nn for nn in valves if flow_rates[nn] > 0)
    interesting_nodes = useful_valves | {"AA"}
    # No depth cutoff, so every reachable pair has an entry whatever max_time is.
    path_lengths = {
        src: {dst: dist for dst, dist in dists.items() if dst in interesting_nodes}
        for src, dists in nx.all_pairs_shortest_path_length(graph)
        if src in interesting_nodes
    }

    # Memoise on the search state only. The gain still obtainable from a state does
    # not depend on the score accumulated before reaching it, so different paths that
    # arrive at the same (time, valve, opened set) share one cached result.
    @lru_cache(maxsize=None)
    def best_gain(curr_time: int, curr_valve: str, opened_valves: frozenset) -> int:
        best = 0
        for valve in useful_valves - opened_valves:
            dist = path_lengths[curr_valve].get(valve)
            if dist is None:
                continue  # not reachable from curr_valve
            new_time = curr_time + dist + 1  # walk there, then one minute to open it
            if new_time < max_time:
                gain = (max_time - new_time) * flow_rates[valve]
                best = max(best, gain + best_gain(new_time, valve, opened_valves | {valve}))
        return best

    # Nothing is opened yet. If "AA" itself has a positive flow it may be opened first.
    return best_gain(0, "AA", frozenset())

In [4]:
# Worked example from the puzzle statement; assert the known answer to catch regressions.
part_one_example = task_one("test-input.txt")
assert part_one_example == 1651, part_one_example
part_one_example

1651

In [5]:
# Part 1 on the real puzzle input: one agent, 30 minutes.
task_one(filename="input.txt", max_time=30)

2114

In [6]:
def task_two(filename: str, max_time: int = 30) -> dict[tuple[str, ...], int]:
    """Record the best pressure release achievable for every reachable set of opened valves.

    Runs an exhaustive depth-first search for a single agent starting at valve "AA".
    The agent walks to a closed valve with a positive flow rate and spends one minute
    opening it. A valve opened at minute ``t`` contributes
    ``(max_time - t) * flow_rate``.

    Args:
        filename: Path to the puzzle input.
        max_time: Number of minutes available.

    Returns:
        Dict mapping a sorted tuple of valve names to the best score seen for exactly
        that set of opened valves. Every key contains the start valve "AA", which is
        marked opened from the start and contributes nothing. Two agents' results can
        therefore be combined when their keys intersect in exactly ``{"AA"}``. The key
        ``("AA",)`` (nothing opened) is always present with value 0.
    """
    with open(filename) as f:
        valves, flow_rates, targets = parse_input(f)

    graph = nx.DiGraph()
    graph.add_nodes_from(valves)
    for valve, target in zip(valves, targets):
        for t in target:
            graph.add_edge(valve, t)

    # Prune non-starting nodes with zero flow-rate
    interesting_nodes = {"AA"}.union({nn for nn in valves if flow_rates[nn] > 0})
    # No depth cutoff, so every reachable pair has an entry whatever max_time is.
    path_lengths = {
        src: {dst: dist for dst, dist in dists.items() if dst in interesting_nodes}
        for src, dists in nx.all_pairs_shortest_path_length(graph)
        if src in interesting_nodes
    }

    best_by_opened: dict[tuple[str, ...], int] = {}

    def explore(curr_score: int, curr_time: int, curr_valve: str, opened_valves: frozenset) -> None:
        # Record the score for the set opened so far (curr_valve already included).
        key = tuple(sorted(opened_valves))
        best_by_opened[key] = max(best_by_opened.get(key, 0), curr_score)

        for valve in interesting_nodes - opened_valves:
            dist = path_lengths[curr_valve].get(valve)
            if dist is None:
                continue  # not reachable from curr_valve
            new_time = curr_time + dist + 1  # walk there, then one minute to open it
            if new_time < max_time:
                new_score = curr_score + (max_time - new_time) * flow_rates[valve]
                explore(new_score, new_time, valve, opened_valves | {valve})

    explore(0, 0, "AA", frozenset({"AA"}))
    return best_by_opened

In [7]:
# Sanity check: with one agent and 30 minutes, the best recorded set must reproduce
# the part-one answer for the worked example.
part_two_single_agent_example = max(task_two("test-input.txt", max_time=30).values())
assert part_two_single_agent_example == 1651, part_two_single_agent_example
part_two_single_agent_example

1651

In [8]:
# Cross-check on the real input: 30 minutes with one agent should match the part-one answer.
max(task_two(filename="input.txt", max_time=30).values())

2114

In [9]:
cached_dict = task_two("test-input.txt", max_time=26)
# Every key contains the start valve "AA". Strip it once so each pair check is a cheap
# frozenset disjointness test. Unordered pairs suffice because the sum is symmetric.
opened_sets = [(frozenset(k) - {"AA"}, v) for k, v in cached_dict.items()]
part_two_example = max(
    (v_h + v_el for (s_h, v_h), (s_el, v_el) in combinations(opened_sets, 2) if s_h.isdisjoint(s_el)),
    default=0,
)
assert part_two_example == 1707, part_two_example
part_two_example

1707

In [10]:
cached_dict = task_two("input.txt", max_time=26)
# if we replaced tuples in the key representation with e.g. ints (as they are hashable and we can directly calc the intersection)
# this would be way quicker
%time max(v_h + v_el for (k_h, v_h), (k_el, v_el) in combinations(cached_dict.items(), 2) if set(k_h) & set(k_el) == {"AA"})

CPU times: user 7.72 s, sys: 0 ns, total: 7.72 s
Wall time: 7.75 s


2666