-
Notifications
You must be signed in to change notification settings - Fork 2
/
__init__.py
182 lines (143 loc) · 5.37 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from __future__ import annotations
from pprint import pprint
from typing import *
from aocpy import BaseChallenge
from dataclasses import dataclass
import re
from math import inf
import itertools
import copy
@dataclass
class Node:
    """A single valve in the tunnel graph.

    ``@dataclass`` is used for the generated ``__repr__``/``__eq__``; because
    the class defines its own ``__init__``, the decorator leaves that
    constructor in place instead of generating one.

    Attributes:
        name: Label of the valve (e.g. ``"AA"``).
        value: Number stored for the valve; ``parse`` passes the flow rate.
        edges: Directly connected nodes.
    """

    name: str
    value: int
    edges: List[Node]

    def __init__(
        self, name: str, value: int = 0, edges: Optional[List[Node]] = None
    ):
        """Create a node.

        Args:
            name: Label of the valve.
            value: Value to store, ``0`` by default.
            edges: Neighbouring nodes. ``None`` (the default) gives the
                instance its own fresh empty list, so instances never share
                a default list.
        """
        self.name = name
        self.value = value
        self.edges = [] if edges is None else edges
def shortest_path(nodes: Dict[str, Node], begin: str, end: str) -> List[str]:
    """Find a shortest route from ``begin`` to ``end``, one step per edge.

    The node with the smallest known distance is expanded next; among equally
    distant candidates the one that appears first in ``nodes`` wins.

    Args:
        nodes: Every node of the graph, keyed by name.
        begin: Name of the node to start from.
        end: Name of the node to reach.

    Returns:
        The names of the nodes strictly between ``begin`` and ``end`` in
        travel order. Both endpoints are left out, so the list is empty when
        the two are adjacent or identical, and ``len(result) + 1`` is the
        number of steps whenever ``begin != end``.

    Raises:
        KeyError: If ``begin`` is not a key of ``nodes``.
        ValueError: If ``end`` cannot be reached from ``begin``.
    """
    if begin not in nodes:
        raise KeyError(begin)

    # name -> (best known distance from ``begin``, predecessor on that route)
    priorities: Dict[str, Tuple[Union[int, float], Optional[str]]] = {
        node: (inf, None) for node in nodes
    }
    # The start keeps predecessor None; route reconstruction stops there.
    priorities[begin] = (0, None)
    visited = {begin}
    cursor = begin

    while cursor != end:
        step_length = priorities[cursor][0] + 1

        # Relax every neighbour that has not been expanded yet.
        for neighbour in nodes[cursor].edges:
            if neighbour.name in visited:
                continue
            if priorities[neighbour.name][0] > step_length:
                priorities[neighbour.name] = (step_length, cursor)

        # Pick the closest unexpanded node; strict ``<`` keeps the first one
        # in ``nodes`` order when distances tie.
        closest: Optional[str] = None
        closest_distance: Union[int, float] = inf
        for node_name, (distance, _) in priorities.items():
            if node_name not in visited and distance < closest_distance:
                closest, closest_distance = node_name, distance

        if closest is None:
            # Nothing reachable is left, so ``end`` is disconnected/unknown.
            raise ValueError(f"no path from {begin!r} to {end!r}")

        cursor = closest
        visited.add(cursor)

    # Walk the predecessor chain back from ``end``; ``begin`` is not recorded.
    route: List[str] = []
    while priorities[cursor][1] is not None:
        route.append(cursor)
        cursor = priorities[cursor][1]
    route.reverse()

    # Drop ``end`` itself so only the intermediate nodes remain.
    return route[:-1]
# Matches one valve description line. Capture groups:
#   1 - the valve's name (upper-case letters),
#   2 - its flow rate (digits),
#   3 - the list of directly connected valve names.
# The optional "s" in "tunnels?", "leads?" and "valves?" lets the same pattern
# accept both the singular and the plural wording.
parse_re = re.compile(
r"Valve ([A-Z]+) has flow rate=(\d+); tunnels? leads? to valves? ((?:[A-Z]+,? ?)+)"
)
def parse(
    instr: str,
) -> Tuple[Dict[str, Node], List[str], Dict[Tuple[str, str], int]]:
    """Turn the puzzle text into a graph plus a table of walking distances.

    Args:
        instr: Puzzle input, one valve description per line. Surrounding
            whitespace is stripped before the text is split into lines.

    Returns:
        A 3-tuple ``(nodes, unjammed_nodes, shortest_paths)``:

        * ``nodes`` maps every valve name to its ``Node``.
        * ``unjammed_nodes`` lists the names of valves whose flow rate is
          non-zero, in input order.
        * ``shortest_paths`` maps ``(from_name, to_name)`` to the number of
          steps between the two valves. Only valves with a non-zero flow
          rate, plus ``"AA"``, appear in it.

    Raises:
        ValueError: If a line does not match the expected description format.
        KeyError: If a line refers to a valve that is never described.
    """
    sp = []
    for line in instr.strip().splitlines():
        matched = parse_re.match(line)
        if matched is None:
            # Fail with the offending line rather than an AttributeError on None.
            raise ValueError(f"unparseable valve description: {line!r}")
        sp.append(matched.groups())

    # First pass: create every node so the second pass can link by name.
    nodes: Dict[str, Node] = {}
    unjammed_nodes: List[str] = []
    for valve_name, flow_rate_str, _ in sp:
        flow_rate = int(flow_rate_str)
        nodes[valve_name] = Node(valve_name, flow_rate)
        if flow_rate != 0:
            unjammed_nodes.append(valve_name)

    # Second pass: wire up the tunnels.
    for valve_name, _, further_nodes_str in sp:
        n = nodes[valve_name]
        for connected_node_name in further_nodes_str.split(", "):
            n.edges.append(nodes[connected_node_name])

    # Work out a matrix of the distances between the interesting nodes.
    # Zero-valued valves are skipped as endpoints, except "AA" as a start.
    shortest_paths: Dict[Tuple[str, str], int] = {}
    for start_node in nodes:
        if nodes[start_node].value == 0 and start_node != "AA":
            continue
        for end_node in nodes:
            if end_node == start_node or nodes[end_node].value == 0:
                continue
            path = shortest_path(nodes, start_node, end_node)
            # shortest_path omits both endpoints, so one more step is needed
            # to actually arrive at end_node.
            pl = len(path) + 1
            shortest_paths[(start_node, end_node)] = pl
            shortest_paths[(end_node, start_node)] = pl

    return nodes, unjammed_nodes, shortest_paths
