In [1]:
import itertools
import math
from collections import deque

In [2]:
def parse(input_file):
    """Read a valve description file into ``[name, rate, neighbours]`` records.

    Every non-blank line is split on whitespace and read positionally:
    the 2nd word is the valve name, the 5th word carries the flow rate after
    an ``=`` sign (a trailing ``;`` is dropped), and every word from the 10th
    onward is a neighbour name (trailing commas are dropped).
    Presumably the lines look like
    ``Valve AA has flow rate=0; tunnels lead to valves DD, II, BB`` --
    confirm against the actual input files.

    Args:
        input_file: path of the text file to read.

    Returns:
        A list with one ``[name, rate, neighbours]`` list per valve, in file
        order, where ``rate`` is an ``int`` and ``neighbours`` a list of names.

    Raises:
        IndexError / ValueError: if a non-blank line does not have the
            expected shape.
    """
    valves = []
    with open(input_file) as f:
        for line in f:
            words = line.split()
            if not words:
                # Tolerate blank lines (e.g. an empty line at the end of the
                # file) instead of failing with an IndexError.
                continue
            name = words[1]
            rate = int(words[4].rstrip(';').split('=')[-1])
            neighbours = [word.rstrip(',') for word in words[9:]]
            valves.append([name, rate, neighbours])
    return valves


class Valve:
    """A single valve: its name, flow rate and the names of adjacent valves."""

    def __init__(self, name: str, rate: int, neighbours: list[str]) -> None:
        """Store the valve description.

        Args:
            name: identifier of the valve.
            rate: flow rate of the valve.
            neighbours: names of the valves directly connected to this one.
        """
        self.name = name
        self.rate = rate
        self.neighbour_names = neighbours
        # Resolved neighbour objects. Starts empty so the attribute always
        # exists; the owning network replaces it once all valves are known.
        self.neighbours: list["Valve"] = []
        self.open = False


