In [1]:
# allows editing aoc_utils "live" without restarting kernel
# see https://ipython.org/ipython-doc/stable/config/extensions/autoreload.html
# and https://stackoverflow.com/a/17551284
%load_ext autoreload
%autoreload 2

# Add the aoc_utils path
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import aoc_utils
get_input = aoc_utils.get_input
print = aoc_utils.debug_print

timer = aoc_utils.start_timer()

In [2]:
# Useful imports
import re
from collections import defaultdict, deque
import heapq
import functools
import queue
import itertools
import math
import random
from collections import Counter
import statistics
import parse
import operator
from functools import reduce

In [3]:
# aliases from utils
getnums = aoc_utils.getnums

In [4]:
# from norvig's pytudes
cat = ''.join

In [5]:
def Day1(data=get_input(1, 2021)):
  nums = list(map(int, data))
  def p1():
    return len([i for i in range(1, len(nums)) if nums[i] > nums[i-1]])
  def p2():
    sums = [sum(nums[i:i+3]) for i in range(len(nums) - 2)]
    return len([i for i in range(1, len(sums)) if sums[i] > sums[i-1]])
    
  return p1(),p2()

assert Day1() == (1548, 1589)

In [6]:
def Day2(data=get_input(2,2021)):
  def p1():
    h,d=0,0
    for line in data:
      n = getnums(line)[0]
      if 'forward' in line:
        h += n
      elif 'down' in line:
        d += n
      else:
        d -= n
    return d * h
  def p2():
    h,d,aim=0,0,0
    for line in data:
      n = getnums(line)[0]
      if 'forward' in line:
        h += n
        d += aim * n
      elif 'down' in line:
        aim += n
      else:
        aim -= n
    return d * h
  return p1(),p2()

assert Day2() == (1694130, 1698850445),Day2()

In [7]:
def Day3(data = get_input(3,2021)):
  b2d = aoc_utils.bin_to_decimal

  def p1():
    cnts = [Counter() for _ in range(12)]
    for line in data:
      for idx,c in enumerate(line):
        cnts[idx].update(c)
    a = b2d(''.join([cnt.most_common()[0][0] for cnt in cnts]))
    b = b2d(''.join([cnt.most_common()[-1][0] for cnt in cnts]))
    return a*b

  def p2():
    def winnow(arr, most_com=True):
      for idx in range(len(arr[0])):
        cnt = Counter()
        for l in arr:
          cnt.update(l[idx])
        v = None
        if most_com:
          v = '0' if cnt['0'] > cnt['1'] else '1'
        else:
          v = '1' if cnt['1'] < cnt['0'] else '0'
        arr = [l for l in arr if l[idx] == v]
        if len(arr) == 1:
          return arr[0]

    O = winnow(data, most_com=True)
    C  = winnow(data, most_com=False)
    return b2d(O)*b2d(C)
  return p1(),p2()

assert Day3() == (1540244, 4203981)

In [8]:
def Day4(data=get_input(4,2021)):
  nums = getnums(data[0])
  boards = list(aoc_utils.split_list(data[2:]))
  grids = []
  for board in boards:
    grid = {}
    for y,line in enumerate(board):
      for x,num in enumerate(getnums(line)):
        grid[(x,y)] = num
    grids.append(grid)

  def boardsum(g, seen):
    unmarked = []
    for x in range(5):
      for y in range(5):
        n = g[(x,y)]
        if n not in seen:
          unmarked.append(n)
    return sum(unmarked)

  def winner(grid, seen):
    for x in range(5):
      if all(grid[(x,y)] in seen for y in range(5)):
        return True
    for y in range(5):
      if all(grid[(x,y)] in seen for x in range(5)):
        return True
    return False

  def p1():
    seen = set()
    for n in nums:
      seen.add(n)
      for g in grids:
        if winner(g, seen):
          return boardsum(g,seen)*n
  def p2():
    seen = set()
    winners = []
    remaining = grids
    for n in nums:
      seen.add(n)
      for g in remaining:
        if winner(g, seen):
          winners.append(g)
        if len(winners) == len(grids):
          return boardsum(g,seen)*n
      remaining = [_g for _g in remaining if _g not in winners]
  return p1(),p2()
assert Day4() == (71708, 34726)

In [9]:
def Day5(data=get_input(5,2021)):
  lines = list(map(getnums, data))
  def points(x1,y1,x2,y2):
    if y1 == y2:
      x1,x2 = sorted([x1,x2])
      return [(x,y1) for x in range(x1,x2+1)]
    elif x1 == x2:
      y1,y2 = sorted([y1,y2])
      return [(x1,y) for y in range(y1,y2+1)]
    else:
      m,b = statistics.linear_regression([x1,x2],[y1,y2])
      assert m == 1 or m == -1
      x1,x2 = sorted([x1,x2])
      return [(x,int(m*x+b)) for x in range(x1,x2+1)]

  assert points(0,0,2,0) == [(0,0), (1,0), (2,0)]
  assert points(0,0,0,2) == [(0,0), (0,1), (0,2)]
  assert points(2,0,0,0) == [(0,0), (1,0), (2,0)]
  assert points(0,2,0,0) == [(0,0), (0,1), (0,2)]
  assert points(1,1,1,3) == [(1,1),(1,2),(1,3)]
  assert sorted(points(9,7,7,7)) == [(7,7),(8,7),(9,7)]
  assert sorted(points(7,7,9,7)) == [(7,7),(8,7),(9,7)]
  assert sorted(points(0,0,3,3)) == [(0,0),(1,1),(2,2),(3,3)]

  p1_cnt = defaultdict(int)
  p2_cnt = defaultdict(int)
  for line in lines:
    x1,y1,x2,y2 = line
    for p in points(x1,y1,x2,y2):
      p2_cnt[p] += 1
      if x1 == x2 or y1 == y2:
        p1_cnt[p] += 1
  p1 = sum([1 for v in p1_cnt.values() if v > 1])
  p2 = sum([1 for v in p2_cnt.values() if v > 1])

  return p1,p2

assert Day5() == (5608, 20299)


In [10]:
def Day6(data = get_input(6, 2021)):
  p1 = 0
  p2 = 0
  data = getnums(data[0])
  P1_DAYS = 80
  P2_DAYS = 256
  cnt = Counter(data)
  for d in range(P2_DAYS):
    cyc = d%7
    tmp = cnt[7]
    cnt[7] = cnt[8]
    cnt[8] = cnt[cyc]
    cnt[cyc] += tmp
    if d + 1 == P1_DAYS:
      p1 = cnt.total()
  p2 = cnt.total()
  return p1,p2

assert Day6() == (360610, 1631629590423) 

In [11]:
def Day7(data = aoc_utils.get_input(7, 2021)):
  p1 = p2 = 0
  data = getnums(data[0])
  median = statistics.median(data)
  for d in data:
    p1 += abs(d - median)

  @functools.lru_cache()
  def get_cost(v,target):
    n = abs(v-target)
    return n * (n+1) / 2

  p2 = min([sum([get_cost(c,target) for c in data]) for target in range(min(data), max(data))])

  return p1,p2

assert Day7() == (347509, 98257206)

