# Advent of Code 2021

I really liked [Peter Norvig's approach](https://github.com/norvig/pytudes/blob/main/ipynb/Advent-2020.ipynb) to writing and tracking solutions, so I'm going to use his method this year.

## Day 0: Imports and Utility Functions

Preparations prior to Day 1:

- Some imports.
- A way to read the day's data file and to print/check the output.
- Some utilities that are likely to be useful.


In [3]:
from collections import defaultdict, namedtuple
from collections.abc import Iterator, Sequence
from dataclasses import dataclass
from itertools import chain
from math import prod
import operator
from statistics import mean, median
from typing import Callable

In [4]:
def data(day: int, parser=str, sep='\n', filetype="input") -> list:
    """Split the day's input file into sections separated by `sep`, and apply `parser` to each.

    Reads ``data/advent2021/{filetype}{day}.txt`` (relative to the current
    working directory), strips trailing whitespace from the whole file, splits
    it on `sep`, and returns ``[parser(section) for section in sections]``.
    """
    # `with` guarantees the file handle is closed even if reading fails.
    with open(f'data/advent2021/{filetype}{day}.txt', encoding='utf-8') as f:
        sections = f.read().rstrip().split(sep)
    return [parser(section) for section in sections]
     
def do(day, *answers) -> list:
    """Run the solutions for `day` on its input and return their results as a list.

    Looks up the module-level functions ``day{day}_1`` and ``day{day}_2`` and the
    input ``in{day}``. Each part whose function is defined is called on that
    input, so e.g. ``do(3)`` returns ``[day3_1(in3), day3_2(in3)]``. Parts
    without a defined function are skipped, so the list may be shorter than 2.

    If `answers` are given, the result for part N is asserted to equal
    ``answers[N - 1]``.
    """
    g = globals()
    got = []
    for part in (1, 2):
        fname = f'day{day}_{part}'
        if fname not in g:
            continue
        got.append(g[fname](g[f'in{day}']))
        if len(answers) >= part:
            assert got[-1] == answers[part - 1], (
                f'{fname}(in{day}) got {got[-1]}; expected {answers[part - 1]}')
    return got

def first(iterable, default=None) -> object:
    "Return the first item of `iterable`, or `default` when it yields nothing."
    for item in iterable:
        return item
    return default

def rest(sequence) -> object:
    "Return everything in `sequence` after its first element (empty if there is none)."
    return sequence[slice(1, None)]

In [5]:
def quantify(iterable, pred=bool) -> int:
    """Return how many items of `iterable` satisfy `pred` (plain truthiness by default)."""
    count = 0
    for item in iterable:
        if pred(item):
            count += 1
    return count

def sliding_window(iterable: Sequence, window_size: int) -> Iterator[Sequence]:
    """Yield every contiguous slice of length `window_size` from `iterable`, left to right.

    `iterable` must support ``len()`` and slicing (e.g. a list, tuple or str);
    each window is a slice of the same type. Nothing is yielded when the input
    is shorter than the window.

    Raises:
        ValueError: if `window_size` < 1. This is a generator, so the error
            surfaces on the first iteration.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    for start in range(len(iterable) - window_size + 1):
        yield iterable[start:start + window_size]

## Day 1: Sonar Sweep

1. How many measurements are larger than the previous measurement?
2. Consider sums of a three-measurement sliding window. How many sums are larger than the previous sum?

In [6]:
# Day 1 puzzle input: one integer depth reading per line.
in1 = data(day=1, parser=int)

In [7]:
def day1_1(depths):
    "Count the measurements that are larger than the measurement just before them."
    def increased(pair):
        previous, current = pair
        return current > previous
    return quantify(sliding_window(depths, 2), increased)

In [8]:
def day1_2(depths):
    "Count how often a three-measurement window sum exceeds the previous window's sum."
    window_sums = [sum(window) for window in sliding_window(depths, 3)]
    return quantify(sliding_window(window_sums, 2), lambda pair: pair[0] < pair[1])

In [9]:
# Solve both parts of day 1 and check them against the accepted answers.
do(1, 1121, 1065)

[1121, 1065]

## Day 2: Dive!

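1. What do you get if you multiply your final horizontal position by your final depth?
2. Now `down X` and `up X` change your *aim*, and `forward X` also increases depth by aim × X. Using this interpretation, what do you get if you multiply your final horizontal position by your final depth?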

In [10]:
# Day 2 puzzle input: each line split into its whitespace-separated tokens, e.g. ['forward', '5'].
in2 = data(2, parser=str.split)

In [11]:
Position = tuple[int, int]

def move(position: Position, delta: Position) -> Position:
    "Return `position` shifted by `delta`, component by component."
    return Position(a + b for a, b in zip(position, delta))

def navigate(instructions: list[str], position: Position = (0,0)) -> Position:
    for instruction in instructions:
        match instruction:
            case ["forward", n]:
                position = move(position, (int(n), 0))
            case ["down", n]:
                position = move(position, (0, int(n)))
            case ["up", n]:
                position = move(position, (0, -int(n)))
            case _:
                raise ValueError(f"Unmatched instruction: {instruction}")
    return position

In [12]:
def day2_1(instructions):
    "Multiply the final horizontal position by the final depth."
    horizontal, depth = navigate(instructions)
    return horizontal * depth

In [13]:
AimedPosition = namedtuple("AimedPosition", ["horizontal", "depth", "aim"])

def navigate_by_aim(instructions: list[str], position: AimedPosition = AimedPosition(0,0,0)) -> AimedPosition:
    for instruction in instructions:
        match instruction:
            case ["forward", n]:
                position = position._replace(
                    horizontal=position.horizontal + int(n),
                    depth=position.depth + int(n)*position.aim,
                )
            case ["down", n]:
                position = position._replace(aim=position.aim + int(n))
            case ["up", n]:
                position = position._replace(aim=position.aim - int(n))
            case _:
                raise ValueError(f"Unmatched instruction: {instruction}")
    return position

In [14]:
def day2_2(instructions):
    "Multiply the final horizontal position by the final depth under the aim rules."
    horizontal, depth, _aim = navigate_by_aim(instructions)
    return horizontal * depth

In [15]:
# Solve both parts of day 2 and check them against the accepted answers.
do(2, 1484118, 1463827010)

[1484118, 1463827010]

## Day 3: Binary Diagnostic

Power consumption is gamma rate times epsilon rate, where the gamma rate is the most common bit setting by position in the input and the epsilon rate is the least common bit setting (the bitwise negation).

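Life support rating is oxygen generator rating times CO2 scrubber rating. Each rating is found by repeatedly filtering the list of values. At bit position n, keep only the values whose nth bit matches the most common (oxygen) or least common (CO2) nth bit among the values that survived the previous filtering. Continue until a single value remains. Ties favor 1 for oxygen and 0 for CO2.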

1. What is the power consumption of the submarine?
2. What is the life support rating of the submarine?

In [16]:
Bitlist = list[int]

def _bit_balance(readings: list[str]) -> list[int]:
    """Per bit position: (# readings with '1') - (# readings with any other char).

    Raises:
        ValueError: if `readings` is empty (there is no bit width to measure).
    """
    if not readings:
        raise ValueError("readings must be non-empty")
    balance = [0] * len(readings[0])
    for reading in readings:
        for i, bit in enumerate(reading):
            balance[i] += 1 if bit == "1" else -1
    return balance

def most_common_bits(readings: list[str]) -> Bitlist:
    """Most common bit at each position across `readings`; ties resolve to 1."""
    return [int(balance >= 0) for balance in _bit_balance(readings)]

def least_common_bits(readings: list[str]) -> Bitlist:
    """Least common bit at each position across `readings`; ties resolve to 0."""
    return [int(balance < 0) for balance in _bit_balance(readings)]

def bitlist_to_int(bitlist: Bitlist) -> int:
    """Read `bitlist` as a binary number, first element most significant (empty -> 0)."""
    value = 0
    for bit in bitlist:
        value = value * 2 + bit
    return value

def toggle(bitlist: Bitlist) -> Bitlist:
    """Flip every bit: truthy entries become 0, falsy ones become 1."""
    return [int(not bit) for bit in bitlist]

In [17]:
# Day 3 puzzle input: one binary string per line.
in3 = data(day=3)

In [18]:
def day3_1(readings):
    "Power consumption: gamma rate (most common bits) times epsilon rate (least common bits)."
    gamma_rate = bitlist_to_int(most_common_bits(readings))
    epsilon_rate = bitlist_to_int(least_common_bits(readings))
    return gamma_rate * epsilon_rate

In [19]:
def filter_bits(readings: list[str], filter_fn: Callable[[list[str]], list[int]]) -> int:
    """Narrow `readings` down to one value by successive bit criteria and return it as an int.

    For each bit position i, `filter_fn` is applied to the readings still in
    play and returns one target bit per position. Only readings whose i-th
    character equals the i-th target bit are kept. Filtering stops as soon as a
    single reading remains, which is then parsed as a base-2 integer.

    Raises:
        ValueError: if `readings` is empty, or the criteria do not leave exactly
            one reading.
    """
    if not readings:
        raise ValueError("readings must be non-empty")
    remaining = readings
    for i in range(len(readings[0])):
        # Check before filtering, so a lone reading is never filtered away.
        if len(remaining) == 1:
            break
        target_bits = filter_fn(remaining)
        remaining = [reading for reading in remaining if int(reading[i]) == target_bits[i]]
    if len(remaining) != 1:
        raise ValueError(f"expected exactly one reading to remain, got {len(remaining)}")
    return int(remaining[0], base=2)

def o2_rating(readings: list[str]) -> int: return filter_bits(readings, most_common_bits)
def co2_rating(readings: list[str]) -> int: return filter_bits(readings, least_common_bits)



In [20]:
# Puzzle sample for day 3; expected ratings: oxygen 23, CO2 10.
sample3 = """
00100 11110 10110 10111 10101 01111
00111 11100 10000 11001 00010 01010
""".split()
assert o2_rating(sample3) == 23
assert co2_rating(sample3) == 10

In [21]:
def day3_2(readings):
    "Life support rating: oxygen generator rating times CO2 scrubber rating."
    oxygen = o2_rating(readings)
    scrubber = co2_rating(readings)
    return oxygen * scrubber

In [22]:
# Solve both parts of day 3 and check them against the accepted answers.
do(3, 3429254, 5410338)

[3429254, 5410338]

## Day 4: Giant Squid

If all numbers in any row or any column of a board are marked, that board wins. (Diagonals don't count.)

The score of the winning board is the sum of all unmarked numbers on that board times the number that was just called when the board won.

1. To guarantee victory against the giant squid, figure out which board will win first. What will your final score be if you choose that board?
2. Figure out which board will win last. Once it wins, what would its final score be?

In [23]:
BingoBoard = list[int]

def board_to_list(board: list[str]) -> BingoBoard:
    """Flatten a board given as rows of whitespace-separated numbers into one row-major list of ints."""
    return [int(number) for row in board for number in row.split()]
# Day 4 puzzle input: blank-line-separated sections, each split into its lines
# (the first section holds the comma-separated draws, the rest are boards).
in4 = data(4, sep="\n\n", parser=str.splitlines)


In [24]:
BOARD_DIM = 5
def is_winner(board: BingoBoard) -> bool:
    """True if any full row or any full column of `board` is marked (all squares None)."""
    def fully_marked(squares) -> bool:
        return all(square is None for square in squares)

    return any(
        fully_marked(board[i * BOARD_DIM:(i + 1) * BOARD_DIM]) or fully_marked(board[i::BOARD_DIM])
        for i in range(BOARD_DIM)
    )

In [25]:
test_board = list(range(1, BOARD_DIM**2 + 1))
# The same board with its whole first row, or its whole first column, marked (None).
winning_row = [None] * BOARD_DIM + test_board[BOARD_DIM:]
winning_col = [None if index % BOARD_DIM == 0 else square for index, square in enumerate(test_board)]

assert not is_winner(test_board)
assert is_winner(winning_row)
assert is_winner(winning_col)

In [26]:
def print_board(board):
    """Print `board` as BOARD_DIM rows of BOARD_DIM squares (marked squares print as None)."""
    for row_start in range(0, BOARD_DIM * BOARD_DIM, BOARD_DIM):
        print([board[row_start + offset] for offset in range(BOARD_DIM)])

def day4_1(lines: list[str]) -> int:
    """Score of the first board to win: sum of its unmarked squares times the winning draw.

    Returns None if no board ever wins.
    """
    draws = [int(number) for number in first(lines)[0].split(",")]
    boards = [board_to_list(section) for section in rest(lines)]
    for draw in draws:
        for board in boards:
            if draw in board:
                board[board.index(draw)] = None
            if is_winner(board):
                return draw * sum(square for square in board if square is not None)


In [27]:
def day4_2(lines: list[str]) -> int:
    """Score of the last board to win: sum of its unmarked squares times the draw that completed it.

    Returns None if some board never wins.
    """
    draws = [int(number) for number in first(lines)[0].split(",")]
    boards = [board_to_list(section) for section in rest(lines)]
    for draw in draws:
        # Mark the drawn number on every board still in play.
        for board in boards:
            if draw in board:
                board[board.index(draw)] = None
        winners = [board for board in boards if is_winner(board)]
        boards = [board for board in boards if not is_winner(board)]
        if not boards:
            # Every board has now won; the first one that won on this draw is scored.
            return draw * sum(square for square in winners[0] if square is not None)



In [28]:
# Solve both parts of day 4 and check them against the accepted answers.
do(4, 27027, 36975)
# 5120 is too low

[27027, 36975]

## Day 5: Hydrothermal Venture

Each line of vents is given as a line segment in the format x1,y1 -> x2,y2 inclusive.

1. Consider only horizontal and vertical lines. At how many points do at least two lines overlap?
2. Consider all lines (horizontal, vertical, and 45° diagonal). At how many points do at least two lines overlap?

In [66]:
@dataclass(frozen=True)
class Point:
    """An integer (x, y) grid coordinate, usable as a dict/set key.

    Frozen so a Point cannot change while it is stored as a key. The
    dataclass-generated ``__hash__`` hashes ``(x, y)``, exactly like the
    previous hand-written one.
    """
    x: int
    y: int

    def __repr__(self):
        return f"({self.x}, {self.y})"

@dataclass(frozen=True)
class Line:
    """A vent line segment from `start` to `end`, both endpoints inclusive."""
    start: Point
    end: Point

In [30]:
def str_to_line(input: str) -> Line:
    """Parse ``"x1,y1 -> x2,y2"`` into a Line from (x1, y1) to (x2, y2)."""
    endpoints = [Point(*map(int, coords.split(","))) for coords in input.split(" -> ")]
    return Line(*endpoints)

# Day 5 puzzle input: one vent Line per input line.
in5 = data(day=5, parser=str_to_line)


In [31]:
def print_grid(grid: dict[Point, int]):
    """Print the grid row by row from (0, 0), showing '.' for points with no entry."""
    width = max(point.x for point in grid) + 1
    height = max(point.y for point in grid) + 1
    for y in range(height):
        print("".join(str(grid.get(Point(x, y), ".")) for x in range(width)))

def day5_1(lines: list[Line], do_diagonal=False) -> int:
    """Count grid points covered by at least two vent lines.

    Horizontal and vertical lines are always drawn. Diagonal lines are drawn
    only when `do_diagonal` is true, and must then be exactly 45 degrees.

    Raises:
        ValueError: if `do_diagonal` is true and a line is neither horizontal,
            vertical nor 45-degree diagonal. The previous stepping loop would
            never reach the end point of such a line and spun forever.
    """
    grid = defaultdict(int)
    for line in lines:
        dx = line.end.x - line.start.x
        dy = line.end.y - line.start.y
        if dx != 0 and dy != 0:
            if not do_diagonal:
                continue
            if abs(dx) != abs(dy):
                raise ValueError(f"Line is not horizontal, vertical or 45-degree diagonal: {line}")
        # Unit step toward the end point along each axis (-1, 0 or +1).
        step_x = (dx > 0) - (dx < 0)
        step_y = (dy > 0) - (dy < 0)
        # Endpoints are inclusive, hence the + 1.
        for i in range(max(abs(dx), abs(dy)) + 1):
            grid[Point(line.start.x + i * step_x, line.start.y + i * step_y)] += 1
    return quantify(grid.values(), lambda x: x >= 2)

def day5_2(lines: list[Line]): return day5_1(lines, do_diagonal=True)

In [32]:
# Puzzle sample for day 5; expected overlaps: 5 (straight lines only) and 12 (with diagonals).
sample5 = data(5, filetype="sample", parser=str_to_line)
assert (day5_1(sample5), day5_2(sample5)) == (5, 12)

In [33]:
# Solve both parts of day 5 and check them against the accepted answers.
do(5, 5632, 22213)

[5632, 22213]

## Day 6: Lanternfish

Each lanternfish produces a new lanternfish every 7 days: its timer counts down from 6 to 0, then resets to 6. The first cycle takes 2 days longer, because a newborn's timer starts at 8.

1. How many lanternfish would there be after 80 days?
2. How many lanternfish would there be after 256 days?

In [34]:
# Day 6 puzzle input: comma-separated initial timer values, one per fish.
in6 = data(6, parser=int, sep=",")

In [35]:
def sim_lanternfish(initial_state: list[int], spawn_rate: int, initial_spawn_rate: int, days: int) -> dict[int, int]:
    """Simulate a lanternfish population and return the final count of fish per timer value.

    Each day every timer decreases by one. A fish at timer 0 resets to
    `spawn_rate` and produces a newborn at `initial_spawn_rate`.

    Args:
        initial_state: starting timer value of each fish.
        spawn_rate: timer value a fish resets to after spawning.
        initial_spawn_rate: timer value of a newborn fish.
        days: number of days to simulate.

    Returns:
        Mapping timer value -> number of fish, with keys 0..max_timer in order.
    """
    # Size the table so no timer (including larger initial ones) falls off the end.
    max_timer = max(spawn_rate, initial_spawn_rate, *initial_state)
    counts = {timer: 0 for timer in range(max_timer + 1)}
    for fish in initial_state:
        counts[fish] += 1

    for _ in range(days):
        spawning = counts[0]
        # Shift every timer down by one, then add the resets and newborns.
        # Adding them after the shift keeps either from being overwritten,
        # whichever of the two rates is larger.
        counts = {timer: counts.get(timer + 1, 0) for timer in range(max_timer + 1)}
        counts[spawn_rate] += spawning
        counts[initial_spawn_rate] += spawning

    return counts

In [36]:
def day6_1(state):
    "Total number of lanternfish after 80 days."
    counts = sim_lanternfish(state, spawn_rate=6, initial_spawn_rate=8, days=80)
    return sum(counts.values())

In [37]:
def day6_2(state):
    "Total number of lanternfish after 256 days."
    counts = sim_lanternfish(state, spawn_rate=6, initial_spawn_rate=8, days=256)
    return sum(counts.values())

In [38]:
# Solve both parts of day 6 and check them against the accepted answers.
do(6, 379114, 1702631502303)

[379114, 1702631502303]

## Day 7: The Treachery of Whales

We have a bunch of crabs that can only move left and right on a line.

1. Each step costs one unit of fuel. Determine the horizontal position that the crabs can align to using the least fuel possible. How much fuel must they spend to align to that position?
2. Each step costs one more unit of fuel than the previous step. Determine the horizontal position that the crabs can align to using the least fuel possible. How much fuel must they spend to align to that position?

In [39]:
# Day 7 puzzle input: comma-separated horizontal crab positions.
in7 = data(7, sep=",", parser=int)

In [40]:
def constant_total_fuel(crabs: list[int], target: int) -> int:
    """Fuel for every crab to reach `target` when each step costs 1."""
    return sum(abs(crab - target) for crab in crabs)

def triangular_number(n: int) -> int:
    """Return 1 + 2 + ... + n as an exact int."""
    # n * (n + 1) is always even, so floor division is exact and keeps an int;
    # true division would return a float.
    return n * (n + 1) // 2

def linear_total_fuel(crabs: list[int], target: int) -> int:
    """Fuel for every crab to reach `target` when the k-th step costs k."""
    return sum(triangular_number(abs(crab - target)) for crab in crabs)

def day7_1(crabs: list[int]) -> int:
    """Least constant-cost fuel: a median minimises the total absolute distance."""
    return constant_total_fuel(crabs, int(median(crabs)))

def day7_2(crabs: list[int]) -> int:
    """Least triangular-cost fuel over all integer alignment positions.

    The real-valued optimum lies within 1/2 of the mean. The integer optimum is
    therefore within [floor(mean) - 1, floor(mean) + 2], so check that small
    window. Assuming floor(mean) itself is optimal is wrong: the puzzle sample
    has mean 4.9, yet position 5 is optimal.
    """
    from math import floor
    centre = floor(mean(crabs))
    return min(linear_total_fuel(crabs, target) for target in range(centre - 1, centre + 3))

In [41]:
# Solve both parts of day 7 and check them against the accepted answers.
do(7, 335330, 92439766)


[335330, 92439766]

## Day 8: Seven Segment Search

We're trying to interpret seven-segment displays. The input lines are ten unique input digits, a pipe, and four output digits. The segments that correspond to each letter are consistent within a line, but random from one line to the next.

```text
  0:      1:      2:      3:      4:
 aaaa    ....    aaaa    aaaa    ....
b    c  .    c  .    c  .    c  b    c
b    c  .    c  .    c  .    c  b    c
 ....    ....    dddd    dddd    dddd
e    f  .    f  e    .  .    f  .    f
e    f  .    f  e    .  .    f  .    f
 gggg    ....    gggg    gggg    ....

  5:      6:      7:      8:      9:
 aaaa    aaaa    aaaa    aaaa    aaaa
b    .  b    .  .    c  b    c  b    c
b    .  b    .  .    c  b    c  b    c
 dddd    dddd    ....    dddd    dddd
.    f  e    f  .    f  e    f  .    f
.    f  e    f  .    f  e    f  .    f
 gggg    gggg    ....    gggg    gggg
```

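1. **For now, focus on the easy digits.** In the output values, how many times do digits `1`, `4`, `7`, or `8` appear?
2. Decode each display's four-digit output value. What do you get if you add up all of the output values?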

In [42]:
Digit = set[str]

def digit_to_str(digit: Digit) -> str:
    """Canonical string form of a digit: its segment letters, sorted."""
    return "".join(sorted(digit))

@dataclass
class Display:
    """One input line: the ten unique signal patterns and the four output digits."""
    in_digits: list[Digit]
    out_digits: list[Digit]

def parser8(line: str) -> Display:
    """Parse ``"<patterns> | <outputs>"`` into a Display of per-digit segment sets."""
    patterns, outputs = line.split('|')

    def to_digits(text: str) -> list[Digit]:
        return [set(token) for token in text.split()]

    return Display(in_digits=to_digits(patterns), out_digits=to_digits(outputs))

# Day 8 puzzle input: one Display per line.
in8 = data(day=8, parser=parser8)

In [43]:
def day8_1(displays: list[Display]) -> int:
    """Count output digits that are a 1, 4, 7 or 8 among `displays`.

    These are the only digits with a unique number of lit segments (2, 4, 3
    and 7 respectively). The count uses the `displays` argument; the previous
    version ignored it and read the global ``in8``.
    """
    unique_lengths = {2, 3, 4, 7}
    return quantify(
        (digit for display in displays for digit in display.out_digits),
        lambda digit: len(digit) in unique_lengths,
    )

In [44]:
def decode_display(display: Display) -> int:
    """Work out which of the ten `in_digits` patterns is which digit, then decode `out_digits`.

    The patterns are identified by segment count and set containment (see the
    step comments). Each output digit is mapped to its digit character, and the
    concatenated characters are returned as an int.
    """
    # knowns: digit character ('0'..'9') -> its segment set.
    # unknowns: patterns not yet identified.
    knowns = {}
    unknowns = display.in_digits.copy()
    # 8, 4, 7 and 1 are the only digits with 7, 4, 3 and 2 lit segments respectively.
    for k,v in {'8': 7, '4': 4, '7': 3, '1': 2}.items():
        knowns[k] = next(digit for digit in unknowns if len(digit) == v)
        unknowns.remove(knowns[k])

    # 7 minus 1 leaves only the top segment.
    a = knowns['7'] - knowns['1']
    # 9 is the only remaining pattern that strictly contains 4's segments plus the top.
    partial = knowns['4'] | a
    knowns['9'] = next(digit for digit in unknowns if partial < digit)
    unknowns.remove(knowns['9'])

    # 8 minus 9 leaves only the bottom-left segment.
    e = knowns['8'] - knowns['9']
    # Parsed as (9 - 1) | e, since `-` binds tighter than `|`. 6 is the only
    # remaining pattern that strictly contains that set.
    partial = knowns['9'] - knowns['1'] | e
    knowns['6'] = next(digit for digit in unknowns if partial < digit)
    unknowns.remove(knowns['6'])

    # With 6 and 9 gone, 0 is the only 6-segment pattern left.
    knowns['0'] = next(digit for digit in unknowns if len(digit) == 6)
    unknowns.remove(knowns['0'])

    # 5 is 6 without the bottom-left segment. list.remove raises ValueError if
    # no remaining pattern equals it (inconsistent input).
    knowns['5'] = knowns['6'] - e
    unknowns.remove(knowns['5'])

    # Of the remaining 2 and 3, only 3 strictly contains both segments of 1.
    knowns['3'] = next(digit for digit in unknowns if knowns['1'] < digit)
    unknowns.remove(knowns['3'])

    # Whatever is left must be 2.
    assert len(unknowns) == 1
    knowns['2'] = unknowns[0]

    # Sets are unhashable, so look patterns up by their sorted-letter string.
    digits = {digit_to_str(v): k for k, v in knowns.items()}
    return int(''.join([digits[digit_to_str(digit)] for digit in display.out_digits]))

def day8_2(displays: list[Display]) -> int:
    "Sum of the decoded four-digit output values of every display."
    total = 0
    for display in displays:
        total += decode_display(display)
    return total

In [45]:
# Part 2 on the puzzle's sample input: the decoded outputs sum to 61229.
assert day8_2(data(8, filetype='sample', parser=parser8)) == 61229

In [46]:
# Solve both parts of day 8 and check them against the accepted answers.
do(8, 284, 973499)

[284, 973499]

## Day 9: Smoke Basin

Each cell in the grid is the height of that point (0-9). A low point is a point that is lower than all of its orthogonally adjacent points: up to four, fewer at the edges, and diagonals don't count. The risk level of a point is its height + 1.

A basin is all the points less than height 9 that are connected to a low point.

1. What is the sum of the risk levels of all low points on your heightmap?
2. What do you get if you multiply together the sizes of the three largest basins?


In [47]:
def parser9(line: str) -> list[int]:
    "Convert a row of digit characters into a list of integer heights."
    return [int(char) for char in line]

# Day 9 puzzle input: the heightmap, one list of digit heights per row.
in9 = data(day=9, parser=parser9)

In [49]:
Grid = list[list[int]]

def neighbors(grid: Grid, point: Point) -> list[Point]:
    """In-bounds orthogonal neighbours of `point`, ordered (x-1, y), (x+1, y), (x, y-1), (x, y+1)."""
    candidates = (
        (point.x - 1, point.y),
        (point.x + 1, point.y),
        (point.x, point.y - 1),
        (point.x, point.y + 1),
    )
    return [
        Point(nx, ny)
        for nx, ny in candidates
        if 0 <= ny < len(grid) and 0 <= nx < len(grid[ny])
    ]

def lowest_points(grid: Grid) -> list[Point]:
    """Return every point strictly lower than all of its orthogonal neighbours, in row-major order.

    A point with no neighbours (a 1x1 grid) counts as a low point. Previously
    ``min()`` of its empty neighbour list raised ValueError.
    """
    lows = []
    for y, row in enumerate(grid):
        for x, height in enumerate(row):
            here = Point(x, y)
            if all(height < grid[n.y][n.x] for n in neighbors(grid, here)):
                lows.append(here)
    return lows

def day9_1(grid: Grid) -> int:
    "Sum of the risk levels (height + 1) of all low points."
    return sum(grid[low.y][low.x] + 1 for low in lowest_points(grid))

In [50]:
# Part 1 on the puzzle's sample heightmap: total risk level 15.
assert day9_1(data(9, parser=parser9, filetype="sample")) == 15

In [83]:
def find_basin(grid: Grid, origin: Point) -> list[Point]:
    """Return the points of the basin containing `origin`, in no particular order.

    Flood-fills outward from `origin` through orthogonal neighbours, collecting
    every reachable point whose height is below 9. Height-9 points act as walls
    and are not included. `origin` itself is always included, without checking
    its height.
    """
    basin = {origin}
    # Points already popped from the work list; not queued again afterwards.
    processed = set()
    # Work list, popped from the end (LIFO). A point may be queued more than
    # once before it is popped; the `in basin` check below skips repeats.
    to_process = neighbors(grid, origin)
    while to_process:
        point = to_process.pop()
        processed.add(point)
        if point in basin:
            continue
        if grid[point.y][point.x] < 9:
            basin.add(point)
            for neighbor in filter(lambda nbr: nbr not in processed, neighbors(grid, point)):
                to_process.append(neighbor)
    return list(basin)

def day9_2(grid: Grid) -> int:
    """Product of the sizes of the three largest basins (one basin per low point)."""
    sizes = [len(find_basin(grid, low)) for low in lowest_points(grid)]
    sizes.sort(reverse=True)
    return prod(sizes[:3])

In [84]:
# Part 2 on the puzzle's sample heightmap: product of the three largest basin sizes is 1134.
assert day9_2(data(9, parser=parser9, filetype="sample")) == 1134

In [85]:
# Solve both parts of day 9 and check them against the accepted answers.
do(9, 633, 1050192)

[633, 1050192]