In [1]:
from pathlib import Path
import numpy as np
from dataclasses import dataclass
import string
from typing import List

In [2]:
@dataclass()
class Node:
    """One valve of the tunnel network, built from a single input line."""

    name: str               # valve identifier, e.g. 'AA'
    value: int              # flow rate credited per remaining minute once opened
    neighbours: List[str]   # names of valves reachable with a one-minute move

In [3]:
allowed = string.ascii_uppercase + string.digits + ' '

def parse_line(line):
    """Parse one puzzle line into a Node.

    Lines look like
    ``Valve AA has flow rate=0; tunnels lead to valves DD, II, BB``.
    Once the leading 'V' of "Valve" is dropped, the only characters left
    that are in ``allowed`` (upper-case letters, digits, spaces) are the
    valve names and the flow rate. Splitting the filtered text therefore
    gives ``['AA', '0', 'DD', 'II', 'BB']``.
    """
    # strip first so leading indentation cannot shift which character is dropped
    tokens = ''.join(c for c in line.strip()[1:] if c in allowed).split()
    return Node(tokens[0], int(tokens[1]), tokens[2:])

def read(prefix='data'):
    """Load ``{prefix}/16.txt`` as a dict mapping valve name -> Node.

    Blank lines, including those of an empty file, are skipped rather than
    handed to ``parse_line``, which cannot parse them.
    """
    text = Path(f'{prefix}/16.txt').read_text()
    nodes = [parse_line(line) for line in text.splitlines() if line.strip()]
    return {node.name: node for node in nodes}

In [9]:
@dataclass
class State:
    """Single-agent search state: opened valves, current valve, minutes left."""

    openvalves: set   # names of the valves opened so far
    location: str     # valve the agent currently stands at
    timeleft: int     # minutes remaining

    def __hash__(self):
        # The generated __eq__ compares the set by value, so hash an
        # order-independent (sorted) snapshot of it to stay consistent.
        return hash((self.location, self.timeleft, tuple(sorted(self.openvalves))))

In [10]:
def moves(G, sols):
    """Advance every single-agent state by one minute.

    ``sols`` maps State -> score so far. A state with time remaining either
    walks to a neighbouring valve, or, when it stands on a closed valve with
    positive flow, opens it. Opening credits the flow for each minute left
    after this one. When several predecessors reach the same successor,
    only the best score is kept. States with no time left yield nothing.
    """
    successors = {}

    def relax(state, value):
        # keep the best score per successor (0 is the floor, as before)
        successors[state] = max(value, successors.get(state, 0))

    for sol, score in sols.items():
        if sol.timeleft <= 0:
            continue
        node = G[sol.location]
        remaining = sol.timeleft - 1
        for nb in node.neighbours:
            relax(State(sol.openvalves, nb, remaining), score)
        if node.value > 0 and sol.location not in sol.openvalves:
            relax(State(sol.openvalves | {sol.location}, sol.location, remaining),
                  score + remaining * node.value)
    return successors

In [11]:
def get_sols(G, src, totaltime):
    """Run the single-agent search from valve ``src`` for ``totaltime`` minutes.

    Starts with no valves open and a score of 0, applies ``moves`` once per
    minute, and returns the resulting State -> best-score mapping.
    """
    frontier = {State(set(), src, totaltime): 0}
    for _ in range(totaltime):
        frontier = moves(G, frontier)
    return frontier

In [12]:
# Part 1 on the example input: best final state and its score.
sols = get_sols(read('test'), 'AA', totaltime=30)
max(sols.items(), key=lambda state_score: state_score[1])

(State(openvalves={'EE', 'BB', 'HH', 'CC', 'DD', 'JJ'}, location='GG', timeleft=0),
 1651)

In [13]:
# Part 1 on the real input.
sols = get_sols(read(), 'AA', totaltime=30)
max(sols.items(), key=lambda state_score: state_score[1])

(State(openvalves={'KB', 'AI', 'KS', 'CU', 'YE', 'QK', 'CJ'}, location='YE', timeleft=0),
 1724)

In [14]:
@dataclass
class State2:
    """Two-agent search state: opened valves, both agents' valves, minutes left."""

    openvalves: set   # names of the valves opened so far
    location1: str    # first agent's valve (callers store the smaller name here)
    location2: str    # second agent's valve
    timeleft: int     # minutes remaining

    def __hash__(self):
        # Hash a sorted snapshot of the set so equal sets hash equally,
        # matching the generated field-wise __eq__.
        key = (self.location1, self.timeleft, self.location2,
               tuple(sorted(self.openvalves)))
        return hash(key)

In [15]:

