# In [15]:
"""
https://adventofcode.com/2021/day/11
"""

from typing import Iterable, Tuple

RAW = """5483143223
2745854711
5264556173
6141336146
6357385478
4167524645
2176841721
6882881134
4846848554
5283751526"""


                

class School:
    """
    A group of Octopuses or Octopi...

    Parses ``text`` (one row of single-digit energy levels per line)
    into ``self.grid`` and advances the grid one step at a time.
    """

    # (row, column) offsets of the eight surrounding cells, diagonals
    # included, in the order in which get_nearest yields them.
    _OFFSETS = (
        (1, -1), (1, 0), (1, 1),
        (0, -1),         (0, 1),
        (-1, -1), (-1, 0), (-1, 1),
    )

    def __init__(self, text: str) -> None:
        self.text = text
        self.parse()

    def get_nearest(self, i: int, j: int) -> Iterable[Tuple[int, int]]:
        """
        Generates the (row, column) locations of all neighbours of the
        Octopus at row ``i``, column ``j`` -- the eight surrounding
        cells, diagonals included -- skipping any that fall off the grid.
        """
        for di, dj in self._OFFSETS:
            ni, nj = i + di, j + dj
            if 0 <= ni < self.nr and 0 <= nj < self.nc:
                yield (ni, nj)

    def parse(self) -> None:
        """
        Build ``lines``, ``nr``, ``nc`` and ``grid`` from ``self.text``.

        Blank lines (e.g. extra trailing newlines) are ignored, and an
        empty text yields an empty grid instead of raising IndexError.
        """
        self.lines = [line for line in self.text.splitlines() if line.strip()]
        self.nr = len(self.lines)
        self.nc = len(self.lines[0]) if self.lines else 0
        # Index by column so a row shorter than the first still raises
        # IndexError rather than silently producing a ragged grid.
        self.grid = [
            [int(line[j]) for j in range(self.nc)]
            for line in self.lines
        ]

    def step(self) -> int:
        """
        Increment one step

        Every energy level rises by 1; each Octopus above 9 flashes
        once, raising all of its neighbours by 1 (which may trigger
        further flashes); finally every Octopus that flashed is reset
        to 0.  The grid is updated in place.

        Returns the number of Octopuses that flashed during this step.
        """
        grid = self.grid

        # Raise everything by one and collect the initial flashers.
        pending = []
        for i in range(self.nr):
            row = grid[i]
            for j in range(self.nc):
                row[j] += 1
                if row[j] > 9:
                    pending.append((i, j))

        # Propagate flashes with a work-list; has_flashed guarantees
        # each Octopus flashes at most once per step.
        has_flashed = set(pending)
        while pending:
            i, j = pending.pop()
            for ii, jj in self.get_nearest(i, j):
                grid[ii][jj] += 1
                if grid[ii][jj] > 9 and (ii, jj) not in has_flashed:
                    has_flashed.add((ii, jj))
                    pending.append((ii, jj))

        # Reset only after all propagation is done: zeroing a cell
        # while flashes are still being detected would let it be
        # pushed above 9 and flash again.
        for i, j in has_flashed:
            grid[i][j] = 0
        return len(has_flashed)
                    
def all_flashed(text: str) -> int:
    """
    Find the first step on which all octopuses flash simultaneously.

    Builds a School from ``text`` and keeps stepping it until the
    energy levels of the whole grid sum to zero, then returns the
    1-based number of that step.
    """
    octopuses = School(text=text)
    elapsed = 1
    while True:
        octopuses.step()
        # School.step resets every flasher to 0, so a zero total means
        # the entire grid flashed on this step.
        if sum(map(sum, octopuses.grid)) == 0:
            return elapsed
        elapsed += 1
        
        
    
                    

# Sanity-check both parts against the example grid in RAW before
# touching the real puzzle input.
s = School(text=RAW)
assert sum(s.step() for _ in range(10)) == 204
s = School(text=RAW)
assert sum(s.step() for _ in range(100)) == 1656
assert all_flashed(text=RAW) == 195

# Read the puzzle input up front; the file is only needed for the read.
with open('inputs/day11.txt') as f:
    puzzle = f.read()

# Part 1: total flashes over 100 steps.  Part 2: first step on which
# every octopus flashes (all_flashed builds its own fresh School).
s = School(text=puzzle)
print('p1', sum(s.step() for _ in range(100)))
print('p2', all_flashed(text=puzzle))
    

# Output:
# p1 1667
# p2 488
