# Advent of Code 2021

I really liked [Peter Norvig's approach](https://github.com/norvig/pytudes/blob/main/ipynb/Advent-2020.ipynb) to writing and tracking solutions, so I'm going to use his method this year.

## Day 0: Imports and Utility Functions

Preparations prior to Day 1:

- Some imports.
- A way to read the day's data file and to print/check the output.
- Some utilities that are likely to be useful.


In [335]:
from __future__ import annotations
from collections import Counter, defaultdict, namedtuple
from dataclasses import dataclass
from functools import reduce
from itertools import chain
from math import prod
import operator
from queue import PriorityQueue
from statistics import mean, median
from typing import Callable

In [14]:
def data(day: int, parser=str, sep='\n', filetype="input") -> list:
    "Split the day's input file into sections separated by `sep`, and apply `parser` to each."
    sections = open(f'data/advent2021/{filetype}{day}.txt').read().rstrip().split(sep)
    return [parser(section) for section in sections]
     
def do(day, *answers) -> list:
    "E.g., do(3) returns [day3_1(in3), day3_2(in3)]. Verifies `answers` if given."
    g = globals()
    got = []
    for part in (1, 2):
        fname = f'day{day}_{part}'
        if fname in g: 
            got.append(g[fname](g[f'in{day}']))
            if len(answers) >= part: 
                assert got[-1] == answers[part - 1], (
                    f'{fname}(in{day}) got {got[-1]}; expected {answers[part - 1]}')
    return got

def by_line(text: str) -> list[str]:
    "Split the text into a list of lines."
    return text.strip().splitlines()

def first(iterable, default=None) -> object:
    "Return first item in iterable, or default."
    return next(iter(iterable), default)

def rest(sequence) -> object: return sequence[1:]

In [15]:
def quantify(iterable, pred=bool) -> int:
    "Count the number of items in iterable for which pred is true."
    return sum(1 for item in iterable if pred(item))

def sliding_window(sequence, window_size):
    "Yield each run of `window_size` consecutive items from a sliceable sequence."
    for i in range(len(sequence) - window_size + 1):
        yield sequence[i:i+window_size]

## Day 1: Sonar Sweep

1. How many measurements are larger than the previous measurement?
2. Consider sums of a three-measurement sliding window. How many sums are larger than the previous sum?
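Part 2 doesn't strictly need the window sums: two consecutive three-measurement windows share their middle two values, so `sum(depths[i+1:i+4]) > sum(depths[i:i+3])` exactly when `depths[i+3] > depths[i]`. A minimal sketch of that shortcut (`count_increases` is a hypothetical helper, not part of the solution below), checked on the ten-depth example from the puzzle text:

```python
def count_increases(depths: list[int], gap: int = 1) -> int:
    """Count the indices i where depths[i + gap] > depths[i].

    Hypothetical helper. gap=1 answers part 1. gap=3 answers part 2: adjacent
    three-measurement windows share two values, so comparing their sums reduces
    to comparing the two values that differ, which sit three positions apart.
    """
    return sum(later > earlier for earlier, later in zip(depths, depths[gap:]))


sample = [199, 200, 208, 210, 200, 207, 240, 269, 260, 263]
print(count_increases(sample, gap=1))  # 7
print(count_increases(sample, gap=3))  # 5
```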

In [16]:
in1 = data(1, parser=int)

In [17]:
def day1_1(depths):
    return quantify(sliding_window(depths, 2), lambda window: window[1] > window[0])

In [18]:
def day1_2(depths):
    sums = list(map(sum, sliding_window(depths, 3)))
    return quantify(sliding_window(sums, 2), lambda window: window[1] > window[0])

In [19]:
do(1, 1121, 1065)

[1121, 1065]

## Day 2: Dive!

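1. What do you get if you multiply your final horizontal position by your final depth?
2. Under the revised rules, `down X` and `up X` change your aim instead of your depth, and `forward X` also increases depth by aim × X. What do you get if you multiply your final horizontal position by your final depth?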
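Both parts can be tracked in a single pass, because the running `down`/`up` total that part 1 treats as depth is exactly what part 2 calls aim. A sketch under that observation (the `dive` helper and its `(direction, amount)` tuple input are assumptions; the notebook's parser yields lists of strings instead), using the six-command example from the puzzle:

```python
def dive(commands: list[tuple[str, int]]) -> tuple[int, int]:
    """Hypothetical helper: return (part 1 answer, part 2 answer)."""
    horizontal = aim = depth = 0
    for direction, amount in commands:
        if direction == "forward":
            horizontal += amount
            depth += aim * amount  # part 2: forward also dives by aim * amount
        elif direction == "down":
            aim += amount
        elif direction == "up":
            aim -= amount
        else:
            raise ValueError(f"Unknown direction: {direction}")
    # Part 1's depth is the net down/up total, which equals the final aim.
    return horizontal * aim, horizontal * depth


example = [("forward", 5), ("down", 5), ("forward", 8), ("up", 3), ("down", 8), ("forward", 2)]
print(dive(example))  # (150, 900)
```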

In [20]:
in2 = data(2, lambda line: line.split())

In [21]:
Position = tuple[int, int]

def move(position: Position, delta: Position) -> Position:
    return Position(map(operator.add, position, delta))

def navigate(instructions: list[list[str]], position: Position = (0,0)) -> Position:
    for instruction in instructions:
        match instruction:
            case ["forward", n]:
                position = move(position, (int(n), 0))
            case ["down", n]:
                position = move(position, (0, int(n)))
            case ["up", n]:
                position = move(position, (0, -int(n)))
            case _:
                raise ValueError(f"Unmatched instruction: {instruction}")
    return position

In [22]:
def day2_1(instructions): return prod(navigate(instructions))

In [23]:
AimedPosition = namedtuple("AimedPosition", ["horizontal", "depth", "aim"])

def navigate_by_aim(instructions: list[list[str]], position: AimedPosition = AimedPosition(0,0,0)) -> AimedPosition:
    for instruction in instructions:
        match instruction:
            case ["forward", n]:
                position = position._replace(
                    horizontal=position.horizontal + int(n),
                    depth=position.depth + int(n)*position.aim,
                )
            case ["down", n]:
                position = position._replace(aim=position.aim + int(n))
            case ["up", n]:
                position = position._replace(aim=position.aim - int(n))
            case _:
                raise ValueError(f"Unmatched instruction: {instruction}")
    return position

In [24]:
def day2_2(instructions):
    final_position = navigate_by_aim(instructions)
    return final_position.horizontal * final_position.depth

In [25]:
do(2, 1484118, 1463827010)

[1484118, 1463827010]

## Day 3: Binary Diagnostic

Power consumption is the gamma rate times the epsilon rate. Each bit of the gamma rate is the most common bit in that position across all the readings; the epsilon rate uses the least common bit instead, so it is the bitwise negation of the gamma rate.

Life support rating is the oxygen generator rating times the CO2 scrubber rating. Each rating is found by filtering the readings one bit position at a time: at position n, keep only the readings whose nth bit matches the most common (oxygen) or least common (CO2) bit among the readings still remaining, and stop when one reading is left. Ties favor `1` for oxygen and `0` for CO2.

1. What is the power consumption of the submarine?
2. What is the life support rating of the submarine?
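Because epsilon is gamma with every bit flipped, one count per column is enough: XOR with a mask of `width` ones turns gamma into epsilon. A sketch of part 1 on plain bit strings (`power_consumption` is a hypothetical name), with ties resolved toward `1` as in `most_common_bits`, checked on the twelve example readings:

```python
def power_consumption(readings: list[str]) -> int:
    """Hypothetical helper: gamma rate times epsilon rate for equal-width binary strings."""
    width = len(readings[0])
    gamma = 0
    for i in range(width):
        ones = sum(reading[i] == "1" for reading in readings)
        most_common = int(2 * ones >= len(readings))  # ties count as 1
        gamma = (gamma << 1) | most_common
    epsilon = gamma ^ ((1 << width) - 1)  # flip each of the `width` low bits
    return gamma * epsilon


readings = ["00100", "11110", "10110", "10111", "10101", "01111",
            "00111", "11100", "10000", "11001", "00010", "01010"]
print(power_consumption(readings))  # 198 (gamma 22, epsilon 9)
```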

In [26]:
Bitlist = list[int]

def most_common_bits(readings: list[str]) -> Bitlist:
    rate = [0] * len(readings[0])
    for reading in readings:
        for i, bit in enumerate(reading):
            rate[i] += 1 if bit == "1" else -1
    return list(map(lambda bit: int(bit >= 0), rate))

def least_common_bits(readings: list[str]) -> Bitlist:
    rate = [0] * len(readings[0])
    for reading in readings:
        for i, bit in enumerate(reading):
            rate[i] += 1 if bit == "1" else -1
    return list(map(lambda bit: int(bit < 0), rate))

def bitlist_to_int(bitlist: Bitlist) -> int:
    val = 0
    for i, bit in enumerate(reversed(bitlist)):
        val += bit * 2**i
    return val

def toggle(bitlist: Bitlist) -> Bitlist:
    return list(map(lambda x: 0 if x else 1, bitlist))

In [27]:
in3 = data(3)

In [28]:
def day3_1(readings):
    gamma = most_common_bits(readings)
    epsilon = least_common_bits(readings)
    return bitlist_to_int(gamma) * bitlist_to_int(epsilon)

In [29]:
def filter_bits(readings: list[str], filter_fn: Callable[[list[str]], Bitlist]) -> int:
    filtered = readings
    for i in range(len(readings[0])):
        # Recompute the target bits from the readings that are still left.
        target_bits = filter_fn(filtered)
        filtered = [reading for reading in filtered if int(reading[i]) == target_bits[i]]
        if len(filtered) == 1:
            break
    assert len(filtered) == 1
    return int(filtered[0], base=2)

def o2_rating(readings: list[str]) -> int: return filter_bits(readings, most_common_bits)
def co2_rating(readings: list[str]) -> int: return filter_bits(readings, least_common_bits)



In [30]:
sample3 = [
    '00100',
    '11110',
    '10110',
    '10111',
    '10101',
    '01111',
    '00111',
    '11100',
    '10000',
    '11001',
    '00010',
    '01010',
]
assert o2_rating(sample3) == 23
assert co2_rating(sample3) == 10

In [31]:
def day3_2(readings): return o2_rating(readings) * co2_rating(readings)

In [32]:
do(3, 3429254, 5410338)

[3429254, 5410338]

## Day 4: Giant Squid

If all numbers in any row or any column of a board are marked, that board wins. (Diagonals don't count.)

The score of the winning board is the sum of all unmarked numbers on that board times the number that was just called when the board won.

1. To guarantee victory against the giant squid, figure out which board will win first. What will your final score be if you choose that board?
2. Figure out which board will win last. Once it wins, what would its final score be?
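Instead of replaying the draws, each board's winning turn can be computed directly: a row or column completes when its latest-drawn number is called, and the board wins at its earliest-completing line. Part 1 is then the board with the smallest winning turn and part 2 the one with the largest. A sketch of that alternative (not the simulation used below; `win_turn` and `score` are hypothetical helpers), with boards as nested lists, a hand-made 2 × 2 example, and the assumption that every board eventually wins:

```python
def win_turn(board: list[list[int]], draws: list[int]) -> int:
    """Hypothetical helper: index into `draws` at which `board` first has a full row or column."""
    turn = {number: i for i, number in enumerate(draws)}
    never = len(draws)  # a number that is never drawn blocks its line forever
    lines = board + [list(column) for column in zip(*board)]  # rows, then columns
    return min(max(turn.get(n, never) for n in line) for line in lines)


def score(board: list[list[int]], draws: list[int], turn: int) -> int:
    """Hypothetical helper: sum of unmarked numbers times the number drawn on `turn`."""
    called = set(draws[:turn + 1])
    return sum(n for row in board for n in row if n not in called) * draws[turn]


draws = [1, 2, 3, 4, 5, 6, 7, 8]
boards = [[[1, 5], [6, 2]], [[3, 4], [7, 8]]]
turns = [win_turn(board, draws) for board in boards]  # [4, 3]
first = min(range(len(boards)), key=turns.__getitem__)
last = max(range(len(boards)), key=turns.__getitem__)
print(score(boards[first], draws, turns[first]))  # 60
print(score(boards[last], draws, turns[last]))    # 30
```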

In [33]:
BingoBoard = list[int | None]  # marked squares are set to None

def board_to_list(board: list[str]) -> BingoBoard:
    """
    Convert 2-D board as list of strings to a 1-D list of ints
    """
    return list(chain(*[list(map(int, row.split())) for row in board]))
in4 = data(4, parser=str.splitlines, sep="\n\n")


In [34]:
BOARD_DIM = 5
def is_winner(board: BingoBoard) -> bool:
    for i in range(BOARD_DIM):
        row = board[i*BOARD_DIM:(i+1)*BOARD_DIM]
        column = board[i::BOARD_DIM]
        if all(square is None for square in row) or all(square is None for square in column):
            return True
    return False

In [35]:
test_board = list(range(1, BOARD_DIM**2 + 1))
winning_row = test_board.copy()
winning_row[0:BOARD_DIM] = [None] * BOARD_DIM
winning_col = test_board.copy()
winning_col[0::BOARD_DIM] = [None] * BOARD_DIM

assert not is_winner(test_board)
assert is_winner(winning_row)
assert is_winner(winning_col)

In [36]:
def print_board(board):
    for i in range(BOARD_DIM):
        row = [board[i*BOARD_DIM+j] for j in range(BOARD_DIM)]
        print(row)

def day4_1(lines: list[str]) -> int:
    draws = list(map(int, first(lines)[0].split(",")))
    boards = list(map(board_to_list, rest(lines)))
    for draw in draws:
        for board in boards:
            try:
                board[board.index(draw)] = None
            except ValueError:
                pass
            if is_winner(board):
                return sum([square for square in board if square is not None]) * draw


In [37]:
def day4_2(lines: list[str]) -> int:
    draws = list(map(int, first(lines)[0].split(",")))
    boards = list(map(board_to_list, rest(lines)))
    for draw in draws:
        winners = []
        for board in boards:
            try:
                board[board.index(draw)] = None
            except ValueError:
                pass
            if is_winner(board):
                winners.append(board)
        for winner in winners:
            boards.remove(winner)
        if len(boards) == 0:
            return sum([square for square in winners[0] if square is not None]) * draw



In [38]:
do(4, 27027, 36975)
# 5120 is too low

[27027, 36975]

## Day 5: Hydrothermal Venture

Each line of vents is given as a line segment in the format `x1,y1 -> x2,y2`, and both endpoints are included.

1. Consider only horizontal and vertical lines. At how many points do at least two lines overlap?
2. Consider all lines (horizontal, vertical, and 45° diagonal). At how many points do at least two lines overlap?
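All three segment orientations can share one loop by stepping from the start point with `dx, dy` in {-1, 0, 1}; the number of steps is the longer side of the segment. A sketch of that idea using plain tuples (`overlaps` and `sign` are hypothetical helpers), on a three-segment example worked out by hand:

```python
from collections import Counter


def sign(v: int) -> int:
    """Hypothetical helper: -1, 0, or 1."""
    return (v > 0) - (v < 0)


def overlaps(segments: list[tuple[tuple[int, int], tuple[int, int]]], diagonals: bool) -> int:
    """Hypothetical helper: count points covered by at least two segments.

    Assumes every segment is horizontal, vertical, or at exactly 45 degrees.
    """
    counts = Counter()
    for (x1, y1), (x2, y2) in segments:
        dx, dy = sign(x2 - x1), sign(y2 - y1)
        if dx and dy and not diagonals:
            continue  # part 1 ignores diagonal segments
        length = max(abs(x2 - x1), abs(y2 - y1))
        for step in range(length + 1):  # endpoints are inclusive
            counts[(x1 + step * dx, y1 + step * dy)] += 1
    return sum(1 for count in counts.values() if count >= 2)


segments = [((0, 0), (2, 0)), ((1, -1), (1, 1)), ((0, 0), (2, 2))]
print(overlaps(segments, diagonals=False))  # 1: only (1, 0)
print(overlaps(segments, diagonals=True))   # 3: (0, 0), (1, 0), (1, 1)
```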

In [107]:
@dataclass
class Point:
    x: int
    y: int

    def __hash__(self): return (self.x, self.y).__hash__()

    def __repr__(self): return f"({self.x}, {self.y})"

    def __add__(self, other): return Point(self.x + other.x, self.y + other.y)

@dataclass
class Line:
    start: Point
    end: Point

In [40]:
def str_to_line(input: str) -> Line:
    points = []
    for coords in input.split(" -> "):
        points.append(Point(*list(map(int, coords.split(",")))))
    return Line(*points)

in5 = data(5, parser=str_to_line)


In [41]:
def day5_1(lines: list[Line], do_diagonal=False) -> int:
    grid = defaultdict(int)
    for line in lines:
        if line.start.x == line.end.x:
            for y in range(min(line.start.y, line.end.y), max(line.start.y, line.end.y) + 1):
                grid[Point(line.start.x, y)] += 1
        elif line.start.y == line.end.y:
            for x in range(min(line.start.x, line.end.x), max(line.start.x, line.end.x) + 1):
                grid[Point(x, line.start.y)] += 1
        elif do_diagonal:
            x_dir = 1 if line.start.x < line.end.x else -1
            y_dir = 1 if line.start.y < line.end.y else -1
            current = Point(line.start.x, line.start.y)
            grid[current] += 1
            while current != line.end:
                current = Point(current.x + x_dir, current.y + y_dir)
                grid[current] += 1
    return quantify(grid.values(), lambda x: x >= 2)

def day5_2(lines: list[Line]): return day5_1(lines, do_diagonal=True)

In [42]:
sample5 = data(5, parser=str_to_line, filetype="sample")
assert day5_1(sample5) == 5
assert day5_2(sample5) == 12

In [43]:
do(5, 5632, 22213)

[5632, 22213]

## Day 6: Lanternfish

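Each lanternfish produces a new lanternfish every 7 days: its timer counts down from 6 to 0, and on the following day it spawns a new fish and resets to 6. A newborn's timer starts at 8, so its first cycle takes 2 days longer.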

1. How many lanternfish would there be after 80 days?
2. How many lanternfish would there be after 256 days?
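Since fish with the same timer behave identically, nine counters are enough, and advancing one day is a left rotation of those counters. A sketch of the same bucketed simulation as `sim_lanternfish`, written with `collections.deque.rotate` (`count_fish` is a hypothetical name), checked against the five-fish example from the puzzle:

```python
from collections import deque


def count_fish(timers: list[int], days: int) -> int:
    """Hypothetical helper: number of lanternfish after `days`, bucketed by timer value."""
    counts = deque([0] * 9)  # counts[t] = number of fish whose timer is t (0..8)
    for t in timers:
        counts[t] += 1
    for _ in range(days):
        counts.rotate(-1)       # every timer drops by one; fish at 0 wrap to index 8 as newborns
        counts[6] += counts[8]  # each parent also restarts its own timer at 6
    return sum(counts)


print(count_fish([3, 4, 3, 1, 2], 18))  # 26
print(count_fish([3, 4, 3, 1, 2], 80))  # 5934
```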

In [44]:
in6 = data(6, sep=",", parser=int)

In [45]:
def sim_lanternfish(initial_state: list[int], spawn_rate: int, initial_spawn_rate: int, days: int) -> dict[int, int]:
    max_spawn_rate = max(spawn_rate, initial_spawn_rate)
    state = {k: 0 for k in range(max_spawn_rate+1)}
    for fish in initial_state:
        state[fish] = state.get(fish, 0) + 1
    
    for _ in range(days):
        new_state = {k: 0 for k in range(max_spawn_rate+1)}
        new_state[initial_spawn_rate] = state.get(0, 0)
        for timer in range(1, max_spawn_rate+1):
            new_state[timer-1] = state.get(timer, 0)
        new_state[spawn_rate] += state[0]
        state = new_state

    return state

In [46]:
def day6_1(state): return sum(sim_lanternfish(state, spawn_rate=6, initial_spawn_rate=8, days=80).values())

In [47]:
def day6_2(state): return sum(sim_lanternfish(state, spawn_rate=6, initial_spawn_rate=8, days=256).values())

In [48]:
do(6, 379114, 1702631502303)

[379114, 1702631502303]

## Day 7: The Treachery of Whales

We have a bunch of crabs that can only move left and right on a line.

1. Each step costs one unit of fuel. Determine the horizontal position that the crabs can align to using the least fuel possible. How much fuel must they spend to align to that position?
2. Each step costs one more unit of fuel than the previous step. Determine the horizontal position that the crabs can align to using the least fuel possible. How much fuel must they spend to align to that position?
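The median minimizes total distance, which is why it answers part 1. For part 2 the triangular cost is minimized within 1/2 of the mean, but the best integer position can fall on either side of it, so truncating the mean alone is not always enough. A brute-force sketch that scans every candidate position makes a handy cross-check (`best_fuel`, `constant`, and `triangular` are hypothetical names); on the puzzle's example it reports 37 and 168, whereas the truncated mean, 4, would cost 170:

```python
from typing import Callable


def best_fuel(crabs: list[int], cost: Callable[[int], int]) -> int:
    """Hypothetical helper: cheapest total fuel over every target between the outermost crabs."""
    return min(
        sum(cost(abs(crab - target)) for crab in crabs)
        for target in range(min(crabs), max(crabs) + 1)
    )


def constant(distance: int) -> int:
    return distance


def triangular(distance: int) -> int:
    return distance * (distance + 1) // 2  # 1 + 2 + ... + distance


crabs = [16, 1, 2, 0, 4, 2, 7, 1, 2, 14]
print(best_fuel(crabs, constant))    # 37
print(best_fuel(crabs, triangular))  # 168
```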

In [49]:
in7 = data(7, parser=int, sep=",")

In [50]:
def constant_total_fuel(crabs: list[int], target: int) -> int: return sum(map(lambda crab: abs(crab-target), crabs))

def triangular_number(n: int) -> int: return n * (n+1) // 2

def linear_total_fuel(crabs: list[int], target: int) -> int: return int(sum(map(lambda crab: triangular_number(abs(crab-target)), crabs)))

def day7_1(crabs: list[int]) -> int: return constant_total_fuel(crabs, int(median(crabs)))

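def day7_2(crabs: list[int]) -> int:
    # The real-valued optimum lies within 1/2 of the mean, so the best integer target
    # is at most one below or two above the truncated mean; check that whole window.
    m = int(mean(crabs))
    return min(linear_total_fuel(crabs, target) for target in range(m - 1, m + 3))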

In [51]:
do(7, 335330, 92439766)


[335330, 92439766]

## Day 8: Seven Segment Search

We're trying to interpret scrambled seven-segment displays. Each input line has ten unique signal patterns (one per digit), a `|` separator, and the four digits of the output value. The wire-to-segment mapping is consistent within a line but scrambled differently on each line.

```text
  0:      1:      2:      3:      4:
 aaaa    ....    aaaa    aaaa    ....
b    c  .    c  .    c  .    c  b    c
b    c  .    c  .    c  .    c  b    c
 ....    ....    dddd    dddd    dddd
e    f  .    f  e    .  .    f  .    f
e    f  .    f  e    .  .    f  .    f
 gggg    ....    gggg    gggg    ....

  5:      6:      7:      8:      9:
 aaaa    aaaa    aaaa    aaaa    aaaa
b    .  b    .  .    c  b    c  b    c
b    .  b    .  .    c  b    c  b    c
 dddd    dddd    ....    dddd    dddd
.    f  e    f  .    f  e    f  .    f
.    f  e    f  .    f  e    f  .    f
 gggg    gggg    ....    gggg    gggg
```

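1. **For now, focus on the easy digits.** In the output values, how many times do digits `1`, `4`, `7`, or `8` appear?
2. Decode every entry's four-digit output value. What do you get if you add up all of the output values?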
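One way to decode a line without set reasoning: across the ten patterns, each wire appears as many times as its true segment appears among the ten digits (a: 8, b: 6, c: 8, d: 7, e: 4, f: 9, g: 7), and that count survives any relabeling. Summing those counts over a digit's wires gives a different total for every digit. A sketch of this alternative to `decode_display` (`decode_line`, `CANONICAL`, and `SCORE_TO_DIGIT` are hypothetical names), checked on the single-line example from the puzzle, whose output reads 5353:

```python
from collections import Counter

# Hypothetical table: segments lit for each digit 0-9 on an unscrambled display.
CANONICAL = ["abcefg", "cf", "acdeg", "acdfg", "bcdf",
             "abdfg", "abdefg", "acf", "abcdefg", "abcdfg"]
CANON_FREQ = Counter("".join(CANONICAL))
# The sum of segment frequencies is unique for each digit, so it identifies the digit.
SCORE_TO_DIGIT = {sum(CANON_FREQ[s] for s in segs): d for d, segs in enumerate(CANONICAL)}


def decode_line(line: str) -> int:
    """Hypothetical helper: decode the four output digits of one input line."""
    patterns, outputs = (part.split() for part in line.split("|"))
    freq = Counter("".join(patterns))  # how many of the ten patterns use each wire
    value = 0
    for word in outputs:
        value = value * 10 + SCORE_TO_DIGIT[sum(freq[wire] for wire in word)]
    return value


example = ("acedgfb cdfbe gcdfa fbcad dab cefabd cdfgeb eafb cagedb ab | "
           "cdfeb fcadb cdfeb cdbaf")
print(decode_line(example))  # 5353
```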

In [52]:
Digit = set[str]

def digit_to_str(digit: Digit) -> str: return ''.join(sorted(digit))

@dataclass
class Display:
    in_digits: list[Digit]
    out_digits: list[Digit]

def parser8(line: str) -> Display:
    in_str, out_str = line.split('|')
    return Display(in_digits=[set(x) for x in in_str.split()], out_digits=[set(x) for x in out_str.split()])

in8 = data(8, parser=parser8)

In [53]:
def day8_1(displays: list[Display]) -> int:
    return quantify([digit for display in displays for digit in display.out_digits], lambda digit: len(digit) in [2,3,4,7])

In [54]:
def decode_display(display: Display) -> int:
    knowns = {}
    unknowns = display.in_digits.copy()
    for k,v in {'8': 7, '4': 4, '7': 3, '1': 2}.items():
        knowns[k] = next(digit for digit in unknowns if len(digit) == v)
        unknowns.remove(knowns[k])

    a = knowns['7'] - knowns['1']
    partial = knowns['4'] | a
    knowns['9'] = next(digit for digit in unknowns if partial < digit)
    unknowns.remove(knowns['9'])

    e = knowns['8'] - knowns['9']
    partial = knowns['9'] - knowns['1'] | e
    knowns['6'] = next(digit for digit in unknowns if partial < digit)
    unknowns.remove(knowns['6'])

    knowns['0'] = next(digit for digit in unknowns if len(digit) == 6)
    unknowns.remove(knowns['0'])

    knowns['5'] = knowns['6'] - e
    unknowns.remove(knowns['5'])

    knowns['3'] = next(digit for digit in unknowns if knowns['1'] < digit)
    unknowns.remove(knowns['3'])

    assert len(unknowns) == 1
    knowns['2'] = unknowns[0]

    digits = {digit_to_str(v): k for k, v in knowns.items()}
    return int(''.join([digits[digit_to_str(digit)] for digit in display.out_digits]))

def day8_2(displays: list[Display]) -> int: return sum([decode_display(display) for display in displays])

In [55]:
assert day8_2(data(8, filetype='sample', parser=parser8)) == 61229

In [56]:
do(8, 284, 973499)

[284, 973499]

## Day 9: Smoke Basin

Each cell in the grid is the height of that point (0-9). A low point is a point that is lower than all of its orthogonal neighbors (up to four; diagonals don't count). The risk level of a low point is its height + 1.

A basin is all the points less than height 9 that are connected to a low point.

1. What is the sum of the risk levels of all low points on your heightmap?
2. What do you get if you multiply together the sizes of the three largest basins?
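Assuming, as the puzzle's guarantee implies, that height-9 locations separate the basins and every other location belongs to exactly one basin, part 2 can also be solved by labeling connected regions of non-9 cells, without locating the low points first. A flood-fill sketch of that alternative (`basin_sizes` is a hypothetical name), run on the 10 × 5 example heightmap from the puzzle:

```python
from math import prod


def basin_sizes(heightmap: list[str]) -> list[int]:
    """Hypothetical helper: sizes of connected regions of non-9 cells (up/down/left/right)."""
    height, width = len(heightmap), len(heightmap[0])
    seen = set()
    sizes = []
    for y in range(height):
        for x in range(width):
            if heightmap[y][x] == "9" or (x, y) in seen:
                continue
            # Flood-fill one basin with an explicit stack.
            seen.add((x, y))
            stack, size = [(x, y)], 0
            while stack:
                cx, cy = stack.pop()
                size += 1
                for nx, ny in ((cx + 1, cy), (cx - 1, cy), (cx, cy + 1), (cx, cy - 1)):
                    if (0 <= nx < width and 0 <= ny < height
                            and heightmap[ny][nx] != "9" and (nx, ny) not in seen):
                        seen.add((nx, ny))
                        stack.append((nx, ny))
            sizes.append(size)
    return sizes


heightmap = ["2199943210",
             "3987894921",
             "9856789892",
             "8767896789",
             "9899965678"]
sizes = sorted(basin_sizes(heightmap), reverse=True)
print(sizes)            # [14, 9, 9, 3]
print(prod(sizes[:3]))  # 1134
```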


In [57]:
def str_to_digit_list(line: str) -> list[int]: return list(map(int, line))

in9 = data(9, parser=str_to_digit_list)

In [58]:
Grid = list[list[int]]

def neighbors(grid: Grid, point: Point, nsew=True, diag=False) -> list[Point]:
    ret = []
    x, y = point.x, point.y
    nbrs = [(x-1, y), (x+1, y), (x, y-1), (x, y+1)] if nsew else []
    if diag:
        nbrs += [(x-1, y-1), (x-1, y+1), (x+1, y-1), (x+1, y+1)]
    for x_, y_ in nbrs:
        if 0 <= y_ < len(grid) and 0 <= x_ < len(grid[y_]):
            ret.append(Point(x_, y_))
    return ret

def lowest_points(grid: Grid) -> list[Point]:
    mins = []
    for y in range(len(grid)):
        for x in range(len(grid[y])):
            if grid[y][x] < min(map(lambda pt: grid[pt.y][pt.x], neighbors(grid, Point(x, y)))):
                mins.append(Point(x, y))
    return mins

def day9_1(grid: Grid) -> int:
    lows = lowest_points(grid)
    return sum([grid[low.y][low.x] for low in lows]) + len(lows)

In [59]:
assert day9_1(data(9, parser=str_to_digit_list, filetype="sample")) == 15

In [60]:
def find_basin(grid: Grid, origin: Point) -> list[Point]:
    basin = {origin}
    processed = set()
    to_process = neighbors(grid, origin)
    while to_process:
        point = to_process.pop()
        processed.add(point)
        if point in basin:
            continue
        if grid[point.y][point.x] < 9:
            basin.add(point)
            for neighbor in filter(lambda nbr: nbr not in processed, neighbors(grid, point)):
                to_process.append(neighbor)
    return list(basin)

def day9_2(grid: Grid) -> int:
    lows = lowest_points(grid)
    basins = [find_basin(grid, low) for low in lows]
    return prod(sorted(map(len, basins), reverse=True)[0:3])

In [61]:
assert day9_2(data(9, parser=str_to_digit_list, filetype="sample")) == 1134

In [62]:
do(9, 633, 1050192)

[633, 1050192]

## Day 10: Syntax Scoring

Time for the annual parser! Lines are made of nested grouping characters (`()`, `{}`, `[]`, and `<>`). A line is corrupted if a closing character doesn't match the most recent unclosed opener; the first such illegal character in a line gets a score:

- `)`: 3 points.
- `]`: 57 points.
- `}`: 1197 points.
- `>`: 25137 points.

If we autocomplete the missing closing characters, each of those gets a score as well:

- `)`: 1 point.
- `]`: 2 points.
- `}`: 3 points.
- `>`: 4 points.

The total autocomplete score starts at 0; for each character of the completion string, in order, multiply the running score by 5 and then add that character's score.

1. Find the first illegal character in each corrupted line of the navigation subsystem. What is the total syntax error score for those errors?
2. Find the completion string for each incomplete line, score the completion strings, and sort the scores. What is the middle score?
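Both scores come out of one stack pass: push the closer each opener expects, pop on every closer, and whatever is still on the stack at the end, read from the top down, is the completion string. A sketch on bare strings (`check` and its score tables are hypothetical names), checked on two lines from the puzzle's example; the incomplete one needs `}}]])})]`, which scores 288957:

```python
# Hypothetical tables mirroring the scores listed above.
CLOSER_FOR = {"(": ")", "[": "]", "{": "}", "<": ">"}
ERROR_POINTS = {")": 3, "]": 57, "}": 1197, ">": 25137}
COMPLETION_POINTS = {")": 1, "]": 2, "}": 3, ">": 4}


def check(line: str) -> tuple[str, int]:
    """Hypothetical helper: ("corrupted", error score) or ("incomplete", autocomplete score)."""
    expected = []  # stack of the closers still owed, innermost last
    for char in line:
        if char in CLOSER_FOR:
            expected.append(CLOSER_FOR[char])
        elif not expected or char != expected.pop():
            return "corrupted", ERROR_POINTS[char]
    score = 0
    for closer in reversed(expected):  # close the innermost group first
        score = score * 5 + COMPLETION_POINTS[closer]
    return "incomplete", score


print(check("{([(<{}[<>[]}>{[]{[(<()>"))  # ('corrupted', 1197)
print(check("[({(<(())[]>[[{[]{<()<>>"))  # ('incomplete', 288957)
```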

In [63]:
in10 = data(10)

In [64]:
OPENERS = {
    ')': '(',
    ']': '[',
    '}': '{',
    '>': '<',
}
CLOSERS = {v:k for k, v in OPENERS.items()}

OPEN_CHARS = OPENERS.values()
CLOSE_CHARS = OPENERS.keys()

UNMATCHED_SCORES = {
    '': 0,
    ')': 3,
    ']': 57,
    '}': 1197,
    '>': 25137,
}

AUTOCOMPLETE_SCORES = {
    '': 0,
    ')': 1,
    ']': 2,
    '}': 3,
    '>': 4,
}

def unmatched(chunk: str, autocomplete=False) -> str:
    stack = []
    for char in chunk:
        if char in OPEN_CHARS:
            stack.append(char)
        elif char in CLOSE_CHARS:
            # An unexpected closer (including one with nothing left to close) corrupts the line.
            if not stack or stack.pop() != OPENERS[char]:
                return '' if autocomplete else char
        else:
            raise ValueError(f"Unknown character '{char}'")
    if autocomplete and stack:
        stack.reverse()
        return ''.join([CLOSERS[char] for char in stack])
    else:
        return ''

def day10_1(chunks: list[str]) -> int:
    return sum([UNMATCHED_SCORES[unmatched(chunk)] for chunk in chunks])

In [65]:
def autocomplete_score(completions: str) -> int:
    return reduce(lambda total, score: total * 5 + score, [AUTOCOMPLETE_SCORES[char] for char in completions], 0)

def day10_2(chunks: list[str]) -> int:
    scores = [autocomplete_score(unmatched(chunk, autocomplete=True)) for chunk in chunks]
    return int(median([score for score in scores if score > 0]))


In [66]:
do(10, 362271, 1698395182)

[362271, 1698395182]

## Day 11: Dumbo Octopus

Each octopus occupies a space in a grid. On every step:

- The energy level of each octopus increases by 1
- Any octopus with an energy level **greater than 9** flashes, which increases the energy level of every adjacent octopus, including diagonals, by 1.
  - If this increase puts other octopuses above 9, they also flash, _etc._ Each octopus flashes at most once per step.
- Any octopus that flashed has its energy level set back to 0

1. How many total flashes are there after 100 steps?
2. What is the first step during which all octopuses flash?

### Plan

1. Increase energy by 1 across the board
2. Set the cells to be checked to all of them
3. While there are cells to be checked:
   1. Look at all the cells to be checked for values >9. If found:
      1. Add it to the list of flashers
      2. Add all neighbors to the set to be increased
   2. Increment all the neighbors that aren't flashers
   3. The set of cells to check is now the unflashed neighbors
4. Return the total number of flashers
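An alternative to the round-by-round plan above is to push every octopus that crosses 9 onto a stack and let each flash push any neighbor it lifts past 9. Deferring the reset to 0 until the cascade finishes means an octopus that has already flashed can safely absorb extra energy. A sketch that mutates the grid in place (`cascade_step` is a hypothetical name), checked on the small 5 × 5 example from the puzzle, where nine octopuses flash in the first step:

```python
def cascade_step(grid: list[list[int]]) -> int:
    """Hypothetical helper: advance `grid` one step in place; return how many octopuses flashed."""
    height, width = len(grid), len(grid[0])
    to_flash = []
    for y in range(height):
        for x in range(width):
            grid[y][x] += 1
            if grid[y][x] > 9:
                to_flash.append((x, y))
    flashed = set()
    while to_flash:
        x, y = to_flash.pop()
        if (x, y) in flashed:
            continue  # already flashed this step; a cell may be pushed more than once
        flashed.add((x, y))
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                nx, ny = x + dx, y + dy
                if (dx or dy) and 0 <= nx < width and 0 <= ny < height:
                    grid[ny][nx] += 1
                    if grid[ny][nx] > 9 and (nx, ny) not in flashed:
                        to_flash.append((nx, ny))
    for x, y in flashed:
        grid[y][x] = 0  # reset only after the cascade has finished
    return len(flashed)


octopuses = [[1, 1, 1, 1, 1],
             [1, 9, 9, 9, 1],
             [1, 9, 1, 9, 1],
             [1, 9, 9, 9, 1],
             [1, 1, 1, 1, 1]]
print(cascade_step(octopuses))  # 9
print(octopuses)  # [[3, 4, 5, 4, 3], [4, 0, 0, 0, 4], [5, 0, 0, 0, 5], [4, 0, 0, 0, 4], [3, 4, 5, 4, 3]]
```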

In [67]:
in11 = data(11, parser=str_to_digit_list)

In [68]:
def print_grid(grid: Grid):
    for line in grid:
        print("".join([str(digit) for digit in line]))

In [69]:
def flasher_step(grid: Grid) -> tuple[Grid, int]:
    grid = [[digit + 1 for digit in line] for line in grid]
    to_check = set()
    flashed = set()
    for y in range(len(grid)):
        for x in range(len(grid[y])):
            to_check.add(Point(x, y))
    while to_check:
        updated = set()
        for point in to_check:
            if grid[point.y][point.x] > 9:
                flashed.add(point)
                grid[point.y][point.x] = 0
                for nbr in neighbors(grid, point, nsew=True, diag=True):
                    if nbr in flashed:
                        continue
                    if grid[nbr.y][nbr.x] <= 9:
                        grid[nbr.y][nbr.x] +=1
                        updated.add(nbr)
        to_check = updated
    return grid, len(flashed)

In [70]:
def day11_1(grid: Grid) -> int:
    total_flashed = 0
    for step in range(100):
        grid, flashed = flasher_step(grid)
        total_flashed += flashed
    return total_flashed

In [71]:
def day11_2(grid: Grid) -> int:
    grid_size = len(grid) * len(grid[0])
    step = 0
    while True:
        step += 1
        grid, flashed = flasher_step(grid)
        if flashed == grid_size:
            return step

In [72]:
do(11, 1732, 290)

[1732, 290]

## Day 12: Passage Pathing

Input is a list of cave connections including a "start" and an "end". Big caves are capitalized, small caves are lowercase. When finding paths, big caves can be revisited and small caves cannot.

1. How many paths through this cave system are there that visit small caves at most once?
2. If a single small cave may be visited twice per path (with `start` and `end` still visited only once), how many paths through this cave system are there?

### Plan

- Dict of caves: key is cave name, value is set of neighbors
- Need to add neighbors in both directions
- Keep a set of visited small caves and subtract it from the neighbors to determine the valid next caves
- Assert that no two big caves are neighbors, so infinite loops are impossible
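Only the number of paths is needed, so the search can return counts instead of building every path, and carrying the visited small caves as a `frozenset` makes it safe to use as a default argument. A sketch along those lines (`build_graph` and `count_paths` are hypothetical names), on the smallest example from the puzzle, which has 10 paths under the part 1 rules and 36 under part 2:

```python
from collections import defaultdict


def build_graph(edges: list[str]) -> dict[str, set[str]]:
    """Hypothetical helper: undirected adjacency sets from lines like 'start-A'."""
    graph = defaultdict(set)
    for edge in edges:
        a, b = edge.split("-")
        graph[a].add(b)
        graph[b].add(a)
    return graph


def count_paths(graph, cave: str = "start", seen: frozenset = frozenset(),
                can_repeat: bool = False) -> int:
    """Hypothetical helper: count paths from `cave` to 'end'; `seen` holds visited small caves."""
    if cave == "end":
        return 1
    if cave.islower():
        seen = seen | {cave}  # builds a new frozenset, so the default is never mutated
    total = 0
    for nxt in graph[cave]:
        if nxt == "start":
            continue
        if nxt not in seen:
            total += count_paths(graph, nxt, seen, can_repeat)
        elif can_repeat:
            total += count_paths(graph, nxt, seen, False)  # spend the single repeat
    return total


example = ["start-A", "start-b", "A-c", "A-b", "b-d", "A-end", "b-end"]
caves = build_graph(example)
print(count_paths(caves))                   # 10
print(count_paths(caves, can_repeat=True))  # 36
```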

In [73]:
Caves = defaultdict[str, set[str]]
Path = list[str]
Paths = list[Path]

def small(cave: str) -> bool: return cave.lower() == cave

def parser12(lines: str) -> Caves:
    caves = defaultdict(set)
    for line in lines.split("\n"):
        l, r = line.split('-')
        caves[l].add(r)
        caves[r].add(l)
        # Two big caves connected would create an infinite loop
        assert small(l) or small(r)
    return caves

in12 = data(12, parser=parser12, sep="\n\n")[0]

In [74]:
def traverse(caves: Caves, path: Path | None = None, current="start", allow_one_small=False) -> Paths:
    "Return every path to 'end' that extends `path` with `current`."
    # Build a new list instead of appending to a mutable default shared across calls.
    path = (path or []) + [current]
    if current == "end":
        return [path]
    paths = Paths()
    for cave in caves[current]:
        if cave == "start":
            continue
        if small(cave) and cave in path:
            # Revisiting a small cave spends the one allowed repeat.
            if allow_one_small:
                paths += traverse(caves, path=path, current=cave, allow_one_small=False)
        else:
            paths += traverse(caves, path=path, current=cave, allow_one_small=allow_one_small)
    return paths


In [75]:
sample12 = data(12, parser=parser12, sep="\n\n", filetype="sample")[0]
assert len(traverse(sample12)) == 10
assert len(traverse(sample12, allow_one_small=True)) == 36

In [76]:
def day12_1(caves: Caves) -> int: return len(traverse(caves))
def day12_2(caves: Caves) -> int: return len(traverse(caves, allow_one_small=True))

In [77]:
do(12, 5104, 149220)

[5104, 149220]

## Day 13: Transparent Origami

Input is a list of coordinates of dots on a page, followed by a series of folding instructions. Folds with `y=` fold the bottom part over the top part; folds with `x=` fold the right part over the left part. (`(0, 0)` never moves.)

After all folds you should see eight capital letters.

1. How many dots are visible after completing just the first fold instruction on your transparent paper?
2. What code do you use to activate the infrared thermal imaging camera system?
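A fold along `x = k` sends every dot with `x > k` to `2k - x` and leaves the rest alone (likewise for `y`); collecting the results in a set merges dots that land on the same spot. A sketch on plain tuples (`fold_dots` is a hypothetical name, separate from the `Fold` class below), assuming, as the puzzle does, that no dot lies on a fold line:

```python
def fold_dots(dots: set[tuple[int, int]], axis: str, line: int) -> set[tuple[int, int]]:
    """Hypothetical helper: fold along `axis`=`line`, reflecting dots past the line onto it."""
    def reflect(v: int) -> int:
        return 2 * line - v if v > line else v

    if axis == "x":
        return {(reflect(x), y) for x, y in dots}
    return {(x, reflect(y)) for x, y in dots}


dots = {(0, 0), (3, 0), (4, 1), (1, 4)}
print(sorted(fold_dots(dots, "x", 2)))  # [(0, 0), (0, 1), (1, 0), (1, 4)]
print(sorted(fold_dots(dots, "y", 2)))  # [(0, 0), (1, 0), (3, 0), (4, 1)]
```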

In [78]:
Dots = set[Point]

@dataclass
class Fold():
    direction: str
    value: int

    def beyond(self, dot: Point) -> bool:
        return dot.x > self.value if self.direction == 'x' else dot.y > self.value

    def over(self, dot: Point) -> Point:
        if self.direction == 'x':
            return Point(self.value - (dot.x - self.value), dot.y)
        else:
            return Point(dot.x, self.value - (dot.y - self.value))

Folds = list[Fold]

@dataclass
class Paper():
    dots: Dots
    folds: Folds

    def __str__(self):
        max_x = max([dot.x for dot in self.dots])
        max_y = max([dot.y for dot in self.dots])
        lines = []
        for y in range(max_y+1):
            line = ""
            for x in range(max_x+1):
                line += "#" if Point(x,y) in self.dots else "."
            lines.append(line)
        return "\n".join(lines)

def parser13dots(lines: str) -> Dots:
    dots = Dots()
    for line in lines.split('\n'):
        if ',' in line:
            x, y = line.split(',')
            dots.add(Point(x=int(x), y=int(y)))
        else:
            raise ValueError(f"Unmatched line: {line}")
    return dots

def parser13folds(lines: str) -> Folds:
    folds = Folds()
    for line in lines.split('\n'):
        if line.startswith('fold along '):
            dir = line[11]
            val = int(line[13:])
            folds.append(Fold(direction=dir, value=val))
        else:
            raise ValueError(f"Unmatched line: {line}")
    return folds

In [79]:
blocks13 = data(13, sep="\n\n")
in13 = Paper(parser13dots(blocks13[0]), parser13folds(blocks13[1]))

In [80]:
def fold(dots: Dots, step: Fold) -> Dots:
    moving = set([dot for dot in dots if step.beyond(dot)])
    moved = set([step.over(dot) for dot in moving])
    dots = dots - moving | moved
    return dots

In [81]:
def day13_1(paper: Paper) -> int:
    return len(fold(paper.dots, paper.folds[0]))

def day13_2(paper: Paper) -> None:
    for instruction in paper.folds:
        paper.dots = fold(paper.dots, instruction)
    print(paper)
    return None

In [82]:
sample_blocks13 = data(13, sep="\n\n", filetype="sample")
sample13 = Paper(parser13dots(sample_blocks13[0]), parser13folds(sample_blocks13[1]))
sample13.dots = fold(sample13.dots, sample13.folds[0])
print(sample13)

#.##..#..#.
#...#......
......#...#
#...#......
.#.#..#.###


In [83]:
do(13, 666)
# 13-2: "CJHAZHKU"

.##....##.#..#..##..####.#..#.#..#.#..#
#..#....#.#..#.#..#....#.#..#.#.#..#..#
#.......#.####.#..#...#..####.##...#..#
#.......#.#..#.####..#...#..#.#.#..#..#
#..#.#..#.#..#.#..#.#....#..#.#.#..#..#
.##...##..#..#.#..#.####.#..#.#..#..##.


[666, None]

## Day 14: Extended Polymerization

The first line of the file is the polymer template. The rest of the file is pair insertion rules of the form `AB -> C`, meaning insert `C` between every adjacent `AB`, so `AB` becomes `ACB`. All insertions in one step happen simultaneously.

1. Apply 10 steps of pair insertion to the polymer template and find the most and least common elements in the result. What do you get if you take the quantity of the most common element and subtract the quantity of the least common element?
2. Now apply 40 steps and repeat the calculation.
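The polymer roughly doubles in length every step, so 40 steps can't be built as a string. Counting adjacent pairs is enough: each rule `AB -> C` turns every `AB` into one `AC` and one `CB`, and element counts come from the first letter of each pair plus the template's last letter, which never moves. A sketch that checks the pair-count approach against direct string expansion on the puzzle's example (template `NNCB`; `expand_string`, `spread`, and `RULES` are hypothetical names):

```python
from collections import Counter

# Hypothetical rule table: the 16 example rules from the puzzle.
RULES = dict(line.split(" -> ") for line in """\
CH -> B
HH -> N
CB -> H
NH -> C
HB -> C
HC -> B
HN -> C
NN -> C
BH -> H
NC -> B
NB -> B
BN -> B
BB -> N
BC -> B
CC -> N
CN -> C""".splitlines())


def expand_string(template: str, steps: int) -> str:
    """Hypothetical helper: direct expansion, only practical for a handful of steps."""
    for _ in range(steps):
        template = template[0] + "".join(RULES[a + b] + b for a, b in zip(template, template[1:]))
    return template


def spread(template: str, steps: int) -> int:
    """Hypothetical helper: most common minus least common element count, via pair counts."""
    pairs = Counter(a + b for a, b in zip(template, template[1:]))
    for _ in range(steps):
        new_pairs = Counter()
        for pair, n in pairs.items():
            inserted = RULES[pair]
            new_pairs[pair[0] + inserted] += n
            new_pairs[inserted + pair[1]] += n
        pairs = new_pairs
    letters = Counter({template[-1]: 1})  # the last letter is never the first of a pair
    for pair, n in pairs.items():
        letters[pair[0]] += n
    return max(letters.values()) - min(letters.values())


print(expand_string("NNCB", 1))  # NCNBCHB
print(spread("NNCB", 10))        # 1588
```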


In [84]:
Polymer = Counter[str]
Substitutions = dict[str, list[str]]

def parser14(start: list[str], rules: list[str]) -> tuple[list[str], Substitutions]:
    "Map each pair `AB` with rule `AB -> C` to the two pairs it becomes, `AC` and `CB`."
    substitutions = {}
    for rule in rules:
        pair, addition = rule.split(" -> ")
        substitutions[pair] = [pair[0] + addition, addition + pair[1]]
    return (start, substitutions)


In [85]:
in14 = parser14(*data(14, parser=by_line, sep="\n\n"))

In [86]:
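def apply_substitutions(polymer: Polymer, substitutions: Substitutions) -> Polymer:
    new_polymer = Polymer()
    for pair, count in polymer.items():
        # Add counts one new pair at a time: a rule like `AA -> A` yields `AA` twice,
        # which a dict comprehension would collapse into a single key.
        for new_pair in substitutions[pair]:
            new_polymer[new_pair] += count
    return new_polymer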

def day14_1(inputs, steps=10):
    start, substitutions = inputs
    start = start[0]
    polymer = Polymer(sliding_window(start, 2))
    counts = Counter({start[-1]: 1})
    for _ in range(steps):
        polymer = apply_substitutions(polymer, substitutions)
    for pair,count in polymer.items():
        counts.update({pair[0]: count})
    return (max(counts.values()) - min(counts.values()))

def day14_2(inputs): return day14_1(inputs, steps=40)

In [87]:
do(14, 2360, 2967977072188)

[2360, 2967977072188]

In [88]:
sample14 = parser14(*data(14, parser=by_line, sep="\n\n", filetype="sample"))

In [89]:
day14_1(sample14)

1588

## Day 15: Chiton

The input is an $n \times n$ risk map. We need to find the path from $(0,0)$ to $(n-1,n-1)$ with the minimum total risk. No diagonal moves, and the risk of $(0,0)$ is not included in the total.

### Plan

Implement [Dijkstra's algorithm](https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm) using a [priority queue](https://docs.python.org/3/library/queue.html#queue.PriorityQueue)

For part 2, the input is only one tile of the full map, which is five tiles wide and five tiles tall. Each tile's values are 1 higher than those of the tile above it or to its left, and values above `9` wrap around to `1`.

1. What is the lowest total risk of any path from the top left to the bottom right?
2. Using the full map, what is the lowest total risk of any path from the top left to the bottom right?
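The same shortest-path search can also be written with tuples on a `heapq` heap, skipping entries that went stale because a cheaper route was found later. A sketch of that variant together with the part 2 tiling (`lowest_risk` and `tile` are hypothetical names; `(v - 1) % 9 + 1` is one way to make values above 9 wrap to 1), checked on small hand-worked grids:

```python
import heapq


def lowest_risk(grid: list[list[int]]) -> int:
    """Hypothetical helper: Dijkstra from top-left to bottom-right; the start's risk is not counted."""
    height, width = len(grid), len(grid[0])
    target = (width - 1, height - 1)
    best = {(0, 0): 0}  # lowest known total risk for each (x, y)
    heap = [(0, 0, 0)]  # entries are (total risk, x, y)
    while heap:
        risk, x, y = heapq.heappop(heap)
        if (x, y) == target:
            return risk
        if risk > best[(x, y)]:
            continue  # stale entry: a cheaper route here was already expanded
        for nx, ny in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
            if 0 <= nx < width and 0 <= ny < height:
                new_risk = risk + grid[ny][nx]
                if new_risk < best.get((nx, ny), float("inf")):
                    best[(nx, ny)] = new_risk
                    heapq.heappush(heap, (new_risk, nx, ny))
    raise ValueError("bottom-right corner is unreachable")


def tile(grid: list[list[int]], times: int = 5) -> list[list[int]]:
    """Hypothetical helper: tile `grid` times x times; tile (i, j) adds i + j, wrapping 10 to 1."""
    height, width = len(grid), len(grid[0])
    return [
        [(grid[y % height][x % width] + y // height + x // width - 1) % 9 + 1
         for x in range(width * times)]
        for y in range(height * times)
    ]


grid3 = [[1, 1, 6],
         [1, 3, 8],
         [2, 1, 3]]
print(lowest_risk(grid3))  # 7: down, down, right, right
print(tile([[8]])[0])      # [8, 9, 1, 2, 3]
```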

In [90]:
in15 = data(15, parser=str_to_digit_list)

In [91]:
@dataclass
class QueuedPoint():
    priority: int
    point: Point

    def __lt__(self, other): return self.priority < other.priority

def dijkstra(nodes: Grid, source: Point = Point(0, 0), target: Point | None = None) -> int:
    "Return the lowest total risk of any path from `source` to `target` (default: bottom-right)."
    if target is None:
        target = Point(len(nodes[0]) - 1, len(nodes) - 1)

    queue = PriorityQueue()
    queue.put(QueuedPoint(0, source))
    risks = {source: 0}  # lowest known total risk to reach each point

    while not queue.empty():
        current = queue.get()
        if current.point == target:
            break  # the first time the target is dequeued, its risk is final
        if current.priority > risks[current.point]:
            continue  # stale entry: a cheaper path to this point was queued later
        for neighbor in neighbors(nodes, current.point):
            risk = risks[current.point] + nodes[neighbor.y][neighbor.x]
            if neighbor not in risks or risk < risks[neighbor]:
                risks[neighbor] = risk
                queue.put(QueuedPoint(risk, neighbor))

    return risks[target]

In [92]:
def day15_1(nodes: Grid) -> int: return dijkstra(nodes)

In [93]:
def day15_2(nodes: Grid) -> int:
    "Tile the map 5x5; each tile adds its tile column + tile row, wrapping values above 9 back to 1."
    mega_grid = []
    for y in range(5):
        for row in nodes:
            mega_row = []
            for x in range(5):
                mega_row += [(cell + x + y) for cell in row]
            mega_row = [cell - 9 if cell > 9 else cell for cell in mega_row]
            mega_grid.append(mega_row)
    return dijkstra(mega_grid)

In [94]:
sample15 = data(15, parser=str_to_digit_list, filetype="sample")
assert day15_2(sample15) == 315

In [95]:
do(15, 447, 2825)

[447, 2825]

## Day 16: Packet Decoder

Input is a hex-encoded string of **a single packet** (potentially containing subpackets) with complex rules for interpretation.

All packets start with a 3-bit version and a 3-bit type ID. If `type == 4`, the packet contains a literal value, encoded in groups of 5 bits: the first bit of every group is a continuation bit, which is `1` if more groups follow and `0` for the last group, and the remaining 4 bits are the next hex digit of the value. The hex encoding may add a few zero bits after the outermost packet; these are not part of the transmission, and subpackets are never padded.
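The literal encoding can be checked in isolation, without the notebook's `Packet` class. A minimal sketch (the name `decode_literal` is hypothetical) that assumes the whole input is one literal packet and reads its version, type ID, and value:

```python
def decode_literal(hex_string: str) -> tuple[int, int, int]:
    "Decode a hex string known to hold a single literal packet: return (version, type_id, value)."
    # zfill restores leading zero bits that bin() drops (e.g. from a leading hex '0' or '3').
    bits = bin(int(hex_string, 16))[2:].zfill(4 * len(hex_string))
    version, type_id = int(bits[0:3], 2), int(bits[3:6], 2)
    i, value = 6, 0
    while True:
        group = bits[i:i + 5]                       # 1 continuation bit + 4 value bits
        value = (value << 4) | int(group[1:], 2)
        i += 5
        if group[0] == "0":                         # continuation bit 0 marks the last group
            break
    return version, type_id, value
```

For `D2FE28` this reads groups `10111`, `11110`, `00101`, so the value bits are `0111 1110 0101`, which is 2021.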

Every packet type other than literal is an operator of some sort and will contain one or more subpackets. The next bit after the type is the length type ID:

* `0` -- the next `15` bits give the total **length in bits** of this packet's subpackets
* `1` -- the next `11` bits give the **number** of subpackets immediately contained in this packet
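Reading an operator packet's header takes only fixed-width slices. A sketch with a hypothetical `operator_header` helper, separate from the notebook's `Packet` class, that returns either the 15-bit length or the 11-bit count depending on the length type ID:

```python
def operator_header(hex_string: str) -> tuple[int, int, str, int]:
    "Return (version, type_id, length_type_id, length_or_count) for an operator packet."
    bits = bin(int(hex_string, 16))[2:].zfill(4 * len(hex_string))
    version, type_id, length_type = int(bits[0:3], 2), int(bits[3:6], 2), bits[6]
    if length_type == "0":
        return version, type_id, length_type, int(bits[7:22], 2)  # 15-bit total length in bits
    return version, type_id, length_type, int(bits[7:18], 2)      # 11-bit number of subpackets
```

The two operator examples used in the cells below decode to a 27-bit payload (`38006F45291200`) and a count of three subpackets (`EE00D40C823060`).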

1. Decode the structure of your hexadecimal-encoded BITS transmission; what do you get if you add up the version numbers in all packets?
2. What do you get if you evaluate the expression represented by your hexadecimal-encoded BITS transmission?


In [96]:
Bit = str   # a single "0" or "1"
Bits = str  # a run of "0"/"1" characters, as returned by Packet.hex_to_bits

class Packet():
    LEN_VERSION = 3
    LEN_TYPE = 3
    LEN_LITERAL = 4
    LEN_CHILD_BIT_COUNT = 15
    LEN_CHILD_PACKET_COUNT = 11
    
    TYPE_SUM = 0
    TYPE_PRODUCT = 1
    TYPE_MIN = 2
    TYPE_MAX = 3
    TYPE_LITERAL = 4
    TYPE_GT = 5
    TYPE_LT = 6
    TYPE_EQ = 7


    BITS = {
        '0': '0000', '1': '0001', '2': '0010', '3': '0011',
        '4': '0100', '5': '0101', '6': '0110', '7': '0111',
        '8': '1000', '9': '1001', 'A': '1010', 'B': '1011',
        'C': '1100', 'D': '1101', 'E': '1110', 'F': '1111',
    }

    @staticmethod
    def hex_to_bits(hex_string: str) -> Bits: return "".join(Packet.BITS[digit] for digit in hex_string)
    @staticmethod
    def bits_to_int(bits: Bits) -> int: return int("".join(bits), base=2)
    
    def __init__(self, bits: Bits, start_index=0):
        self.value = None
        self.children = []

        self.bits = bits
        self.index = self.start_index = start_index
        self.version = Packet.bits_to_int(self.take_bits(Packet.LEN_VERSION))
        self.type = Packet.bits_to_int(self.take_bits(Packet.LEN_TYPE))

        if self.type == self.TYPE_LITERAL:
            value = Bits()
            while True:
                continuation = self.take_bits(1)
                value += self.take_bits(Packet.LEN_LITERAL)
                if continuation == "0":
                    break
            self.value = Packet.bits_to_int(value)
            # No padding to skip here: only the outermost packet is padded, and its trailing zeros are never read.
        else:  # operator packet
            length_type = self.take_bits(1)
            if length_type == "0":
                num_child_bits = Packet.bits_to_int(self.take_bits(Packet.LEN_CHILD_BIT_COUNT))
                sub_bits = self.take_bits(num_child_bits)
                sub_index = 0
                while sub_index < num_child_bits:
                    new_child = Packet(sub_bits, sub_index)
                    sub_index += new_child.bits_used()
                    self.children.append(new_child)
            elif length_type == "1":
                num_child_packets = Packet.bits_to_int(self.take_bits(Packet.LEN_CHILD_PACKET_COUNT))
                for _ in range(num_child_packets):
                    new_child = Packet(self.bits, self.index)
                    self.index += new_child.bits_used()
                    self.children.append(new_child)

    def execute(self) -> int:
        if self.type == Packet.TYPE_LITERAL:
            return self.value
        elif self.type == Packet.TYPE_SUM:
            return sum([child.execute() for child in self.children])
        elif self.type == Packet.TYPE_PRODUCT:
            return prod([child.execute() for child in self.children])
        elif self.type == Packet.TYPE_MIN:
            return min([child.execute() for child in self.children])
        elif self.type == Packet.TYPE_MAX:
            return max([child.execute() for child in self.children])
        elif self.type == Packet.TYPE_GT:
            return int(self.children[0].execute() > self.children[1].execute())
        elif self.type == Packet.TYPE_LT:
            return int(self.children[0].execute() < self.children[1].execute())
        elif self.type == Packet.TYPE_EQ:
            return int(self.children[0].execute() == self.children[1].execute())
        else:
            raise ValueError(f"Unknown type: {self.type}")
                
    def take_bits(self, n) -> Bits:
        start = self.index
        self.index += n
        return self.bits[start:self.index]

    def total_bits(self): return len(self.bits)

    def bits_used(self): return self.index - self.start_index

    def sum_versions(self): return self.version + sum(child.sum_versions() for child in self.children)
    
    def __str__(self):
        ret = [
            f"Version: {self.version}",
            f"Type: {self.type}",
            f"Bits: used {self.bits_used()} of {self.total_bits()} (index: {self.index})"
        ]
        if self.type == self.TYPE_LITERAL:
            ret.append(f"Value: {self.value}")
        else:
            ret += [str(child) for child in self.children]
        return "\n".join(ret)

In [97]:
print(Packet(Packet.hex_to_bits("D2FE28")))

Version: 6
Type: 4
Bits: used 21 of 24 (index: 21)
Value: 2021


In [98]:
print(Packet(Packet.hex_to_bits("EE00D40C823060")))

Version: 7
Type: 3
Bits: used 51 of 56 (index: 51)
Version: 2
Type: 4
Bits: used 11 of 56 (index: 29)
Value: 1
Version: 4
Type: 4
Bits: used 11 of 56 (index: 40)
Value: 2
Version: 1
Type: 4
Bits: used 11 of 56 (index: 51)
Value: 3


In [99]:
print(Packet(Packet.hex_to_bits("38006F45291200")))

Version: 1
Type: 6
Bits: used 49 of 56 (index: 49)
Version: 6
Type: 4
Bits: used 11 of 27 (index: 11)
Value: 10
Version: 2
Type: 4
Bits: used 16 of 27 (index: 27)
Value: 20


In [100]:
Packet(Packet.hex_to_bits("C0015000016115A2E0802F182340")).sum_versions()

23

In [101]:
in16 = data(16)[0]

def day16_1(packet) -> int: return Packet(Packet.hex_to_bits(packet)).sum_versions()
def day16_2(packet) -> int: return Packet(Packet.hex_to_bits(packet)).execute()

In [102]:
do(16, 893, 4358595186090)

[893, 4358595186090]

## Day 17: Trick Shot

Launch a probe from $(0, 0)$ at an initial velocity of $(x, y)$. Each step, the probe moves by its velocity; then its $x$ velocity drops by $1$ until it reaches $0$ (drag) and its $y$ velocity drops by $1$ forever (gravity). Find the initial velocity that maximizes the probe's height while still putting it inside the target area after some step.
 
My input: `target area: x=25..67, y=-260..-200`

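1. What is the highest y position it reaches on this trajectory?
2. How many distinct initial velocity values cause the probe to be within the target area after any step?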
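The `FIXME` in `trajectories` below asks whether there's something better than brute force. For part 1 there is a closed form, sketched here under two assumptions that hold for this input: the target lies entirely below $y = 0$, and some $x$ velocity comes to rest inside the target's $x$ range, so the probe can drop straight down into it. The helper names are hypothetical.

```python
def stalls_in_range(min_x: int, max_x: int) -> bool:
    "True if some x velocity n comes to rest (at x = n(n+1)/2) inside [min_x, max_x]."
    return any(min_x <= n * (n + 1) // 2 <= max_x for n in range(max_x + 1))

def max_height(min_y: int) -> int:
    """Closed-form part 1, assuming the target is entirely below y = 0
    and stalls_in_range(...) holds for the target's x range."""
    # A probe launched upward at vy returns to y = 0 moving at -(vy + 1),
    # so its next step lands at y = -(vy + 1), which must not be below min_y.
    vy = -min_y - 1
    return vy * (vy + 1) // 2  # peak height = vy + (vy - 1) + ... + 1
```

This reproduces 45 for the puzzle's sample target and 33670 for `y=-260..-200`; part 2 still needs the search.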

In [298]:
in17 = (Point(25, -260), Point(67, -200))

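def trajectory(velocity: Point, min_target: Point, max_target: Point) -> int | None:
    """Return the max height reached if the probe lands in the target area, else None."""
    velocity = Point(velocity.x, velocity.y)  # work on a copy so the caller's Point isn't mutated
    current = Point(0, 0)
    max_y = 0

    while True:
        current += velocity
        max_y = max(max_y, current.y)
        # The first step at or past the target's left edge and top edge decides hit or miss:
        # x never decreases, and below y = 0 the probe is falling, so y only decreases.
        if current.x >= min_target.x and current.y <= max_target.y:
            if current.x > max_target.x or current.y < min_target.y:
                return None
            return max_y
        if velocity.x > 0:
            velocity.x -= 1
        elif current.x < min_target.x:
            return None  # x has stalled short of the target
        velocity.y -= 1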

def trajectories(target_area) -> dict[Point, int]:
    target_min, target_max = target_area
    assert target_min.y < 0, "Target min Y must be negative"
    assert target_max.y < 0, "Target max Y must be negative"

    assert target_max.x > target_min.x, "Target max X must be greater than target min X"
    assert target_max.y > target_min.y, "Target max Y must be greater than target min Y"

    # FIXME: Is there a more elegant solution than brute force? Probably...
    heights = dict()
    for x in range(0, target_max.x+1):
        for y in range(-abs(target_min.y), abs(target_min.y)+1):
            height = trajectory(Point(x, y), target_min, target_max)
            if height is not None:
                heights[Point(x,y)] = height
    return heights

def day17_1(target_area) -> int: return max(trajectories(target_area).values())

def day17_2(target_area) -> int: return len(trajectories(target_area))


In [302]:
assert trajectory(Point(6,9), Point(20,-10), Point(30, -5)) == 45
assert trajectory(Point(6,0), Point(20,-10), Point(30, -5)) == 0
assert day17_2((Point(20,-10), Point(30, -5))) == 112

In [301]:
do(17, 33670, 4903)

[33670, 4903]

## Day 18: Snailfish

Today we're working with snailfish numbers: trees of pairs, where each pair's left and right elements are each either a regular number (an int) or another pair. Adding two snailfish numbers is easy: make a new pair with the two addends as its left and right children. The catch is that every sum must then be reduced, and the reduction rules are where things get complicated.

### Reduction

To reduce a snailfish number, you must repeatedly do the first action in this list that applies to the snailfish number:

1. If any pair is nested inside four pairs, the leftmost such pair explodes.
2. If any regular number is 10 or greater, the leftmost such regular number splits.

After applying an action, restart the reduction process from the top. Repeat until no action applies.

#### Exploding

To explode a pair, the pair's left value is added to the first regular number to the left of the exploding pair (if any), and the pair's right value is added to the first regular number to the right of the exploding pair (if any). Exploding pairs will always consist of two regular numbers. Then, the entire exploding pair is replaced with the regular number 0.

#### Splitting

To split a regular number, replace it with a pair; the left element of the pair should be the regular number divided by two and rounded down, while the right element of the pair should be the regular number divided by two and rounded up. For example, 10 becomes [5,5], 11 becomes [5,6], 12 becomes [6,6], and so on.
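The `TODO` in the code cell below suggests working on a flat list of leaves instead of the tree. A minimal sketch of that idea, assuming snailfish numbers arrive as nested Python lists (each puzzle line is also valid JSON, so `json.loads` would produce one): every regular number becomes a `[value, depth]` entry, where depth 1 means a direct child of the outermost pair, so an exploding pair's two numbers are the entries at depth 5. The names `flatten`, `explode_once`, `split_once`, and `add` are hypothetical.

```python
Leaves = list[list[int]]  # each entry is [value, depth]; depth 1 = direct child of the outer pair

def flatten(tree, depth: int = 0) -> Leaves:
    "Turn a nested-list snailfish number into its regular numbers, left to right, with depths."
    if isinstance(tree, int):
        return [[tree, depth]]
    return flatten(tree[0], depth + 1) + flatten(tree[1], depth + 1)

def explode_once(leaves: Leaves) -> bool:
    "Explode the leftmost pair nested inside four pairs; return True if one exploded."
    for i, (value, depth) in enumerate(leaves):
        if depth > 4:
            # Assumes leaves[i + 1] is the same pair's right number (pairs never nest deeper than 5).
            right = leaves[i + 1][0]
            if i > 0:
                leaves[i - 1][0] += value       # nearest regular number to the left
            if i + 2 < len(leaves):
                leaves[i + 2][0] += right       # nearest regular number to the right
            leaves[i:i + 2] = [[0, depth - 1]]  # the whole pair becomes a single 0
            return True
    return False

def split_once(leaves: Leaves) -> bool:
    "Split the leftmost number >= 10 into [n // 2, (n + 1) // 2], one level deeper."
    for i, (value, depth) in enumerate(leaves):
        if value >= 10:
            leaves[i:i + 1] = [[value // 2, depth + 1], [(value + 1) // 2, depth + 1]]
            return True
    return False

def add(a: Leaves, b: Leaves) -> Leaves:
    "Snailfish addition: pair the two numbers (every depth grows by 1), then reduce."
    leaves = [[value, depth + 1] for value, depth in a + b]
    # `or` short-circuits, so a split only happens when nothing can explode,
    # and each action restarts the loop from the top.
    while explode_once(leaves) or split_once(leaves):
        pass
    return leaves
```

Explode and split only ever touch list neighbors, so no parent pointers are needed; the trade-off is that magnitude is easier to compute on the tree form.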

#### Magnitude

The magnitude of a pair is 3 times the magnitude of its left element plus 2 times the magnitude of its right element. The magnitude of a regular number is just that number.
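The magnitude rule is a direct recursion. A sketch on nested Python lists rather than the notebook's `Pair` class; the name `magnitude` is hypothetical:

```python
def magnitude(number) -> int:
    "3 * magnitude(left) + 2 * magnitude(right); a regular number is its own magnitude."
    if isinstance(number, int):
        return number
    left, right = number
    return 3 * magnitude(left) + 2 * magnitude(right)
```

For example, `[[1,2],[[3,4],5]]` gives 3 * 7 + 2 * 61 = 143.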


1. Add up all of the snailfish numbers from the homework assignment in the order they appear. What is the magnitude of the final sum?

In [590]:
@dataclass
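class Pair():
    "A snailfish number: either a regular number (`value`) or a pair of `left` and `right`."
    value: int | None = None
    left: Pair | None = None
    right: Pair | None = None

    @staticmethod
    def from_str(string: str) -> Pair:
        "Parse a snailfish number; assumes every regular number in the string is a single digit."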
        stack = list()
        for char in string:
            if char in ["[", ","]:
                continue
            elif char == "]":
                r = stack.pop()
                l = stack.pop()
                pair = Pair(left=l, right=r)
                stack.append(pair)
            else:
                v = int(char)
                pair = Pair(value=v)
                stack.append(pair)
        assert len(stack) == 1
        return stack[0]

    def __str__(self) -> str:
        if self.value is not None:
            return str(self.value)
        else:
            return f"[{str(self.left)},{str(self.right)}]"

    def __add__(self,other) -> Pair:
        return Pair(left=self, right=other)

    def traverse(self, depth=0, parent=None):
        "Yield (leaf, depth, parent) for every regular number, left to right."
        if self.left is not None:
            yield from self.left.traverse(depth=depth+1, parent=self)
        if self.value is not None:
            yield (self, depth, parent)
        if self.right is not None:
            yield from self.right.traverse(depth=depth+1, parent=self)


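    def explode(self) -> bool:
        "Explode the leftmost pair nested inside four pairs; return True if one exploded."
        left = None  # nearest regular number to the left of the exploding pair
        exploded = False
        leaves = self.traverse()
        for current, depth, parent in leaves:
            if depth < 5:
                left = current
            else:
                exploded = True
                break
        if exploded:
            # `current` is the exploding pair's left number; the next leaf is its right number.
            skip, _, _ = next(leaves)
            assert parent.right is skip
            # The leaf after that, if there is one, is the nearest regular number to the right.
            right, _, _ = next(leaves, (None, None, None))
            if left is not None:
                left.value += parent.left.value
            if right is not None:
                right.value += parent.right.value
            parent.left = parent.right = None
            parent.value = 0
        return exploded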

# TODO: Traverse tree, make list of (leaf, depth) tuples
# First leaf with depth > 4 gets exploded to its neighbors on the list
# Also use same leaf list method for splits


In [591]:
for example in [
    "[1,2]",
    "[[1,2],3]",
    "[9,[8,7]]",
    "[[1,9],[8,5]]",
    "[[[[1,2],[3,4]],[[5,6],[7,8]]],9]",
    "[[[9,[3,8]],[[0,9],6]],[[[3,7],[4,9]],3]]",
    "[[[[1,3],[5,3]],[[1,3],[8,7]]],[[[4,9],[6,9]],[[8,2],[7,3]]]]",
]:
    assert str(Pair.from_str(example)) == example

assert str(Pair.from_str("[1,2]") + Pair.from_str("[[3,4],5]")) == "[[1,2],[[3,4],5]]"

In [592]:
traversal = Pair.from_str("[[[[1,2],[3,4]],[[5,6],[7,8]]],9]").traverse()
assert [(leaf.value, depth) for leaf, depth, _ in traversal] == [(1,4), (2,4), (3,4), (4,4), (5,4), (6,4), (7,4), (8,4), (9,1)]

In [594]:
test = Pair.from_str("[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]")
test.explode()
str(test)

'[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]'