def moves2(G, sols):
    """Advance every two-agent state by one minute.

    For each state, both agents independently either walk to a neighbour or
    stay put. An agent that stays on a closed valve with positive flow opens
    it, crediting that flow for each minute left after this one. Successors
    store the two locations in sorted order because the agents are
    interchangeable. Only the best score per successor is kept.

    States with no time left produce no successors, matching ``moves``.
    Without this guard they would expand into states with negative time
    and negative pressure credits.
    """
    nw = {}

    def can_turnon(loc, openvalves):
        return (loc not in openvalves) and (G[loc].value > 0)

    for sol, score in sols.items():
        if sol.timeleft <= 0:
            continue
        loc1, loc2 = sol.location1, sol.location2
        remaining = sol.timeleft - 1
        # "staying" is modelled as moving to your own location
        options1 = G[loc1].neighbours + [loc1]
        options2 = G[loc2].neighbours + [loc2]
        for nb1 in options1:
            for nb2 in options2:
                new_score = score
                ov = sol.openvalves
                if nb1 == loc1 and can_turnon(loc1, ov):
                    new_score += remaining * G[loc1].value
                    ov = ov | {loc1}
                # checked against the updated ov so two agents on the same
                # valve cannot open (and score) it twice
                if nb2 == loc2 and can_turnon(loc2, ov):
                    new_score += remaining * G[loc2].value
                    ov = ov | {loc2}
                state = State2(ov, min(nb1, nb2), max(nb1, nb2), remaining)
                nw[state] = max(new_score, nw.get(state, 0))
    return nw

In [48]:
def max_remaining_score(G, sol, agents=None):
    """Optimistic upper bound on the score ``sol`` can still add.

    Parameters
    ----------
    G : dict
        Valve name -> Node; only ``.value`` is read.
    sol : State or State2
        Search state; ``openvalves`` and ``timeleft`` are read.
    agents : int, optional
        Number of agents acting in parallel. If omitted, it is 2 for states
        that have a ``location2`` attribute (State2) and 1 otherwise.

    Returns
    -------
    int
        A value never below the best achievable continuation, which is
        what makes the bound test in ``prune`` safe.
    """
    if agents is None:
        agents = 2 if hasattr(sol, 'location2') else 1
    closed = set(G.keys()) - set(sol.openvalves)
    rates = sorted((G[v].value for v in closed), reverse=True)
    # Opening takes a minute, and two different valves are at least one move
    # apart. So an agent's k-th future opening (k = 0, 1, ...) happens with at
    # most timeleft-1-2k minutes of flow left. With several agents, each
    # multiplier is available once per agent (they can open in the same
    # minute).
    #
    # The old schedule timeleft-1, timeleft-2, ... undercounts two agents:
    # two agents on two closed rate-10 valves with 5 minutes left gain 80,
    # but that schedule bounded it at 70.
    multipliers = [t for t in range(sol.timeleft - 1, 0, -2) for _ in range(agents)]
    # largest rates with largest multipliers maximises the sum (rearrangement)
    return max(sum(r * m for r, m in zip(rates, multipliers)), 0)

In [49]:
def prune(G, sols, top=None):
    """Drop states that cannot catch up with the best score seen so far.

    Parameters
    ----------
    G : dict
        Valve graph, passed through to ``max_remaining_score``.
    sols : dict
        State -> score accumulated so far. Not modified.
    top : int or None
        If given, keep only the ``top`` highest-scoring survivors. Unlike the
        bound test, this cap is a heuristic and may discard the optimum.

    Returns
    -------
    dict
        A new State -> score mapping.
    """
    if not sols:
        # max() of an empty sequence would raise, e.g. on the step after top=0
        return {}
    best = max(sols.values())
    pruned = {
        sol: score
        for sol, score in sols.items()
        if score + max_remaining_score(G, sol) >= best
    }
    if top is not None:
        # sort is stable (also with reverse=True), so ties keep their order
        pruned = dict(sorted(pruned.items(), key=lambda kv: kv[1], reverse=True)[:top])
    return pruned

In [50]:
def get_sols2(G, src, totaltime, top=None, verbose=False):
    """Run the two-agent search for ``totaltime`` minutes, both starting at ``src``.

    Each minute applies ``moves2`` and then ``prune`` (``top`` optionally caps
    the surviving state count). When ``verbose`` is set, the step index and
    the current number of states are printed before each step.
    """
    frontier = {State2(set(), src, src, totaltime): 0}
    for step in range(totaltime):
        if verbose:
            print(f'{step}/{totaltime}, {len(frontier)}')
        frontier = prune(G, moves2(G, frontier), top)
    return frontier

In [51]:
# Part 2 on the example input; top=1000 caps the state set after each step.
sols = get_sols2(read('test'), 'AA', totaltime=26, top=1000)
max(sols.values())


1707

In [53]:
# Part 2 on the real input.
# NOTE(review): the captured output below reports more than 1000 states from
# step 4 onward, although top=1000 caps the pruned set, and it ends without
# a result. It looks like a stale or interrupted run made with different
# settings -- re-run to refresh it.
sols = get_sols2(read(), 'AA', totaltime=26, top=1000, verbose=True)
max(sols.values())


0/26, 1
1/26, 21
2/26, 66
3/26, 375
4/26, 1430
5/26, 3481
6/26, 8513
7/26, 20157
8/26, 44539
9/26, 89904
10/26, 100000
11/26, 100000
