# --- Day 11: Dumbo Octopus --- 

https://adventofcode.com/2021/day/11

## Get Input Data

In [1]:
import numpy as np

In [2]:
small_test_data = '11111\n19991\n19191\n19991\n11111'
# One row per line, one integer energy level per character.
small_test_energy_levels = np.matrix(
    [list(map(int, row)) for row in small_test_data.splitlines()]
)
small_test_energy_levels

matrix([[1, 1, 1, 1, 1],
        [1, 9, 9, 9, 1],
        [1, 9, 1, 9, 1],
        [1, 9, 9, 9, 1],
        [1, 1, 1, 1, 1]])

In [3]:
with open('../inputs/test_energy_levels.txt') as file:
    # Skip blank lines (e.g. an extra newline at the end of the file): an empty
    # row would make the rows ragged, which cannot form a 2-D integer grid.
    test_energy_levels = np.matrix(
        [[int(x) for x in line.strip()] for line in file if line.strip()]
    )

In [4]:
with open('../inputs/energy_levels.txt') as file:
    # Skip blank lines (e.g. an extra newline at the end of the file): an empty
    # row would make the rows ragged, which cannot form a 2-D integer grid.
    energy_levels = np.matrix(
        [[int(x) for x in line.strip()] for line in file if line.strip()]
    )

## Part 1
---

In [5]:
def get_neighbors(point, max_n):
    """Return the in-bounds neighbours of a cell: vertical, horizontal and diagonal.

    Parameters
    ----------
    point : sequence
        ``(i, j)`` row/column index of the cell; only ``point[0]`` and
        ``point[1]`` are read.
    max_n : int
        Largest valid index along BOTH axes, inclusive (callers pass
        ``len(grid) - 1``). The same bound is applied to rows and columns,
        so a square grid is assumed.

    Returns
    -------
    list of two int64 numpy arrays
        ``[rows, cols]`` of the up-to-8 neighbours whose row and column both
        lie in ``0..max_n``, in a fixed order, ready for fancy indexing via
        ``grid[tuple(result)]``.
    """
    i, j = point[0], point[1]

    # Candidate offsets: (+1,0), (-1,0), (0,+1), (0,-1), then the four diagonals.
    neighbor_i = np.array([i+1, i-1, i+0, i+0, i+1, i-1, i+1, i-1], dtype='int64')
    neighbor_j = np.array([j+0, j+0, j+1, j-1, j+1, j-1, j-1, j+1], dtype='int64')

    # Keep only pairs whose row AND column fall in [0, max_n]. One boolean mask
    # replaces the repeated np.delete calls and preserves the candidate order.
    in_bounds = ((neighbor_i >= 0) & (neighbor_i <= max_n)
                 & (neighbor_j >= 0) & (neighbor_j <= max_n))

    return [neighbor_i[in_bounds], neighbor_j[in_bounds]]

In [6]:
def find_flashes(points, grid):
    """Increment the given cells and count every flash this triggers.

    Each cell in ``points`` gets +1 energy, in place. A cell whose energy
    reaches exactly 10 flashes: it is counted and its in-bounds neighbours
    (from ``get_neighbors``) get +1 in turn, which can cascade. Energy only
    ever rises one unit at a time, so a cell passes through 10 -- and is
    counted -- at most once; cells already above 9 keep rising without
    flashing again.

    Parameters
    ----------
    points : pair of index arrays ``(rows, cols)``
        E.g. the tuple from ``np.where`` or the list from ``get_neighbors``.
    grid : 2-D numpy array or matrix
        Energy levels; modified in place. The neighbour bound is taken from
        ``len(grid) - 1``, so a square grid is assumed.

    Returns
    -------
    int
        Number of flashes triggered.
    """
    num_flashes = 0

    # Explicit work-list instead of recursion: a long flash cascade on a large
    # grid would otherwise nest one call per flash and could exceed Python's
    # recursion limit. Every batch increments each of its cells exactly once,
    # so the final energies and the flash count do not depend on the order in
    # which batches are processed.
    pending = [points]
    while pending:
        rows, cols = pending.pop()
        for i, j in zip(rows, cols):
            grid[i, j] += 1

            if grid[i, j] == 10:
                num_flashes += 1
                pending.append(get_neighbors((i, j), len(grid)-1))

    return num_flashes

In [7]:
def run_steps(grid, n):
    """Simulate ``n`` steps on a copy of ``grid`` and return the total flash count.

    The caller's grid is left untouched.
    """
    state = grid.copy()
    total_flashes = 0

    for _step in range(n):
        # Index every cell with energy >= 0 (parsed digits are never negative).
        every_cell = np.where(state >= 0)
        total_flashes += find_flashes(every_cell, state)
        # Cells that flashed this step now sit above 9; drop them back to 0.
        state = np.where(state > 9, 0, state)

    return total_flashes

### Run on Test Data

In [8]:
test_flashes_10_steps = run_steps(test_energy_levels, 10)
assert test_flashes_10_steps == 204, test_flashes_10_steps  # expected value for the test data
test_flashes_10_steps

204

In [9]:
test_flashes_100_steps = run_steps(test_energy_levels, 100)
assert test_flashes_100_steps == 1656, test_flashes_100_steps  # expected value for the test data
test_flashes_100_steps

1656

### Run on Input Data

In [10]:
# Part 1 answer: total flashes over 100 steps of the puzzle input
run_steps(grid=energy_levels, n=100)

1705

## Part 2
---

In [11]:
def run_steps_until_all_flash(grid):
    """Return the first step (1-based) during which every cell flashes at once.

    The simulation runs on a copy, so the caller's ``grid`` is not modified.
    There is no step limit: if the cells never all flash in the same step,
    this loops forever.
    """

    grid = grid.copy()

    step = 0
    while True:
        step += 1

        points = np.where(grid >= 0)
        num_flashes = find_flashes(points, grid)

        # Reset energy levels > 9 back to 0
        grid = np.where(grid > 9, 0, grid)

        # Each cell flashes at most once per step, so "all cells flashed" is
        # exactly "flash count == number of cells". Checking after the step
        # (rather than `grid.sum() > 0` before it) means an all-zero starting
        # grid is simulated instead of being reported as step 0.
        if num_flashes == grid.size:
            return step

### Run on Test Data

In [12]:
test_sync_step = run_steps_until_all_flash(test_energy_levels)
assert test_sync_step == 195, test_sync_step  # expected value for the test data
test_sync_step

195

### Run on Input Data

In [13]:
# Part 2 answer: first step in which every octopus flashes simultaneously
run_steps_until_all_flash(grid=energy_levels)

265