In [79]:
import doctest
import string
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Iterable, Tuple
import re


In [88]:
# Matches each maximal (greedy) run of digits in a row. The capture group spans
# the whole match, so m.group() and m.group(1) are the same text.
NUMBER_RE = re.compile(r'(\d+)')

def is_symbolic(c: str):
  """Tell whether a grid character is a symbol.

  A symbol is anything that is neither the empty-cell marker '.' nor a digit.

  >>> is_symbolic('.')
  False
  >>> is_symbolic('4')
  False
  >>> is_symbolic('$')
  True
  >>> is_symbolic('#')
  True
  """
  if c == '.':
    # Empty cell marker is never a symbol.
    return False
  return c not in string.digits


@dataclass(frozen=True)
class Coord:
  """Immutable grid position: ``x`` is the column index, ``y`` the row index.

  ``coord + n`` and ``coord - n`` move *both* axes by the same scalar ``n``.
  """
  x: int
  y: int

  def __sub__(self, n) -> 'Coord':
    """Return this position shifted by ``-n`` on both axes."""
    new_x, new_y = self.x - n, self.y - n
    return Coord(x=new_x, y=new_y)

  def __add__(self, n) -> 'Coord':
    """Return this position shifted by ``+n`` on both axes."""
    new_x, new_y = self.x + n, self.y + n
    return Coord(x=new_x, y=new_y)


@dataclass(frozen=True)
class Rect:
  """Axis-aligned rectangle from ``start`` to ``end``.

  ``end`` is exclusive on both axes: a rectangle whose end equals its start
  on either axis contains no cells, and rectangles that merely touch along
  an edge do not intersect.
  """
  start: Coord
  end: Coord

  def expand(self, size: int) -> 'Rect':
    """Grow the rectangle by ``size`` cells on every side."""
    return Rect(start=self.start - size, end=self.end + size)

  def intersects(self, r: 'Rect') -> bool:
    """Return True if this rectangle and ``r`` share at least one cell."""
    # Size of the overlap on each axis; it is positive only when the
    # half-open ranges genuinely overlap.
    overlap_w = min(self.end.x, r.end.x) - max(self.start.x, r.start.x)
    overlap_h = min(self.end.y, r.end.y) - max(self.start.y, r.start.y)
    return overlap_w > 0 and overlap_h > 0


@dataclass(frozen=True)
class PartNumber:
  """A run of digits found in a single schematic row.

  ``start`` is the cell of the first digit. ``end`` is exclusive on both axes:
  one column past the last digit, and one row below. This means
  ``Rect(start, end)`` covers exactly the digits.
  """
  start: Coord
  end: Coord
  number: int

  @classmethod
  def find_in(cls, s: str, y: int) -> Iterable['PartNumber']:
    """Yield a PartNumber for each run of digits in row ``s`` at row index ``y``.

    >>> list(PartNumber.find_in('..123...45.6', y=5))
    [PartNumber(start=Coord(x=2, y=5), end=Coord(x=5, y=6), number=123),
     PartNumber(start=Coord(x=8, y=5), end=Coord(x=10, y=6), number=45),
     PartNumber(start=Coord(x=11, y=5), end=Coord(x=12, y=6), number=6)]
    """
    for match in NUMBER_RE.finditer(s):
      first_col, past_last_col = match.span()
      yield PartNumber(start=Coord(first_col, y),
                       end=Coord(past_last_col, y + 1),
                       number=int(match.group()))
    

@dataclass(frozen=True)
class Symbol:
  """A symbol character together with the grid cell it sits in."""
  coord: Coord
  symbol: str

  def as_rect(self) -> 'Rect':
    """Return the 1x1 rectangle covering this symbol's cell (end is exclusive)."""
    one_past = self.coord + 1
    return Rect(start=self.coord, end=one_past)

