In [1]:
import ast
import copy
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from aocd import get_data, submit

DAY = 16
YEAR = 2022

In [2]:
# use test data
raw_test = """Valve AA has flow rate=0; tunnels lead to valves DD, II, BB
Valve BB has flow rate=13; tunnels lead to valves CC, AA
Valve CC has flow rate=2; tunnels lead to valves DD, BB
Valve DD has flow rate=20; tunnels lead to valves CC, AA, EE
Valve EE has flow rate=3; tunnels lead to valves FF, DD
Valve FF has flow rate=0; tunnels lead to valves EE, GG
Valve GG has flow rate=0; tunnels lead to valves FF, HH
Valve HH has flow rate=22; tunnel leads to valve GG
Valve II has flow rate=0; tunnels lead to valves AA, JJ
Valve JJ has flow rate=21; tunnel leads to valve II"""

# use real data
raw = get_data(day=DAY, year=YEAR)

print(raw_test)

Valve AA has flow rate=0; tunnels lead to valves DD, II, BB
Valve BB has flow rate=13; tunnels lead to valves CC, AA
Valve CC has flow rate=2; tunnels lead to valves DD, BB
Valve DD has flow rate=20; tunnels lead to valves CC, AA, EE
Valve EE has flow rate=3; tunnels lead to valves FF, DD
Valve FF has flow rate=0; tunnels lead to valves EE, GG
Valve GG has flow rate=0; tunnels lead to valves FF, HH
Valve HH has flow rate=22; tunnel leads to valve GG
Valve II has flow rate=0; tunnels lead to valves AA, JJ
Valve JJ has flow rate=21; tunnel leads to valve II


In [3]:
def parse_data(data):
    data = data.split("\n")
    out = {}
    for d in data:
        src, *dst = re.findall("([A-Z]{2})", d)
        pr = re.search("([0-9]+)", d).group(1)
        out[src] = {"dst": dst, "pr": int(pr)}

    return out


dummy = parse_data(raw_test)
real = parse_data(raw)

dummy

{'AA': {'dst': ['DD', 'II', 'BB'], 'pr': 0},
 'BB': {'dst': ['CC', 'AA'], 'pr': 13},
 'CC': {'dst': ['DD', 'BB'], 'pr': 2},
 'DD': {'dst': ['CC', 'AA', 'EE'], 'pr': 20},
 'EE': {'dst': ['FF', 'DD'], 'pr': 3},
 'FF': {'dst': ['EE', 'GG'], 'pr': 0},
 'GG': {'dst': ['FF', 'HH'], 'pr': 0},
 'HH': {'dst': ['GG'], 'pr': 22},
 'II': {'dst': ['AA', 'JJ'], 'pr': 0},
 'JJ': {'dst': ['II'], 'pr': 21}}

# Part 1

In [21]:
# https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm#Algorithm
def find_optimal_path_djikstra(data, start, fill_value=99999):
    current = start
    unvisited = set(list(data.keys()))
    cost_from_start = {node: fill_value for node in unvisited}
    path_from_start = {node: [] for node in unvisited}
    cost_from_start[current] = 0

    while True:
        neighbors = set(data[current]["dst"]) & unvisited

        for nb in neighbors:
            cost_from_start[nb] = cost_from_start[current] + 1
            path_from_start[nb] = path_from_start[current] + [current]

        unvisited -= {current}
        init_cands = {node for node in unvisited if cost_from_start[node] == fill_value}
        if len(unvisited) == len(init_cands):
            break

        current = sorted(unvisited - init_cands, key=lambda x: cost_from_start[x])[0]

    return cost_from_start, path_from_start

