In [1]:
# allows editing aoc_utils "live" without restarting kernel
# see https://ipython.org/ipython-doc/stable/config/extensions/autoreload.html
# and https://stackoverflow.com/a/17551284
%load_ext autoreload
%autoreload 2

# Add the aoc_utils path: the parent of the current working directory
# (presumably the repo root that holds aoc_utils) is appended to sys.path once.
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import aoc_utils
# Short aliases for the aoc_utils helpers used throughout the notebook.
get_input = aoc_utils.get_input
mapints = aoc_utils.mapints
nums = mapints  # same helper; later days use this name
cat = aoc_utils.cat


In [2]:
# Useful imports
import re
from collections import defaultdict, deque
import heapq
import functools
import queue
import itertools
import math
from functools import cache
import time
import operator

In [None]:
data = get_input(1, 2022)

# Total calories per elf, one sum per group returned by aoc_utils.split_list.
xs = [sum(map(int, grp)) for grp in aoc_utils.split_list(data)]
p1 = max(xs)
p2 = sum(sorted(xs, reverse=True)[:3])
# `assert p1,p2 == (...)` parses as `assert p1, <message>` and never compares
# anything; compare the tuple explicitly.
assert (p1, p2) == (66616, 199172)

In [None]:
data = get_input(2, 2022)

# Elf's shape (A/B/C) -> the XYZ shape that beats it / loses to it.
TOWIN = {
  'A': 'Y',
  'B': 'Z',
  'C': 'X',
}
TOLOSE = {
  'A': 'Z',
  'B': 'X',
  'C': 'Y',
}

ELF = 'ABC'  # elf's shapes
YOU = 'XYZ'  # the same shapes in the same order; position + 1 is the shape score

# Part 1: the second column is the shape we play.
p1 = 0
for line in data:
  lhs,rhs = line.split(' ')
  p1 += 1 + YOU.index(rhs)  # .index raises on bad input instead of find()'s silent -1
  if ELF.index(lhs) == YOU.index(rhs):  # draw
    p1 += 3
  elif TOWIN[lhs] == rhs:  # win
    p1 += 6

# Part 2: the second column is the required outcome (X lose, Y draw, Z win).
p2 = 0
for line in data:
  l,r = line.split(" ")
  if r == 'X':
    shape, outcome = TOLOSE[l], 0
  elif r == 'Y':
    shape, outcome = YOU[ELF.index(l)], 3
  elif r == 'Z':
    shape, outcome = TOWIN[l], 6
  else:
    # previously fell through and reused `shape` from the prior line
    raise ValueError(f"unexpected outcome code {r!r}")
  p2 += outcome + 1 + YOU.index(shape)

# the old `assert p1,p2 == (...)` only checked that p1 was truthy
assert (p1, p2) == (13682, 12881)

In [None]:
data = get_input(3, 2022)

from string import ascii_letters  # priority = position in this string + 1

# Part 1: the one item type present in both halves of each rucksack.
p1 = 0
for line in data:
  half = len(line) // 2
  xs = set(line[:half]).intersection(line[half:])
  assert len(xs) == 1
  (item,) = xs
  p1 += ascii_letters.index(item) + 1

# Part 2: the one item type shared by each consecutive group of three lines.
p2 = 0
for grp in aoc_utils.chunker(data, 3):
  xs = set(grp[0]).intersection(grp[1], grp[2])
  assert len(xs) == 1
  (item,) = xs
  p2 += ascii_letters.index(item) + 1
assert (p1, p2) == (8233, 2821)

In [None]:
data = get_input(4,2022)

p1 = p2 = 0
for line in data:
  # each line is "a0-a1,b0-b1": two inclusive integer ranges
  (a0, a1), (b0, b1) = [map(int, x.split('-')) for x in line.split(',')]
  # part 1: one range fully contains the other
  p1 += (b0 <= a0 and a1 <= b1) or (a0 <= b0 and b1 <= a1)
  # part 2: the ranges overlap at all. The previous test required *both*
  # endpoints of one range to lie in the other, which is full containment
  # again, so p2 always came out equal to p1.
  p2 += a0 <= b1 and b0 <= a1

# the old `assert p1,p2 == (...)` never compared p2
assert (p1, p2) == (515, 883)

In [None]:
data = get_input(5,2022)

# Two sections from aoc_utils.split_list: the crate drawing, then the move list.
towerdata,moves = aoc_utils.split_list(data)

def mk_towers(towerdata):
  """Parse the crate drawing into stacks.

  towerdata: drawing lines; the last line holds the stack numbers.
  Returns {stack_number: [crates bottom..top]}. Every numbered stack gets an
  entry even if it starts empty; previously such a stack was absent and a
  later move onto it raised KeyError.
  """
  toweridxs = mapints(towerdata[-1])
  towers = {idx: [] for idx in toweridxs}
  for line in reversed(towerdata[:-1]):  # bottom row first, so lists end with the top crate
    for idx in toweridxs:
      pos = (idx-1)*4 + 1  # column of stack idx's crate letter (cells are 4 chars wide)
      if pos < len(line) and line[pos] != ' ':
        towers[idx].append(line[pos])
  return towers

towers1 = mk_towers(towerdata)
towers2 = mk_towers(towerdata)

for line in moves:
  cnt,start,end = mapints(line)

  assert len(towers1[start]) >= cnt
  assert len(towers2[start]) >= cnt

  # p1: crates move one at a time, so their order reverses
  for _ in range(cnt): towers1[end].append(towers1[start].pop())

  # p2: the top `cnt` crates move together, keeping their order. Split at an
  # explicit index because `[-cnt:]` with cnt == 0 selects the whole stack.
  # Remove before appending so start == end leaves the stack unchanged.
  split = len(towers2[start]) - cnt
  moved = towers2[start][split:]
  del towers2[start][split:]
  towers2[end].extend(moved)

p1 = cat([towers1[idx][-1] for idx in sorted(towers1.keys())])
p2 = cat([towers2[idx][-1] for idx in sorted(towers2.keys())])

assert (p1,p2) == ('LJSVLTWQM', 'BRQWDBBJM')

In [None]:
data = get_input(6, 2022)[0]

def first_marker(s, n):
  """Number of characters read when the first run of `n` distinct characters
  in `s` completes, or None if there is none.

  Checks every window s[end-n:end] up to and including the one ending at the
  last character; the previous loop stopped one window short.
  """
  for end in range(n, len(s) + 1):
    if len(set(s[end-n:end])) == n:
      return end
  return None

p1 = first_marker(data, 4)
p2 = first_marker(data, 14)

assert (p1,p2) == (1343, 2193)

In [None]:
from pathlib import PurePath
data = get_input(7, 2022)

# directory path -> set of (filename, size) for files directly inside it
dirs = defaultdict(set)
cwd = None
mode = None
for line in data:
  if line.startswith('$ cd'):
    mode = None
    target = line.split(' ')[-1]
    if target == '/':
      cwd = PurePath('/')
    elif target == '..':
      cwd = cwd.parent
    else:
      cwd = cwd / target
    # Register every visited directory, including ones that directly hold only
    # subdirectories. dirsize() finds children by scanning the keys of `dirs`,
    # so a missing key silently dropped that subtree from its ancestors' totals.
    dirs.setdefault(cwd, set())
  elif line.startswith('$ ls'):
    mode = 'ls'
  else:
    assert mode == 'ls'
    if line.startswith('dir'):
      dirs.setdefault(cwd / line.split(' ')[-1], set())  # listed subdirectory
    else:
      size = mapints(line)[0]
      file = line.split(' ')[-1]
      dirs[cwd].add( (file, size) )


@cache
def dirsize(target):
  """Total size of `target`: its own files plus every directory below it.

  Children are the keys of the global `dirs` whose parent is `target`. The
  `sub != target` guard matters for the root, whose PurePath parent is itself.
  """
  own = sum(size for _name, size in dirs[target])
  nested = sum(dirsize(sub) for sub in dirs if sub != target and sub.parent == target)
  return own + nested

all_sizes = {d: dirsize(d) for d in dirs}
p1 = sum(s for s in all_sizes.values() if s <= 100_000)

TOTAL = 70000000
NEEDED = 30000000
FREE = TOTAL - dirsize(PurePath('/'))  # currently unused space (was misleadingly named USED)

# smallest directory whose deletion brings free space up to NEEDED
p2 = min(s for s in all_sizes.values() if s >= NEEDED - FREE)

# the old `assert p1,p2 == (...)` only checked that p1 was truthy
assert (p1, p2) == (1783610, 4370655)

In [None]:
timer = aoc_utils.start_timer()  # (re)start aoc_utils' timer; what it measures is defined in aoc_utils — presumably the following puzzles

In [None]:
data = get_input(8, 2022)
# Small example grid (not used below; swap it in for `data` when debugging).
testdata = """30373
25512
65332
33549
35390""".split('\n')

colmax = len(data[0])  # grid width (columns)
rowmax = len(data)     # grid height (rows)
visible = set()        # (col, row) of trees visible from outside the grid

def get_row(idx):
  """Digits of row `idx` of the global `data` grid, as ints (left to right)."""
  return [int(ch) for ch in data[idx]]