class PressureNetwork:
    """Valve graph with all-pairs distances and several heuristic searches.

    The network is built from ``(name, rate, neighbour_names)`` triples. Only
    valves with a positive rate are candidates for opening. Walking through
    one tunnel costs one time unit and so does opening a valve; a valve opened
    with ``t`` time units left contributes ``t * rate``.

    Attributes:
        valves: name -> Valve for every valve.
        start_name: name of the valve every search starts from.
        start_valve: the Valve object for ``start_name``.
        positive_valves: name -> Valve for the valves with ``rate > 0``.
        rate_dict: name -> rate for the positive valves.
        dist: ``dist[a][b]`` is the smallest number of tunnels between ``a``
            and ``b`` (``math.inf`` when unreachable).
        dist_rv: distance from the start valve to each positive valve.
        dist_small: ``dist`` restricted to the positive valves.
    """

    def __init__(self, valve_list, start="AA"):
        """Build the graph and precompute distances.

        Args:
            valve_list: iterable of ``(name, rate, neighbour_names)`` triples.
            start: name of the valve the searches start from.

        Raises:
            KeyError: if ``start`` or a neighbour name is not a known valve.
        """
        self.valves = {name: Valve(name, rate, neigh) for name, rate, neigh in valve_list}
        for valve in self.valves.values():
            valve.neighbours = [self.valves[other] for other in valve.neighbour_names]
        self.start_name = start
        self.start_valve = self.valves[self.start_name]
        self.positive_valves = {k: v for k, v in self.valves.items() if v.rate > 0}
        self.compute_distances()
        self.rate_dict = {k: v.rate for k, v in self.positive_valves.items()}

    def compute_distances(self):
        """Fill ``dist`` (all-pairs shortest paths), ``dist_rv`` and ``dist_small``."""
        names = list(self.valves)
        dist = {v: {u: math.inf for u in names} for v in names}
        for v in names:
            for u in self.valves[v].neighbour_names:
                dist[v][u] = 1
            # Set the diagonal last so a valve listing itself as a neighbour
            # cannot turn its self-distance into 1.
            dist[v][v] = 0

        # Floyd-Warshall relaxation over every intermediate valve ``w``.
        for w in names:
            row_w = dist[w]
            for v in names:
                row_v = dist[v]
                to_w = row_v[w]
                if to_w == math.inf:
                    continue  # nothing can be improved through an unreachable w
                for u in names:
                    candidate = to_w + row_w[u]
                    if candidate < row_v[u]:
                        row_v[u] = candidate

        self.dist = dist
        self.dist_rv = {v: dist[self.start_name][v] for v in self.positive_valves}
        self.dist_small = {v: {u: dist[v][u] for u in self.positive_valves} for v in self.positive_valves}

    def _closed(self, open_valves):
        """Names of positive valves not in ``open_valves``, in ``positive_valves`` order."""
        return tuple(v for v in self.positive_valves if v not in open_valves)

    def score(self, current, open_valves, k, tmax=30):
        """Gain of walking to and opening each closed valve next.

        Args:
            current: name of the valve the agent stands at.
            open_valves: names of valves already opened.
            k: time already elapsed.
            tmax: total time budget.

        Returns:
            dict name -> ``(tmax - k - distance - 1) * rate``. The value is
            not clamped, so it is negative for valves that cannot be opened
            before ``tmax``.
        """
        results = dict()
        for v in self._closed(open_valves):
            distance = self.dist[current][v]
            results[v] = (tmax - k - distance - 1) * self.rate_dict[v]
        return results

    def score2(self, current, open_valves, k, tmax=30):
        """Gains of opening an ordered pair of closed valves next.

        Args:
            current: name of the valve the agent stands at.
            open_valves: names of valves already opened.
            k: time already elapsed.
            tmax: total time budget.

        Returns:
            dict ``(v1, v2) -> (gain_v1, gain_v2)`` for every ordered pair of
            distinct closed valves. When exactly one valve ``v`` is closed the
            result is ``{(v, ""): (gain_v, 0)}``. Gains are not clamped.
        """
        first_gain = self.score(current, open_valves, k, tmax)
        closed = self._closed(open_valves)
        if len(closed) == 1:
            return {(name, ""): (gain, 0) for name, gain in first_gain.items()}
        results = dict()
        for v1, v2 in itertools.permutations(closed, 2):
            distance = self.dist[current][v1] + self.dist[v1][v2]
            # Two valves are opened along the way, hence the "- 2".
            second_gain = (tmax - k - distance - 2) * self.rate_dict[v2]
            results[(v1, v2)] = (first_gain[v1], second_gain)
        return results

    def greedy(self, tmax=30):
        """Repeatedly open the valve with the highest immediate gain.

        Args:
            tmax: total time budget.

        Returns:
            ``(pressure, open_valves)``: total gain and the tuple of opened
            valve names in opening order.
        """
        pressure = 0
        current = self.start_name
        open_valves = ()
        n_valves = len(self.positive_valves)
        k = 0
        while k < tmax and len(open_valves) < n_valves:
            moves_scores = self.score(current, open_valves, k, tmax)
            best_valve, gain = max(moves_scores.items(), key=lambda item: item[1])
            if gain <= 0:
                # No remaining valve can be opened with time to spare; adding
                # this gain would only lower (or not change) the total.
                break
            pressure += gain
            k += (self.dist[current][best_valve] + 1)
            current = best_valve
            open_valves += (best_valve,)
        return pressure, open_valves

    def greedy2(self, tmax=30):
        """Greedy search ranking each move by its gain plus the best follow-up.

        Only the first valve of the chosen pair is actually opened; the pair
        is re-evaluated at every step.

        Args:
            tmax: total time budget.

        Returns:
            ``(pressure, open_valves)`` as for :meth:`greedy`.
        """
        pressure = 0
        current = self.start_name
        open_valves = ()
        n_valves = len(self.positive_valves)
        k = 0
        while k < tmax and len(open_valves) < n_valves:
            moves_scores = self.score2(current, open_valves, k, tmax)
            # A follow-up valve that cannot be opened in time is simply
            # skipped, so it must not count negatively against the first move.
            (v1, _), (p1, _) = max(
                moves_scores.items(),
                key=lambda item: item[1][0] + max(item[1][1], 0),
            )
            if p1 <= 0:
                break  # nothing left that is worth opening
            pressure += p1
            k += (self.dist[current][v1] + 1)
            current = v1
            open_valves += (v1,)
        return pressure, open_valves

    def score_n(self, n, current, open_valves, t_left):
        """Score every feasible sequence of up to ``n`` closed valves.

        Args:
            n: maximum number of valves in a sequence.
            current: name of the valve the agent stands at.
            open_valves: names of valves already opened.
            t_left: time units still available.

        Returns:
            dict mapping a path tuple ``(current, v1, ..., vk)`` to
            ``{"t": time_left_after_path, "score": total_gain}``. The trivial
            path ``(current,)`` is always present first; a path is included
            only if every valve on it is opened with ``t >= 0``. Entries are
            inserted by increasing length and, within a length, in the order
            ``itertools.permutations`` would produce them.
        """
        closed = self._closed(open_valves)
        scores = {(current,): {"t": t_left, "score": 0}}  # scores[path]["t"] scores[path]["score"]
        frontier = [(current,)]
        for _ in range(min(n, len(closed))):
            next_frontier = []
            # Only extend paths that survived the previous level: any longer
            # path with an infeasible prefix is infeasible too.
            for old_path in frontier:
                parent = scores[old_path]
                last = old_path[-1]
                visited = old_path[1:]
                for w in closed:
                    if w in visited:
                        continue
                    new_t = parent["t"] - self.dist[last][w] - 1
                    if new_t >= 0:
                        new_path = old_path + (w,)
                        new_score = parent["score"] + new_t * self.rate_dict[w]
                        scores[new_path] = {"t": new_t, "score": new_score}
                        next_frontier.append(new_path)
            if not next_frontier:
                break
            frontier = next_frontier
        return scores

    def greedy_n(self, n, tmax=30):
        """Greedy search that looks ``n`` valves ahead before each move.

        At every step the best-scoring path from :meth:`score_n` is found and
        only its first valve is opened.

        Args:
            n: look-ahead depth.
            tmax: total time budget.

        Returns:
            ``(pressure, open_valves)`` as for :meth:`greedy`.
        """
        pressure = 0
        current = self.start_name
        open_valves = ()
        n_valves = len(self.positive_valves)
        t_left = tmax
        while t_left > 0 and len(open_valves) < n_valves:
            scores = self.score_n(n, current, open_valves, t_left)
            best_path, _ = max(scores.items(), key=lambda item: item[1]["score"])
            if len(best_path) < 2:
                break  # staying put is the best option
            v = best_path[1]
            t_left -= (self.dist[current][v] + 1)
            pressure += t_left * self.rate_dict[v]
            current = v
            open_valves += (v,)
        return pressure, open_valves

    def greedy_pair(self, n, tmax=26):
        """Two-agent variant of :meth:`greedy_n` ("me" picks first, then the elephant).

        Each round both agents are scored with :meth:`score_n`; "me" takes the
        first valve of its best path, and the elephant takes the first valve
        of its best path that does not start with the valve "me" just took.

        NOTE(review): the loop stops as soon as either agent has no time left
        or "me" has no useful move, even if the other agent could still act.

        Args:
            n: look-ahead depth.
            tmax: time budget of each agent.

        Returns:
            ``(pressure, open_valves)``.
        """
        pressure = 0
        current_me = self.start_name
        current_el = self.start_name
        open_valves = ()
        n_valves = len(self.positive_valves)
        t_me = tmax
        t_el = tmax
        while t_me > 0 and t_el > 0 and len(open_valves) < n_valves:
            scores_me = self.score_n(n, current_me, open_valves, t_me)
            scores_el = self.score_n(n, current_el, open_valves, t_el)
            # naive - I take the best one
            best_me, _ = max(scores_me.items(), key=lambda item: item[1]["score"])
            if len(best_me) < 2:
                break
            v = best_me[1]
            # and remove it from the elephant's
            scores_el = {path: score for path, score in scores_el.items() if len(path) > 1 and path[1] != v}
            t_me -= (self.dist[current_me][v] + 1)
            pressure += t_me * self.rate_dict[v]
            current_me = v
            if scores_el:
                best_el, _ = max(scores_el.items(), key=lambda item: item[1]["score"])
                w = best_el[1]
                t_el -= (self.dist[current_el][w] + 1)
                pressure += t_el * self.rate_dict[w]
                current_el = w
                open_valves += (v, w)
            else:
                # elephant has nowhere to go
                open_valves += (v, )
        return pressure, open_valves

    @property
    def closed_valves(self):
        """name -> Valve for the positive valves whose ``open`` flag is false."""
        return {k: v for k, v in self.positive_valves.items() if not(v.open)}

    def best_paths(self, n=1, tmax=30):
        """Explore, from every state, the ``n`` best next valves; return the best total.

        Args:
            n: how many of the highest-gain next valves to branch on.
            tmax: total time budget.

        Returns:
            The highest total gain found (0 if no valve can be opened).
        """
        q = deque([((self.start_name,), tmax, 0)])  # (path, time_left, score)
        best = 0
        while q:
            path, time_left, score = q.popleft()
            current = path[-1]
            # path[0] is only the starting position, not an opened valve.
            scores = self.score(current, path[1:], tmax - time_left, tmax)
            ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
            for v, p in ranked[:n]:
                dt = self.dist[current][v] + 1
                if dt < time_left:
                    q.append((path + (v,), time_left - dt, score + p))
                    best = max(best, score + p)
        return best

    def score_next_pair(self, current, times, open_valves):
        """List every (agent, closed valve) move with its resulting time and gain.

        Args:
            current: tuple with the position of each agent.
            times: tuple with the time left of each agent.
            open_valves: names of valves already opened.

        Returns:
            A list of ``(agent_index, valve, new_time_left, gain)`` tuples,
            ordered by valve and then by agent. ``new_time_left`` and ``gain``
            are not clamped and may be negative.
        """
        moves = []
        for v in self._closed(open_valves):
            rate = self.rate_dict[v]
            for agent, (position, t) in enumerate(zip(current, times)):
                new_t = t - self.dist[position][v] - 1
                moves.append((agent, v, new_t, new_t * rate))
        return moves

    def best_multiple(self, n_top=1, tmax=26, n_agents=2):
        """Multi-agent version of :meth:`best_paths`.

        From every state the ``n_top`` highest-gain (agent, valve) moves are
        explored; moves that leave the agent with no time are dropped.

        Args:
            n_top: how many of the best moves to branch on per state.
            tmax: time budget of each agent.
            n_agents: number of agents, all starting at the start valve.

        Returns:
            The highest total gain found (0 if no valve can be opened).
        """
        current = (self.start_name, ) * n_agents
        times = (tmax, ) * n_agents
        q = deque([(current, times, (), 0)])  # current, times, open_valves, score
        best = 0
        while q:
            current, times, open_valves, score = q.popleft()
            moves = self.score_next_pair(current, times, open_valves)
            moves = sorted(moves, key=lambda move: move[-1], reverse=True)
            for agent, v, new_t, new_score in moves[:n_top]:
                if new_t > 0:
                    best = max(best, score + new_score)
                    positions = list(current)
                    remaining = list(times)
                    positions[agent] = v
                    remaining[agent] = new_t
                    q.append((tuple(positions), tuple(remaining), open_valves + (v,), score + new_score))
        return best

