In [None]:
import heapq
import string
from dataclasses import dataclass
from pathlib import Path
from typing import List

import numpy as np

In [None]:
@dataclass
class Node:
    """One valve of the tunnel network, as parsed from a single input line."""

    # valve identifier, e.g. 'AA'
    name: str
    # flow rate: the search credits value * (minutes remaining) when the valve is opened
    value: int
    # names of the valves reachable from this one in a single move
    neighbours: List[str]

In [None]:
# Characters kept when tokenising an input line: valve names (upper-case),
# flow rates (digits) and spaces, which separate the tokens.
allowed = string.ascii_uppercase + string.digits + ' '

def parse_line(line):
    """Parse one input line into a Node.

    Expected shape, e.g.:
        'Valve AA has flow rate=0; tunnels lead to valves DD, II, BB'
    Every character outside `allowed` (lower-case words, '=', ';', ',') is
    dropped, leaving the tokens ['AA', '0', 'DD', 'II', 'BB'].

    Raises:
        ValueError: if the line does not yield at least a name and a flow rate.
    """
    text = line.strip()
    # The keyword 'Valve' starts with an upper-case 'V' that would survive the
    # character filter; remove the keyword itself rather than blindly slicing
    # off the first character (which breaks on leading whitespace).
    if text.startswith('Valve'):
        text = text[len('Valve'):]
    parts = ''.join(c for c in text if c in allowed).split()
    if len(parts) < 2:
        raise ValueError(f'cannot parse valve line: {line!r}')
    return Node(parts[0], int(parts[1]), parts[2:])

def read(prefix='data'):
    """Read `{prefix}/16.txt` and return a dict mapping valve name -> Node.

    Blank lines (including an entirely empty file) are skipped instead of
    being passed to `parse_line`, which would reject them.
    """
    text = Path(f'{prefix}/16.txt').read_text()
    nodes = [parse_line(line) for line in text.splitlines() if line.strip()]
    return {n.name: n for n in nodes}

In [None]:
# Example input parsed into {valve name: Node}.
# NOTE(review): later cells in this notebook call read(...) directly instead of reusing G.
G = read(prefix='test')

In [None]:
@dataclass
class State:
    """Single-agent search state: open valves, current valve, minutes left.

    Instances are used as dict keys by the search, so the hash must agree with
    the dataclass-generated ``__eq__``. That ``__eq__`` compares ``openvalves``
    as a set, i.e. independently of insertion order.
    """

    # valves opened so far; must not be mutated while the state is used as a key
    openvalves: set
    # valve the agent is currently at
    location: str
    # minutes remaining
    timeleft: int

    def __hash__(self):
        # frozenset hashes are order-independent, so no sort is needed. They
        # also do not require valve labels to be mutually orderable.
        return hash((self.location, self.timeleft, frozenset(self.openvalves)))

In [None]:
def moves(G, sols):
    """Expand every single-agent state by one minute.

    From a state with time remaining, the agent may walk to any neighbouring
    valve. If the current valve is still closed and has a positive flow rate,
    the agent may instead spend the minute opening it. Opening credits
    rate * (minutes left after the opening minute) immediately.

    Returns a new dict mapping successor State -> best score reaching it.
    """
    successors = {}

    def keep_best(state, candidate):
        # Several predecessors may reach the same state; remember only the best.
        successors[state] = max(candidate, successors.get(state, 0))

    for sol, score in sols.items():
        if sol.timeleft <= 0:
            continue  # out of time: no successors
        here = G[sol.location]
        remaining = sol.timeleft - 1
        for nb in here.neighbours:
            keep_best(State(sol.openvalves, nb, remaining), score)
        if here.value > 0 and sol.location not in sol.openvalves:
            keep_best(
                State(sol.openvalves | {sol.location}, sol.location, remaining),
                score + remaining * here.value,
            )
    return successors

In [None]:
def get_sols(G, src, totaltime):
    """Run the single-agent search for `totaltime` minutes, starting at valve `src`.

    Starts from one state with no valves open and score 0, and applies `moves`
    once per minute. Returns the final dict State -> best score.
    """
    frontier = {State(set(), src, totaltime): 0}
    for _ in range(totaltime):
        frontier = moves(G, frontier)
    return frontier

In [None]:
sols = get_sols(read('test'), 'AA', 30)
best = max(sols.items(), key=lambda kv: kv[1])
# Regression check: the saved output of this cell reports 1651 for the example input.
assert best[1] == 1651, best
best

(State(openvalves={'DD', 'BB', 'JJ', 'CC', 'EE', 'HH'}, location='GG', timeleft=0),
 1651)

In [None]:
# Same single-agent search on the real input (read's default prefix is 'data').
sols = get_sols(read(prefix='data'), 'AA', 30)
max(sols.items(), key=lambda item: item[1])

(State(openvalves={'CJ', 'KS', 'KB', 'YE', 'CU', 'AI', 'QK'}, location='YE', timeleft=0),
 1724)