def get_col(idx):
  """Digits of column `idx` of the global `data` grid, as ints (top to bottom)."""
  return [int(data[r][idx]) for r in range(rowmax)]

def filter_visible(arr):
  """Flag each element that is strictly greater than every element before it.

  Uses the running maximum; the first element is compared against -1.
  """
  running = list(itertools.accumulate(arr, max))
  before = [-1, *running[:-1]]
  return [now > earlier for now, earlier in zip(running, before)]

# A tree is visible from an edge if it is taller than every tree between it
# and that edge: scan each row from both ends, then each column from both ends.
for ridx in range(rowmax):
  row = get_row(ridx)
  visible.update((cidx, ridx) for cidx, seen in enumerate(filter_visible(row)) if seen)
  visible.update((colmax - 1 - k, ridx) for k, seen in enumerate(filter_visible(row[::-1])) if seen)

for cidx in range(colmax):
  col = get_col(cidx)
  visible.update((cidx, ridx) for ridx, seen in enumerate(filter_visible(col)) if seen)
  visible.update((cidx, rowmax - 1 - k) for k, seen in enumerate(filter_visible(col[::-1])) if seen)

p1 = len(visible)

def view_dist(arr, val):
  """Trees seen looking along `arr` (ordered outward) from a tree of height `val`.

  The view stops at and includes the first tree at least as tall; if none
  blocks it, every tree in `arr` is seen.
  """
  blocker = aoc_utils.findindex(arr, lambda h: h >= val)
  if blocker is None:
    return len(arr)
  return blocker + 1

# Build each row and column once; the old loop rebuilt both for every tree.
grid_rows = [get_row(r) for r in range(rowmax)]
grid_cols = [get_col(c) for c in range(colmax)]

p2 = 0
for (cidx,ridx) in itertools.product(range(colmax),range(rowmax)):
  row,col = grid_rows[ridx],grid_cols[cidx]
  # lines of sight, each ordered outward from the tree
  l,r,u,d = row[:cidx][::-1],row[cidx+1:],col[:ridx][::-1],col[ridx+1:]
  curval = row[cidx]
  scenic = math.prod(view_dist(arr,curval) for arr in (u,r,l,d))
  p2 = max(p2, scenic)

assert (p1,p2) == (1854, 527340)


In [None]:
data = get_input(9,2022)
# Example moves (not used below; swap in for `data` when debugging).
testdata = """R 4
U 4
L 3
D 1
R 4
D 1
L 5
R 2""".split('\n')

# Presumably direction letter -> (dx, dy); defined in aoc_utils.
CARDINAL_DELTAS = aoc_utils.CARDINAL_DELTAS
manhattan_distance = aoc_utils.manhattan_distance

def mv_tail(h,t):
  """Return the new position of knot `t` after the knot ahead of it moved to `h`.

  The knot stays put when manhattan_distance(t, h) is 0 or 1, or is 2 with both
  coordinates differing (i.e. touching, with the usual |dx|+|dy| distance).
  Otherwise it steps one cell toward `h` on every axis where they differ.
  """
  (tx, ty), (hx, hy) = t, h
  dist = manhattan_distance(t, h)
  diagonal_touch = dist == 2 and tx != hx and ty != hy
  if dist in [0, 1] or diagonal_touch:
    return t
  step_x = 0 if tx == hx else (1 if tx < hx else -1)
  step_y = 0 if ty == hy else (1 if ty < hy else -1)
  return (tx + step_x, ty + step_y)

def mv_head(h,dir):
  """Move position `h` one step by the offset CARDINAL_DELTAS[dir]."""
  (x, y), (dx, dy) = h, CARDINAL_DELTAS[dir]
  return (x + dx, y + dy)

def mv_rope(rp,dir):
  """Move the head of rope `rp` (knot positions, head first) one step in `dir`,
  then let each later knot follow the one ahead of it; return the new list."""
  moved = [mv_head(rp[0], dir)]
  for knot in rp[1:]:
    moved.append(mv_tail(moved[-1], knot))
  return moved

ROPE = [(0,0)] * 10  # knot positions, head first
# Knot 1 depends only on the head, so it moves exactly like a 2-knot rope's tail.
p1seen = {(0,0)}  # cells visited by knot 1
p2seen = {(0,0)}  # cells visited by the last knot
for line in data:
  direction = line[0]
  steps = mapints(line)[0]
  for _ in range(steps):
    ROPE = mv_rope(ROPE, direction)
    p1seen.add(ROPE[1])
    p2seen.add(ROPE[-1])
p1, p2 = len(p1seen), len(p2seen)

print(p1,p2)
assert (p1,p2) == (6081, 2487)


In [None]:
data = get_input(10,2022)  # CPU program; run() treats any line containing 'addx' as "addx V"

def run(data):
  """Yield the X register value for every cycle of the program in `data`.

  The first value yielded is the initial X (index 0), so the value at index c
  is X during cycle c. 'addx V' lines take two cycles and add V afterwards;
  any other line takes one cycle.
  """
  reg = 1
  yield reg
  for instr in data:
    is_add = 'addx' in instr
    for _ in range(2 if is_add else 1):
      yield reg
    if is_add:
      reg += int(instr.split(' ')[1])

xs = list(run(data))  # xs[c] = X during cycle c
p1 = sum(cyc * x for cyc, x in enumerate(xs) if cyc % 40 == 20)

rows = 6
rowlen = 40
G = {}  # (x, y) -> pixel glyph
for y, x in itertools.product(range(rows), range(rowlen)):
  sprite = xs[1 + x + rowlen * y]  # X during the cycle that draws this pixel
  G[(x, y)] = '⬜' if abs(x - sprite) <= 1 else '⬛'


print(f"p1 {p1}")
assert p1 == 15880

print("p2:")
for y in range(rows):
  print(''.join(G[(x, y)] for x in range(rowlen)))

In [None]:
timer = aoc_utils.start_timer()  # (re)start aoc_utils' timer; what it measures is defined in aoc_utils — presumably the following puzzles

In [None]:
data = get_input(11,2022)

# Example monkeys (not used below; pass to parse() when debugging).
testdata = """Monkey 0:
  Starting items: 79, 98
  Operation: new = old * 19
  Test: divisible by 23
    If true: throw to monkey 2
    If false: throw to monkey 3

Monkey 1:
  Starting items: 54, 65, 75, 74
  Operation: new = old + 6
  Test: divisible by 19
    If true: throw to monkey 2
    If false: throw to monkey 0

Monkey 2:
  Starting items: 79, 60, 97
  Operation: new = old * old
  Test: divisible by 13
    If true: throw to monkey 1
    If false: throw to monkey 3

Monkey 3:
  Starting items: 74
  Operation: new = old + 3
  Test: divisible by 17
    If true: throw to monkey 0
    If false: throw to monkey 1""".split('\n')

# Operator symbol from an "Operation: new = old <op> <v>" line -> function.
OPS = { '+': operator.add, '*': operator.mul }
def parse(data):
  """Parse monkey descriptions into {idx: [inspect, op, v, test, true, false, items]}.

  inspect starts at 0; op is '+' or '*'; v is the operand text (a number or
  'old'); test is the divisor; true/false are target monkey indices; items is
  the list of starting worry levels.
  """
  monkeys = {}
  for idx, grp in enumerate(aoc_utils.split_list(data)):
    items = nums(grp[1])
    op, v = grp[2].split('Operation: new = old ')[1].split(' ')
    test, true, false = nums(' '.join(grp[3:]))
    monkeys[idx] = [0, op, v, test, true, false, items]
  return monkeys


monkeys = parse(data)
# Product of every monkey's divisor (m[3]). Each divisor divides it, so reducing
# a worry level modulo PROD leaves every divisibility test unchanged.
PROD = math.prod(map(lambda m: m[3], monkeys.values()))

def rnd(monkeys,is_p2=False):
  """Play one round, mutating and returning `monkeys`.

  monkeys: {idx: [inspect, op, v, test, true, false, items]} as built by parse().
  Part 1 floor-divides each worry level by 3. Part 2 doesn't, so levels are
  reduced modulo the product of all divisors, which keeps every test result.
  That reduction is valid only without the division by 3: floor-dividing a
  reduced value can differ from reducing the floor-divided one. It is
  therefore no longer applied in part 1.
  The modulus is computed from `monkeys` itself instead of a global.
  """
  modulus = math.prod(m[3] for m in monkeys.values())
  for idx in range(len(monkeys)):
    (inspect,op,v,test,true,false,items) = monkeys[idx]
    monkeys[idx][0] += len(items)  # each held item is inspected this turn
    monkeys[idx][-1] = []  # ...and thrown away
    for lvl in items:
      operand = lvl if v == 'old' else int(v)  # int() also accepts signed operands
      if op == '+':
        lvl = lvl + operand
      elif op == '*':
        lvl = lvl * operand
      else:
        raise ValueError(f"unknown operator {op!r}")
      if is_p2:
        lvl %= modulus
      else:
        lvl //= 3
      target = true if lvl % test == 0 else false
      monkeys[target][-1].append(lvl)
  return monkeys

