# Day 11: Dumbo Octopus

In [1]:
from pathlib import Path
from copy import deepcopy
from itertools import product, dropwhile
from more_itertools import flatten, quantify, take, first
from functools import reduce, partial
from typing import Iterable

from aoc2021.util import read_as_list

## Puzzle input data

In [2]:
parse_input = lambda line: [int(c) for c in line.rstrip()]

# Test data.
tdata = list(map(parse_input, [
    '5483143223',
    '2745854711',
    '5264556173',
    '6141336146',
    '6357385478',
    '4167524645',
    '2176841721',
    '6882881134',
    '4846848554',
    '5283751526',
]))

# Input data.
data = read_as_list(Path('./day11-input.txt'), func=parse_input)
data

[[4, 4, 3, 8, 6, 2, 4, 2, 6, 2],
 [6, 2, 6, 3, 2, 5, 1, 8, 6, 4],
 [2, 6, 1, 8, 8, 1, 2, 4, 3, 4],
 [2, 1, 3, 4, 2, 6, 4, 5, 6, 5],
 [1, 8, 1, 5, 1, 3, 1, 2, 4, 7],
 [2, 6, 1, 2, 4, 5, 7, 3, 2, 5],
 [8, 5, 8, 5, 7, 6, 7, 5, 8, 4],
 [7, 2, 1, 7, 1, 3, 4, 5, 5, 6],
 [2, 8, 2, 5, 4, 5, 6, 5, 6, 3],
 [8, 2, 4, 8, 4, 7, 3, 5, 8, 4]]

## Puzzle answers
### Part 1

In [3]:
Input = list[list[int]]
Pos = tuple[int,int]


def neighbours(pos: Pos, sz: int = 10) -> list[Pos]:
    """TL, T, TR, L, R, BL, B, BR"""
    drs = (-1,-1,-1,0,0,1,1,1)
    dcs = (-1,0,1,-1,1,-1,0,1)
    row,col = pos
    return [(r,c) for r,c in ((row+dr,col+dc) for dr,dc in zip(drs,dcs)) if 0<=r<sz and 0<=c<sz]


def update(energy: Input, pos: Pos) -> Input:
    r, c = pos
    energy[r][c] += 1
    return energy


def flash(energy: Input, pos: Pos) -> Input:
    r, c = pos
    energy[r][c] = 0
    return energy


def step(energy: Input) -> Iterable[Input]:
    e = deepcopy(energy)
    sz = len(e)
    neighbs = partial(neighbours, sz=sz)
    coords = list(product(range(sz), repeat=2))
    while True:
        to_update = coords
        flashed = set()
        while to_update:
            e = reduce(update, to_update, e)
            to_flash = [(r,c) for r,c in set(to_update) - flashed if e[r][c] > 9]
            to_update = [p for p in flatten(map(neighbs, to_flash)) if p not in flashed]
            flashed = flashed.union(to_flash)
        e = reduce(flash, flashed, e)
        yield deepcopy(e)


def nflashes(energy: Input) -> int:
    return quantify(flatten(energy), lambda x: x == 0)


def num_flashes_after(energy: Input, nsteps: int) -> int:
    return sum(map(nflashes, take(nsteps, step(energy))))


assert neighbours((0,0), 10) == [(0,1), (1,0), (1,1)]
assert neighbours((9,9), 10) == [(8,8), (8,9), (9,8)]
assert next(step([[1,2,3],[4,5,6],[7,8,9]])) == [[2,3,4],[6,8,9],[9,0,0]]
assert nflashes([[2,3,4],[6,8,9],[9,0,0]]) == 2
assert num_flashes_after(tdata, 10) == 204
assert num_flashes_after(tdata, 100) == 1656

In [4]:
n = num_flashes_after(data, 100)
print(f'Total number of flashes after 100 steps: {n}')

Total number of flashes after 100 steps: 1640


### Part 2

In [5]:
def first_syncflash(energy: Input) -> int:
    target = len(energy) ** 2
    i,_ = first(dropwhile(lambda p: p[1] < target, enumerate(map(nflashes, step(tdata)), start=1)))
    return i


assert first_syncflash(tdata) == 195

In [6]:
n = first_syncflash(data)
print(f'The first step during which all octopuses flash: {n}')

The first step during which all octopuses flash: 195
