In [2]:
from pathlib import Path
import numpy as np
from dataclasses import dataclass
import string
from typing import List, Self

In [3]:
string.ascii_uppercase

'ABCDEFGHIJKLMNOPQRSTUVWXYZ'

In [4]:
@dataclass
class Node:
    name : str
    value : int
    neighbours : List[str]

In [5]:
allowed = string.ascii_uppercase + string.digits + ' '

def parse_line(line):
    parts = ''.join([c for c in line[1:] if c in allowed]).split()
    return Node(parts[0], int(parts[1]), parts[2:])

def read(prefix='data'):
    lines = Path(f'{prefix}/16.txt').read_text().rstrip().split('\n')
    nodes = [parse_line(line) for line in lines]
    return {n.name : n for n in nodes}

In [6]:
G = read('test')

In [7]:
@dataclass
class State:
    openvalves : set
    location : str
    timeleft : int

    def __hash__(self):
        l = self.location
        t = self.timeleft
        ov = tuple(sorted(self.openvalves))
        return hash((l,t, ov))

In [33]:
def moves(G, sols):
    nw = {}
    for sol in sols:
        score = sols[sol]
        if sol.timeleft > 0:
            for nb in G[sol.location].neighbours:
                state = State(sol.openvalves, nb, sol.timeleft-1)
                nw[state] = max(score, nw.get(state,0))
            if (sol.location not in sol.openvalves) and (G[sol.location].value > 0):
                ov = sol.openvalves | {sol.location}
                state = State(ov, sol.location, sol.timeleft - 1)
                new_score = score + (sol.timeleft - 1)*G[sol.location].value
                nw[state] = max(new_score, nw.get(state,0))
    return nw

In [34]:
def get_sols(G, src, totaltime):
    src = State(set(), src, totaltime)
    nw = {src : 0}
    sols = {src : 0}
    old_len = 0
    for i in range(totaltime):
    # while len(sols) > old_len:
        nw = moves(G,sols)
        old_len = len(sols)
        for state in nw:
            sols[state] = max(nw[state], sols.get(state,0))
        print(len(sols))
    return sols

In [35]:
sols = get_sols(read('test'), 'AA', 30)
max(sols.items(), key= lambda kv: kv[1])

4
10
22
42
73
120
186
273
382
513
667
847
1053
1281
1528
1792
2072
2368
2678
2996
3316
3636
3956
4276
4596
4916
5236
5556
5876
6196
6196


(State(openvalves={'HH', 'JJ', 'DD', 'CC', 'EE', 'BB'}, location='CC', timeleft=6),
 1651)

In [36]:
sols = get_sols(read(), 'AA', 30)
max(sols.items(), key= lambda kv: kv[1])

6
12
33
69
125
230
404
697
1151
1850
2896
4439
6698
9840
14205
20327
28649
39772
54571
73818
98543
130385
170997
222118
286121
365687
463203
581662
725243
898184
898184


(State(openvalves={'AI', 'CU', 'YE', 'QK', 'KS', 'CJ', 'KB'}, location='YE', timeleft=6),
 1724)