-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
93 lines (79 loc) · 3.02 KB
/
main.py
File metadata and controls
93 lines (79 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import math
from utils import *
from functools import cache
from functools import cmp_to_key
# Puzzle input: one stripped line per valve description. The name shadows the
# builtin input(); it is kept because solve() reads this global.
# A with-block closes the file instead of leaking the handle.
with open("input.txt", "r", encoding="utf-8") as _input_file:
    input = [line.strip() for line in _input_file]
# Flow rate per valve; solve() fills it keyed by name, then re-keys it by index.
flow = {}
# Valve index -> indices of positive-flow valves reachable from it.
connect = defaultdict(list)
# Bitmask of valve indices; assigned by solve().
all_valves = 0
# (from_index, to_index) -> shortest travel time in minutes.
dist = {}
# Index of valve "AA" once solve() has run (999 is only a placeholder).
start = 999
def floyd_warshall(graph, bidirectional=False):
    """Return shortest path lengths between every pair of vertices.

    Args:
        graph: dict of format ``{vertex: [edges]}``; every edge has weight 1.
            Vertices that appear only as an edge target are included too.
        bidirectional: if true, each edge is also usable in reverse.

    Returns:
        ``defaultdict`` mapping ``(src, dst)`` to the path length; unreachable
        pairs read as ``math.inf``. The diagonal is not seeded with 0, so
        ``(v, v)`` is the length of the shortest cycle through ``v`` (or inf);
        seed ``(v, v)`` with 0 first if zero-length self paths matter.
    """
    shortest = defaultdict(lambda: math.inf)
    for node, edges in graph.items():
        for dest in edges:
            shortest[(node, dest)] = 1  # use weight if weighted
            if bidirectional:
                shortest[(dest, node)] = 1
    # Relax over every vertex, not only the dict keys. Otherwise pairs ending
    # at a target-only vertex are never computed, and with bidirectional=True
    # paths that pass through one are missed. dict.fromkeys de-duplicates
    # while keeping first-seen order (keys first, then target-only vertices).
    vertices = list(dict.fromkeys(
        [*graph, *(dest for edges in graph.values() for dest in edges)]
    ))
    for k in vertices:
        for i in vertices:
            for j in vertices:
                if shortest[(i, j)] > (via_k := shortest[(i, k)] + shortest[(k, j)]):
                    shortest[(i, j)] = via_k
    return shortest
# Best accumulated total seen so far, keyed by the bitmask of opened valves.
B = defaultdict(int)


@cache
def dfs(here, mins_remaining, bitmask, accumulated):
    """Explore valve-opening orders from ``here``, recording results in ``B``.

    Args:
        here: index of the current valve.
        mins_remaining: minutes left on arrival at ``here``.
        bitmask: valves already opened (bit ``i`` set for valve index ``i``).
        accumulated: sum, over the opened valves, of flow rate times the
            minutes left after opening it.

    Every visited state raises ``B[bitmask]`` to the best total seen for that
    set of open valves. Returns ``None``; the results live only in ``B``.

    The function is memoized with ``@cache`` while its output goes to the
    global ``B``. A caller that replaces ``B`` should call
    ``dfs.cache_clear()``, or previously seen states will be skipped.
    """
    B[bitmask] = max(B[bitmask], accumulated)
    if mins_remaining <= 1:  # it takes a min to open a valve
        return
    gained = 0
    if flow[here] > 0:
        # Spend a minute opening the valve; it then contributes flow[here]
        # for each minute that remains.
        mins_remaining -= 1
        gained = mins_remaining * flow[here]
        bitmask |= 1 << here
        # Record the new set right away. If no unopened valve is left in
        # connect[here], the loop below makes no call that would record it.
        B[bitmask] = max(B[bitmask], accumulated + gained)
    for nxt in connect[here]:
        if not bitmask & (1 << nxt):
            dfs(nxt, mins_remaining - dist[(here, nxt)], bitmask, accumulated + gained)
def solve():
    """Solve both parts of the valve puzzle for the rows in ``input``.

    Parses each valve description, compresses the tunnel graph to the valves
    with positive flow (plus the start valve "AA"), then runs ``dfs`` with 30
    minutes (part 1) and with 26 minutes (part 2). Part 2 is the best sum of
    two ``B`` entries whose valve bitmasks are disjoint.

    Prints each answer and returns ``(part1, part2)``. Rebinds or mutates the
    module globals ``flow``, ``connect``, ``dist``, ``start``, ``all_valves``
    and ``B``.
    """
    global all_valves, start, dist, connect, B
    # Start from a clean slate so a repeated call does not mix in state from a
    # previous run (index-keyed flow entries, duplicated connect edges).
    flow.clear()
    connect = defaultdict(list)
    dist = {}
    all_conn = {}
    for row in input:
        # Each row like: "Valve AA has flow rate=0; tunnels lead to valves DD, II, BB"
        valve = row.split(" ")[1]
        rate = int(row.split("rate=")[1].split(";")[0])
        if "lead to valves " in row:
            tunnels = row.split("lead to valves ")[1].split(", ")
        else:
            tunnels = [row.split("leads to valve ")[1]]
        flow[valve] = rate
        all_conn[valve] = tunnels

    # Number valves by descending flow rate (ties keep input order), so the k
    # positive-flow valves get indices 0..k-1 and map onto bitmask bits.
    order = sorted(flow, key=lambda name: -flow[name])
    remap = {name: i for i, name in enumerate(order)}
    start = remap["AA"]

    # Keep only edges that end at a positive-flow valve and start from either
    # such a valve or from "AA".
    for (n, o), d in floyd_warshall(all_conn).items():
        if d != math.inf and flow[o] > 0 and (flow[n] > 0 or n == "AA"):
            connect[remap[n]].append(remap[o])
            dist[(remap[n], remap[o])] = d

    # One bit per positive-flow valve: bits 0..k-1.
    all_valves = (1 << sum(1 for rate in flow.values() if rate > 0)) - 1

    # Re-key flow rates by valve index for dfs.
    by_index = {remap[name]: rate for name, rate in flow.items()}
    flow.clear()
    flow.update(by_index)

    # dfs is @cache-memoized but reports through the global B, so its cache
    # must be cleared whenever B is replaced; stale entries would skip states.
    B = defaultdict(int)
    dfs.cache_clear()
    dfs(start, 30, 0, 0)
    part1 = max(B.values())
    print(part1)

    B = defaultdict(int)
    dfs.cache_clear()
    dfs(start, 26, 0, 0)
    part2 = max(
        v1 + v2
        for bitmask1, v1 in B.items()
        for bitmask2, v2 in B.items()
        if not bitmask1 & bitmask2
    )
    print(part2)
    return (part1, part2)
# Regression check: solve() must return the expected (part1, part2) answers.
assert solve() == (2320, 2967)