In [12]:
def Day8(data=get_input(8,2021)):
  p1 = p2 = 0
  data = get_input(8, 2021)
  for line in data:
    digits = line.split('| ')[1].split(' ')
    for d in digits:
      if len(d) in(1,4,7,3,2):
        p1 += 1
    pass

  DIGNUMS = [
    [0,1,2,3,4,6], #0
    [1,2], #1
    [0,1,3,4,5], # 2
    [0,1,2,3,5], # 3
    [1,2,5,6], # 4
    [0,2,3,5,6], # 5
    [0,2,3,4,5,6], # 6
    [0,1,2], #7
    [0,1,2,3,4,5,6], # 8
    [0,1,2,3,5,6] # 9
  ]

  def prune(_map):
    found_keys = [k for k,v in _map.items() if len(v) == 1]
    found_vals = [list(v)[0] for k,v in _map.items() if k in found_keys]
    for k in _map.keys():
      if k not in found_keys:
        _map[k].difference_update(found_vals)
    return _map

  def solve_num(s, poss):
    ons = [list(v)[0] for c,v in poss.items() if c in s]
    ons = list(sorted(ons))
    assert ons in DIGNUMS, (ons,s)
    return DIGNUMS.index(ons)

  for line in data:
    signals, digits = line.split(' | ')
    signals = signals.split(' ')
    digits = digits.split(' ')
    poss = defaultdict(lambda: set([0,1,2,3,4,5,6]))
    for s in signals:
      if len(s) == 2:
        _poss = set(DIGNUMS[1]) # 1l
        for c in s: poss[c] = poss[c] & _poss
      elif len(s) == 3:
        _poss = set(DIGNUMS[7]) # 7
        for c in s: poss[c] = poss[c] & _poss
      elif len(s) == 4:
        _poss = set(DIGNUMS[4]) # 4
        for c in s: poss[c] = poss[c] & _poss
      elif len(s) == 7:
        _poss = set(DIGNUMS[8]) # 8
        for c in s: poss[c] = poss[c] & _poss
      elif len(s) == 5:
        _poss = set(DIGNUMS[2]) #2
        _poss = _poss | set(DIGNUMS[3]) #3
        _poss = _poss | set(DIGNUMS[5]) # 5
        for c in s: poss[c] = poss[c] & _poss
      elif len(s) == 6:
        _poss = set(DIGNUMS[6]) # 6
        _poss = _poss | set(DIGNUMS[0]) # 0
        _poss = _poss | set(DIGNUMS[9]) # 9
        for c in s: poss[c] = poss[c] & _poss
      else:
        assert False, s
    found = []
    right_segs = [c for c,v in poss.items() if len(v) == 2]
    assert len(right_segs) == 2
    a,b = right_segs
    a_cnt = len([s for s in signals if a in s])
    b_cnt = len([s for s in signals if b in s])
    if a_cnt == 8:
      assert b_cnt == 9
      poss[a] = set([1])
      poss[b] = set([2])
    else:
      assert a_cnt == 9
      assert b_cnt == 8
      poss[a] = set([2])
      poss[b] = set([1])
    poss = prune(poss)
    left_segs = [c for c,v in poss.items() if len(v) == 2]
    assert len(left_segs) == 2
    a,b = left_segs
    a_cnt = len([s for s in signals if a in s])
    b_cnt = len([s for s in signals if b in s])
    assert poss[a] == set([5,6])
    assert poss[b] == set([5,6])
    if a_cnt == 6:
      assert b_cnt == 7
      poss[a] = set([6])
      poss[b] = set([5])
    else:
      assert b_cnt == 6
      assert a_cnt == 7
      poss[a] = set([5])
      poss[b] = set([6])
    poss = prune(poss)
    rem_segs = [c for c,v in poss.items() if len(v) > 1]
    assert len(rem_segs) == 2
    a,b = rem_segs
    a_cnt = len([s for s in signals if a in s])
    b_cnt = len([s for s in signals if b in s])
    if a_cnt == 7:
      assert b_cnt == 4
      poss[a] = set([3])
      poss[b] = set([4])
    else:
      assert a_cnt == 4
      assert b_cnt == 7
      poss[a] = set([4])
      poss[b] = set([3])
    assert all([len(v) == 1 for v in poss.values()])
    p2 += int(''.join(map(str, [solve_num(s,poss) for s in digits])))


  return p1,p2

assert Day8() == (387, 986034)

In [15]:
def Day9(data=get_input(9,2021)):
  p1 = p2 = 0
  grid = defaultdict(lambda: 9)
  XMAX = len(data[0])
  YMAX = len(data)
  for y,line in enumerate(data):
    for x,c in enumerate(line):
      grid[(x,y)] = int(c)

  lows = []
  for y in range(YMAX):
    for x in range(XMAX):
      v = grid[(x,y)]
      low = all([grid[n] > v for n in aoc_utils.neighbors((x,y))])
      if low:
        lows.append((x,y))
        p1 += v+1

  basins = []
  for low in lows:
    x,y = low
    v = grid[(x,y)]
    seen = set()
    basin = set()
    neighbors = aoc_utils.neighbors(low)
    while len(neighbors):
      n = neighbors.pop()
      if n not in seen:
        seen.add(n)
        if grid[n] < 9:
          basin.add(n)
          for _n in aoc_utils.neighbors(n):
            neighbors.append(_n)
    basins.append(basin)
  bsizes = list(sorted(map(len, basins)))

  p2 = reduce(operator.mul, bsizes[-3:],1)
  return p1,p2

assert Day9() == (486, 1059300)

In [16]:
def Day10(data=get_input(10,2021)):
  opens = "([{<"
  closes = ")]}>"
  scores1 = dict(zip(closes, [3,57,1197,25137]))
  scores2 = dict(zip(closes, [1,2,3,4]))
  p1 = p2 = 0

  def get_score_end(s):
    score = 0
    for c in s:
      score = 5 * score + scores2[c]
    return score

  assert get_score_end("}}]])})]") == 288957
  assert get_score_end(")}>]})") == 5566

  def get_score1(line):
    open = deque()
    for c in line:
      if c in opens:
        open.append(c)
      if c in closes:
        last = open.pop()
        if last is None:
          return scores1[c]
        exp = closes[opens.index(last)]
        if c != exp:
          return scores1[c]
    return 0

  p1 = sum(get_score1(line) for line in data)
  
  def get_end(line):
    open = deque()
    for c in line:
      if c in opens:
        open.append(c)
      else:
        last = open.pop()
        exp = closes[opens.index(last)]
        assert c == exp
    return cat(closes[opens.index(c)] for c in reversed(open))

  assert get_end("[({(<(())[]>[[{[]{<()<>>") == "}}]])})]"


  is_not_corrupt = lambda l: get_score1(l) == 0
  p2 = statistics.median(get_score_end(get_end(line)) for line in filter(is_not_corrupt,data))

  return p1,p2

assert Day10() == (370407, 3249889609)

In [18]:
def Day11(data=get_input(11,2021)):
  ex = "11111\n19991\n19191\n19991\n11111"
  ex = """
  5483143223
  2745854711
  5264556173
  6141336146
  6357385478
  4167524645
  2176841721
  6882881134
  4846848554
  5283751526"""
  ex = [l.strip() for l in ex.strip().split('\n')]
  # data = ex
  p1 = p2 = 0
  grid = defaultdict(lambda: -float('inf'))
  STEPS = 100
  def dbggrid(grid):
    lines = []
    xmax = max(p[0] for p in grid)
    ymax = max(p[1] for p in grid)
    for y in range(ymax+1):
      line = ""
      for x in range(xmax+1):
        line += str(grid[(x,y)])
      lines.append(line)
    return '\n'.join(lines)
  for y,line in enumerate(data):
    for x,v in enumerate(line):
      v = int(v)
      grid[(x,y)] = v
  grid2 = grid.copy()
  def advance(grid):
    seen = set()
    for point in grid:
      grid[point] += 1
    while any(v > 9 for v in grid.values()):
      ps = [p for p in grid if grid[p] > 9 and not p in seen]
      if len(ps) == 0:
        break
      for p in ps:
        for n in aoc_utils.neighbors(p, only_cardinal=False):
          grid[n] += 1
      for p in ps:
        seen.add(p)
    for p in seen:
      grid[p] = 0
    return grid,len(seen)
  for s in range(STEPS):
    grid,flashes = advance(grid)
    p1 += flashes

  grid = grid2
  for i in range(100000):
    grid,_ = advance(grid)
    # "-inf" is one of the values
    if len(set(grid.values())) == 2:
      p2 = i + 1
      break

  return p1,p2
assert Day11() == (1675, 515)