def permutations(
    current_node: str,
    nodes_remaining: List[str],
    shortest_paths: Dict[Tuple[str, str], int],
    path: List[str],
    cost_remaining: int,
) -> Generator[List[str], None, None]:
    """Yield every visiting order that fits within a travel budget.

    Starting at ``current_node``, each node of ``nodes_remaining`` whose
    travel cost is strictly below ``cost_remaining`` is tried as the next
    stop, recursively. Only the travel cost taken from ``shortest_paths`` is
    deducted from the budget. Longer orders are yielded before the shorter
    prefixes they extend, and ``path`` itself is always yielded last.

    Args:
        current_node: Name of the node the walk is currently at.
        nodes_remaining: Names that may still be visited; not modified.
        shortest_paths: Travel cost for each ``(from_name, to_name)`` pair.
        path: Names visited so far; not modified, but yielded as-is.
        cost_remaining: Budget left; a move is taken only if its cost is
            strictly smaller than this.

    Yields:
        Lists of node names in visiting order.

    Raises:
        KeyError: If ``shortest_paths`` lacks a required pair.
    """
    for next_node in nodes_remaining:
        cost = shortest_paths[(current_node, next_node)]
        if cost >= cost_remaining:
            continue

        # Work on a copy so the caller's list stays intact for its own loop.
        still_unvisited = list(nodes_remaining)
        still_unvisited.remove(next_node)

        yield from permutations(
            next_node,
            still_unvisited,
            shortest_paths,
            path + [next_node],
            cost_remaining - cost,
        )

    yield path
def calc_vented(
    nodes: Dict[str, Node],
    shortest_paths: Dict[Tuple[str, str], int],
    visit_order: List[str],
    time_remaining: int,
) -> int:
    """Total the pressure released by opening valves in a given order.

    The walk starts at ``"AA"``. Reaching a valve costs its distance from the
    previous stop plus one extra unit of time to open it; once open, the
    valve contributes ``value * time still left``.

    Args:
        nodes: Every node of the graph, keyed by name.
        shortest_paths: Distance for each ``(from_name, to_name)`` pair.
        visit_order: Valve names in the order they are opened.
        time_remaining: Time available at the start.

    Returns:
        The summed contribution of every valve opened with time to spare.
        Valves that would only be opened at or after the deadline, and any
        that follow them in ``visit_order``, contribute nothing.

    Raises:
        KeyError: If a needed distance or node is missing.
    """
    current = "AA"
    pressure = 0
    for node_name in visit_order:
        # Travel time plus one unit to open the valve.
        time_remaining -= shortest_paths[(current, node_name)] + 1
        if time_remaining <= 0:
            # Out of time: this valve and everything after it release
            # nothing, so do not let a negative remainder reduce the total.
            break
        pressure += nodes[node_name].value * time_remaining
        current = node_name
    return pressure
class Challenge(BaseChallenge):
    """Solutions for both parts of the valve puzzle."""

    @staticmethod
    def one(instr: str) -> int:
        """Return the most pressure one walker can release in 30 time units.

        Every visiting order produced by ``permutations`` is scored with
        ``calc_vented`` and the best score is kept. The result is never
        below ``0``.
        """
        nodes, unjammed_nodes, shortest_paths = parse(instr)

        max_pressure = 0
        for visit_order in permutations("AA", unjammed_nodes, shortest_paths, [], 30):
            max_pressure = max(
                max_pressure, calc_vented(nodes, shortest_paths, visit_order, 30)
            )
        return max_pressure

    @staticmethod
    def two(instr: str) -> int:
        """Return the most pressure two walkers can release in 26 time units.

        Each walker follows one visiting order and the two orders must not
        share a valve. All single-walker orders are scored once, sorted from
        best to worst, and then paired.
        """
        nodes, unjammed_nodes, shortest_paths = parse(instr)

        # Score every order once. Only the set of valves matters for the
        # disjointness check, so keep a frozenset instead of the list.
        candidates = sorted(
            (
                (
                    calc_vented(nodes, shortest_paths, visit_order, 26),
                    frozenset(visit_order),
                )
                for visit_order in permutations(
                    "AA", unjammed_nodes, shortest_paths, [], 26
                )
            ),
            key=lambda candidate: candidate[0],
            reverse=True,
        )

        max_pressure = 0
        for i, (pressure_a, valves_a) in enumerate(candidates):
            # Every partner still to come scores at most pressure_a, so no
            # remaining pair can beat the current best.
            if pressure_a * 2 < max_pressure:
                break
            # Partners must come from the SAME sorted list the index refers
            # to; slicing the unsorted list here would skip valid pairs.
            for pressure_b, valves_b in candidates[i + 1 :]:
                pressure = pressure_a + pressure_b
                if pressure <= max_pressure:
                    # Later partners only score lower.
                    break
                if valves_a.isdisjoint(valves_b):
                    max_pressure = pressure
                    # First disjoint partner is the best one for this order.
                    break
        return max_pressure