In [145]:
import heapq
from collections import defaultdict
from dataclasses import dataclass, field
import itertools
import re
from typing import List, Dict, Iterable, Set, Tuple

@dataclass
class Valve:
  name: str
  rate: int
  tunnels: List[str]

@dataclass
class ValveGraph:
  start: str
  valves: Dict[str, Valve]
  distances: Dict[Tuple[str, str], int] = field(init=False, repr=False)

  def __post_init__(self):
    distances = [[0]*len(self.valves) for v in self.valves]
    n = len(self.valves)
    keys = list(self.valves.keys())
    for i in range(n):
      for j in range(i, n):
        d = float('inf')
        if i == j:
          d = 0
        elif keys[i] in self.valves[keys[j]].tunnels:
          d = 1
        distances[i][j] = d
        distances[j][i] = d
    # Floyd-Warshall:
    for k in range(n):
      for i in range(n):
        for j in range(n):
          if distances[i][k] + distances[k][j] < distances[i][j]:
            distances[i][j] = distances[i][k] + distances[k][j]
    self.distances = {}
    for i, row in enumerate(distances):
      for j, dist in enumerate(row):
        self.distances[(keys[i], keys[j])] = dist
    # Optimize by deleting zero-rate nodes with 2 neighbors
    while True:
      found = False
      keys = list(self.valves.keys() - {self.start})
      for v in keys:
        neighbors = self.valves[v].tunnels
        if self.valves[v].rate == 0 and len(neighbors) == 2:
          found = True
          # Incr the distance b/w connected valves by 1.
          a, b = neighbors
          dist_ab = self.distances[(a, b)]
          dist_av = self.distances[(a, v)]
          dist_vb = self.distances[(v, b)]
          print(f'deleting valve {v}; connecting {a} to {b} at d=min({dist_ab}, {dist_av} + {dist_vb})')
          self.distances[(a, b)] = min(dist_ab, dist_av + dist_vb)
          self.distances[(b, a)] = min(dist_ab, dist_av + dist_vb)
          del self.valves[v]
          # Clean out deleted tunnels.
          self.valves[a].tunnels.remove(v)
          self.valves[a].tunnels.append(b)
          self.valves[b].tunnels.remove(v)
          self.valves[b].tunnels.append(a)
      if not found:
        break
      
@dataclass(eq=True, frozen=True)
class ValveState:
  graph: ValveGraph = field(compare=False, repr=False)
  pos: str = field(compare=True)
  time: int = field(default=0, compare=True)
  released: int = field(default=0, compare=True)
  open: Set[str] = field(default_factory=frozenset, compare=True)
  
  def __repr__(self) -> str:
    return f'ValveState({self.pos}, t={self.time}, released={self.released}, open={self.open})'

  def __lt__(self, other) -> int:
    return other.released < self.released

  def neighbors(self, time_limit=30) -> Iterable['ValveState']:
    release_rate = sum(self.graph.valves[v].rate for v in self.open)
    # Open valve at current position (if it makes sense).
    if self.pos not in self.open and self.graph.valves[self.pos].rate > 0:
      yield ValveState(self.graph, self.pos, self.time+1, self.released+release_rate, self.open | {self.pos})
    # Move to neighboring tunnel.
    for v in self.graph.valves[self.pos].tunnels:
      delta_t = self.graph.distances[(v, self.pos)]
      if self.time + delta_t > time_limit:
        print(f'shortening duration from {delta_t} to {time_limit-self.time}; {self.released}+{release_rate}*{time_limit-delta_t}')
        delta_t = time_limit - self.time
      yield ValveState(self.graph, v, self.time+delta_t, self.released+delta_t*release_rate, self.open)
    # # Lastly, noop.
    # yield ValveState(self.graph, self.pos, self.time+1, self.released+release_rate, self.open)
      
  def heuristic_fn(self, time_limit=30) -> int:
    time_remaining = time_limit - self.time
    estimate = 0
    for v in self.graph.valves:
      if v in self.open:
        # If it's open, estimating is easy:
        estimate += (time_remaining) * self.graph.valves[v].rate
      else:
        # If it's not open, see if we can get there and open it in the remaining time:
        dist = self.graph.distances[(self.pos, v)]
        if dist + 1 <= time_remaining:
          estimate += (time_remaining - dist - 1) * self.graph.valves[v].rate
    return estimate

  @staticmethod
  def starting_from(graph: ValveGraph):
    return ValveState(graph, pos=graph.start)

In [124]:
puzzle = """Valve AA has flow rate=0; tunnels lead to valves DD, II, BB
Valve BB has flow rate=13; tunnels lead to valves CC, AA
Valve CC has flow rate=2; tunnels lead to valves DD, BB
Valve DD has flow rate=20; tunnels lead to valves CC, AA, EE
Valve EE has flow rate=3; tunnels lead to valves FF, DD
Valve FF has flow rate=0; tunnels lead to valves EE, GG
Valve GG has flow rate=0; tunnels lead to valves FF, HH
Valve HH has flow rate=22; tunnel leads to valve GG
Valve II has flow rate=0; tunnels lead to valves AA, JJ
Valve JJ has flow rate=21; tunnel leads to valve II"""

