In [1]:
from collections import deque
from itertools import combinations
from operator import itemgetter
import re

from aocd import get_data, submit

In [2]:
# Worked example from the puzzle statement: one valve description per line,
# joined with "\n" (no trailing newline).
example_data = "\n".join([
    "Valve AA has flow rate=0; tunnels lead to valves DD, II, BB",
    "Valve BB has flow rate=13; tunnels lead to valves CC, AA",
    "Valve CC has flow rate=2; tunnels lead to valves DD, BB",
    "Valve DD has flow rate=20; tunnels lead to valves CC, AA, EE",
    "Valve EE has flow rate=3; tunnels lead to valves FF, DD",
    "Valve FF has flow rate=0; tunnels lead to valves EE, GG",
    "Valve GG has flow rate=0; tunnels lead to valves FF, HH",
    "Valve HH has flow rate=22; tunnel leads to valve GG",
    "Valve II has flow rate=0; tunnels lead to valves AA, JJ",
    "Valve JJ has flow rate=21; tunnel leads to valve II",
])

# Personal puzzle input, fetched through aocd.
real_data = get_data(year=2022, day=16)

In [3]:
class Valve:
    """One valve of the cave network.

    Attributes:
        name: valve label, e.g. ``"AA"``.
        rate: flow rate of the valve (pressure released per minute once open).
        tunnels: names of the valves directly connected to this one.
        distance_map: shortest number of tunnel hops from this valve to every
            reachable valve (itself included, at distance 0). Empty until
            ``set_distance_map`` is called.
    """

    def __init__(self, name, rate, tunnels):
        self.name = name
        self.rate = rate
        self.tunnels = tunnels
        self.distance_map = {}

    def set_distance_map(self, valves):
        """Fill ``distance_map`` with breadth-first hop counts from this valve.

        Args:
            valves: dict mapping valve name -> ``Valve``; used to follow
                tunnels. Every name reachable through ``tunnels`` must be a key.

        Valves that cannot be reached from this one are absent from the map.
        """
        self.distance_map = {self.name: 0}
        # deque gives O(1) pops from the left; list.pop(0) is O(n) per pop.
        queue = deque([self.name])

        while queue:
            current = queue.popleft()
            for next_valve in valves[current].tunnels:
                # distance_map doubles as the visited set: in BFS the first
                # time a valve is discovered is via a shortest path.
                if next_valve not in self.distance_map:
                    self.distance_map[next_valve] = self.distance_map[current] + 1
                    queue.append(next_valve)

In [4]:
# Input selection: switch to `example_data` to check against the worked example.
puzzle_input = real_data

valves = dict()

pattern_text = r'^Valve ([A-Z]+) has flow rate=(\d+); tunnels? leads? to valves? (.*)$'
pattern = re.compile(pattern_text)

for line in puzzle_input.splitlines():
    if not line.strip():
        # Tolerate blank / trailing lines instead of crashing on them.
        continue
    match = pattern.match(line)
    if match is None:
        # Without this check a malformed line fails with an opaque
        # "'NoneType' object has no attribute 'groups'".
        raise ValueError(f"Unrecognised valve line: {line!r}")
    name, rate, tunnels = match.groups()
    valves[name] = Valve(name, int(rate), tunnels.split(", "))

# Shortest tunnel-hop distances from every valve to every reachable valve.
for valve in valves.values():
    valve.set_distance_map(valves)

# Only valves with a positive flow rate are worth opening; ordered by
# descending rate.
relevant_valves = [valve.name for valve in sorted(valves.values(), key=lambda valve: valve.rate, reverse=True) if valve.rate > 0]

# One bit per relevant valve plus one for the start valve 'AA'. 'AA' is appended
# only when it is not already relevant, so no bit position is left unused.
mapped_names = relevant_valves + ([] if 'AA' in relevant_valves else ['AA'])
valves_mapping = {valve: 1 << i for i, valve in enumerate(mapped_names)}

In [5]:
%%time

# Bitmask with the bit of every relevant valve set. Built from the actual bit
# assignments: the start valve 'AA' is never opened by the search, so a mask
# that includes its bit can never be reached and the early exit below would
# never fire.
all_valves_open = 0
for _relevant_name in relevant_valves:
    all_valves_open |= valves_mapping[_relevant_name]


def find_max_rate(max_minutes=30):
    """Enumerate every valve-opening sequence that fits in ``max_minutes``.

    Starting at valve ``AA`` with nothing open, each step moves to a
    still-closed relevant valve via its shortest distance (``distance_map``)
    and spends one extra minute opening it. That valve then adds
    ``rate * remaining_minutes`` to the total released pressure.

    Reads the module-level ``valves``, ``relevant_valves``, ``valves_mapping``
    and ``all_valves_open``.

    Args:
        max_minutes: time budget in minutes.

    Returns:
        dict mapping an open-valve bitmask (bits from ``valves_mapping``) to the
        best total pressure of any sequence that ends with exactly those valves
        open. Mask 0 (nothing opened) is always present with value 0.
    """
    max_sum_rate = {}
    # Breadth-first queue of (position, open_mask, total_pressure, minutes_left).
    states = deque([("AA", 0, 0, max_minutes)])

    while states:
        current_valve, open_valves, sum_rate, remaining_minutes = states.popleft()

        max_sum_rate[open_valves] = max(max_sum_rate.get(open_valves, 0), sum_rate)

        if open_valves == all_valves_open:
            continue  # nothing left to open

        distances = valves[current_valve].distance_map
        for valve in relevant_valves:
            if valves_mapping[valve] & open_valves:
                continue  # already open
            if valve not in distances:
                continue  # unreachable from here

            # Walk there, then one minute to open it.
            next_remaining_minutes = remaining_minutes - distances[valve] - 1
            if next_remaining_minutes <= 0:
                continue  # opening it would release nothing

            states.append((
                valve,
                open_valves | valves_mapping[valve],
                sum_rate + valves[valve].rate * next_remaining_minutes,
                next_remaining_minutes,
            ))

    return max_sum_rate

# Part 1: best total pressure when working alone for 30 minutes.
print(f"Maximum pressure : {max(find_max_rate(30).values())}")

# Part 2: you and the elephant each get 26 minutes and must open disjoint sets
# of valves. The answer is therefore the best sum over pairs of
# non-overlapping open-valve masks. Unordered pairs suffice because the sum is
# symmetric. Mask 0 is always in the result, so working alone is covered too.
rates_by_open_valves = find_max_rate(26)
best_with_elephant = max(
    (
        sum_1 + sum_2
        for (open_valves_1, sum_1), (open_valves_2, sum_2)
        in combinations(rates_by_open_valves.items(), 2)
        if open_valves_1 & open_valves_2 == 0
    ),
    default=0,
)
print(f"Maximum pressure with elephant: {best_with_elephant}")

Maximum pressure : 1871
Maximum pressure with elephant: 2416
CPU times: total: 3.73 s
Wall time: 3.77 s
