In [1]:
test_input = """Valve AA has flow rate=0; tunnels lead to valves DD, II, BB
Valve BB has flow rate=13; tunnels lead to valves CC, AA
Valve CC has flow rate=2; tunnels lead to valves DD, BB
Valve DD has flow rate=20; tunnels lead to valves CC, AA, EE
Valve EE has flow rate=3; tunnels lead to valves FF, DD
Valve FF has flow rate=0; tunnels lead to valves EE, GG
Valve GG has flow rate=0; tunnels lead to valves FF, HH
Valve HH has flow rate=22; tunnel leads to valve GG
Valve II has flow rate=0; tunnels lead to valves AA, JJ
Valve JJ has flow rate=21; tunnel leads to valve II"""

In [2]:
import re
from functools import cache, lru_cache
import networkx as nx
import itertools

matcher = "Valve (\w+) has flow rate=(\d+); tunnels? leads? to valves? (.*)"

In [3]:
def parse_input(s):
    tunnel_data = {}

    for l in s.split("\n"):
        r = re.match(matcher, l)

        valve = r.group(1)
        flow_rate = int(r.group(2))
        connected = r.group(3).split(", ")

        tunnel_data[valve] = (flow_rate, connected)
    return tunnel_data

In [5]:
def parse(s):
    tunnel_data = parse_input(s)

    g = nx.DiGraph()

    for tunnel, (flow_rate, neighbors) in tunnel_data.items():
        g.add_node(tunnel)
        g.nodes[tunnel]['label'] = f"{tunnel} - {flow_rate}"
        g.nodes[tunnel]['fr'] = flow_rate

        for n in neighbors:
            g.add_edge(tunnel, n)

    shortest_paths = dict(nx.all_pairs_shortest_path_length(g))

    return {l: (tunnel_data[l][0], {n: d for n, d in shortest_paths[l].items() if n != l}) for l in g.nodes()}

In [6]:
def p1(tunnel_data, time):
    all_to_vist = set([l for l, d in tunnel_data.items() if d[0] > 0 or l == 'AA'])

    @cache
    def choose(location, remaining_time, already_on):
        '''
        return the max total flow we can add from this position.
        '''
        if remaining_time <= 1:
            return 0

        flow_rate, shortest_paths = tunnel_data[location]
        added_flow = remaining_time * flow_rate

        new_on = already_on | frozenset([location])
        to_visit = all_to_vist - new_on

        return added_flow + max((choose(n, remaining_time-shortest_paths[n]-1, new_on) for n in to_visit), default=0)
    
    result = choose("AA", time, frozenset())
    print(choose.cache_info())
    return result

In [7]:
from itertools import chain, combinations

# https://docs.python.org/3/library/itertools.html#itertools-recipes
def powerset(iterable):
    "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

In [8]:
tunnel_data = parse(open("inputs/16").read())
# tunnel_data = parse(test_input)

In [9]:
p1(tunnel_data, 30)

CacheInfo(hits=242099, misses=309926, maxsize=None, currsize=309926)


1728

In [15]:
def p2(tunnel_data, time):
    nonzero_flow = frozenset([l for l, d in tunnel_data.items() if d[0] > 0])
    all_to_visit = nonzero_flow | set(['AA'])

    @cache
    def choose(location, remaining_time, already_on):
        '''
        return the max total flow we can add from this position.
        '''
        if remaining_time <= 0:
            return 0

        flow_rate, shortest_paths = tunnel_data[location]
        added_flow = remaining_time * flow_rate

        new_on = already_on | frozenset([location])
        to_visit = all_to_visit - new_on

        return added_flow + max((choose(n, remaining_time-shortest_paths[n]-1, new_on) for n in to_visit), default=0)

    m = 0

    for a in powerset(nonzero_flow):
        subset_a = frozenset(a)
        subset_b = nonzero_flow - subset_a

        # could add a heuristic here about the size of the subsets but it's not necessary

        r = choose("AA", time, subset_a) + choose("AA", time, subset_b)

        if r > m:
            m = r

    return m

In [16]:
p2(tunnel_data, 26)

2304