line_re = re.compile("Valve (.+) has flow rate=(\d+); tunnels? leads? to valves? (.*)")
def parse_valve(line):
  valve, rate, tunnels = line_re.match(line).groups()
  return Valve(valve, int(rate), [v.strip() for v in tunnels.split(',')])

def parse_valve_graph(lines):
  valves = {}
  for l in lines.splitlines():
    valve = parse_valve(l)
    valves[valve.name] = valve
  return ValveGraph('AA', valves)

In [126]:
test_graph = parse_valve_graph(puzzle)
test_graph

deleting valve II; connecting AA to JJ at d=min(2, 1 + 1)
deleting valve GG; connecting FF to HH at d=min(2, 1 + 1)
deleting valve FF; connecting EE to HH at d=min(3, 1 + 2)


ValveGraph(start='AA', valves={'AA': Valve(name='AA', rate=0, tunnels=['DD', 'BB', 'JJ']), 'BB': Valve(name='BB', rate=13, tunnels=['CC', 'AA']), 'CC': Valve(name='CC', rate=2, tunnels=['DD', 'BB']), 'DD': Valve(name='DD', rate=20, tunnels=['CC', 'AA', 'EE']), 'EE': Valve(name='EE', rate=3, tunnels=['DD', 'HH']), 'HH': Valve(name='HH', rate=22, tunnels=['EE']), 'JJ': Valve(name='JJ', rate=21, tunnels=['AA'])})

In [155]:
def heuristic_fn(state: ValveState, time_limit=30) -> int:
  time_remaining = time_limit - state.time
  estimate = 0
  for v in state.graph.valves:
    if v in state.open:
      # If it's open, estimating is easy:
      estimate += (time_remaining) * state.graph.valves[v].rate
    else:
      # If it's not open, see if we can get there and open it in the remaining time:
      dist = state.graph.distances[(state.pos, v)]
      if dist + 1 <= time_remaining:
        estimate += (time_remaining - dist - 1) * state.graph.valves[v].rate
  return estimate

def a_star_search(valve_graph, time_limit=30, state_cls=ValveState):
  # Initial state.
  frontier = []
  visited = set()
  heapq.heappush(frontier, (0, state_cls.starting_from(valve_graph)))
  i = 0

  while frontier:
    curr_cost, curr_state = heapq.heappop(frontier)
    if curr_state in visited: continue
    i += 1
    visited.add(curr_state)
    # print(f'at {curr_state} with est cost {curr_cost}')

    if curr_state.time >= time_limit:
      break

    for next_state in curr_state.neighbors(time_limit):
      if next_state in visited or next_state.time > time_limit:
        continue
      estimate = next_state.heuristic_fn(time_limit)
      next_cost = -(next_state.released + estimate)
      # print(f'\tconsidering {next_state} with est cost +{estimate}={next_cost}')
      heapq.heappush(frontier, (next_cost, next_state))
  
  print(curr_state)
  print(f'searched {i} states to arrive at released={curr_state.released}')

In [128]:
a_star_search(test_graph)

ValveState(CC, t=30, released=1651, open=frozenset({'BB', 'EE', 'HH', 'DD', 'CC', 'JJ'}))
searched 796 states to arrive at released=1651


In [102]:
with open('day16.txt') as puzzle:
  valve_graph = parse_valve_graph(puzzle.read())
  a_star_search(valve_graph)

deleting valve ED; connecting PS to AW at d=min(2, 1 + 1)
deleting valve BL; connecting GJ to XG at d=min(2, 1 + 1)
deleting valve RQ; connecting HH to GF at d=min(2, 1 + 1)
deleting valve FT; connecting IN to YH at d=min(2, 1 + 1)
deleting valve PN; connecting MF to QR at d=min(2, 1 + 1)
deleting valve KR; connecting AA to PB at d=min(2, 1 + 1)
deleting valve VE; connecting PH to AW at d=min(2, 1 + 1)
deleting valve PS; connecting AY to AW at d=min(3, 1 + 2)
deleting valve YY; connecting PH to GJ at d=min(2, 1 + 1)
deleting valve RI; connecting PB to AY at d=min(2, 1 + 1)
deleting valve SI; connecting AA to HX at d=min(2, 1 + 1)
deleting valve MF; connecting BE to QR at d=min(3, 1 + 2)
deleting valve MK; connecting HX to DV at d=min(2, 1 + 1)
deleting valve KJ; connecting RM to FY at d=min(2, 1 + 1)
deleting valve GC; connecting BI to GJ at d=min(2, 1 + 1)
deleting valve XG; connecting AA to GJ at d=min(3, 1 + 2)
deleting valve ZG; connecting AA to PH at d=min(2, 1 + 1)
deleting valve

