In [22]:
import doctest
from dataclasses import dataclass
from typing import Iterable, Optional
from itertools import count, islice, product

In [28]:
from collections import defaultdict


type vec2d = tuple[int, int] 


@dataclass
class ProblemInput:
  grid: list[str]
  height: int
  width: int

  def is_oob(self, pos: vec2d) -> bool:
    i, j = pos
    return i < 0 or i >= self.height or j < 0 or j >= self.width
  
  def at(self, pos: vec2d, oob_char: str = '?') -> bool:
    return oob_char if self.is_oob(pos) else self.grid[pos[0]][pos[1]]
  
  @classmethod
  def parse_input(cls, input: str) -> 'ProblemInput':
    grid = input.splitlines()
    return ProblemInput(grid, height=len(grid), width=len(grid[0]))


# The eight neighbouring directions as (row delta, column delta) pairs.
# (0, 0) is left out: stepping by it never leaves the start cell, so a word
# made of one repeated letter would otherwise "match" without moving.
DIRS = [
  (i, j) for i in (-1, 0, 1) for j in (-1, 0, 1) if (i, j) != (0, 0)
]

# The four diagonal directions as (row delta, column delta) pairs.
DIAGONALS = list(product((-1, 1), repeat=2))


def step2d(pos: vec2d, dir: vec2d, n: Optional[int] = None) -> Iterable[vec2d]:
  """Yield positions starting at `pos` and advancing by `dir` each step.

  Args:
    pos: starting (row, column); it is the first position yielded.
    dir: (row delta, column delta) added on every step.
    n: maximum number of positions to yield. None means unbounded; 0 yields
      nothing.

  Yields:
    Successive (row, column) tuples. No bounds checking is done here.
  """
  i, j = pos
  di, dj = dir
  stepper = zip(count(i, di), count(j, dj))
  # Compare with None explicitly: a truthiness test would treat n == 0 as
  # "no limit" and turn a zero-length request into an infinite stream.
  if n is not None:
    stepper = islice(stepper, n)
  yield from stepper
    

def get_word(grid: ProblemInput, pos: vec2d, dir: vec2d, n: int) -> str:
  """Read `n` characters from `grid`, starting at `pos` and moving along `dir`.

  Positions outside the grid contribute whatever `grid.at` returns for
  out-of-bounds cells (its default `oob_char`).
  """
  letters = [grid.at(cell) for cell in step2d(pos, dir, n)]
  return ''.join(letters)


def find_occurrences(grid: ProblemInput, word: str = 'XMAS', dirs: list[vec2d] = DIRS) -> Iterable[tuple[vec2d, vec2d]]:
  """Yield a (start, direction) pair for every place `word` can be read in `grid`.

  Every cell is tried as a start, in row-major order, against each direction
  in `dirs`.
  """
  span = len(word)
  for row in range(grid.height):
    for col in range(grid.width):
      start = (row, col)
      for heading in dirs:
        if get_word(grid, start, heading, span) == word:
          yield (start, heading)
        

def count_xmas_v1(*args, **kwargs) -> int:
  """Return how many matches `find_occurrences` yields for the given arguments.

  All positional and keyword arguments are forwarded unchanged.
  """
  return sum(1 for _match in find_occurrences(*args, **kwargs))


def count_xmas_v2(grid: ProblemInput) -> int:
  """Count cells that are the middle of more than one diagonal 'MAS'.

  Each diagonal 'MAS' match is credited to its middle cell (one step from the
  start along the match direction); cells credited at least twice are counted.
  """
  hits_per_center = defaultdict(int)
  for (row, col), (d_row, d_col) in find_occurrences(grid, 'MAS', DIAGONALS):
    hits_per_center[row + d_row, col + d_col] += 1
  return sum(hits >= 2 for hits in hits_per_center.values())


In [29]:
# Run the doctests of the current module (testmod checks `__main__` when no
# module is passed). The returned TestResults is this cell's displayed value.
doctest.testmod(
  report=True,
  verbose=False,
  exclude_empty=True,
  optionflags=doctest.NORMALIZE_WHITESPACE,
)

TestResults(failed=0, attempted=0)

In [30]:
# Small 10 x 10 sample grid used as a regression check for both parts.
test_input = "\n".join([
  "MMMSXXMASM",
  "MSAMXMSMSA",
  "AMXSXMAAMM",
  "MSAMASMSMX",
  "XMASAMXAMM",
  "XXAMMXXAMA",
  "SMSMSASXSS",
  "SAXAMASAAA",
  "MAMMMXMMMM",
  "MXMXAXMASX",
])

problem = ProblemInput.parse_input(test_input)

# Expected counts for the sample grid: 18 for part 1, 9 for part 2.
assert count_xmas_v1(problem) == 18, "p1 test failed"
assert count_xmas_v2(problem) == 9, "p2 test failed"

In [31]:
# Final answers: solve both parts for the grid stored in inputs/day04.txt.
with open('inputs/day04.txt') as f:
  problem = ProblemInput.parse_input(f.read().strip())

# The file is fully read once the grid is parsed, so report outside the `with`.
print('Part 1: ', count_xmas_v1(problem))
print('Part 2: ', count_xmas_v2(problem))

Part 1:  2578
Part 2:  1972