In [49]:
@dataclass
class State2:
    """Two-agent search state: open valves, both agents' valves, minutes left.

    `moves2` always builds states with location1 = min and location2 = max of
    the two positions. Swapping the agents therefore maps to the same state.
    The hash must agree with the dataclass-generated ``__eq__``, which compares
    ``openvalves`` as a set.
    """

    # valves opened so far; must not be mutated while the state is used as a key
    openvalves: set
    location1: str
    location2: str
    # minutes remaining
    timeleft: int

    def __hash__(self):
        # frozenset: order-independent hash without sorting; works for any hashable labels
        return hash((self.location1, self.timeleft, self.location2,
                     frozenset(self.openvalves)))

In [118]:

def moves2(G, sols):
    """Expand every two-agent state by one minute.

    Each agent independently either walks to a neighbouring valve or stays put.
    Staying at a closed valve with positive flow opens it and credits
    rate * (minutes left after this minute). Staying anywhere else is a no-op
    wait. If both agents stay on the same valve, it is opened only once,
    because the second check sees the updated set.

    States with no time left have no successors, mirroring `moves`.

    Returns a new dict mapping successor State2 -> best score reaching it.
    """
    nw = {}

    def can_turnon(loc, openvalves):
        return (loc not in openvalves) and (G[loc].value > 0)

    for sol, score in sols.items():
        if sol.timeleft <= 0:
            # Without this guard a zero-time state would yield timeleft -1 and
            # negative (timeleft - 1) * value contributions.
            continue
        remaining = sol.timeleft - 1
        loc1, loc2 = sol.location1, sol.location2
        for nb1 in G[loc1].neighbours + [loc1]:
            for nb2 in G[loc2].neighbours + [loc2]:
                new_score = score
                ov = sol.openvalves
                if nb1 == loc1 and can_turnon(loc1, ov):
                    new_score += remaining * G[loc1].value
                    ov = ov | {loc1}
                if nb2 == loc2 and can_turnon(loc2, ov):
                    new_score += remaining * G[loc2].value
                    ov = ov | {loc2}
                # Canonical ordering makes (a, b) and (b, a) the same state.
                state = State2(ov, min(nb1, nb2), max(nb1, nb2), remaining)
                nw[state] = max(new_score, nw.get(state, 0))

    return nw

In [122]:
def get_sols2(G, src, totaltime, top=None, verbose=True):
    """Run the two-agent search for `totaltime` minutes, both agents starting at `src`.

    Args:
        G: dict mapping valve name -> Node.
        src: starting valve for both agents.
        totaltime: number of minutes (calls to `moves2`) to simulate.
        top: if not None, keep only the `top` highest-scoring states after each
            minute (a beam of non-negative width). This bounds the work but is
            a heuristic: it may discard the state that leads to the true optimum.
        verbose: if True, print `minute/totaltime, number of states` before each step.

    Returns:
        dict State2 -> best score after the final step.
    """
    sols = {State2(set(), src, src, totaltime): 0}
    for i in range(totaltime):
        if verbose:
            print(f'{i}/{totaltime}, {len(sols)}')
        sols = moves2(G, sols)
        if top is not None:
            # heapq.nlargest is documented as equivalent to
            # sorted(..., reverse=True)[:n], so ties keep their original order.
            # It runs in O(n log top) instead of sorting the whole frontier.
            sols = dict(heapq.nlargest(top, sols.items(), key=lambda kv: kv[1]))
    return sols

In [123]:
sols = get_sols2(read('test'), 'AA', 26, top=1000)
best = max(sols.items(), key=lambda kv: kv[1])
# Regression check: the saved output of this cell reports 1707 for the example input.
# The beam (`top`) is a heuristic, so a narrower beam could silently lose this.
assert best[1] == 1707, best
best


0/26, 1
1/26, 10
2/26, 43
3/26, 142
4/26, 355
5/26, 691
6/26, 1000
7/26, 1000
8/26, 1000
9/26, 1000
10/26, 1000
11/26, 1000
12/26, 1000
13/26, 1000
14/26, 1000
15/26, 1000
16/26, 1000
17/26, 1000
18/26, 1000
19/26, 1000
20/26, 1000
21/26, 1000
22/26, 1000
23/26, 1000
24/26, 1000
25/26, 1000


(State2(openvalves={'DD', 'BB', 'JJ', 'CC', 'EE', 'HH'}, location1='DD', location2='FF', timeleft=0),
 1707)

In [125]:
# Two-agent search on the real input, with the same beam width as the example.
# NOTE(review): the saved output shows this run was interrupted (KeyboardInterrupt) before finishing.
sols = get_sols2(read(prefix='data'), 'AA', 26, top=1000)
max(sols.items(), key=lambda item: item[1])


0/26, 1
1/26, 21
2/26, 66
3/26, 375
4/26, 1000
5/26, 1000
6/26, 1000
7/26, 1000
8/26, 1000
9/26, 1000
10/26, 1000


KeyboardInterrupt: 