In [156]:
from typing import Optional

OptionalValveState = Optional[ValveState]

  
@dataclass(eq=True, frozen=True)
class ValveStateWithElephant(ValveState):
  pos: Tuple[str, str] = field(compare=True)
  next_states: Tuple[OptionalValveState, OptionalValveState] = field(compare=True, default=(None, None))

  def a_neighbors(self, time_limit=26) -> Iterable[ValveState]:
    a_next = self.next_states[0]
    if a_next:
      yield a_next
    else:
      yield from ValveState(self.graph, self.pos[0], self.time, self.released, self.open).neighbors()

  def b_neighbors(self, time_limit=26) -> Iterable[ValveState]:
    b_next = self.next_states[1]
    if b_next:
      yield b_next
    else:
      yield from ValveState(self.graph, self.pos[1], self.time, self.released, self.open).neighbors()

  def merge_states(self, a_state: ValveState, b_state: ValveState, next_states: Tuple[OptionalValveState, OptionalValveState]) -> 'ValveStateWithElephant':
    assert a_state.time == b_state.time
    pos = (a_state.pos, b_state.pos)
    return ValveStateWithElephant(a_state.graph, pos=pos, time=a_state.time, 
                                  released=max(a_state.released, b_state.released), 
                                  open=a_state.open | b_state.open, next_states=next_states)

  def split_state(self, state: ValveState, idx: int, time: int) -> ValveState:
    r_delta = state.released - self.released
    t_delta = state.time - self.time
    r_rate = r_delta // t_delta
    return ValveState(state.graph, self.pos[idx], time, self.released + r_rate*(time - self.time), self.open)
  
  def split_states(self, a_state: ValveState, b_state: ValveState):
    if a_state.time == b_state.time:
      # No splitting necessary.
      # print('\tno split')
      return (a_state, b_state), (None, None)
    # Splitting needed at t = min(a_time, b_time)
    t_delta = min(a_state.time, b_state.time) - self.time
    if a_state.time > b_state.time:
      # print('\tsplit a:', a_state)
      a_part = self.split_state(a_state, 0, b_state.time)
      return (a_part, b_state), (a_state, None)
    else:
      # print('\tsplit b:', b_state)
      b_part = self.split_state(b_state, 1, a_state.time)
      return (a_state, b_part), (None, b_state)
  
  def neighbors(self, time_limit=26) -> Iterable['ValveStateWithElephant']:
    release_rate = sum(self.graph.valves[v].rate for v in self.open)
    
    for (a_state, b_state) in itertools.product(self.a_neighbors(), self.b_neighbors()):
      (a_part, b_part), next_states = self.split_states(a_state, b_state)
      merged = self.merge_states(a_part, b_part, next_states)
      # print(f'\tmerging {a_part.pos}:{a_part.open} and {b_part.pos}:{b_part.open} into {merged}')
      yield merged

  def heuristic_fn(self, time_limit=26) -> int:
    time_remaining = time_limit - self.time
    estimate = 0
    for v in self.graph.valves:
      # estimate += (time_remaining) * self.graph.valves[v].rate
      if v in self.open:
        # If it's open, estimating is easy:
        estimate += (time_remaining) * self.graph.valves[v].rate
      else:
        # If it's not open, see if we can get there and open it in the remaining time:
        dist = min(self.graph.distances[(p, v)] for p in self.pos)
        if self.next_states[0]:
          dist = min(dist, self.graph.distances[(self.next_states[0].pos, v)])
        if self.next_states[1]:
          dist = min(dist, self.graph.distances[(self.next_states[1].pos, v)])
        if dist + 1 <= time_remaining:
          estimate += (time_remaining - dist - 1) * self.graph.valves[v].rate
    return estimate

  @staticmethod
  def starting_from(graph: ValveGraph):
    return ValveStateWithElephant(graph, pos=(graph.start, graph.start))

In [157]:
a_star_search(test_graph, time_limit=26, state_cls=ValveStateWithElephant)

ValveStateWithElephant(pos=('DD', 'DD'), time=26, released=1707, open=frozenset({'JJ', 'EE', 'HH', 'DD', 'CC', 'BB'}), next_states=(None, None))
searched 103 states to arrive at released=1707


In [158]:
a_star_search(valve_graph, time_limit=26, state_cls=ValveStateWithElephant)

ValveStateWithElephant(pos=('AA', 'RM'), time=26, released=2999, open=frozenset({'HH', 'RM', 'BE', 'FY', 'PB', 'IN', 'QR', 'OW', 'LX', 'SV', 'AW', 'PH', 'HX'}), next_states=(None, ValveState(FY, t=27, released=3208, open=frozenset({'HH', 'RM', 'BE', 'FY', 'PB', 'IN', 'QR', 'OW', 'LX', 'SV', 'AW', 'PH', 'HX'}))))
searched 68796 states to arrive at released=2999


In [None]:
# Record
# 2835 -> Too Low; bug was in compare=False for next_states