def monkey_business(rounds, is_p2=False):
  """Play `rounds` rounds on a freshly parsed input and return the product of
  the two largest inspection counts."""
  monkeys = parse(data)
  for _ in range(rounds):
    monkeys = rnd(monkeys, is_p2=is_p2)
  a, b = sorted(m[0] for m in monkeys.values())[-2:]
  return a * b

p1 = monkey_business(20)
p2 = monkey_business(10_000, is_p2=True)

# p1 in 2123s
print("p1",p1)
assert p1 == 62491

# p2 in 2778s
print("p2",p2)
assert p2 == 17408399184

In [None]:
data = get_input(12,2022)

S = None  # (x, y) of the 'S' marker
E = None  # (x, y) of the 'E' marker
GRID = {}  # (x, y) -> index of the letter in ascii_letters ('S' counts as 'a', 'E' as 'z')
from string import ascii_letters as alpha
for y, line in enumerate(data):  # y = row number, x = column within the row
  for x, ch in enumerate(line):
    pos = (x, y)
    if ch == 'S':
      S, ch = pos, 'a'
    elif ch == 'E':
      E, ch = pos, 'z'
    GRID[pos] = alpha.index(ch)

def solve(start, goals, GRID, is_p2=False):
  """Fewest steps from `start` to any position in `goals`, or None if unreachable.

  GRID maps (x, y) -> elevation; moves are to the 4 orthogonal neighbours that
  are in GRID. Forward (is_p2=False) a step may climb at most one level. With
  is_p2=True the search runs backwards from the summit, so the condition is
  mirrored: a step may descend at most one level.
  """
  from collections import deque
  goals = set(goals)  # O(1) membership; part 2 passes every lowest cell
  seen = {start}  # include the start so it is never re-queued
  frontier = deque([(start, 0)])  # plain BFS: every step costs 1
  while frontier:
    pos, moves = frontier.popleft()
    if pos in goals:
      return moves
    x, y = pos
    for n in ((x+1, y), (x-1, y), (x, y+1), (x, y-1)):
      if n not in GRID or n in seen:
        continue
      if is_p2:  # going backwards, so reverse elevation condition
        ok = GRID[pos] <= GRID[n] + 1
      else:
        ok = GRID[n] <= GRID[pos] + 1
      if ok:
        seen.add(n)
        frontier.append((n, moves + 1))
  return None


lowest = alpha.index('a')
starts = [pos for pos, height in GRID.items() if height == lowest]  # every 'a' cell, S included

p1 = solve(S, [E], GRID)
# Part 2 searches backwards from E and stops at whichever start it reaches first.
p2 = solve(E, starts, GRID, is_p2=True)

print(f"p1 {p1}") # 624s
print(f"p2 {p2}") # 859s

assert p1 == 420
assert p2 == 414

In [None]:
timer = aoc_utils.start_timer()  # (re)start aoc_utils' timer; what it measures is defined in aoc_utils — presumably the following puzzles

In [None]:
data = get_input(13,2022)
# Example packet pairs (not used below; swap in for `data` when debugging).
testdata = """[1,1,3,1,1]
[1,1,5,1,1]

[[1],[2,3,4]]
[[1],4]

[9]
[[8,7,6]]

[[4,4],4,4]
[[4,4],4,4,4]

[7,7,7,7]
[7,7,7]

[]
[3]

[[[]]]
[[]]

[1,[2,[3,[4,[5,6,7]]]],8,9]
[1,[2,[3,[4,[5,6,0]]]],8,9]""".split('\n')

from aoc_utils import split_list
from itertools import chain, cycle
from functools import cmp_to_key

def isint(x):
  """True for ints (packet leaves); lists are packet nodes."""
  return isinstance(x, int)

LT = -1 # In-order
EQ = 0  # Equal (-> In-order)
GT = 1  # Out of order

def cmp(l,r):
  """Compare two packets; returns GT if out of order, otherwise LT (ties count as in order)."""
  def order(a, b):
    a_int, b_int = isint(a), isint(b)
    if a_int and b_int:
      if a < b:
        return LT
      return EQ if a == b else GT
    # mixed int/list: wrap the int in a one-element list
    if a_int:
      a = [a]
    if b_int:
      b = [b]
    for x, y in zip(a, b):
      verdict = order(x, y)
      if verdict != EQ:
        return verdict
    return order(len(a), len(b))  # shorter list first

  return GT if order(l, r) == GT else LT

# Spot checks for cmp as (left, right, expected); equal packets come back as LT.
for left, right, expected in [
  ([1,1,3,1,1], [1,1,5,1,1], LT),
  ([[1],[2,3,4]], [[1],4], LT),
  ([9], [[8,7,6]], GT),
  ([[4,4],4,4], [[4,4],4,4,4], LT),
  ([7,7,7,7], [7,7,7], GT),
  ([], [3], LT),
  ([1,2,3], [1,2,3], LT),  # EQ -> LT
  ([[[]]], [[]], GT),
  ([1,[2,[3,[4,[5,6,7]]]],8,9], [1,[2,[3,[4,[5,6,0]]]],8,9], GT),
  ([1,[2,[3,[4,[5,6,7]]]],8,9], [1,[2,[3,[4,[5,6,6]]]],8,9], GT),
  ([1,[2,[3,[4,[5,6,7]]]],8,9], [1,[2,[3,[4,[5,6,8]]]],8,9], LT),
  ([1,[2,[3,[4,[5,6,7]]]],8,9], [1,[2,[3,[4,[5,6,7]]]],8,8], GT),
  ([1,2,3], [2], LT),
  ([1], [2,3,4], LT),
  ([], [2,3,4], LT),
]:
  assert cmp(left, right) == expected

import json

# Packets are nested lists of ints, which json.loads parses directly.
# Unlike eval(), it never executes the input text.
p1 = sum(idx for idx,pair in enumerate(split_list(data), 1) if cmp(*map(json.loads,pair)) == LT)

DIVIDERS = [ [[2]], [[6]] ]
packets = DIVIDERS + [json.loads(line) for line in data if line != '']
packets = sorted(packets, key=cmp_to_key(cmp))
# product of the 1-based positions of the divider packets after sorting
p2 = math.prod(idx for idx,pkt in enumerate(packets, 1) if pkt in DIVIDERS)

print(f"p1 {p1}")
print(f"p2 {p2}")

assert p1 == 5330
assert p2 == 27648

In [None]:
data = get_input(14,2022)
# Example rock paths (not used below; swap in for `data` when debugging).
testdata = """498,4 -> 498,6 -> 496,6
503,4 -> 502,4 -> 502,9 -> 494,9""".split("\n")

def parse(data, is_p2=False, source_x=500):
  """Build the set of blocked cells from the rock paths in `data`.

  Returns (G, maxy). Without is_p2, maxy is the lowest rock row. With is_p2, a
  floor is added two rows below that and maxy is the floor row.
  source_x: x of the sand source; the floor is centred on it.
  """
  G = set()

  def draw_line(prev, cur):
    (x1,y1), (x2,y2) = prev, cur
    x1,x2 = sorted([x1,x2])
    y1,y2 = sorted([y1,y2])
    for pos in itertools.product(range(x1,x2+1),range(y1,y2+1)):
      G.add(pos)

  for line in data:
    points = list(aoc_utils.chunker(nums(line),2))
    for (prev,cur) in zip(points,points[1:]):
      draw_line(prev,cur)

  maxy = max(pos[1] for pos in G)
  if is_p2:
    maxy += 2
    # Sand spreads at most one column per row, so it stays within
    # source_x ± maxy. The old range(-5*maxy, 5*maxy) was centred on x=0 and
    # covered that span only once maxy > 125. Smaller inputs (e.g. the example)
    # had no floor under the source, so the grains fell out.
    for x in range(source_x - maxy - 1, source_x + maxy + 2):
      G.add((x,maxy))

  return G, maxy

sand_deltas = [ [0,1], [-1,1], [1,1] ] # S, SW, SE
SOURCE = (500,0)

def fallsand(G,maxy,pos=SOURCE):
  """Drop one grain from `pos` into the blocked-cell set `G`.

  Each step the grain tries straight down, then down-left, then down-right (in
  sand_deltas order). It rests when all three are blocked, and that cell is
  added to G and True returned. Returns False, leaving G unchanged, if `pos`
  is already blocked or the grain reaches row `maxy`.
  """
  if pos in G:
    return False
  x, y = pos
  while y < maxy:
    for dx, dy in sand_deltas:
      nxt = (x + dx, y + dy)
      if nxt not in G:
        x, y = nxt
        break
    else:
      G.add((x, y))
      return True
  return False


def solve(is_p2=False):
  """Parse the global `data` and count grains that come to rest before fallsand() first fails."""
  G,maxy = parse(data,is_p2)
  rested = 0
  while fallsand(G,maxy):
    rested += 1
  return rested

# Part 1: count until a grain falls below the lowest rock.
p1 = solve()
print(f"p1 {p1}")
assert p1 == 763

# Part 2: with the floor added, count until the source itself is blocked.
p2 = solve(is_p2=True)
print(f"p2 {p2}")
assert p2 == 23921