In [3]:
sample_network = PressureNetwork(parse("16_sample.txt"))

# Each entry pairs a heading with the solver to sweep; every solver is run
# with its tuning parameter k going from 1 to 9.
sample_runs = (
    ("part 1 - look k steps ahead", sample_network.greedy_n),
    ("part 1 - try k best next", sample_network.best_paths),
    ("part 2", sample_network.best_multiple),
)
for heading, solver in sample_runs:
    print(heading)
    for k in range(1, 10):
        print(f"{k} -", solver(k))

part 1 - look k steps ahead
1 - (1595, ('JJ', 'DD', 'HH', 'BB', 'EE', 'CC'))
2 - (1649, ('DD', 'JJ', 'BB', 'HH', 'EE', 'CC'))
3 - (1645, ('JJ', 'BB', 'DD', 'HH', 'EE', 'CC'))
4 - (1651, ('DD', 'BB', 'JJ', 'HH', 'EE', 'CC'))
5 - (1651, ('DD', 'BB', 'JJ', 'HH', 'EE', 'CC'))
6 - (1651, ('DD', 'BB', 'JJ', 'HH', 'EE', 'CC'))
7 - (1651, ('DD', 'BB', 'JJ', 'HH', 'EE', 'CC'))
8 - (1651, ('DD', 'BB', 'JJ', 'HH', 'EE', 'CC'))
9 - (1651, ('DD', 'BB', 'JJ', 'HH', 'EE', 'CC'))
part 1 - try k best next
1 - 1595
2 - 1649
3 - 1651
4 - 1651
5 - 1651
6 - 1651
7 - 1651
8 - 1651
9 - 1651
part 2
1 - 1699
2 - 1707
3 - 1707
4 - 1707
5 - 1707
6 - 1707
7 - 1707
8 - 1707
9 - 1707


