In [58]:
from termcolor import colored
from typing import List 

class Octopus:
  """One octopus on the grid: its position, energy level, and whether it has
  already flashed during the current step."""

  def __init__(self, x: int, y: int, initial_energy: int) -> None:
    # Column (x) and row (y) of this octopus within the grid.
    self.x = x
    self.y = y
    self.energy_level = initial_energy
    self.has_flashed = False

  def increase_energy(self) -> None:
    """Raise the energy level by one."""
    self.energy_level = self.energy_level + 1

  def flash(self) -> bool:
    """Flash if energy is above 9 and this octopus has not flashed yet.

    Returns True only on the call that triggers the flash; later calls in the
    same step return False until reset() clears the flag.
    """
    ready = self.energy_level > 9 and not self.has_flashed
    if ready:
      self.has_flashed = True
    return ready

  def reset(self) -> None:
    """End-of-step cleanup: an octopus that flashed drops back to energy 0.

    Octopuses that did not flash keep their energy unchanged.
    """
    if not self.has_flashed:
      return
    assert self.energy_level > 9, 'Invalid state!'
    self.has_flashed = False
    self.energy_level = 0

  def __str__(self) -> str:
    # Energy left-justified to width 3; flashed octopuses are highlighted.
    text = str(self.energy_level).ljust(3, " ")
    return colored(text, "yellow") if self.has_flashed else text

  def __repr__(self) -> str:
    name = type(self).__name__
    return (f'{name}(x={self.x},y={self.y},'
            f'energy_level={self.energy_level},has_flashed={self.has_flashed})')


class Field:
  """Rectangular grid of octopuses stored row-major as grid[y][x]."""

  # (dx, dy) offsets of the eight neighbours (diagonals included), in the same
  # order the neighbours were historically listed: left, up-left, up,
  # up-right, right, down-right, down, down-left.
  _NEIGHBOUR_OFFSETS = (
    (-1, 0), (-1, -1), (0, -1), (1, -1),
    (1, 0), (1, 1), (0, 1), (-1, 1),
  )

  def __init__(self, octopus_grid: List[List['Octopus']]) -> None:
    """Wrap an existing grid of octopuses.

    Args:
      octopus_grid: list of rows; the row index is y and the column index is x.

    Raises:
      ValueError: if the grid has no cells or its rows differ in length.
    """
    if not octopus_grid or not octopus_grid[0]:
      raise ValueError('octopus_grid must contain at least one non-empty row')
    width = len(octopus_grid[0])
    # Neighbour lookup assumes a rectangle; a ragged grid would index out of range.
    if any(len(row) != width for row in octopus_grid):
      raise ValueError('octopus_grid rows must all have the same length')
    self.grid = octopus_grid
    self.width = width
    self.height = len(octopus_grid)
    self.octopus_count = self.width * self.height

  def __str__(self) -> str:
    """Render one grid row per line, each cell via its own str()."""
    return ''.join(
      ''.join(str(octopus) for octopus in row) + '\n' for row in self.grid
    )

  def increase_energy(self) -> None:
    """Add one energy to every octopus."""
    for octopus in self.get_all_octopus():
      octopus.increase_energy()

  def flash_all(self) -> List[bool]:
    """Give every octopus one chance to flash, scanning in row-major order.

    Each octopus that flashes immediately raises the energy of all its
    neighbours, so a neighbour later in the scan may flash in this same pass.
    Callers repeat the pass until it reports no flashes.

    Returns:
      One bool per octopus (row-major), True where it flashed in this pass.
    """
    flash_status = []
    for octopus in self.get_all_octopus():
      flashed = octopus.flash()
      if flashed:
        for neighbour in self.get_adjacent(octopus):
          neighbour.increase_energy()
      flash_status.append(flashed)
    return flash_status

  def reset_all(self) -> None:
    """Reset every octopus at the end of a step."""
    for octopus in self.get_all_octopus():
      octopus.reset()

  def get_flash_count(self) -> int:
    """Number of octopuses whose has_flashed flag is currently set."""
    return sum(1 for octopus in self.get_all_octopus() if octopus.has_flashed)

  def get_adjacent(self, octopus: 'Octopus') -> List['Octopus']:
    """Return the in-bounds neighbours of `octopus`, diagonals included (3 to 8)."""
    adjacent = []
    for dx, dy in self._NEIGHBOUR_OFFSETS:
      nx, ny = octopus.x + dx, octopus.y + dy
      if 0 <= nx < self.width and 0 <= ny < self.height:
        adjacent.append(self.get_octopus(nx, ny))
    return adjacent

  def get_octopus(self, x: int, y: int) -> 'Octopus':
    """Return the octopus at column x, row y."""
    return self.grid[y][x]

  def get_all_octopus(self) -> List['Octopus']:
    """Flatten the grid into a new row-major list.

    Uses a comprehension: sum(grid, []) re-copies the accumulator for every
    row, which is quadratic.
    """
    return [octopus for row in self.grid for octopus in row]

  @classmethod
  def from_input_lines(cls, lines: List[str]) -> 'Field':
    """Build a field from puzzle input lines of single-digit energy levels.

    Surrounding whitespace is stripped and blank lines (e.g. a trailing empty
    line) are skipped so they do not become empty rows.

    Raises:
      ValueError: on a non-digit character, or if the remaining rows are
        empty or differ in length.
    """
    rows = [row for row in (line.strip() for line in lines) if row]
    octopus_grid = [
      [Octopus(x, y, int(value)) for x, value in enumerate(row)]
      for y, row in enumerate(rows)
    ]
    return cls(octopus_grid)

      

In [59]:
# Load the puzzle grid. The handle is named input_file (not `input`) so the
# builtin input() is not shadowed in the kernel namespace for later cells.
with open('./input11') as input_file:
  lines = [line.strip() for line in input_file]

field = Field.from_input_lines(lines)

def step(field: Field) -> int:
  """Advance the simulation by one step and return how many octopuses flashed.

  Every octopus first gains one energy. Flash passes then repeat until a pass
  produces no new flash. Finally, the flashes are counted and flashed
  octopuses are reset.
  """
  field.increase_energy()
  while True:
    if not any(field.flash_all()):
      break
  flashed = field.get_flash_count()
  field.reset_all()
  return flashed
  
# Part 1: total number of flashes over the first 100 steps.
flash_count = sum(step(field) for _ in range(100))

flash_count



1617

In [67]:
# Reload a fresh grid for part 2 (part 1 mutated `field`). The handle is named
# input_file so the builtin input() is not shadowed for later cells.
with open('./input11') as input_file:
  lines = [line.strip() for line in input_file]

field = Field.from_input_lines(lines)

# Part 2: first step (1-based) on which every octopus flashes simultaneously.
# The limit is only a safety net against an input that never synchronises;
# exceeding it raises instead of silently printing nothing, which is what the
# previous fixed range(500) did when the answer lay beyond it.
MAX_SYNC_STEPS = 100_000

for step_number in range(1, MAX_SYNC_STEPS + 1):
  if step(field) == field.octopus_count:
    print(step_number)
    break
else:
  raise RuntimeError(f'octopuses did not synchronise within {MAX_SYNC_STEPS} steps')




258