In [19]:
def Day12(data=get_input(12,2021)):
  EDGES = defaultdict(set)

  for line in data:
    a,b = line.split('-')
    EDGES[a].add(b)
    EDGES[b].add(a)

  def explore(head, seen=set(), seen_twice=None):
    if head == 'end':
      return 1
    if head == 'start' and len(seen) > 0:
      return 0
    if head.islower() and head in seen:
      if not seen_twice:
        seen_twice = head
      else:
        return 0
    seen = seen | {head}
    return sum(explore(next_head, seen, seen_twice) for next_head in EDGES[head])

  p1 = explore('start', set(), True)
  p2 = explore('start')

  return p1,p2
assert Day12() == (3679, 107395)

In [20]:
def dbggrid(grid,bounds=None):
    lines = []
    if not bounds:
      xmax = max(p[0] for p in grid)
      ymax = max(p[1] for p in grid)
      bounds = (xmax,ymax)
    xmax,ymax = bounds
    for y in range(ymax+1):
      line = ""
      for x in range(xmax+1):
        line += str(grid[(x,y)])
      lines.append(line)
    return '\n'.join(lines)

def Day13(data=get_input(13,2021)):
  ex = map(str.strip, """
  6,10
  0,14
  9,10
  0,3
  10,4
  4,11
  6,0
  6,12
  4,1
  0,13
  10,12
  3,4
  3,0
  8,4
  1,10
  2,14
  8,10
  9,0

  fold along y=7
  fold along x=5
  """.split("\n"))
  #data = ex
  p1 = p2 = 0
  G = Counter()
  coords,_dirs = aoc_utils.split_list(data)
  for line in coords:
    x,y = getnums(line)
    G[(x,y)] = 1
  # print(dbggrid(G))
  dirs = []
  for line in _dirs:
    val = getnums(line)[0]
    if "x" in line:
      dirs.append(("x",val))
    else:
      dirs.append(("y",val))
  bounds = (max(x for x,y in G), max(y for x,y in G))
  for idx,(dir,val) in enumerate(dirs):
    if idx == 1:
      #print(idx)
      #print(dbggrid(G))
      p1 = len(list(p for p in G if G[p] > 0 and p[0] <= bounds[0] and p[1] <= bounds[1]))
    oldxmax, oldymax = bounds
    if dir == "y":
      bounds = (bounds[0], bounds[1] // 2 - 1)
      for (x,y) in itertools.product(range(bounds[0]+1), range(bounds[1]+1)):
        # 0 - 14
        # 1 - 13
        # 2 - 12
        # 3 - 11
        # 4 - 10
        # 5 - 9
        # 6 - 8
        # 7 - 7
        ref_y = oldymax - y
        G[(x,y)] += G[(x, ref_y)]
    else:
      bounds = (bounds[0] // 2 - 1, bounds[1])
      for (x,y) in itertools.product(range(bounds[0]+1), range(bounds[1]+1)):
        ref_x = oldxmax - x
        G[(x,y)] += G[(ref_x, y)]
  G = dict(G)
  for p in G:
    G[p] = " " if G[p] == 0 else "█"
  print(dbggrid(G, bounds))

  p2 = "PZFJHRFZ"

  return p1,p2

assert Day13() == (610, "PZFJHRFZ")

███  ████ ████   ██ █  █ ███  ████ ████ 
█  █    █ █       █ █  █ █  █ █       █ 
█  █   █  ███     █ ████ █  █ ███    █  
███   █   █       █ █  █ ███  █     █   
█    █    █    █  █ █  █ █ █  █    █    
█    ████ █     ██  █  █ █  █ █    ████ 


In [21]:
def Day14(data=get_input(14,2021)):
  STEPS_P1 = 10
  STEPS_P2 = 40

  pairwise = itertools.pairwise
  p1 = p2 = 0

  RULES = {}
  SOURCE = ""
  for line in data:
    if '->' in line:
      a,b = line.split(' -> ')
      RULES[tuple(a)] = b
    elif line:
      SOURCE = line

  @functools.lru_cache(maxsize=None)
  def count(char, pair, times):
    """
    Count how many times `char` is generated by `times`
    iterations of expanding `pair`
    """
    if times == 0:
      return 0
    new = RULES[pair]
    a,b = pair
    lhs,rhs = (a,new),(new,b)
    return (1 if new == char else 0) + \
      count(char,lhs,times-1) + \
      count(char,rhs,times-1)


  c1 = Counter()
  c2 = Counter()

  for char in SOURCE:
    c1[char] = SOURCE.count(char) + \
      sum([count(char, pair, STEPS_P1) for pair in pairwise(SOURCE)])
    c2[char] = SOURCE.count(char) + \
      sum([count(char, pair, STEPS_P2) for pair in pairwise(SOURCE)])

  p1 = max(c1.values()) - min(c1.values())
  p2 = max(c2.values()) - min(c2.values())

  return p1,p2

assert Day14() == (2602, 2942885922173)

In [23]:
ex = list(map(str.strip, """
1163751742
1381373672
2136511328
3694931569
7463417111
1319128137
1359912421
3125421639
1293138521
2311944581
""".strip().split('\n')))

def Day15(data=get_input(15,2021)):
  p1 = p2 = 0

  def gridify(data):
    G = {}
    for y,line in enumerate(data):
      for x,c in enumerate(line):
        G[(x,y)] = int(c)
    return G
  G = gridify(data)
  BOUNDS = tuple(map(max, zip(*G.keys())))
  BOUNDS = (BOUNDS[0]+1, BOUNDS[1]+1)

  def in_grid(G, pos, is_p1):
    if is_p1:
      return pos in G
    else:
      x,y = pos
      xmax,ymax = BOUNDS
      return x >= 0 and y >= 0 and x < xmax*5 and y < ymax*5

  def cost_of(G, pos, is_p1):
    if is_p1:
      return G[pos]
    else:
      if pos in G:
        return G[pos]
      else:
        xmax,ymax = BOUNDS
        x,y = pos
        adj_x = x % xmax
        adj_y = y % ymax
        rep_x = x // xmax
        rep_y = y // ymax
        adj_v = G[(adj_x,adj_y)] + rep_x + rep_y
        if adj_v > 9:
          adj_v = (adj_v % 10) + 1
        return adj_v

  def search(G,is_p1=True):
    if is_p1:
      TARGET = (BOUNDS[0]-1,BOUNDS[1]-1)
    else:
      TARGET = (BOUNDS[0]*5-1,BOUNDS[1]*5-1)
    nodes = []
    heapq.heappush(nodes, (0, (0,0), 0))
    seen = set()
    while nodes:
      total_cost, cur_pos, cur_cost = heapq.heappop(nodes)
      if cur_pos in seen:
        continue
      seen.add(cur_pos)
      if cur_pos == TARGET:
        return total_cost + cur_cost
      for n in aoc_utils.neighbors(cur_pos):
        if in_grid(G,n,is_p1) and n not in seen:
          next_cost = cost_of(G,n,is_p1)
          heapq.heappush(nodes, (total_cost + cur_cost, n, next_cost))

  p1 = search(G, is_p1=True)
  p2 = search(G, is_p1=False)

  return p1, p2

assert Day15(ex) == (40,315)
assert Day15() == (714, 2948)

In [24]:
def Day16(data=get_input(16,2021)):
  p1 = p2 = 0
  BITS = []

  def hex2bits(c, numbits=4):
    return list(map(int, bin(int(c,16))[2:].zfill(numbits)))
  assert hex2bits("0") == [0,0,0,0]
  assert hex2bits("1") == [0,0,0,1]
  assert hex2bits("A") == [1,0,1,0]
  assert hex2bits("E") == [1,1,1,0]
  def bin2int(bits):
    out = 0
    for i,v in enumerate(reversed(bits)):
      out += v*(2**i)
    return out
  assert bin2int([0,0,0,1]) == 1
  assert bin2int([0,0,1,1]) == 3
  assert bin2int([0,1,1,1]) == 7
  assert bin2int([1,0,1,0]) == 10

  def data2bits(data):
    bits = []
    for line in data:
      for c in line:
        for b in hex2bits(c):
          bits.append(b)
    return bits

  def readint(bits, pos, size=4):
    val = bin2int(bits[pos:pos+size])
    return pos+size,val
  assert readint([0,0,0,1,1,0,0,0], 4) == (8,8)
  assert readint([0,0,0,1,1,0,0,0], 0) == (4,1)

  def readbits(bits, pos, size):
    val = bits[pos:pos+size]
    return pos+size,val

  def readliteral(bits,pos):
    valbits = []
    while True:
      pos,b5 = readbits(bits,pos,5)
      valbits.extend(b5[1:])
      if b5[0] == 0:
        break
    return pos, bin2int(valbits)

  BITS = data2bits(data)
    
  def decode(bits, pos):
    pos,version = readint(bits, pos, 3)
    pos,id_type = readint(bits, pos, 3)
    val = None
    sub_packets = []
    match id_type:
      case 4: # literal
        pos,val = readliteral(bits,pos)
      case _:
        pos,len_type = readbits(bits, pos, 1)
        match len_type:
          case [0]: # next 15 total len in bits of sub-packets
            pos, sub_bit_len = readint(bits, pos, 15)
            end_pos = pos + sub_bit_len
            sub_packets = []
            while pos < end_pos:
              pos, sub_packet = decode(bits, pos)
              sub_packets.append(sub_packet)
          case [1]: # next 11 are number of sub-packets
            pos, sub_count = readint(bits, pos, 11)
            sub_packets = []
            for _ in range(sub_count):
              pos,sub_packet = decode(bits, pos)
              sub_packets.append(sub_packet)
    return pos,(version,id_type,val,sub_packets)

  def calc(packet):
    version,id_type,val,sub_packets = packet
    match id_type:
      case 0:
        assert not val
        return reduce(operator.add, map(calc, sub_packets), 0)
      case 1:
        assert not val
        return reduce(operator.mul, map(calc, sub_packets), 1)
      case 2:
        assert not val
        return min(map(calc, sub_packets))
      case 3:
        assert not val
        return max(map(calc, sub_packets))
      case 4: # literal
        assert not sub_packets
        return val
      case 5: # greater-than
        assert not val
        assert len(sub_packets) == 2
        a,b = sub_packets
        return 1 if calc(a) > calc(b) else 0
      case 6: # less-than
        assert not val
        assert len(sub_packets) == 2
        a,b = sub_packets
        return 1 if calc(a) < calc(b) else 0
      case 7: # equal
        assert not val
        assert len(sub_packets) == 2
        a,b = sub_packets
        return 1 if calc(a) == calc(b) else 0
      case _:
        assert False, f"Unexpected id type {id_type}"

  _,packet = decode(BITS, 0)
  packets = [packet]
  while packets:
    version,_,_,sub_packets = packets.pop()
    p1 += version
    packets.extend(sub_packets)
  p2 = calc(packet)

  return p1,p2

assert Day16() == (955, 158135423448)

In [25]:
def Day17(data=get_input(17,2021)):
  p1 = p2 = 0
  nums = getnums(data[0])
  xmin,xmax,ymin,ymax = nums
  xr = range(xmin,xmax+1)
  yr = range(ymin,ymax+1)
  x,y = (0,0)
  def stepX(x,vx):
    x+=vx
    vx += -1 if vx > 0 else 1 if vx < 0 else 0
    return x,vx
  def stepY(y,vy):
    y+=vy
    vy -= 1
    return y,vy
  
  maxes = []
  poss_vs = set()
  for ivx in range(1,xmax*2):
    for ivy in range(-300,300):
      x = y = 0
      vx = ivx
      vy = ivy
      curmax = 0
      while x < xmax and y > ymin:
        x,vx = stepX(x,vx)
        y,vy = stepY(y,vy)
        if y > curmax:
          curmax = y
        if x in xr and y in yr:
          maxes.append(curmax)
          poss_vs.add((ivx,ivy))
        if x > xmax or y < ymin:
          break



  p1 = max(maxes)
  p2 = len(poss_vs)
  return p1,p2

assert Day17() == (4851, 1739)


In [26]:
def Day18(data=get_input(18,2021)):
  def to_tree(l,parent=None):
    if type(l) == int:
      return {'v':l,'p':parent}
    else:
      assert type(l) == list
      assert len(l) == 2
      lhs,rhs = l
      node = {'p':parent}
      node['l'] = to_tree(lhs, node)
      node['r'] = to_tree(rhs, node)
      return node

  def from_tree(t):
    if 'v' in t:
      return t['v']
    else:
      return [
        from_tree(t['l']),
        from_tree(t['r']),
      ]

  def dfs(t, pred, depth=0):
    if 'v' in t:
      if pred(t['v'], depth):
        return (t, depth)
    else:
      return dfs(t['l'], pred, depth+1) or dfs(t['r'], pred, depth+1)
  def split_snail(n):
    return (math.floor(n/2),math.ceil(n/2))
  assert split_snail(10) == (5,5)
  assert split_snail(11) == (5,6)

  def explode(t,n):
    lval = n['l']['v']
    rval = n['r']['v']
    n['l']['v'] = -1
    n['r']['v'] = -1
    s = str(from_tree(t)).replace(' ','')
    nums = getnums(s)
    assert nums.count(-1) == 2
    lidx = nums.index(-1)
    ridx = lidx + 1
    assert nums[ridx] == -1
    lnewval = rnewval = None
    lcurval = rcurval = None
    if lidx > 0:
      lcurval = nums[lidx-1]
      lnewval = lcurval+lval
    if ridx < len(nums) - 1:
      rcurval = nums[ridx+1]
      rnewval = rcurval+rval
    if lcurval is not None:
      s = replace_last_before(s, lcurval, lnewval)
    if rcurval is not None:
      s = replace_first_after(s, rcurval, rnewval)
    assert "[-1,-1]" in s
    s = s.replace("[-1,-1]","0")
    if ",," in s:
      s = s.replace(',,',)
    if s.endswith(','):
      s = s[0:-1]
    return to_tree(eval(s))

  def replace_last_before(s, curv, newv):
    idx = s.rfind(str(curv), 0, s.index('-1'))
    return s[0:idx] + str(newv) + s[idx+len(str(curv)):]
  def replace_first_after(s, curv, newv):
    idx = s.find(str(curv), s.rindex('-1')+len('-1'))
    return s[0:idx] + str(newv) + s[idx+len(str(curv)):]

  def add_tree(t1, t2, should_reduce=True):
    new_root = {}
    new_root['l'] = t1
    new_root['r'] = t2
    t1['p'] = new_root
    t2['p'] = new_root
    return red(new_root) if should_reduce else new_root
  t1 = to_tree([[[[4,3],4],4],[7,[[8,4],9]]])
  t2 = to_tree([1,1])
  assert from_tree(add_tree(t1,t2,should_reduce=False)) == [[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]]

  def red(t, times=float('inf')):
    if times == 0:
      return t
    found = dfs(t, lambda v,d:d==5)
    if found:
      node,depth = found
      assert depth == 5
      node = node['p']
      assert 'v' in node['l'] and 'v' in node['r']
      t = explode(t,node)
      return red(t, times-1)
    else:
      found = dfs(t, lambda v,d:v>=10)
      if found:
        node,depth = found
        assert node['v'] >= 10
        lhs,rhs = split_snail(node['v'])
        del node['v']
        node['l'] = {'v':lhs,'p':node}
        node['r'] = {'v':rhs,'p':node}
        return red(t, times-1)
    return t

  assert from_tree(red(to_tree([[[[[9,8],1],2],3],4]),times=1)) == [[[[0,9],2],3],4]
  assert from_tree(red(to_tree([7,[6,[5,[4,[3,2]]]]]),times=1)) == [7,[6,[5,[7,0]]]]
  assert from_tree(red(to_tree([[6,[5,[4,[3,2]]]],1]))) == [[6,[5,[7,0]]],3]
  assert from_tree(red(to_tree([[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]), times=1)) == [[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]
  assert from_tree(red(to_tree([[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]))) == [[3,[2,[8,0]]],[9,[5,[7,0]]]]
  assert from_tree(red(to_tree([[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]))) == [[3,[2,[8,0]]],[9,[5,[7,0]]]]
  assert from_tree(red(to_tree([[[[0,7],4],[[7,8],[0,13]]],[1,1]]),times=1)) == [[[[0,7],4],[[7,8],[0,[6,7]]]],[1,1]]
  assert from_tree(red(to_tree([[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]]))) == [[[[0,7],4],[[7,8],[6,0]]],[8,1]]

  def mag_tree(t):
    if 'v' in t: return t['v']
    return mag_tree(t['l'])*3 + mag_tree(t['r'])*2

  assert mag_tree(to_tree([[1,2],[[3,4],5]])) == 143
  assert mag_tree(to_tree([[[[0,7],4],[[7,8],[6,0]]],[8,1]])) == 1384

  trees = [to_tree(eval(l)) for l in data]
  total = trees[0]
  for t in trees[1:]:
    total = add_tree(total, t)

  p1 = mag_tree(total)
  p2 = 0
  for d1,d2 in itertools.combinations(data,2):
    t1 = to_tree(eval(d1))
    t2 = to_tree(eval(d2))
    m1 = mag_tree(add_tree(t1,t2))
    p2 = max(p2, m1)
    t1 = to_tree(eval(d1))
    t2 = to_tree(eval(d2))
    m2 = mag_tree(add_tree(t2,t1))
    p2 = max(p2, m2)

  return p1,p2
assert Day18() == (3411, 4680)

In [27]:
ex = """
--- scanner 0 ---
404,-588,-901
528,-643,409
-838,591,734
390,-675,-793
-537,-823,-458
-485,-357,347
-345,-311,381
-661,-816,-575
-876,649,763
-618,-824,-621
553,345,-567
474,580,667
-447,-329,318
-584,868,-557
544,-627,-890
564,392,-477
455,729,728
-892,524,684
-689,845,-530
423,-701,434
7,-33,-71
630,319,-379
443,580,662
-789,900,-551
459,-707,401

--- scanner 1 ---
686,422,578
605,423,415
515,917,-361
-336,658,858
95,138,22
-476,619,847
-340,-569,-846
567,-361,727
-460,603,-452
669,-402,600
729,430,532
-500,-761,534
-322,571,750
-466,-666,-811
-429,-592,574
-355,545,-477
703,-491,-529
-328,-685,520
413,935,-424
-391,539,-444
586,-435,557
-364,-763,-893
807,-499,-711
755,-354,-619
553,889,-390

--- scanner 2 ---
649,640,665
682,-795,504
-784,533,-524
-644,584,-595
-588,-843,648
-30,6,44
-674,560,763
500,723,-460
609,671,-379
-555,-800,653
-675,-892,-343
697,-426,-610
578,704,681
493,664,-388
-671,-858,530
-667,343,800
571,-461,-707
-138,-166,112
-889,563,-600
646,-828,498
640,759,510
-630,509,768
-681,-892,-333
673,-379,-804
-742,-814,-386
577,-820,562

--- scanner 3 ---
-589,542,597
605,-692,669
-500,565,-823
-660,373,557
-458,-679,-417
-488,449,543
-626,468,-788
338,-750,-386
528,-832,-391
562,-778,733
-938,-730,414
543,643,-506
-524,371,-870
407,773,750
-104,29,83
378,-903,-323
-778,-728,485
426,699,580
-438,-605,-362
-469,-447,-387
509,732,623
647,635,-688
-868,-804,481
614,-800,639
595,780,-596

--- scanner 4 ---
727,592,562
-293,-554,779
441,611,-461
-714,465,-776
-743,427,-804
-660,-479,-426
832,-632,460
927,-485,-438
408,393,-506
466,436,-512
110,16,151
-258,-428,682
-393,719,612
-211,-452,876
808,-476,-593
-575,615,604
-485,667,467
-680,325,-822
-627,-443,-432
872,-547,-609
833,512,582
807,604,487
839,-516,451
891,-625,532
-652,-548,-490
30,-46,-14""".split("\n")

In [28]:
data=get_input(19,2021)

def parse(data):
  scns = []
  curscn = []
  for line in data:
    if not line:
      continue
    if 'scanner' in line:
      if curscn:
        scns.append(curscn)
        curscn = []
      continue
    try:
      x,y,z = getnums(line)
    except:
      x,y,z
    curscn.append((x,y,z))
  assert len(curscn) > 0
  scns.append(curscn)
  assert len(scns) == len([l for l in data if 'scanner' in l])
  for s in scns:
    scndists(s)
  return scns

def pToManyDists(p,ps):
  return [aoc_utils.dist3d(p,p_) for p_ in ps]

def scndists(scn):
  return [(p, pToManyDists(p,scn)) for p in scn]

def find_overlap(scnA,scnB):
  dstsA,dstsB = scndists(scnA),scndists(scnB)
  for pA,dsA in dstsA:
    for pB,dsB in dstsB:
      overlaps = set(dsA) & set(dsB)
      if len(overlaps) >= 12:
        pairs = []
        # first overlapping pair is pA,pB
        # for each dist in overlap:
        #  find l,r index of that dist
        # . get point in l,r at each index
        for dist in overlaps:
          assert dist in dsA and dist in dsB
          a_idx = dsA.index(dist)
          b_idx = dsB.index(dist)
          pairs.append([scnA[a_idx], scnB[b_idx]])
        return pairs

def matrix_mul(a,b):
  return [sum(itertools.starmap(operator.mul, zip(a, col))) for col in zip(*b)]

_cos  = lambda turns: [1,0,-1,0][turns]
_sin  = lambda turns: [0,1,0,-1][turns]
ROT_X = lambda turns: [ [1,0,0], [0, _cos(turns), -1*_sin(turns)], [0, _sin(turns), _cos(turns)] ]
ROT_Y = lambda turns: [ [_cos(turns), 0, _sin(turns)], [0,1,0], [-1*_sin(turns), 0, _cos(turns)] ]
ROT_Z = lambda turns: [ [_cos(turns), -1*_sin(turns), 0], [_sin(turns), _cos(turns), 0], [0,0,1] ]

def orientations(p):
  px = p
  # pointed at x, 4 rotations around x axis
  for turn in range(4):
    yield tuple(matrix_mul(px, ROT_X(turn)))
  # reverse x
  prevx = matrix_mul(p, ROT_Y(2))
  for turn in range(4):
    yield tuple(matrix_mul(prevx, ROT_X(turn)))
  # point at y
  py = matrix_mul(p, ROT_Z(1))
  for turn in range(4):
    yield tuple(matrix_mul(py, ROT_Y(turn)))
  # reverse y
  pyrev = matrix_mul(py, ROT_Z(2))
  for turn in range(4):
    yield tuple(matrix_mul(pyrev, ROT_Y(turn)))
  # point at z
  pz = matrix_mul(p, ROT_Y(1))
  for turn in range(4):
    yield tuple(matrix_mul(pz, ROT_Z(turn)))
  # reverse z
  pzrev = matrix_mul(pz, ROT_Y(2))
  for turn in range(4):
    yield tuple(matrix_mul(pzrev, ROT_Z(turn)))

def orient(scnA,scnB,overlap):
  add3d = aoc_utils.add3d
  sub3d = aoc_utils.sub3d
  """
  Returns the position of scnB in scnA's coords, and the orientation index of scnB relative to scn A
  They all equal the same thing, which is scnB's pos relative to scnA
  """
  scanAPts = [scnAPt for scnAPt,_ in overlap]
  orientedScnBPts = [list(orientations(scnBPt)) for _,scnBPt in overlap]
  idx = -1
  scanPos = None
  while True:
    idx += 1
    # the array of the idx-th orientation of each of the scn-B points in the overlap
    orientedScnBPtsSlice = [ptsList[idx] for ptsList in orientedScnBPts]
    possibleOrientedOverlap = list(zip(scanAPts, orientedScnBPtsSlice))
    allSums = [sub3d(aPt,bPt) for aPt,bPt in possibleOrientedOverlap]
    if len(set(allSums)) == 1:
      scanPos = allSums[0]
      break
  return (scanPos, idx)


In [29]:
def go():
  add3d = aoc_utils.add3d
  data = get_input(19,2021)
  scns = parse(ex)
  scnA = scnB = overlap = None
  identifiedScanIdxs = [0]
  foundBeacons = set(scns[0])
  for scnIdx,scnA in enumerate(scns):
    for cmpScnIdx,scnB in enumerate(scns):
      if scnIdx == cmpScnIdx:
        continue
      if cmpScnIdx in identifiedScanIdxs:
        continue
      overlap = find_overlap(scnA,scnB)
      if overlap:
        print(f"overlap between idx {scnIdx} and {cmpScnIdx}, len {len(overlap)}")
        identifiedScanIdxs.append(cmpScnIdx)
        cmpScanPos,cmpScanOrientationIdx = orient(scnA,scnB,overlap)
        print(f"overlap between idx {scnIdx} and {cmpScnIdx}, len {len(overlap)}, pos: {cmpScanPos}")
        orientedPts = [list(orientations(p))[cmpScanOrientationIdx] for p in scnB]
        translatedOrientedPts = [add3d(p,cmpScanPos) for p in orientedPts]
        foundBeacons.update(translatedOrientedPts)
        scns[cmpScnIdx] = translatedOrientedPts
        break
  return foundBeacons
bcs = go()
len(bcs)

overlap between idx 0 and 1, len 12
overlap between idx 0 and 1, len 12, pos: (68, -1246, -43)
overlap between idx 1 and 3, len 12
overlap between idx 1 and 3, len 12, pos: (-92, -2380, -20)
overlap between idx 2 and 4, len 12
overlap between idx 2 and 4, len 12, pos: (1125, -168, 72)
overlap between idx 4 and 2, len 12
overlap between idx 4 and 2, len 12, pos: (0, 0, 0)


91

In [30]:
# S0 0,0
# S1 5,2
# P1/S1 -1,-1
# P1/S0 4,1
# scns = parse(ex)
# overlap = find_overlap(scns[0],scns[1])
# pos,idx = orient(scns[0],scns[1],overlap)
# scnBOrientedPts = [list(orientations(p))[idx] for p in scns[1]]
# transBOPts = [aoc_utils.add3d(p,pos) for p in scnBOrientedPts]
# scns[1] = transBOPts
# pos,idx = orient(scns[1], scns[4], find_overlap(scns[1],scns[4]))
# scnBOrientedPts = [list(orientations(p))[idx] for p in scns[4]]
# transBOPts = [aoc_utils.add3d(p,pos) for p in scnBOrientedPts]
# scns[4] = transBOPts
# pos,idx = orient(scns[4], scns[2], find_overlap(scns[4],scns[2]))
# print(pos)
# scnBOrientedPts = [list(orientations(p))[idx] for p in scns[2]]
# transBOPts = [aoc_utils.add3d(p,pos) for p in scnBOrientedPts]
# scns[2] = transBOPts
# pos,idx = orient(scns[1],scns[3],find_overlap(scns[1],scns[3]))
# print(pos)
# scnBOrientedPts = [list(orientations(p))[idx] for p in scns[3]]
# transBOPts = [aoc_utils.add3d(p,pos) for p in scnBOrientedPts]
# scns[3] = transBOPts
# foundBeacons = set(scns[0])
# foundBeacons.update(scns[1])
# foundBeacons.update(scns[2])
# foundBeacons.update(scns[3])
# foundBeacons.update(scns[4])
# len(foundBeacons)

def solve(data):
  positions = []
  scns = parse(data)
  refScn = scns[0]
  foundBeacons = set(refScn)
  unknowns = scns[1:]
  while len(unknowns) > 0:
    for scnB in unknowns:
      overlap = find_overlap(refScn,scnB)
      if overlap:
        pos,idx = orient(refScn,scnB,overlap)
        positions.append(pos)
        scnBOrientedPts = [list(orientations(p))[idx] for p in scnB]
        transBOPts = [aoc_utils.add3d(p,pos) for p in scnBOrientedPts]
        foundBeacons.update(transBOPts)
        refScn.extend(transBOPts)
        del unknowns[unknowns.index(scnB)]
  return len(foundBeacons),positions

solve(get_input(19,2021))

(335,
 [(105, -1130, 56),
  (23, 100, 1151),
  (142, -2383, 131),
  (1361, -1191, 141),
  (117, -1237, 1217),
  (1228, -1252, 1272),
  (-21, -3538, -37),
  (139, 1203, 1204),
  (2434, -1105, 113),
  (-11, 2312, 1188),
  (2390, 37, 114),
  (-1211, -89, 1243),
  (1307, -1212, -1125),
  (3678, -1268, 16),
  (2508, -2394, -26),
  (3757, -1186, 1340),
  (1279, -2373, -1059),
  (3710, -2406, 66),
  (-1082, 1243, 1312),
  (4854, -1225, 104),
  (1316, 44, -1063),
  (65, 1126, 32),
  (2530, 41, -1080),
  (4812, -2428, 1313),
  (2386, -3640, 135),
  (2482, -3626, -1056),
  (4902, -2359, 34)])

In [31]:
positions = [(105, -1130, 56),
  (23, 100, 1151),
  (142, -2383, 131),
  (1361, -1191, 141),
  (117, -1237, 1217),
  (1228, -1252, 1272),
  (-21, -3538, -37),
  (139, 1203, 1204),
  (2434, -1105, 113),
  (-11, 2312, 1188),
  (2390, 37, 114),
  (-1211, -89, 1243),
  (1307, -1212, -1125),
  (3678, -1268, 16),
  (2508, -2394, -26),
  (3757, -1186, 1340),
  (1279, -2373, -1059),
  (3710, -2406, 66),
  (-1082, 1243, 1312),
  (4854, -1225, 104),
  (1316, 44, -1063),
  (65, 1126, 32),
  (2530, 41, -1080),
  (4812, -2428, 1313),
  (2386, -3640, 135),
  (2482, -3626, -1056),
  (4902, -2359, 34)]

maxd = 0
for a,b in itertools.combinations(positions, 2):
  d = aoc_utils.manhattan_distance3d(a,b)
  maxd = max(d,maxd)
maxd


10864

In [34]:
def Day20(data=get_input(20,2021)):
  p1 = p2 = 0
  LIGHT_PXL = '#'
  DARK_PXL = '.'
  LIGHT = 1
  DARK = 0

  def parse(data):
    algo = [LIGHT if c == LIGHT_PXL else DARK for c in data[0]]
    G = defaultdict(lambda: DARK)
    for y,line in enumerate(data[2:]):
      for x,c in enumerate(line):
        G[(x,y)] = LIGHT if c == LIGHT_PXL else DARK
    next_defaults = {}
    next_defaults[LIGHT] = algo[-1]
    next_defaults[DARK] = algo[0]
    return algo,G,next_defaults

  # inclusive
  def gridbounds(G):
    xmin = min([x for x,y in G.keys()])
    ymin = min([y for x,y in G.keys()])
    xmax = max([x for x,y in G.keys()])
    ymax = max([y for x,y in G.keys()])
    return ( (xmin,ymin), (xmax,ymax) )

  def extbounds(bounds, xsize, ysize):
    (xmin,ymin),(xmax,ymax) = bounds
    return (
      (xmin-xsize,ymin-ysize),
      (xmax+xsize,ymax+ysize)
    )

  def step(G, algo, next_defaults):
    cur_dflt = G.default_factory()
    next_dflt = next_defaults[cur_dflt]
    G2 = defaultdict(lambda: next_dflt)
    bounds = extbounds(gridbounds(G), 1, 1)
    (xmin,ymin),(xmax,ymax) = bounds
    for x in range(xmin,xmax+1):
      for y in range(ymin,ymax+1):
        p = (x,y)
        n8 = aoc_utils.neighbors8(p, include_point=True)
        v8 = [G[_p] for _p in n8]
        v = aoc_utils.bits_to_int(v8)
        new_v = algo[v]
        G2[p] = new_v
    return G2


  algo,G,next_defaults = parse(data)

  for idx in range(50):
    if idx == 2:
      p1 = sum(G.values())
    G = step(G, algo, next_defaults)

  p2 = sum(G.values())

  return p1,p2

assert Day20() == (5479, 19012)


In [35]:
def Day21(data=get_input(21,2021)):
  p1start = getnums(data[0])[-1]
  p2start = getnums(data[1])[-1]

  def play1(p1,p2):
    die100 = itertools.cycle(range(1,101))
    p1score = p2score = 0
    rolls = 0

    while True:
      rolls += 3
      p1roll = next(die100) + next(die100) + next(die100)
      p1 = (p1 + p1roll) % 10 or 10
      p1score += p1
      if p1score >= 1000:
        return rolls*p2score
      rolls += 3
      p2roll = next(die100) + next(die100) + next(die100)
      p2 = (p2 + p2roll) % 10 or 10
      p2score += p2
      if p2score >= 1000:
        return rolls*p1score

  ROLL_FREQS = Counter(map(sum, itertools.product([1,2,3], repeat=3)))
  @functools.lru_cache(maxsize=None)
  def play2(p1, p2, score1=0, score2=0, turn=0):
    if score1 >= 21:
      return [1,0]
    if score2 >= 21:
      return [0,1]

    wins = [0,0]

    for roll,freq in ROLL_FREQS.items():
      p1_next     = p1
      score1_next = score1
      p2_next     = p2
      score2_next = score2
      turn_next   = 1 - turn

      if turn == 0:
        p1_next     = (p1 + roll) % 10 or 10
        score1_next = score1 + p1_next

      elif turn == 1:
        p2_next     = (p2 + roll) % 10 or 10
        score2_next = score2 + p2_next

      w1,w2 = play2(p1_next, p2_next, score1_next, score2_next, turn_next)
      wins[0] += w1 * freq
      wins[1] += w2 * freq
    return wins

  p1 = play1(p1start, p2start)
  p2 = max(play2(p1start,p2start))
  return p1,p2

assert Day21() == (513936, 105619718613031)

In [36]:
def Day22(data=get_input(22,2021)):
  G = defaultdict(bool)
  for line in data:
    if any([abs(n) > 50 for n in getnums(line)]): continue
    on = line.startswith('on')
    x1,x2,y1,y2,z1,z2 = getnums(line)
    for x in range(x1,x2+1):
      for y in range(y1,y2+1):
        for z in range(z1,z2+1):
          G[(x,y,z)] = on
  p1 = len([v for v in G.values() if v])

  def intersect(box1,box2):
    x1,x2,y1,y2,z1,z2 = box1
    a1,a2,b1,b2,c1,c2 = box2
    return a1 <= x2 and x1 <= a2 and y1 <= b2 and b1 <= y2 and z1 <= c2 and c1 <= z2

  def find_intersect(box1,box2):
    x1,x2,y1,y2,z1,z2 = box1
    a1,a2,b1,b2,c1,c2 = box2
    return (
      max(x1,a1),min(x2,a2),
      max(y1,b1),min(y2,b2),
      max(z1,c1),min(z2,c2)
    )

  boxes = Counter()
  for line in data:
    on = line.startswith('on')
    curbox = tuple(getnums(line))
    next_boxes = Counter()
    for box,hits in boxes.items():
      if intersect(curbox,box):
        intersection = find_intersect(curbox,box)
        next_boxes[intersection] -= hits
    if on:
      next_boxes[curbox] += 1
    boxes.update(next_boxes)
  total = 0
  for cube,hits in boxes.items():
    x1,x2,y1,y2,z1,z2 = cube
    cube_area = (x2+1-x1)*(y2+1-y1)*(z2+1-z1)
    total += cube_area * hits
  p2 = total

  return p1,p2

assert Day22() == (537042, 1304385553084863)


In [12]:
def Day23WIP():
  return "https://aoc-2021-day-23.netlify.app/"
  from collections import namedtuple
  D = 'D'
  C = 'C'
  B = 'B'
  A = 'A'
  ex_data = [B,C,B,D, D,C,B,A, D,B,A,C, A,D,C,A]
  data = [D,B,C,A, D,C,B,A, D,B,A,C, C,A,D,B]
  nodes = [({'idx':idx,'allowed':[A,B,C,D],'neighbors':[],'occupant':None}) for idx in range(27)]
  nodes[0]['neighbors'] = [nodes[1]]
  nodes[1]['neighbors'] = [nodes[0], nodes[2]]
  nodes[2]['neighbors'] = [nodes[1], nodes[3], nodes[11]]
  nodes[3]['neighbors'] = [nodes[2], nodes[4]]
  nodes[4]['neighbors'] = [nodes[3], nodes[5], nodes[12]]
  nodes[5]['neighbors'] = [nodes[4], nodes[6]]
  nodes[6]['neighbors'] = [nodes[5], nodes[7], nodes[13]]
  nodes[7]['neighbors'] = [nodes[6], nodes[8]]
  nodes[8]['neighbors'] = [nodes[7], nodes[9], nodes[14]]
  nodes[9]['neighbors'] = [nodes[8], nodes[10]]
  nodes[10]['neighbors'] = [nodes[9]]
  nodes[11]['neighbors'] = [nodes[2], nodes[15]]
  nodes[12]['neighbors'] = [nodes[4], nodes[16]]
  nodes[13]['neighbors'] = [nodes[6], nodes[17]]
  nodes[14]['neighbors'] = [nodes[8], nodes[18]]
  nodes[15]['neighbors'] = [nodes[11], nodes[19]]
  nodes[16]['neighbors'] = [nodes[12], nodes[20]]
  nodes[17]['neighbors'] = [nodes[13], nodes[21]]
  nodes[18]['neighbors'] = [nodes[14], nodes[22]]

  nodes[19]['neighbors'] = [nodes[15], nodes[23]]
  nodes[20]['neighbors'] = [nodes[16], nodes[24]]
  nodes[21]['neighbors'] = [nodes[17], nodes[25]]
  nodes[22]['neighbors'] = [nodes[18], nodes[26]]

  nodes[23]['neighbors'] = [nodes[19]]
  nodes[24]['neighbors'] = [nodes[20]]
  nodes[25]['neighbors'] = [nodes[21]]
  nodes[26]['neighbors'] = [nodes[22]]

  nodes[11]['allowed'] = [A]
  nodes[15]['allowed'] = [A]
  nodes[19]['allowed'] = [A]
  nodes[23]['allowed'] = [A]
                           
  nodes[12]['allowed'] = [B]
  nodes[16]['allowed'] = [B]
  nodes[20]['allowed'] = [B]
  nodes[24]['allowed'] = [B]
                           
  nodes[13]['allowed'] = [C]
  nodes[17]['allowed'] = [C]
  nodes[21]['allowed'] = [C]
  nodes[25]['allowed'] = [C]
                           
  nodes[14]['allowed'] = [D]
  nodes[18]['allowed'] = [D]
  nodes[22]['allowed'] = [D]
  nodes[26]['allowed'] = [D]

  # cannot stop at top of a hallway
  nodes[2]['allowed'] = []
  nodes[4]['allowed'] = []
  nodes[6]['allowed'] = []
  nodes[8]['allowed'] = []

  ex_data = [B,C,B,D, D,C,B,A, D,B,A,C, A,D,C,A]
  data = [D,B,C,A, D,C,B,A, D,B,A,C, C,A,D,B]
  nodes[11]['occupant'] = data[0]
  nodes[12]['occupant'] = data[1]
  nodes[13]['occupant'] = data[2]
  nodes[14]['occupant'] = data[3]

  nodes[15]['occupant'] = data[4]
  nodes[16]['occupant'] = data[5]
  nodes[17]['occupant'] = data[6]
  nodes[18]['occupant'] = data[7]

  nodes[19]['occupant'] = data[8]
  nodes[20]['occupant'] = data[9]
  nodes[21]['occupant'] = data[10]
  nodes[22]['occupant'] = data[11]

  nodes[23]['occupant'] = data[12]
  nodes[24]['occupant'] = data[13]
  nodes[25]['occupant'] = data[14]
  nodes[26]['occupant'] = data[15]

  def valid_moves(node, home):
    if node.occupant is None:
      return []
    if node in home:
      return []
    if len(node['allowed']) == 4: # in hallway
      # if node is in hallway: only the bottom-most home node, if path is free
      pass
    # if node has not moved from start:
    #   every spot in hallway if a) path is free and b) spot's "allowed" includes node's occupant
    pass

  # returns the path if there is one
  def path_to_node(n1,n2,cur_path=[]):
    for n in n1['neighbors']:
      if n in cur_path:
        continue
      if n['occupant'] is not None:
        continue
      if n == n2:
        return cur_path[:] + [n]
      path = path_to_node(n,n2,cur_path[:]+[n])
      if path:
        return path
    return None

Day23()

AttributeError: can't set attribute

In [115]:
def Day24(data=get_input(24,2021)):
  def runProgOrig(input):
    prog = data
    regs = {'w':0,'x':0,'y':0,'z':0}
    inputidx = 0
    def lookup(v, regs):
      if v in regs: return regs[v]
      return int(v)

    for line in prog:
      inst,*rest = line.split(' ')
      match inst:
        case 'inp':
          assert len(input) > 0
          assert len(rest) == 1
          reg = rest[0]
          assert reg == 'w'
          v = input[inputidx]
          inputidx += 1
          regs[reg] = v
        case 'add':
          assert len(rest) == 2
          reg = rest[0]
          a = regs[reg]
          b = lookup(rest[1], regs)
          regs[reg] = a + b
        case 'mod':
          assert len(rest) == 2
          reg = rest[0]
          a = regs[reg]
          b = lookup(rest[1], regs)
          assert a >= 0
          assert b > 0
          regs[reg] = a%b
        case 'mul':
          assert len(rest) == 2
          reg = rest[0]
          a = regs[reg]
          b = lookup(rest[1], regs)
          regs[reg] = a*b
        case 'div':
          assert len(rest) == 2
          reg = rest[0]
          a = regs[reg]
          b = lookup(rest[1], regs)
          assert b != 0
          regs[reg] = a//b
        case 'eql':
          assert len(rest) == 2
          reg = rest[0]
          a = regs[reg]
          b = lookup(rest[1], regs)
          v = 1 if a == b else 0
          regs[reg] = v
        case _:
          assert False
    return regs

  def runProg(n):
    w = x = y = z = 0
    idx = 0
    idx += 1
    idx += 1
    idx += 1 # 2
    idx += 1
    z = (n[0] + 1) * 26 + n[1] + 9
    if n[2] - 2 != n[3]:
      z = z * 26 + n[3] + 6
    idx += 1
    z = 26 * z + n[4] + 6
    idx += 1
    w = n[6]; idx += 1
    if n[5] - 1 != n[6]:
      z = 26 * z + n[6] + 13
    w = n[7]; idx += 1
    w = n[8]; idx += 1
    if n[7] - 3 != n[8]:
      z = 26 * z + n[8] + 7
    w = n[9]; idx += 1
    w = n[10]; idx += 1
    if n[9] - 7 != n[10]:
      z = z * 26 + n[10] + 10
    w = n[11]; idx += 1
    if n[9] - 7 != n[10]:
      x = n[10] - 1
    elif n[7] - 3 != n[8]:
      x = n[8] - 4
    elif n[5] - 1 != n[6]:
      x = n[6] + 2
    else:
      # fallback to just: x = z % 26 - 11 (double-check its 11)
      x = n[4] - 5 # <-- double-check?
    z = z // 26
    if x != n[11]:
      z = z * 26 + n[11] + 14
    w = n[12]; idx += 1
    x = z % 26
    z = z // 26
    x = x - 6
    if x != n[12]:
      z = z * 26 + n[12] + 7
    w = n[13]; idx += 1
    x = z % 26
    z = z // 26
    x = x - 5
    if x == n[13]:
      x = 0
      y = 0
    else:
      x = 1
      y = n[13] + 1
      z = z * 26 + n[13] + 1
    return {'w':w,'x':x,'y':y,'z':z}

  def p1():
    n = [9] * 14
    for n2 in reversed(range(3,10)):
      n[2] = n2
      n[3] = n2 - 2
      for n5 in reversed(range(2, 10)):
        n[5] = n5
        n[6] = n5 - 1
        for n7 in reversed(range(4, 10)):
          n[7] = n7
          n[8] = n7 - 3
          for n9 in [9, 8]:
            n[9] = n9
            n[10] = n9 - 7
            for n4 in reversed(range(1,10)):
              n[4] = n4
              for n0 in reversed(range(1,10)):
                n[0] = n0
                for n1 in reversed(range(1,10)):
                  n[1] = n1
                  for n11 in reversed(range(1,10)):
                    n[11] = n11
                    for n12 in reversed(range(1,10)):
                      n[12] = n12
                      for n13 in reversed(range(1,10)):
                        n[13] = n13
                        reg = runProg(n)
                        if reg['z'] == 0:
                          return int(cat(map(str,n)))

  def p2():
    n = [9] * 14
    for n2 in (range(3,10)):
      n[2] = n2
      n[3] = n2 - 2
      for n5 in (range(2, 10)):
        n[5] = n5
        n[6] = n5 - 1
        for n7 in (range(4, 10)):
          n[7] = n7
          n[8] = n7 - 3
          for n9 in [8,9]:
            n[9] = n9
            n[10] = n9 - 7
            for n4 in (range(1,10)):
              n[4] = n4
              for n0 in (range(1,10)):
                n[0] = n0
                for n1 in (range(1,10)):
                  n[1] = n1
                  for n11 in (range(1,10)):
                    n[11] = n11
                    for n12 in (range(1,10)):
                      n[12] = n12
                      for n13 in (range(1,10)):
                        n[13] = n13
                        reg = runProg(n)
                        if reg['z'] == 0:
                          return int(cat(map(str,n)))
  return p1(),p2()

assert Day24() == (96979989692495, 51316214181141)

In [111]:
def Day25(data=get_input(25,2021)):
  S = 'v'
  E = '>'
  EMPTY = '.'
  def parse(data):
    G = defaultdict(lambda:EMPTY)
    Y = len(data)
    X = len(data[0])
    for y,line in enumerate(data):
      for x,c in enumerate(line):
        if c in [S,E]: G[(x,y)] = c
    return G,X,Y

  def step(G, X, Y):
    moved = False
    G2 = defaultdict(lambda:EMPTY)
    for (x,y),c in G.items():
      if c == E:
        n = ((x+1)%X,y)
        if n not in G:
          moved = True
          G2[n] = c
        else:
          G2[(x,y)] = c
      else:
        G2[(x,y)] = c
    G3 = defaultdict(lambda:EMPTY)
    for (x,y),c in G2.items():
      if c == S:
        n = (x,(y+1)%Y)
        if n not in G2:
          moved = True
          G3[n] = c
        else:
          G3[(x,y)] = c
      else:
        G3[(x,y)] = c
    return G3,moved
  G,X,Y = parse(data)
  for idx in itertools.count(1):
    G,moved = step(G,X,Y)
    if not moved:
      return idx

assert Day25() == 560