In [43]:
def find_paths(data, time_limit, path, pressure=0, unreleased=None, result=None):
    global path_info

    if unreleased is None:
        unreleased = set([k for k, v in data.items() if v["pr"] > 0])

    if len(unreleased) == 0:
        if result is None or pressure > result[1]:
            result = (path, pressure)
        return result

    current = path[-1]
    dist_to_next, path_to_next = path_info[current]

    # stop early
    unreleased_potential = sum([data[cand]["pr"] * time_limit for cand in unreleased])
    if result is not None and pressure + unreleased_potential < result[1]:
        return result

    candidates = [(c, dist) for (c, dist) in dist_to_next.items() if c in unreleased]
    candidates = sorted(candidates, key=lambda x: dist_to_next[x[0]])

    for cand, dist in candidates:
        ntime = time_limit - dist_to_next[cand] - 1

        npath = path if ntime < 0 else path + path_to_next[cand][1:] + [cand]
        np = pressure if ntime < 0 else pressure + ntime * data[cand]["pr"]
        nr = unreleased if ntime < 0 else unreleased - {cand}

        if ntime <= 0:
            if result is None or np > result[1]:
                result = (path, np)
            continue

        result = find_paths(data, ntime, npath, np, nr, result)

    return result


data = real.copy()

path_info = {k: find_optimal_path_djikstra(data, k) for k in data.keys()}

start = "AA"
result = find_paths(data, 30, ["AA"])[-1]
result

2181

In [38]:
# submit(result, part="a", day=DAY, year=YEAR)

# Part 2

In [47]:
def find_paths(data, time_limit1, time_limit2, path1, path2, pressure=0, unreleased=None, result=None):
    global path_info

    if unreleased is None:
        unreleased = set([k for k, v in data.items() if v["pr"] > 0])

    if len(unreleased) == 0 or (time_limit1 <= 0 and time_limit2 <= 0):
        if result is None or pressure > result[-1]:
            result = (path1, path2, pressure)
            print(result[-1])
        return result

    current1 = path1[-1]
    current2 = path2[-1]
    dist_to_next1, path_to_next1 = path_info[current1]
    dist_to_next2, path_to_next2 = path_info[current2]

    candidates1 = [(c, dist) for (c, dist) in dist_to_next1.items() if c in unreleased]
    candidates1 = sorted(candidates1, key=lambda x: x[1])

    # stop early
    unreleased_potential = sum([data[cand]["pr"] * max([time_limit1, time_limit2]) for cand in unreleased])
    if result is not None and pressure + unreleased_potential < result[-1]:
        return result

    for cand1, dist1 in candidates1:
        ntime1 = time_limit1 - dist_to_next1[cand1] - 1

        npath1 = path1 if ntime1 < 0 else path1 + path_to_next1[cand1][1:] + [cand1]
        np_a1 = pressure if ntime1 < 0 else pressure + ntime1 * data[cand1]["pr"]
        nr_a1 = unreleased if ntime1 < 0 else unreleased - {cand1}

        candidates2 = [(c, dist) for (c, dist) in dist_to_next2.items() if c in nr_a1]
        candidates2 = sorted(candidates2, key=lambda x: x[1])

        if ntime1 <= 0:
            if result is None or np_a1 > result[-1]:
                result = (path1, path2, np_a1)
                print(result[-1])
            continue

        for cand2, dist2 in candidates2:
            ntime2 = time_limit2 - dist_to_next2[cand2] - 1

            npath2 = path2 if ntime2 < 0 else path2 + path_to_next2[cand2][1:] + [cand2]
            np = np_a1 if ntime2 < 0 else np_a1 + ntime2 * data[cand2]["pr"]
            nr = nr_a1 if ntime2 < 0 else nr_a1 - {cand2}

            if ntime2 <= 0:
                if result is None or np > result[-1]:
                    result = (path1, path2, np)
                    print(result[-1])
                continue

            result = find_paths(data, ntime1, ntime2, npath1, npath2, np, nr, result)

    return result


data = real.copy()

path_info = {k: find_optimal_path_djikstra(data, k) for k in data.keys()}

start = "AA"
result = find_paths(data, 26, 26, ["AA"], ["AA"])[-1]
result

1985
2039
2194
2218
2237
2295
2371
2383
2399
2445
2498
2516
2582
2646
2669
2732
2734
2737
2772
2786
2824


2824

In [61]:
# submit(result, part="b", day=DAY, year=YEAR)