In [None]:
data = get_input(15,2022)
# Example sensors; used below as a sanity check (row 10).
testdata = """Sensor at x=2, y=18: closest beacon is at x=-2, y=15
Sensor at x=9, y=16: closest beacon is at x=10, y=16
Sensor at x=13, y=2: closest beacon is at x=15, y=3
Sensor at x=12, y=14: closest beacon is at x=10, y=16
Sensor at x=10, y=20: closest beacon is at x=10, y=16
Sensor at x=14, y=17: closest beacon is at x=10, y=16
Sensor at x=8, y=7: closest beacon is at x=2, y=10
Sensor at x=2, y=0: closest beacon is at x=2, y=10
Sensor at x=0, y=11: closest beacon is at x=2, y=10
Sensor at x=20, y=14: closest beacon is at x=25, y=17
Sensor at x=17, y=20: closest beacon is at x=21, y=22
Sensor at x=16, y=7: closest beacon is at x=15, y=3
Sensor at x=14, y=3: closest beacon is at x=15, y=3
Sensor at x=20, y=1: closest beacon is at x=15, y=3""".split('\n')

from aoc_utils import manhattan_distance as md

def parse(data):
  """Parse sensor lines.

  Returns (S, beaconYs): S is a set of ((sx, sy), d), where d = md(sensor, its
  closest beacon). beaconYs is a defaultdict(int) mapping y -> number of
  distinct beacons on that row.
  """
  S = set()
  beacons = set()
  for line in data:
    sx, sy, bx, by = nums(line)
    beacons.add((bx, by))
    S.add(((sx, sy), md((sx, sy), (bx, by))))
  beaconYs = defaultdict(int)
  for _bx, by in beacons:
    beaconYs[by] += 1
  return S,beaconYs

def solve(S,B,targety):
  """Scan row `targety`.

  S: set of ((sx, sy), d) sensors with coverage radius d (as from parse()).
  B: y -> number of beacons on that row, indexed directly (parse() returns a
  defaultdict(int)).
  Returns (covered, free):
    covered -- covered x positions on the row, minus the beacons sitting there
    free    -- uncovered x positions between the leftmost and rightmost covered
               ones, ascending
  Merges the sorted coverage intervals. The previous scan loop stopped at
  `x < maxx`, so a jump landing exactly on the rightmost covered cell left it
  uncounted. It also crashed when no sensor reached the row; that case now
  returns (0, []).
  """
  spans = []
  for (sx, sy), d in S:
    reach = d - abs(targety - sy)  # half-width of this sensor's coverage on the row
    if reach >= 0:
      spans.append((sx - reach, sx + reach))
  if not spans:
    return 0, []
  spans.sort()
  covered = 0
  free = []
  lo, hi = spans[0]
  for nlo, nhi in spans[1:]:
    if nlo > hi + 1:  # gap between the current run and the next span
      covered += hi - lo + 1
      free.extend(range(hi + 1, nlo))
      lo, hi = nlo, nhi
    else:  # overlapping or directly adjacent: extend the run
      hi = max(hi, nhi)
  covered += hi - lo + 1
  # subtract any beacons in targety
  return covered - B[targety], free

# Sanity check against the example: 26 covered non-beacon positions on row 10.
S,B = parse(testdata)
assert solve(S, B, 10)[0] == 26

TARGETY = 2000000
S,B = parse(data)
p1,_ = solve(S,B,TARGETY)
print("p1",p1)
assert p1 == 5832528