In [4]:
%%time

volcano = PressureNetwork(parse("16_input.txt"))

# (heading, solver, exclusive upper bound of k): the sweeps are shorter than
# on the sample, and each solver has its own range.
volcano_runs = (
    ("part 1 - look k steps ahead", volcano.greedy_n, 6),
    ("part 1 - try k best next", volcano.best_paths, 10),
    ("part 2", volcano.best_multiple, 5),
)
for heading, solver, stop in volcano_runs:
    print(heading)
    for k in range(1, stop):
        print(f"{k} -", solver(k))

part 1 - look k steps ahead
1 - (1884, ('MC', 'ED', 'XZ', 'XL', 'FX'))
2 - (1947, ('MC', 'ED', 'XZ', 'JY', 'XL', 'EM'))
3 - (1991, ('MC', 'XL', 'ED', 'XZ', 'JY'))
4 - (1991, ('MC', 'XL', 'ED', 'XZ', 'JY'))
5 - (1991, ('MC', 'XL', 'ED', 'XZ', 'JY'))
part 1 - try k best next
1 - 1884
2 - 1947
3 - 1991
4 - 1991
5 - 1991
6 - 1991
7 - 1991
8 - 1991
9 - 1991
part 2
1 - 2274
2 - 2517
3 - 2705
4 - 2705
CPU times: user 3.81 s, sys: 38.6 ms, total: 3.85 s
Wall time: 3.85 s