@dataclass
class Schematic:
  """An engine schematic: a grid of characters stored as one string per row.

  '.' is an empty cell, digits form part numbers, and every other character
  is a symbol (as decided by ``is_symbolic``).
  """
  grid: List[str]  # one string per row; each character is a single cell

  @classmethod
  def parse_from(cls, grid_str: str) -> 'Schematic':
    """Build a schematic from multi-line text.

    Surrounding whitespace is stripped from every line, so indented input
    parses the same as flush-left input. Constructing via ``cls`` lets
    subclasses inherit this factory and get instances of their own type.
    """
    return cls([line.strip() for line in grid_str.splitlines()])

  def symbols(self) -> Iterable[Symbol]:
    """Yield every symbol cell in row-major order."""
    for y, row in enumerate(self.grid):
      for x, c in enumerate(row):
        if is_symbolic(c):
          yield Symbol(coord=Coord(x, y), symbol=c)

  def part_numbers(self) -> Iterable[Tuple[Symbol, PartNumber]]:
    """Yield a ``(symbol, part_number)`` pair for every symbol/number adjacency.

    A number is adjacent to a symbol when any of its digits lies in one of
    the eight cells around the symbol, diagonals included. A number touching
    several symbols is yielded once per symbol; callers that want each
    number only once must de-duplicate.
    """
    symbols = list(self.symbols())
    # Grow each symbol's 1x1 cell by one in every direction, so a plain
    # rectangle-overlap test also catches diagonal neighbours.
    sym_rects = [s.as_rect().expand(1) for s in symbols]
    for y, row in enumerate(self.grid):
      for pn in PartNumber.find_in(row, y):
        pn_rect = Rect(pn.start, pn.end)
        for sym, sym_rect in zip(symbols, sym_rects):
          if sym_rect.intersects(pn_rect):
            yield sym, pn

  def gear_ratio(self, sym: Symbol, pns: List[PartNumber]) -> int:
    """Return the gear ratio of ``sym``, or 0 if it is not a gear.

    A gear is a '*' adjacent to exactly two part numbers; its ratio is the
    product of those two numbers. ``pns`` needs ``len`` and unpacking, so it
    is a list rather than an arbitrary iterable.
    """
    if sym.symbol != '*' or len(pns) != 2:
      return 0
    first, second = pns
    return first.number * second.number

  def gear_ratios(self) -> int:
    """Sum the gear ratios of every '*' symbol in the schematic."""
    parts_by_gear = defaultdict(list)
    for sym, pn in self.part_numbers():
      # Only '*' can be a gear, so skip collecting neighbours of anything else.
      if sym.symbol == '*':
        parts_by_gear[sym].append(pn)
    return sum(self.gear_ratio(sym, pns) for sym, pns in parts_by_gear.items())
      
  
def solution_1(schematic: 'Schematic') -> int:
  """Part 1: sum every number that is adjacent to at least one symbol.

  ``schematic.part_numbers()`` yields one pair per (symbol, number)
  adjacency, so a number touching several symbols shows up several times.
  Collecting the part numbers into a set counts each one exactly once. This
  works because a part number is hashable and records its own position, so
  equal values at different positions stay distinct.
  """
  unique_parts = {pn for _, pn in schematic.part_numbers()}
  return sum(pn.number for pn in unique_parts)

def solution_2(schematic: Schematic):
  """Part 2: total of all gear ratios in ``schematic``."""
  total_ratio = schematic.gear_ratios()
  return total_ratio


In [89]:
# Run the doctests embedded in this notebook's docstrings. NORMALIZE_WHITESPACE
# lets multi-line reprs (e.g. PartNumber.find_in) wrap freely.
doctest_results = doctest.testmod(verbose=False, report=True, exclude_empty=True,
                                  optionflags=doctest.NORMALIZE_WHITESPACE)
# Fail loudly under Restart & Run All instead of only printing a report.
assert doctest_results.failed == 0, doctest_results
doctest_results


TestResults(failed=0, attempted=5)

In [90]:
test_input = """467..114..
...*......
..35..633.
......#...
617*......
.....+.58.
..592.....
......755.
...$.*....
.664.598.."""

schematic = Schematic.parse_from(test_input)
print(solution_1(schematic))
print(solution_2(schematic))


4361
467835


In [92]:
# Final answers: read the puzzle input, then solve both parts from it.
with open('../data/day03.txt') as f:
  puzzle_text = f.read().strip()

schematic = Schematic.parse_from(puzzle_text)
print('Part 1: ', solution_1(schematic))
print('Part 2: ', solution_2(schematic))


Part 1:  540131
Part 2:  86879020