# Part 2: find the single uncovered cell inside the square [0, MAXY] x [0, MAXY].
MAXY = 4000000
p2 = None  # stays None (and trips the assert below) if no gap is found
for y in range(0, MAXY + 1):
  if y > 0 and y % (MAXY // 10) == 0: print(f"p2 {100*y//MAXY}%") # progress
  _, free = solve(S,B,y)
  free = [x for x in free if 0 <= x <= MAXY]  # gaps outside the square don't count
  if free:
    assert len(free) == 1
    p2 = free[0] * MAXY + y
    break
assert p2 is not None, "no uncovered position found"
print("p2",p2)
assert p2 == 13360899249595

In [None]:
try:
  import networkx
except:
  import sys
  !{sys.executable} -m pip install networkx
try:
  import matplotlib
except:
  import sys
  !{sys.executable} -m pip install matplotlib
import networkx as nx

data = get_input(16,2022)
# Example scan (not used below; swap in for `data` when debugging).
testdata = """Valve AA has flow rate=0; tunnels lead to valves DD, II, BB
Valve BB has flow rate=13; tunnels lead to valves CC, AA
Valve CC has flow rate=2; tunnels lead to valves DD, BB
Valve DD has flow rate=20; tunnels lead to valves CC, AA, EE
Valve EE has flow rate=3; tunnels lead to valves FF, DD
Valve FF has flow rate=0; tunnels lead to valves EE, GG
Valve GG has flow rate=0; tunnels lead to valves FF, HH
Valve HH has flow rate=22; tunnel leads to valve GG
Valve II has flow rate=0; tunnels lead to valves AA, JJ
Valve JJ has flow rate=21; tunnel leads to valve II""".split("\n")

from copy import copy
# Plain-dict view of the input. parse() below builds the graph the solver uses;
# these two are referenced only by the commented-out code further down.
RATES = {}  # valve -> flow rate
MAP = defaultdict(set)  # valve -> valves one tunnel away

for line in data:
  r = nums(line)[0]
  src = line.split('Valve ')[1].split(' ')[0]
  # handles both "tunnels lead to valves X, Y" and "tunnel leads to valve X"
  outs = line.split('to valve')[1]
  if outs.startswith('s'): outs = outs[1:]
  outs = outs.strip()
  outs = outs.split(', ')
  RATES[src] = r
  MAP[src] |= set(outs)

RATES,MAP

def parse(data):
  """Build the valve graph and its pairwise distance table.

  Returns (g, dists, flowvalves):
    g          -- networkx DiGraph; nodes carry 'rate', each tunnel is an edge with len=1
    dists      -- {(v1, v2): fewest tunnels from v1 to v2} for every ordered pair v1 != v2
    flowvalves -- {valve: rate} for valves with a non-zero rate
  """
  g = nx.DiGraph()
  rates = {}
  valves = set()
  for line in data:
    r = nums(line)[0]
    src = line.split('Valve ')[1].split(' ')[0]
    valves.add(src)
    rates[src] = r
    # handles both "tunnels lead to valves X, Y" and "tunnel leads to valve X"
    outs = line.split('to valve')[1]
    if outs.startswith('s'): outs = outs[1:]
    outs = outs.strip().split(', ')
    g.add_node(src, rate=r)
    for out in outs:
      g.add_edge(src, out, len=1)
  # One BFS per source instead of a separate shortest_path_length query for
  # every ordered pair. An unreachable pair now raises KeyError here.
  lengths = dict(nx.all_pairs_shortest_path_length(g))
  dists = {(v1, v2): lengths[v1][v2] for v1 in valves for v2 in valves if v1 != v2}
  flowvalves = {v:r for v,r in rates.items() if r != 0}
  return g,dists,flowvalves

# Initial search state. `state` is consumed only by the commented-out
# `run(state)` below; MAX_MINS is also read by run2.
valve = 'AA'
mins = 1
rate = 0
released = 0
# Was `open`, which shadowed the builtin open() for every later cell in the kernel.
opened = frozenset()
state = (valve,mins,opened,rate,released)
maxrate = defaultdict(int)
MAX_MINS = 30
maxseen = -1

from functools import cache

# @cache
# def run(state):
#   global maxseen
#   valve,mins,open,rate,released = state
#   if mins > MAX_MINS: assert False
#   if mins == MAX_MINS:
#     if released > maxseen:
#       print(f"new max {released}")
#       maxseen = released
#     return released + rate
#   # if maxrate[mins] > rate:
#   #   return -1
#   # elif rate > maxrate[mins]:
#   #   maxrate[mins] = rate

#   next_released = released + rate
#   next_mins = mins + 1
#   next_states = []
#   if valve not in open and valve != 'AA':
#     next_open = open | set([valve])
#     next_rate = rate + RATES[valve]
#     next_states.append((valve,next_mins,next_open,next_rate,next_released))
#   for next_valve in MAP[valve]:
#     next_states.append(
#       (next_valve,next_mins,open,rate,next_released)
#     )
#   return max(run(s) for s in next_states)


# def make_mermaid_chart(RATES,MAP):
#   print("graph LR")
#   for src,outs in MAP.items():
#     for out in outs:
#       print(f"{src}[{src} {RATES[src]}]-- 1 -->{out}[{out} {RATES[out]}]")
#     if RATES[src] == 0:
#       print(f"style {src} fill:green")
#   print(f"style AA fill:red")
    
  

# make_mermaid_chart(RATES,MAP)
# run(state)

maxrel = -1  # best total any run2 branch has reached so far

RUN2_SMALLSTATES = defaultdict(int)  # (opened valves, minute) -> highest rate seen there
pruned = 0  # branches cut off for running past MAX_MINS

@cache
def run2(G,state=None):
  """Depth-first search for the most pressure released by minute MAX_MINS.

  G: networkx DiGraph whose nodes carry 'rate' and whose edges carry 'len'.
  state: (valve, minute, opened valves, current rate, released so far). None
  starts at 'AA' on minute 1 with 'AA' already counted as opened.
  Returns the best total, or -1 for a pruned branch.
  """
  # Without `global`, `pruned += 1` below made `pruned` a local and raised
  # UnboundLocalError the first time a branch overran MAX_MINS.
  global maxrel, pruned
  if state is None:
    state = ('AA', 1, frozenset(['AA']), 0, 0)
  valve,mins,opened,rate,released = state
  if mins > MAX_MINS:
    pruned += 1
    return -1
  # Heuristic prune: give up if another branch reached the same (opened set,
  # minute) with a strictly higher rate. NOTE(review): this ignores `released`
  # and position, so it is not an exact bound.
  smallstate = (opened,mins)
  if rate < RUN2_SMALLSTATES[smallstate]:
    return -1
  RUN2_SMALLSTATES[smallstate] = rate
  if mins == MAX_MINS:
    total = released + rate
    maxrel = max(maxrel, total)
    return total

  next_states = []
  if valve not in opened:
    # spend one minute opening this valve
    next_states.append(
      (valve, mins + 1, opened | {valve}, rate + G.nodes[valve]['rate'], released + rate)
    )
  for next_valve in list(G.successors(valve)):
    assert next_valve != valve
    l = G.get_edge_data(valve,next_valve)['len']
    assert l > 0
    # walking an edge of length l releases `rate` for each minute spent
    next_states.append(
      (next_valve, mins + l, opened, rate, released + rate*l)
    )
  return max((run2(G,s) for s in next_states), default=-1)

from copy import copy, deepcopy

def fully_reduce_graph(g):
  """Repeat reduce_graph until a pass no longer changes the node count; return that result."""
  size = None
  while size != len(g):
    size = len(g)
    g = reduce_graph(g)
  return g

def reduce_graph(g):
  """Return a copy of `g` with zero-rate pass-through valves contracted away.

  Each node other than 'AA' with rate 0, exactly two predecessors and no
  self-loop is removed. Its two predecessors are joined in both directions by
  an edge whose 'len' is the sum of the two incoming edge lengths.
  NOTE(review): this assumes such a node is a plain corridor (its predecessors
  are also its successors), which holds when tunnels are listed both ways.
  """
  G = deepcopy(g)
  for n in list(g.nodes):
    if n == 'AA' or G.nodes[n]['rate'] != 0:
      continue
    # Read the degree from the graph being edited: earlier contractions may have
    # changed it, and the old code used the stale value from `g`.
    if G.in_degree(n) != 2 or G.has_edge(n, n):
      continue
    a, b = G.predecessors(n)
    l = G[a][n]['len'] + G[b][n]['len']
    G.remove_node(n)
    # keep an existing shorter tunnel between a and b rather than overwriting it
    for u, v in ((a, b), (b, a)):
      if not G.has_edge(u, v) or G[u][v]['len'] > l:
        G.add_edge(u, v, len=l)
  return G

# MAX_MINS2 = 26
# maxrel = -1
# @cache
# def run3(G,state=None):
#   global maxrel
#   if state is None:
#     p1 = 'AA'
#     p2 = 'AA'
#     p1m = 1
#     p2m = 1
#     open = frozenset( [('AA',1)] )
#     seen = frozenset(['AA'])
#     state = (p1,p1m,p2,p2m,open,seen)
#   p1,p1m,p2,p2m,open,seen = state
#   if p1m > MAX_MINS2 or p2m > MAX_MINS2:
#     return -1
#   if p1m == MAX_MINS2 and p2m == MAX_MINS2:
#     total = sum(G.nodes[v]['rate'] * (MAX_MINS2-m) for v,m in open)
#     if total > maxrel:
#       maxrel = total
#       print(f"best {total}")
#     return total

#   next_states = []
#   if p1m < MAX_MINS2:
#     if p1 not in seen:
#       nextp1m = p1m+1
#       next_open = open | set([(p1,p1m)])
#       next_seen = seen | set([p1])
#       next_states.append(
#         (p1,nextp1m,p2,p2m,next_open,next_seen)
#       )
#     for nextp1 in G.successors(p1):
#       l = G.get_edge_data(p1,nextp1)['len']
#       assert l > 0
#       nextp1m = min(MAX_MINS2, p1m + l)
#       next_states.append(
#         (nextp1,nextp1m,p2,p2m,open,seen)
#       )
#   if p2m < MAX_MINS2:
#     if p2 not in seen:
#       nextp2m = p2m+1
#       next_open = open | set([(p2,p2m)])
#       next_seen = seen | set([p2])
#       next_states.append(
#         (p1,p1m,p2,nextp2m,next_open,next_seen)
#       )
#     for nextp2 in G.successors(p2):
#       l = G.get_edge_data(p2,nextp2)['len']
#       assert l > 0
#       nextp2m = min(MAX_MINS2, p2m + l)
#       next_states.append(
#         (p1,p1m,nextp2,nextp2m,open,seen)
#       )
#   return max(run3(G,s) for s in next_states)

# G: tunnel graph; DISTS: (valve, valve) -> tunnel distance; FLOWVALVES: valve -> non-zero rate
G,DISTS,FLOWVALVES = parse(data)
# G = fully_reduce_graph(G)

In [None]:
# Bare names for interactive inspection; mid-cell expressions display nothing.
DISTS
FLOWVALVES

from functools import cache

@cache
def search(rem_mins,start='AA',rem_valves=frozenset(FLOWVALVES.keys()), is_p2=False):
  """Most pressure releasable in `rem_mins` minutes starting at `start`,
  opening only valves in `rem_valves`.

  Moving to valve v costs DISTS[(start, v)] minutes and opening it one more;
  it then releases FLOWVALVES[v] per remaining minute. With is_p2, the option
  "stop here" is replaced by a fresh 26-minute search from 'AA' over whatever
  valves are left (the second agent).
  """
  # Baseline when we stop opening valves ourselves.
  best = search(26, rem_valves=rem_valves) if is_p2 else 0
  for v in rem_valves:
    d = DISTS[(start,v)]
    if d >= rem_mins:
      continue  # can't reach and open v in time
    left = rem_mins - d - 1
    gained = FLOWVALVES[v] * left
    best = max(best, gained + search(left, v, rem_valves - {v}, is_p2=is_p2))
  return best

FLOWVALVES,frozenset(FLOWVALVES)
# Part 1 (30 minutes) and part 2 (26 minutes plus the second search over leftovers); shown as the cell output.
search(30), search(26, is_p2=True)

In [None]:
from functools import cache
from itertools import cycle, islice
data = get_input(17,2022)
# data = [">>><<><>><<<>><>>><<<>>><<<><<<>><>><<>>"]

# Raw jet characters; not referenced elsewhere in this cell (parse() uses literals).
R = '>'
L = '<'
def parse(data):
  """Jet pattern from the first input line: '<' becomes 'L', anything else 'R'."""
  return ['R' if ch != '<' else 'L' for ch in data[0]]

# Rock shapes as (x, y) offsets from an anchor at the top-left of each shape's
# bounding box: x grows right, and cells sit at y <= 0 (below the anchor).
SHAPES = [
  ((0,0),(1,0),(2,0),(3,0)), # horiz bar
  ((1,0),(0,-1),(1,-1),(2,-1),(1,-2)), # cross
  ((2,0),(2,-1),(0,-2),(1,-2),(2,-2)), # L
  ((0,0),(0,-1),(0,-2),(0,-3)), # vert bar
  ((0,0),(1,0),(0,-1),(1,-1)) # square
]

@cache
def get_extents(shape):
  """Cells of `shape` on each outer edge, keyed 'R', 'L', 'U', 'D'.

  'R'/'L': cells that are right-/left-most in their row.
  'U'/'D': cells that are top-/bottom-most in their column.
  Each list keeps the order of `shape`.
  """
  col_lo, col_hi, row_lo, row_hi = {}, {}, {}, {}
  for x, y in shape:
    col_lo[x] = min(col_lo.get(x, y), y)
    col_hi[x] = max(col_hi.get(x, y), y)
    row_lo[y] = min(row_lo.get(y, x), x)
    row_hi[y] = max(row_hi.get(y, x), x)
  return {
    'R': [p for p in shape if p[0] == row_hi[p[1]]],
    'L': [p for p in shape if p[0] == row_lo[p[1]]],
    'U': [p for p in shape if p[1] == col_hi[p[0]]],
    'D': [p for p in shape if p[1] == col_lo[p[0]]],
  }

# Spot checks for get_extents. Some compare lists in shape order and some
# compare sorted lists, matching which orderings are expected to be stable.
assert get_extents(SHAPES[0])['L'] == [(0,0)]
assert get_extents(SHAPES[0])['R'] == [(3,0)]
assert get_extents(SHAPES[0])['U'] == list(SHAPES[0])
assert get_extents(SHAPES[0])['D'] == list(SHAPES[0])

assert get_extents(SHAPES[1])['L'] == [(1,0),(0,-1),(1,-2)]
assert get_extents(SHAPES[1])['R'] == [(1,0),(2,-1),(1,-2)]
assert sorted(get_extents(SHAPES[1])['U']) == sorted([(0,-1),(1,0),(2,-1)])
assert sorted(get_extents(SHAPES[1])['D']) == sorted([(0,-1),(1,-2),(2,-1)])

assert get_extents(SHAPES[2])['L'] == [(2,0),(2,-1),(0,-2)]
assert get_extents(SHAPES[2])['R'] == [(2,0),(2,-1),(2,-2)]
assert sorted(get_extents(SHAPES[2])['U']) == sorted([(2,0),(0,-2),(1,-2)])
assert sorted(get_extents(SHAPES[2])['D']) == sorted([(0,-2),(1,-2),(2,-2)])

assert get_extents(SHAPES[3])['L'] == list(SHAPES[3])
assert get_extents(SHAPES[3])['R'] == list(SHAPES[3])
assert sorted(get_extents(SHAPES[3])['U']) == [(0,0)]
assert sorted(get_extents(SHAPES[3])['D']) == [(0,-3)]

assert sorted(get_extents(SHAPES[4])['L']) == sorted([(0,0),(0,-1)])
assert sorted(get_extents(SHAPES[4])['R']) == sorted([(1,0),(1,-1)])
assert sorted(get_extents(SHAPES[4])['U']) == sorted([(0,0),(1,0)])
assert sorted(get_extents(SHAPES[4])['D']) == sorted([(0,-1),(1,-1)])


DIR_DELTAS = {
  'L': (-1,0),
  'R': (1,0),
  'D': (0,-1)
}

WIDTH = 7

def valid(p,G):
  """True if cell `p` is inside the WIDTH-wide chamber, at y >= 0, and not occupied in G."""
  x, y = p
  in_chamber = 0 <= x < WIDTH and y >= 0
  return in_chamber and p not in G

def addv2(a,b):
  """Component-wise sum of two 2-vectors (asserts both have length 2)."""
  assert len(a) == 2 and len(b) == 2
  ax, ay = a
  bx, by = b
  return (ax + bx, ay + by)

def translate(shape,origin):
  """Lazily yield each cell of `shape` shifted by `origin`."""
  return map(lambda p: addv2(p, origin), shape)

def can_mv(origin,shape,dir,G):
  """True if `shape` anchored at `origin` can shift one step in `dir` ('L', 'R' or 'D').

  Only the shape's leading edge in that direction (get_extents) is checked.
  """
  leading_edge = get_extents(shape)[dir]
  delta = DIR_DELTAS[dir]
  targets = [addv2(addv2(origin, side), delta) for side in leading_edge]
  return all(valid(t, G) for t in targets)

@cache
def height(shape):
  """Rows spanned by `shape`: the largest |y| on its bottom edge, plus one."""
  deepest = max(abs(y) for _x, y in get_extents(shape)['D'])
  return deepest + 1

assert height(SHAPES[0]) == 1
assert height(SHAPES[1]) == 3
assert height(SHAPES[2]) == 3
assert height(SHAPES[3]) == 4
assert height(SHAPES[4]) == 2

# Module-level shape/jet sequences; run() re-parses its own jets and indexes SHAPES directly.
shapes = cycle(SHAPES)
dirs = parse(data)

def get_maxy(G):
  """Highest y among the occupied cells (keys) of G."""
  return max(cell[1] for cell in G)

def make_cache_key(G,dir_idx,shape_idx):
  """Cycle-detection key: (dir_idx, shape_idx, frozenset of (column, relative top)).

  Each of the 7 columns' highest occupied y is taken relative to the lowest
  such top. Raises ValueError if a column has no occupied cell.
  """
  tops = [max(y for (x, y) in G if x == col) for col in range(7)]
  floor = min(tops)
  profile = frozenset((col, top - floor) for col, top in enumerate(tops))
  return (dir_idx, shape_idx, profile)

def run(data):
  """Simulate falling rocks and try to extrapolate the tower height after 1e12 rocks.

  A key from make_cache_key (jet index, next shape index, column-top profile)
  is recorded with (rock step, tower height) on every jet step. When a key
  repeats and the remaining rocks are an exact multiple of the step
  difference, the extrapolated height is printed as "p2" and the function
  returns.

  NOTE(review): work in progress (see the trailing "too high" note). Concerns:
    - the key is checked on every jet step rather than once per rock, and it
      ignores the falling rock's position;
    - `rem_steps = 1e12 - step` makes d/rem floats, so the printed answer is a float.
  """
  G = {}
  cache = {}  # cycle key -> (step, get_maxy(G)) when first seen; shadows functools.cache locally
  for x in range(WIDTH):
    G[(x,0)] = True  # floor row at y=0
  dirs = parse(data)
  dir_idx = 0
  shape_idx = 0
  for step in range(int(1e12)):
    if step % 500 == 0: print("step",step)
    shape = SHAPES[shape_idx]
    shape_idx = (shape_idx + 1) % len(SHAPES)
    # anchor x=2; bottom row lands at get_maxy(G)+4, leaving three empty rows above the tower
    origin = (2,get_maxy(G)+3+height(shape)) 
    while True:
      dir = dirs[dir_idx]
      dir_idx = (dir_idx + 1) % len(dirs)

      key = make_cache_key(G,dir_idx,shape_idx)
      if key in cache:
        # print(key,step,get_maxy(G),cache[key])
        prev_step,prev_maxy = cache[key]
        pos_cycle_size = step - prev_step
        rem_steps = 1e12 - step
        curmaxy = get_maxy(G)
        height_inc = curmaxy - prev_maxy
        d,rem = divmod(rem_steps,pos_cycle_size)
        print(key,step,pos_cycle_size,height_inc,"REM",rem)
        if rem == 0:
          # remaining rocks are a whole number of cycles: extrapolate linearly
          print("p2", curmaxy + height_inc * d)
          return
      else:
        cache[key] = (step,get_maxy(G))

      # jet push (if not blocked), then fall one row or come to rest
      if can_mv(origin,shape,dir,G):
        origin = addv2(origin, DIR_DELTAS[dir])
      if can_mv(origin,shape,'D',G):
        origin = addv2(origin, DIR_DELTAS['D'])
      else:
        for p in translate(shape, origin):
          assert p not in G
          G[p] = True
        break
  print(get_maxy(G)) # 11322 too high

# NOTE(review): loops (printing every 500 rocks) until run() finds a usable cycle; potentially very long.
run(data)

In [47]:
data = get_input(18,2022)
data
# Small example inputs for debugging (not used below in this chunk).
testdata = """2,2,2
1,2,2
3,2,2
2,1,2
2,3,2
2,2,1
2,2,3
2,2,4
2,2,6
1,2,5
3,2,5
2,1,5
2,3,5""".split('\n')
testdata2 = ["1,1,1","2,1,1"]

# 3x3x3 block with the centre cube (2,2,2) missing.
testdata3 = """1,1,2
1,1,3
1,2,1
1,2,2
1,2,3
1,3,1
1,3,2
1,3,3
2,1,1
2,1,2
2,1,3
2,2,1
2,2,3
2,3,1
2,3,2
2,3,3
3,1,1
3,1,2
3,1,3
3,2,1
3,2,2
3,2,3
3,3,1
3,3,2
3,3,3""".split("\n")

def parse(data):
  """Set of (x, y, z) cube coordinates, one per input line."""
  return set(tuple(nums(line)) for line in data)

def faces(cube):
  """The six faces of `cube` as (x, y, z, dir), one per DIRS_3D key in key order."""
  result = []
  for d in DIRS_3D:
    result.append((*cube, d))
  return result

def reflect_face(face):
  """Return the face of the adjacent cube that touches `face` (pointing back at it).

  face is (x, y, z, dir); e.g. (1, 1, 1, '+x') -> (2, 1, 1, '-x').
  The old `try: ... except: pass` around this unpack swallowed the error and
  then failed later with a confusing NameError; a malformed face now raises here.
  """
  *cube,dir = face
  adj_cube = translate_cube(cube,dir)
  return (*adj_cube,opp_dir(dir))

def opp_dir(dir):
  """Flip a direction's sign: '+x' -> '-x', '-y' -> '+y' (any non-'-' sign becomes '-')."""
  flipped = '+' if dir.startswith('-') else '-'
  return flipped + dir[1]

def translate_cube(cube,dir):
  """The cube one unit from `cube` by the offset DIRS_3D[dir]."""
  x, y, z = cube
  ox, oy, oz = DIRS_3D[dir]
  return (x + ox, y + oy, z + oz)

def translate_face(face,t_dir):
  """Shift `face` one unit in `t_dir`, keeping the direction the face points."""
  *cube, facing = face
  return (*translate_cube(cube, t_dir), facing)

def cube_neighbors6(cube):
  """The six cubes sharing a face with `cube`, in DIRS_3D key order."""
  return list(map(lambda d: translate_cube(cube, d), DIRS_3D))

def ortho_dirs(dir):
  """The four directions perpendicular to `dir`'s axis, as ['-a', '+a', '-b', '+b']."""
  if 'x' in dir:
    axis = 'x'
  elif 'y' in dir:
    axis = 'y'
  else:
    assert 'z' in dir
    axis = 'z'
  return [sign + other for other in 'xyz' if other != axis for sign in '-+']

# Faces adjacent to `face` (x, y, z, dir), in this order:
# . the 4 faces on the same cube orthogonal to dir
# . the same 4 directions on the neighbouring cube one unit along dir
# . the 4 faces pointing the same way as dir on the cubes one unit away within
#   the face's plane (e.g. for +x, the +x faces of the cubes at y±1 and z±1)
# Each of these 12 is immediately followed by its reflection (the touching face
# on the cube it points at).
def adjacent_faces(face):
  *cube, dir = face
  sides = ortho_dirs(dir)
  neighbour = translate_cube(cube, dir)
  candidates = (
    [(*cube, d) for d in sides]
    + [(*neighbour, d) for d in sides]
    + [translate_face(face, d) for d in sides]
  )
  result = []
  for f in candidates:
    result.extend((f, reflect_face(f)))
  return result

def cube_manhattan_dist(cubea,cubeb):
  ax,ay,az = cubea
  bx,by,bz = cubeb
  return abs(ax-bx) + abs(ay-by) + abs(az-bz)

# the 2 faces that connect these two cubes, if they are connected
def adjoining_faces(cubea,cubeb):
  d = cube_manhattan_dist(cubea,cubeb)
  if d != 1: return []
  ax,ay,az = cubea
  bx,by,bz = cubeb
  if ax < bx:
    return [(*cubea,'+x'),(*cubeb,'-x')]
  elif bx < ax:
    return [(*cubeb,'+x'),(*cubea,'-x')]
  elif ay < by:
    return [(*cubea,'+y'),(*cubeb,'-y')]
  elif by < ay:
    return [(*cubeb,'+y'),(*cubea,'-y')]
  elif az < bz:
    return [(*cubea,'+z'),(*cubeb,'-z')]
  elif bz < az:
    return [(*cubeb,'+z'),(*cubea,'-z')]
  else:
    assert False

def cube_dir(src,target):
  """Return the DIRS_3D key pointing from `src` to its face-adjacent neighbour `target`."""
  sx, sy, sz = src
  tx, ty, tz = target
  assert abs(sx - tx) + abs(sy - ty) + abs(sz - tz) == 1
  for s, t, up, down in ((sx, tx, '+x', '-x'), (sy, ty, '+y', '-y'), (sz, tz, '+z', '-z')):
    if s < t:
      return up
    if s > t:
      return down
  assert False

def in_bounds(cube,mins,maxs):
  """True iff mins[i] <= cube[i] <= maxs[i] (inclusive) for each of the 3 axes."""
  for idx in range(3):
    if not (mins[idx] <= cube[idx] <= maxs[idx]):
      return False
  return True

# Unit offset for each of the six axis-aligned directions.
# cube_neighbors6 returns neighbours in this key order.
DIRS_3D = dict([
  ('+x', (1,0,0)),
  ('-x', (-1,0,0)),
  ('+y', (0,1,0)),
  ('-y', (0,-1,0)),
  ('+z', (0,0,1)),
  ('-z', (0,0,-1)),
])

def categorize(face,all_faces,cubes):
  """Classify face (x, y, z, dir) of a cube in `cubes`.

  Returns:
    'connected' - the neighbouring cube one unit along `dir` is in `cubes`,
                  so the face is shared and hidden.
    'exterior'  - no cube lies strictly further along `dir` on that axis
                  (larger coordinate for '+' dirs, smaller for '-'), so the
                  straight line out of the face is unobstructed.
    'unk'       - neither cheap test decides it (could be exterior or trapped).

  `all_faces` is unused; it is kept for call-site compatibility.
  """
  x,y,z,dir = face
  src = (x,y,z)
  dx,dy,dz = DIRS_3D[dir]
  if (x+dx,y+dy,z+dz) in cubes:
    return 'connected'
  axis_idx = 0 if 'x' in dir else 1 if 'y' in dir else 2
  # Look *ahead* of the face: for '+' dirs that means a larger coordinate
  # (comparing src > target here would look behind the face and mislabel
  # trapped faces as exterior).
  if '+' in dir:
    ahead = any(c[axis_idx] > src[axis_idx] for c in cubes)
  else:
    ahead = any(c[axis_idx] < src[axis_idx] for c in cubes)
  return 'unk' if ahead else 'exterior'

def solve2(data, is_p2=False):
  """Count exposed faces of the unit cubes parsed from `data`.

  Part 1 (is_p2=False): faces not shared with another cube.
  Part 2 (is_p2=True): faces reachable from outside. These are found by
  flood-filling the empty cells of the bounding box, padded by one unit on
  every side, starting from its minimum corner.
  """
  # NOTE(review): a later cell redefines `parse` (blueprint parser); re-running
  # this cell after that one would call the wrong parser.
  cubes = parse(data)
  all_faces = {}
  for c in cubes:
    for f in faces(c):
      assert f not in all_faces
      all_faces[f] = 'unk'

  # Only values of existing keys are overwritten, so iterating is safe.
  for f in all_faces:
    all_faces[f] = categorize(f,all_faces,cubes)

  if not is_p2:
    return sum(1 for kind in all_faces.values() if kind != 'connected')

  if not cubes:
    return 0
  mins = [min(c[idx] for c in cubes) - 1 for idx in [0,1,2]]
  maxes = [max(c[idx] for c in cubes) + 1 for idx in [0,1,2]]
  start = tuple(mins)
  assert start not in cubes
  assert in_bounds(start, mins, maxes)
  stack = [start]
  # Empty cells reached from outside. Cells are marked when pushed, so each
  # one is expanded exactly once.
  filled = {start}
  ext_faces = set()
  while stack:
    cube = stack.pop()
    for neighbor in cube_neighbors6(cube):
      if not in_bounds(neighbor, mins, maxes) or neighbor in filled:
        continue
      if neighbor in cubes:
        # `cube` is empty and reachable from outside, so the face of
        # `neighbor` pointing back at it is on the exterior surface.
        face = (*neighbor, cube_dir(neighbor,cube))
        ext_faces.add(face)
        assert face in all_faces
        assert all_faces[face] in ["unk", "exterior"]
        all_faces[face] = 'exterior'
      else:
        filled.add(neighbor)
        stack.append(neighbor)
  print(len(filled),len(cubes),mins,maxes)
  # An earlier attempt propagated 'exterior' between adjacent_faces() instead
  # of flood-filling empty space; it did not give the right answer and was dropped.
  return len(ext_faces)

# Part 1: every face not shared with another cube; part 2: only faces reachable from outside.
p1, p2 = (solve2(data, is_p2=flag) for flag in (False, True)) # p2: 4284 too high, 2528 too low
print(f"p1 {p1}")
print(f"p2 {p2}")

9541 2881 [-1, -1, -1] [22, 22, 22]
p1 4308
p2 2540


In [111]:
from collections import Counter
from heapq import heappop,heappush
import re
data = get_input(19,2022)
# Two example blueprints, one per line.
testdata = """Blueprint 1: Each ore robot costs 4 ore.  Each clay robot costs 2 ore.  Each obsidian robot costs 3 ore and 14 clay.  Each geode robot costs 2 ore and 7 obsidian.
Blueprint 2: Each ore robot costs 2 ore.  Each clay robot costs 3 ore.  Each obsidian robot costs 3 ore and 8 clay.  Each geode robot costs 3 ore and 12 obsidian.""".splitlines()

def parse(data):
  """Parse blueprint lines into a list of (id, robots) pairs.

  `robots` maps each robot type to a Counter of the resources it costs. Types
  are inserted in reverse order of appearance in the line; for lines like
  those in `testdata` that is geode, obsidian, clay, ore.
  """
  blueprints = []
  for line in data:
    bp_id = nums(line)[0]
    recipes = {}
    for clause in reversed(re.findall(r"Each \w+ robot costs.*?\.", line)):
      kind = re.findall(r"Each (\w+) robot costs", clause)[0]
      amounts = re.findall(r"(\d+) (\w+)", clause)
      recipes[kind] = Counter({resource: int(qty) for (qty, resource) in amounts})
    blueprints.append((bp_id, recipes))
  return blueprints

def available_robots(bp,stuff):
  """Return [(robot_type, costs Counter), ...] for each robot in `bp` affordable with `stuff`.

  `bp` is (id, robots). `robots` may be the dict built by parse()
  ({rtype: Counter(costs)}) or an iterable of (rtype, [(resource, amount), ...])
  pairs. `stuff` is a Counter of available resources.
  """
  robots = bp[1]
  # Iterating a dict directly yields bare key strings, which cannot be
  # unpacked into (rtype, costs); use its items instead.
  entries = robots.items() if hasattr(robots, 'items') else robots
  avail_robots = []
  for rtype, costs in entries:
    # dict() accepts both a mapping and an iterable of (resource, amount) pairs.
    costs = Counter(dict(costs))
    if all(stuff[t] >= costs[t] for t in costs):
      avail_robots.append((rtype, costs))
  return avail_robots

# Time limit in minutes used by search, search2 and search3 (search4 takes its own max_t).
MAX_T = 24

def get_max_robots(bp):
  """For each robot type in `bp`, return the largest amount of that resource any recipe costs.

  Types that no recipe consumes (max 0) get 1e10, i.e. effectively uncapped.
  search/search2 use the result as a per-type cap on robots.
  """
  _, recipes = bp
  caps = {}
  for resource in recipes:
    most = max((recipe[resource] or 0) for recipe in recipes.values())
    caps[resource] = most if most != 0 else 1e10
  return caps

def search(bp,state=None):
  """Depth-first search for the most geodes obtainable by minute MAX_T with blueprint `bp`.

  state: optional starting (t, robots Counter, stuff Counter); defaults to
  minute 0 with one ore robot and no resources.
  States are memoized on (robots, stuff) independent of t. Reaching the same
  configuration again at the same or a later minute cannot do better than the
  earlier visit (which had at least as much time), so such revisits are pruned.
  """
  seen = {}
  _, bp_robots = bp
  prune_count = 0
  max_robots = get_max_robots(bp)
  print(f"max_robots {max_robots}")

  if state is None:
    state = (0, Counter({"ore":1}), Counter())

  def inner_search(state):
    nonlocal prune_count
    t,robots,stuff = state
    key = cache_key(state)
    if key in seen and t >= seen[key]:
      prune_count += 1
      return -1
    seen[key] = t

    if t == MAX_T:
      return stuff['geode']

    next_states = [(t+1, robots, stuff + robots)]
    for rtype in ['geode','obsidian','clay','ore']:
      rcosts = bp_robots[rtype]
      # Only one robot is built per minute, so more than max_robots[rtype]
      # robots of a type can never all be spent; cap the robot count
      # (capping the stockpile instead can prune optimal builds).
      if robots[rtype] < max_robots[rtype] and afford(stuff,rcosts):
        next_stuff = stuff + robots - rcosts
        next_robots = robots + Counter({rtype:1})
        next_states.append((t+1,next_robots,next_stuff))
    return max(inner_search(s) for s in next_states)

  res = inner_search(state)
  print(f"pruned {prune_count}, seen size: {len(seen)}")
  return res

def afford(stuff, costs):
  """True iff `stuff` holds at least costs[k] of every resource k listed in `costs`."""
  margins = [stuff[resource] - costs[resource] for resource in costs]
  return all(m >= 0 for m in margins)

def cache_key(state):
  """Hashable key for a (t, robots, stuff) state that ignores t."""
  _, robots, stuff = state
  return tuple(frozenset(counts.items()) for counts in (robots, stuff))

def get_score(state):
  """Return a 1-tuple holding the geode count of the state's last element (its stuff)."""
  *_, inventory = state
  geodes = inventory['geode']
  return (geodes,)

def get_priority(state,max_robots):
  """Heap priority: (t, then stuff minus max_robots for geode, obsidian, clay, ore)."""
  t, _, stuff = state
  gaps = tuple(stuff[r] - max_robots[r] for r in ('geode', 'obsidian', 'clay', 'ore'))
  return (t,) + gaps

def get_priority3(state):
  """Heap priority: (t, then negated robot counts for geode, obsidian, clay, ore)."""
  t, robots, _ = state
  negated = tuple(0 - robots[r] for r in ('geode', 'obsidian', 'clay', 'ore'))
  return (t,) + negated

def prune_queue(queue, limit=2_000):
  """Beam pruning for search4: keep the `limit` highest-ranked (robots, stuff) states.

  States are ranked lexicographically, highest first, by
  (robots[geode], stuff[geode], robots[obsidian], stuff[obsidian],
   robots[clay], stuff[clay], robots[ore], stuff[ore]).
  This is a heuristic; a dropped state could have led to a better result, so
  a larger `limit` trades speed for safety.
  """
  def prune_key(state):
    robots, stuff = state
    return [count for rtype in ('geode','obsidian','clay','ore') for count in (robots[rtype], stuff[rtype])]
  return sorted(queue, key=prune_key, reverse=True)[:limit]

def search4(bp,max_t=24):
  """Beam search for the most geodes obtainable from blueprint `bp` in `max_t` minutes.

  Each minute, every (robots, stuff) state branches into building any
  affordable robot (paid from the current stuff, while the existing robots
  still produce this minute) or building nothing. The next generation is cut
  down by prune_queue, so the result is heuristic and may miss the optimum.
  """
  _, bp_robots = bp
  queue = [(Counter({'ore': 1}), Counter())]

  # Exactly max_t expansions: afterwards `queue` holds the states at minute max_t.
  for minute in range(max_t):
    next_queue = []
    for robots, stuff in queue:
      for rtype, costs in bp_robots.items():
        if afford(stuff, costs):
          next_queue.append((robots + Counter({rtype: 1}), stuff + robots - costs))
      next_queue.append((robots, stuff + robots))
    queue = prune_queue(next_queue)
  return max((stuff['geode'] for _, stuff in queue), default=0)

def search2(bp):
  """Best-first search over (t, robots, stuff) states for blueprint `bp`.

  Returns `bests`: a list indexed by minute 0..MAX_T holding the best
  get_score() seen at that minute. States are deduplicated by cache_key (t is
  ignored), and a state scoring below the best already seen at its minute is
  dropped. That prune is heuristic and can discard the optimal path.
  """
  _, bp_robots = bp
  bests = [(0,)] * (MAX_T + 1)
  queue = []
  # Tie-breaker so heapq never compares the Counter-bearing states themselves
  # when priorities are equal.
  tiebreak = itertools.count()

  max_robots = get_max_robots(bp)
  print(f"max_robots {max_robots}")

  def push(state):
    heappush(queue, (get_priority(state,max_robots), next(tiebreak), state))

  push((0, Counter({'ore': 1}), Counter()))
  seen = set()

  while queue:
    _, _, state = heappop(queue)
    t,robots,stuff = state
    k = cache_key(state)
    if k in seen: continue
    seen.add(k)
    if t > MAX_T: continue
    score = get_score(state)
    if score < bests[t]:
      continue
    elif score > bests[t]:
      bests[t] = score

    next_stuff = stuff + robots
    push((t+1, robots, next_stuff))
    for rtype,rcosts in bp_robots.items():
      # Never build more robots of a type than the most any recipe spends in a minute.
      if robots[rtype] < max_robots[rtype] and afford(stuff,rcosts):
        # The priority is computed from the successor state, not the current one.
        push((t+1, robots + Counter({rtype:1}), next_stuff - rcosts))
  return bests

def search3(costs_map,state,target_stuff):
  """Best-first search from `state` (t, robots, stuff) for a state whose stuff covers `target_stuff`.

  States are popped in get_priority3 order (minute first). Only robots whose
  type still appears in the remaining deficit are built. Returns the first
  covering state popped, or None if none is found without exceeding MAX_T.
  """
  tiebreak = itertools.count()
  queue = [(get_priority3(state), next(tiebreak), state)]

  while queue:
    _, _, state = heappop(queue)

    t,robots,stuff = state
    if t > MAX_T: continue
    # Counter subtraction keeps only positive counts, so an empty deficit means
    # stuff covers target_stuff. This avoids Counter rich comparisons, which
    # only exist from Python 3.10.
    needed = target_stuff - stuff
    if not needed:
      return state

    next_stuff = stuff + robots
    next_states = [
      (t+1, robots + Counter({rtype:1}), next_stuff - costs)
      for rtype,costs in costs_map.items()
      if needed[rtype] and afford(stuff,costs)
    ]
    next_states.append((t+1, robots, next_stuff))
    for s in next_states:
      heappush(queue, (get_priority3(s), next(tiebreak), s))
  return None

# Robot costs matching "Blueprint 1" in `testdata`: robot type -> resources consumed.
costs_map = dict([
  ('ore', Counter({'ore': 4})),
  ('clay', Counter({'ore': 2})),
  ('obsidian', Counter({'ore': 3, 'clay': 14})),
  ('geode', Counter({'ore': 2, 'obsidian': 7})),
])

# Sub-goals for search3, listed from the last (8 geodes) back to the first
# (1 clay); the commented-out driver below walks them in reverse.
target_stuffs = list(map(Counter, [
  {'geode': 8},
  {'geode': 1},
  {'obsidian': 7, 'ore': 2},
  {'obsidian': 1},
  {'clay': 14, 'ore': 3},
  {'clay': 1},
]))

# state = (0, Counter({'ore': 1}), Counter())
# for target_stuff in reversed(target_stuffs):
#   next_state = search3(costs_map,state,target_stuff)
#   if next_state:
#     print(next_state)
#     state = next_state



# bps = parse(testdata)
# search4(bps[1])
# # for bp in bps[10:]: print(f"{bp[0]}: {search(bp)}")
# all_results = {}
# for bp in parse(data):
#   timer = aoc_utils.start_timer()
#   # print(data[bp[0] - 1])
#   res = search4(bp)
#   print(f"search4: {res} in {round(timer())}")
#   # timer = aoc_utils.start_timer()
#   # print(f"search2: {search2(bp)} in {round(timer())}")
#   all_results[bp[0]] = res

# Part 2: only the first three blueprints, each searched for 32 minutes.
all_results_32 = {}
for bp in parse(data)[:3]:
  timer = aoc_utils.start_timer()
  res = search4(bp, max_t=32)
  print(f"search4: {res} in {round(timer())}")
  all_results_32.update({bp[0]: res})

# BP ID 22: search:6 vs search2:7 !!
# 1120 too low
# 1154 too low
# 1176 too low

search4: 25 in 2
search4: 19 in 3
search4: 31 in 3


In [112]:
# Part 2 answer: product of the geode counts of the first three blueprints.
math.prod(all_results_32[bp_id] for bp_id in all_results_32)

14725