In [31]:
import re

lines = open('./data.txt', 'r').read().splitlines()
line_parse = re.compile(r'Valve (\D\D) has flow rate=(\d+); tunnels? leads? to valves? (.*)')

class Node:
    def __init__(self, node_line: str):
        result = line_parse.search(node_line)
        (name, flow, paths) = result.groups()
        self.n = name
        self.f = int(flow)
        self.p_n = paths.split(', ')
        self.costs = {}

    def compute_cost(self, target, current_path):
        cur_len = len(current_path)
        if self is target:
            return cur_len
        costs = [n.compute_cost(target, current_path + [n.n]) for n in self.p if n.n not in current_path]
        costs = [c for c in costs if c is not None]
        if costs:
            cost = min(costs)
            return cost
        else:
            return None

    def get_cost(self, target, current_path = None):
        if target.n not in self.costs:
            self.costs[target.n] = self.compute_cost(target, [])
        return self.costs[target.n]
    
    def set_neighbors(self, neighbors):
        self.p = neighbors.copy()
        self.p_d = { n.n:n for n in neighbors }
        for n in neighbors:
            self.costs[n.n] = 1
    
    def get_potential(self, time_remaining: int):
        return self.f * time_remaining
    
    def get_priority(self, target, time:int):
        c = self.get_cost(target)
        if c > time - 1:
            return None
        return -(target.get_potential(time - c))

node_list = [Node(line) for line in lines]
node_map = {n.n: n for n in node_list}
for n in node_list:
    n.set_neighbors([node_map[name] for name in n.p_n])

start = node_map['AA']

all_targets = [n for n in node_list if n.f > 0]

all_costs = set()

for t1 in all_targets:
    all_costs.add(start.get_cost(t1))
    for t2 in all_targets:
        if t1 is not t2:
            all_costs.add(t1.get_cost(t2))

min_cost = min(all_costs)


In [2]:
class TreeNode:
    def __init__(self, node: Node, visited: list[str], time: int, flow: int):
        self.n = node
        self.v = visited
        self.t = time
        self.f = flow
        self.next = []
    
    def __str__(self):
        return f'TreeNode n={self.n.n} v={self.v} t={self.t} f={self.f} nc={len(self.next)}'

    def add_next(self, node):
        self.next.append(node)

In [None]:
root_node = TreeNode(node_map['AA'], ['AA'], 30, 0)
leaves = []

def build_tree(node: TreeNode):
    add_leaf = False
    for n in all_targets:
        if n.n not in node.v:
            c = node.n.get_cost(n)
            if ((c + 1) < node.t):
                p = n.get_potential(node.t - c)
                node.add_next(TreeNode(n, node.v + [n.n], node.t - c - 1, node.f + p))
            else:
                add_leaf = True
    if (add_leaf):
        leaf = TreeNode(None, node.v, 0, node.f)
        node.add_next(leaf)
        leaves.append(leaf)
    for n in node.next:
        if (n.n):
            build_tree(n)

build_tree(root_node)

In [72]:
[str(n) for n in root_node.next]
len(leaves)
max([l.f for l in leaves])

1880

In [35]:
from itertools import combinations, permutations
leaves2 = []

Cur = tuple[Node, Node, int]

# def traverse_tree(workers:list[Cur], time: int, visited: set[str], flow):
#     def create_leaf():
#         leaf = TreeNode(None, visited, 0, flow)
#         leaves2.append(leaf)
#     if time == 0:
#         create_leaf()
#         return

#     add_leaf = False
#     next_flow = flow
#     next_time = time - 1
#     needs_targ = []
#     moving = []
#     for w in workers:
#         if w[2] == 0 and w[1]:
#             needs_targ.append(w)
#             next_flow += w[1].get_potential(time + 1)
#         elif not w[1]:
#             moving.append(w)
#         else:
#             moving.append((w[0], w[1], w[2] - 1))
    
#     if needs_targ:
#         remaining = [t for t in all_targets if t.n not in visited]
#         combos = list(combinations(remaining, len(needs_targ)))
#         for p in combos:
#             new_workers = [m for m in moving]
#             for i in range(len(needs_targ)):
#                 t = needs_targ[i]
#                 new_cur = t[1]
#                 new_targ = p[i]
#                 c = new_cur.get_cost(new_targ) + 1
#                 if c < next_time:
#                     new_workers.append((new_cur, new_targ, c))
#                 else:
#                     new_workers.append((new_cur, None, 0))
#             if [w for w in new_workers if w[1] is not None]:
#                 next_visited = visited.union([n[1].n for n in new_workers if n[1] is not None])
#                 traverse_tree(new_workers, next_time, next_visited, next_flow)
#             else:
#                 add_leaf = True
#         if (add_leaf):
#             create_leaf()
#     else:
#         traverse_tree(moving, next_time, visited, next_flow)

# traverse_tree([(None, start, 0), (None, start, 0)], 27, set(), 0)

sums = set()
max_time = 26
init_rem = set(all_targets)
print(min_cost)

def permute(l_last: Node, lt: int, r_last: Node, rt: int, remaining: set[Node], total: int, depth: int):
    if rt <= min_cost + 1 and lt <= min_cost + 1:
        sums.add(total)
        return
    l_o = [n for n in remaining if l_last.costs[n.n] + 1 < lt]
    r_o = [n for n in remaining if r_last.costs[n.n] + 1 < rt]
    if not l_o and not r_o:
        sums.add(total)
        return
    l_o = l_o or [None]
    r_o = r_o or [None]
    for l in l_o:
        for r in r_o:
            if l is r:
                continue
            l_new, r_new = l_last, r_last
            lt_new, rt_new = lt, rt
            new_total = total
            is_leaf = True
            if l and l_last.costs[l.n] < (lt - 1):
                l_new = l
                lt_new = lt - 1 - l_last.costs[l.n]
                new_total += l.get_potential(lt_new)
                is_leaf = False
            if r and r_last.costs[r.n] < (rt - 1):
                r_new = r
                rt_new = rt - 1 - r_last.costs[r.n]
                new_total += r.get_potential(rt_new)
                is_leaf = False
            if not is_leaf:
                permute(r_new, rt_new, l_new, lt_new, remaining.difference([l_new, r_new]), new_total, depth + 1)
            else:
                sums.add(new_total)

permute(start, max_time, start, max_time, init_rem, 0, 0)

2


In [36]:
max(sums)

2520