In [None]:
from pathlib import Path
import numpy as np
from scipy.signal import convolve2d

In [None]:
# Small example grid exercised by the test cells below (they expect 1656 flashes
# after 100 steps and a first all-at-once flash on step 195).
test_input_1 = """5483143223
2745854711
5264556173
6141336146
6357385478
4167524645
2176841721
6882881134
4846848554
5283751526
"""

# Real puzzle input. The path is relative, so it is resolved against the kernel's
# current working directory.
input_1 = Path("input_1.txt").read_text()

In [None]:
# 3x3 block of ones: used with convolve2d to spread a flash over a cell's
# 3x3 neighbourhood (the cell itself included).
FLASH_KERNEL = np.ones((3, 3), dtype=int)

def parse_input(input_string):
    """Parse lines of single-digit energy levels into a 2-D integer grid.

    Parameters
    ----------
    input_string : str
        One grid row per line, each character a digit. Surrounding blank
        lines, per-line leading/trailing whitespace and ``\\r\\n`` line
        endings are tolerated.

    Returns
    -------
    numpy.ndarray
        ``int`` array of shape ``(rows, columns)``; ``(0, 0)`` for blank input.

    Raises
    ------
    ValueError
        If the rows differ in length or a row contains a non-digit character.
    """
    # splitlines() handles \n, \r\n and \r alike; stripping each line stops stray
    # whitespace from inflating the row width.
    rows = [line.strip() for line in input_string.strip().splitlines()]
    if not rows:
        return np.zeros((0, 0), dtype=int)

    width = len(rows[0])
    for row in rows:
        if len(row) != width:
            raise ValueError(
                f"Ragged grid: expected rows of length {width}, got {row!r}"
            )

    # int() raises ValueError on any non-digit character instead of silently skipping it.
    return np.array([[int(ch) for ch in row] for row in rows], dtype=int)

def next_step(octopi):
    """Advance the octopus grid by one step and return the resulting grid.

    Every octopus gains 1 energy. Any octopus above 9 flashes, adding 1 to
    every cell in its 3x3 neighbourhood (via ``FLASH_KERNEL``), which may
    push further octopi over the threshold. Each octopus flashes at most once
    per step, and every octopus that flashed ends the step at 0.

    Parameters
    ----------
    octopi : numpy.ndarray
        2-D integer grid of energy levels. It is NOT modified.

    Returns
    -------
    numpy.ndarray
        A new grid holding the energy levels after the step.
    """
    # Fresh array, so the caller's grid is left untouched.
    energy = np.asarray(octopi) + 1
    flashed = np.zeros(energy.shape, dtype=bool)
    ready = energy > 9
    # Resolve the cascade in waves: all octopi that crossed the threshold and
    # have not flashed yet flash together. The convolution counts, for every
    # cell, how many of them lie in its 3x3 neighbourhood. The kernel's centre
    # also bumps the flasher itself, which is harmless because `flashed`
    # already excludes it from later waves.
    while ready.any():
        flashed |= ready
        energy += convolve2d(ready.astype(int), FLASH_KERNEL, mode='same')
        ready = (energy > 9) & ~flashed
    energy[flashed] = 0
    return energy

def flashes_after_steps(steps, octopi, verbose=False):
    """Simulate ``steps`` steps and return the total number of flashes.

    Parameters
    ----------
    steps : int
        Number of steps to simulate.
    octopi : numpy.ndarray
        Starting 2-D energy grid. A copy is simulated, so the caller's grid
        is not modified.
    verbose : bool, optional
        If true, print the grid after every step.

    Returns
    -------
    int
        Total number of flashes over all simulated steps.
    """
    # Defensive copy: keeps the caller's grid intact even if next_step works in place.
    octopi = np.array(octopi, copy=True)
    flashes = 0
    for _ in range(steps):
        octopi = next_step(octopi)
        # Every octopus gains energy each step, so only those that flashed
        # (and were reset) sit at 0 afterwards.
        flashes += np.count_nonzero(octopi == 0)
        if verbose:
            print()
            print(octopi)
    return flashes

def first_mega_flash(octopi, max_steps=1000):
    """Return the first (1-based) step on which every octopus flashes at once.

    Parameters
    ----------
    octopi : numpy.ndarray
        Starting 2-D energy grid. A copy is simulated, so the caller's grid
        is not modified.
    max_steps : int, optional
        Give up after this many steps.

    Returns
    -------
    int
        The step number on which the whole grid is 0 after the step.

    Raises
    ------
    RuntimeError
        If no synchronised flash occurs within ``max_steps`` steps
        (a subclass of ``Exception``, so existing handlers still catch it).
    """
    # Defensive copy: keeps the caller's grid intact even if next_step works in place.
    octopi = np.array(octopi, copy=True)
    for n in range(1, max_steps + 1):
        octopi = next_step(octopi)
        if (octopi == 0).all():
            return n
    raise RuntimeError(f"No mega flash detected after {max_steps} steps")

In [None]:
# Part 1 - Test: the example grid must produce 1656 flashes over 100 steps
octopi = parse_input(test_input_1)
test_flashes = flashes_after_steps(100, octopi)
assert test_flashes == 1656, f"expected 1656 flashes, got {test_flashes}"

In [None]:
# Part 1 - total number of flashes over 100 steps for the real puzzle input
octopi = parse_input(input_string=input_1)
flashes_after_steps(steps=100, octopi=octopi)

In [None]:
# Part 2 - Test: the example grid first flashes in unison on step 195
octopi = parse_input(test_input_1)
test_mega_flash_step = first_mega_flash(octopi)
assert test_mega_flash_step == 195, (
    f"expected first mega flash at step 195, got {test_mega_flash_step}"
)

In [None]:
# Part 2 - first step on which every octopus in the real input flashes together
octopi = parse_input(input_string=input_1)
first_mega_flash(octopi=octopi)