In [1]:
import doctest
from dataclasses import dataclass
from typing import Iterable, Optional
from itertools import count, islice, product

In [43]:
from collections import defaultdict
import itertools
from typing import TypeVar


# Generic element type; lets `middle` return the same type its list holds.
T = TypeVar('T')
# Ordering rules keyed by page number: rules[p] is the set of pages n for which
# a `p|n` rule line was parsed.
type Rules = dict[int, set[int]]


def parse_ints(s: str, sep: str = ',') -> list[int]:
  """Split `s` on `sep` and convert every piece to an int.

  Raises ValueError if any piece is not a valid integer literal (this
  includes the empty string).
  """
  pieces = s.split(sep)
  return list(map(int, pieces))


@dataclass
class ProblemInput:
  """Parsed puzzle input: the page-ordering rules and the updates to check."""
  # rules[p] is the set of pages n for which a `p|n` rule line was given.
  rules: Rules
  # One list of page numbers per comma-separated update line.
  updates: list[list[int]]

  @classmethod
  def parse_input(cls, input: str) -> 'ProblemInput':
    """
    Build an instance from raw puzzle text.

    The text holds one `p|n` rule per line, then a blank line, then one
    comma-separated update per line. Everything after the first blank line
    is the update section; blank lines inside either section (for example
    trailing newlines at the end of a file) are skipped.

    Raises ValueError when the blank-line separator is missing, when a rule
    line does not have exactly one `|`, or when a number cannot be parsed.
    """
    # partition (not split + unpack) so extra blank lines later in the text
    # do not break the two-way unpacking.
    raw_rules, separator, raw_updates = input.partition('\n\n')
    if not separator:
      raise ValueError('expected a blank line between rules and updates')

    rules = defaultdict(set)
    for rule in raw_rules.splitlines():
      if not rule.strip():
        continue
      p, n = rule.split('|')
      rules[int(p)].add(int(n))

    updates = [parse_ints(update)
               for update in raw_updates.splitlines()
               if update.strip()]

    # `cls` rather than the hard-coded class name so subclasses get their own type.
    return cls(rules, updates)


def is_correct(update: list[int], rules: Rules) -> bool:
  """
  Report whether every page in `update` is immediately followed by a page
  that `rules` explicitly lists after it.

  Only neighbouring pages are compared, and a neighbouring pair with no
  matching rule counts as incorrect. Updates with fewer than two pages are
  always correct.

  >>> p = ProblemInput.parse_input("47|53\\n97|13\\n13|29\\n\\n")
  >>> is_correct([47, 53], p.rules)
  True
  >>> is_correct([13, 97], p.rules)
  False
  >>> is_correct([97, 13, 29], p.rules)
  True
  """
  for page, successor in itertools.pairwise(update):
    # .get avoids inserting new keys when `rules` is a defaultdict.
    allowed_after = rules.get(page, {})
    if successor not in allowed_after:
      return False
  return True


def make_correct(update: list[int], rules: Rules) -> list[int]:
  corrected = update[:]
  while not is_correct(corrected, rules):
    for i in range(len(corrected) - 1):
      p = corrected[i]
      n = corrected[i+1]
      if not is_correct([p, n], rules):
        if is_correct([n, p], rules):
          corrected[i] = n
          corrected[i+1] = p
  return corrected


def middle(l: list[T]) -> T:
  """
  Return the element at index `len(l) // 2`; for an even-length list that
  is the later of the two central elements. An empty list raises IndexError.

  >>> middle([3, 2, 9])
  2
  >>> middle([1])
  1
  """
  midpoint = len(l) // 2
  return l[midpoint]



In [51]:
def part_1_solution(p: ProblemInput) -> int:
  """Sum the middle page of every update that `is_correct` accepts as-is."""
  total = 0
  for update in p.updates:
    if is_correct(update, p.rules):
      total += middle(update)
  return total

def part_2_solution(p: ProblemInput) -> int:
  """Sum the middle page, after reordering with `make_correct`, of every
  update that `is_correct` rejects."""
  total = 0
  for update in p.updates:
    if is_correct(update, p.rules):
      continue  # already-correct updates do not count here
    total += middle(make_correct(update, p.rules))
  return total

In [52]:
# Run the docstring examples defined in this notebook; the returned
# TestResults is the cell's displayed value.
doctest.testmod(
  verbose=False,
  report=True,
  exclude_empty=True,
  optionflags=doctest.NORMALIZE_WHITESPACE,
)

TestResults(failed=0, attempted=6)

In [53]:
# Small sample input: `p|n` ordering rules, a blank line, then one
# comma-separated update per line.
test_input = """47|53
97|13
97|61
97|47
75|29
61|13
75|53
29|13
97|29
53|29
61|53
97|53
61|29
47|13
75|47
97|75
47|61
75|61
47|29
75|13
53|13

75,47,61,53,29
97,61,53,29,13
75,29,13
75,97,47,61,53
61,13,29
97,13,75,29,47"""

# Parse the sample and pin the expected totals for both parts: 143 and 123.
problem = ProblemInput.parse_input(test_input)
assert part_1_solution(problem) == 143, "p1 test failed"
assert part_2_solution(problem) == 123, "p2 test failed"

In [54]:
%%time
# Final answers
# Reads the full puzzle input from a path relative to the working directory.
# strip() removes leading/trailing whitespace from the file contents before
# parsing. Note that `problem` is rebound here, replacing the sample parsed
# in the earlier cell.
with open('inputs/day05.txt') as f:
    problem = ProblemInput.parse_input(f.read().strip())
    print('Part 1: ', part_1_solution(problem))
    print('Part 2: ', part_2_solution(problem))

Part 1:  5762
Part 2:  4130
CPU times: user 19 ms, sys: 0 ns, total: 19 ms
Wall time: 29.3 ms
