In [4]:
import doctest
from dataclasses import dataclass
from typing import Iterable, Optional
from itertools import count, islice, product

In [270]:
from collections import defaultdict
from dataclasses import field
from enum import Enum

class GuardState(Enum):
  OK = 1
  OOB = 2
  LOOP = 3


def parse_ints(s: str, sep: str = ',') -> list[int]:
  return [int(n) for n in s.split(sep)]


_ROTATIONS = {'^': '>', '>': 'v', 'v': '<', '<': '^'}
_STARTING_DIR = '^'

@dataclass
class ProblemInput:
  grid: list[list[str]]
  height: int
  width: int
  guard: tuple[str, int, int]
  visited: set[tuple[str, int, int]] = field(default_factory=lambda: set())
  
  def __post_init__(self):
    self.mark_visited()

  @classmethod
  def parse_input(cls, input: str) -> 'ProblemInput':
    grid = [list(l) for l in input.strip().split('\n')]
    height = len(grid)
    width = len(grid[0])
    for i in range(height):
      for j in range(width):
        if grid[i][j] == _STARTING_DIR:
          return ProblemInput(grid, height, width, (_STARTING_DIR, i, j))

  def is_oob(self, pos: tuple[int, int]) -> bool:
    i, j = pos
    return i < 0 or i >= self.height or j < 0 or j >= self.width

  def is_obstacle(self, pos: tuple[int, int]) -> bool:
    i, j = pos
    return not self.is_oob(pos) and self.grid[i][j] == '#'
  
  def will_loop(self, pos: tuple[int, int]) -> bool:
    return (self.guard[0], pos[0], pos[1]) in self.visited
  
  def rotate(self):
    dir, i, j = self.guard
    self.guard = (_ROTATIONS[dir], i, j)
    
  def next(self) -> bool:
    dir, i, j = self.guard
    if dir == '^':
      return i-1, j
    elif dir == '>':
      return i, j+1
    elif dir == 'v':
      return i+1, j
    elif dir == '<':
      return i, j-1
    
  def mark(self, pos: tuple[int, int], c: str):
    if self.is_oob(pos): return
    i, j = pos
    self.grid[i][j] = c
    
  def mark_visited(self):
    self.visited.add(self.guard)
    
  def move(self, pos: tuple[int, int]):
    self.guard = (self.guard[0],) + pos
    self.mark_visited()

  def step(self) -> GuardState:
    next_pos = self.next()
    if self.is_oob(next_pos):
      return GuardState.OOB
    elif self.will_loop(next_pos):
      return GuardState.LOOP
    
    if self.is_obstacle(next_pos):
      self.rotate()
    else:
      self.move(next_pos)
    return GuardState.OK

  def get_visited(self) -> int:
    return set(v[1:] for v in self.visited)


In [272]:
def part_1_solution(p: ProblemInput) -> int:
  # While we haven't hit a termination case.
  while p.step() == GuardState.OK:
    pass
  return len(p.get_visited())

def part_2_solution(p: ProblemInput) -> int:
  start = p.guard
  def copy():
    return ProblemInput(p.grid, p.height, p.width, start)
  
  # Solve once to find visited states.
  pp = copy()
  while pp.step() == GuardState.OK:
    pass
  visited = pp.get_visited()
  
  # Mark all visited stats as obstructions and look for loops.
  looped = 0
  for pos in visited:
    if pos == start[1:]: continue
    pp = copy()
    # Mark as obstacle.
    pp.mark(pos, '#')
    while (state := pp.step()) == GuardState.OK:
      pass
    if state == GuardState.LOOP:
      # print('FOUND!', pos)
      looped += 1
    # Revert.
    pp.mark(pos, '.')
  return looped
  

In [273]:
doctest.testmod(verbose=False, report=True, exclude_empty=True, optionflags=doctest.NORMALIZE_WHITESPACE)

TestResults(failed=0, attempted=0)

In [274]:
problem = ProblemInput.parse_input(test_input)
part_1_solution(problem)
# print(problem.step())
# print(problem.count_visited())

41

In [275]:
test_input = """....#.....
.........#
..........
..#.......
.......#..
..........
.#..^.....
........#.
#.........
......#..."""

problem = ProblemInput.parse_input(test_input)
assert part_1_solution(problem) == 41, "p1 test failed"
problem = ProblemInput.parse_input(test_input)
assert part_2_solution(problem) == 6, "p2 test failed"

In [279]:
%%time
# Final answers
with open('inputs/day06.txt') as f:
    input = f.read().strip()
    problem = ProblemInput.parse_input(input)
    print('Part 1: ', part_1_solution(problem))

    problem = ProblemInput.parse_input(input)
    print('Part 2: ', part_2_solution(problem))

Part 1:  5516
Part 2:  2008
CPU times: user 11.3 s, sys: 0 ns, total: 11.3 s
Wall time: 11.3 s
