In [1]:
from dataclasses import dataclass
from pathlib import Path
import re
from collections import deque
import math

In [2]:
# Sample puzzle input: four monkey descriptions separated by blank lines.
# The parser assertions and the "Part 1 - test" / "Part 2 - test" cells run against it.
test_input = """Monkey 0:
  Starting items: 79, 98
  Operation: new = old * 19
  Test: divisible by 23
    If true: throw to monkey 2
    If false: throw to monkey 3

Monkey 1:
  Starting items: 54, 65, 75, 74
  Operation: new = old + 6
  Test: divisible by 19
    If true: throw to monkey 2
    If false: throw to monkey 0

Monkey 2:
  Starting items: 79, 60, 97
  Operation: new = old * old
  Test: divisible by 13
    If true: throw to monkey 1
    If false: throw to monkey 3

Monkey 3:
  Starting items: 74
  Operation: new = old + 3
  Test: divisible by 17
    If true: throw to monkey 0
    If false: throw to monkey 1
"""

In [3]:
# One pattern per line of a monkey description. Raw strings keep `\d` a regex
# escape instead of an invalid Python string escape (which triggers a
# DeprecationWarning / SyntaxWarning on recent interpreters).
# Group 1 of each pattern is the payload of that line.
ITEMS_PATTERN = re.compile(r"Starting items: ((?:\d+(?:, )?)+)")  # "79, 98"
OPERATION_PATTERN = re.compile(r"Operation: new = (.*)")  # "old * 19"
TEST_PATTERN = re.compile(r"Test: divisible by (\d+)")  # divisor
TRUE_PATTERN = re.compile(r"If true: throw to monkey (\d+)")  # target index
FALSE_PATTERN = re.compile(r"If false: throw to monkey (\d+)")  # target index


@dataclass
class Monkey:
    items: deque[int]
    operation: tuple
    test: int
    if_true: int
    if_false: int
    inspections: int = 0
        
    def inspect_items(self, crazy_bananas=None):
        if self.items:
            self.inspections += len(self.items)
        if crazy_bananas:
            self.items = [self._operation(item) for item in self.items]
            self.items = [item % crazy_bananas for item in self.items]
        else:
            self.items = [math.floor(self._operation(item) / 3) for item in self.items]

    def throw_items(self):
        throw_items = [
            (self.if_true if item % self.test == 0 else self.if_false, item)
            for item in self.items
        ]
        self.items = []
        return throw_items
    
    def receive_item(self, item):
        self.items.append(item)
        
    def _operation(self, item):
        match self.operation:
            case a, "+", b:
                a = item if a == "old" else int(a)
                b = item if b == "old" else int(b)
                return a + b
            case a, "*", b:
                a = item if a == "old" else int(a)
                b = item if b == "old" else int(b)
                return a * b
            case _:
                raise Exception()
    
    @classmethod
    def from_string(cls, monkey_string):
        if items_match := ITEMS_PATTERN.search(monkey_string):
            items = list(map(int, items_match.group(1).split(", ")))
            
        if items_match := OPERATION_PATTERN.search(monkey_string):
            operation = tuple(
                map(lambda e: int(e) if e.isnumeric() else e, items_match.group(1).split(" "))
            )
        
        if test_match := TEST_PATTERN.search(monkey_string):
            test = int(test_match.group(1))
        
        if true_match := TRUE_PATTERN.search(monkey_string):
            if_true = int(true_match.group(1))

        if false_match := FALSE_PATTERN.search(monkey_string):
            if_false = int(false_match.group(1))
        
        return cls(items, operation, test, if_true, if_false)

def parse_input(monkey_input):
    """Build one Monkey per blank-line-separated paragraph of the puzzle text.

    Leading and trailing whitespace of the whole input is dropped before
    splitting; the monkeys are returned in input order.
    """
    paragraphs = monkey_input.strip().split("\n\n")
    return list(map(Monkey.from_string, paragraphs))

def do_round(monkeys, crazy_bananas=False):
    """Play one round: each monkey, in list order, inspects and then throws.

    Args:
        monkeys: Sequence of monkeys; throw targets are indices into it.
        crazy_bananas: Passed through unchanged to each monkey's
            ``inspect_items``.
    """
    for thrower in monkeys:
        thrower.inspect_items(crazy_bananas=crazy_bananas)
        thrown = thrower.throw_items()
        for target_index, worry in thrown:
            # Items land immediately, so a later monkey in the same round
            # already sees what was thrown to it.
            monkeys[target_index].receive_item(worry)

# Sanity-check the parser on the sample: four monkeys, and the first one
# carries exactly the fields written in its description.
monkeys = parse_input(test_input)
assert len(monkeys) == 4
assert (
    monkeys[0].items,
    monkeys[0].operation,
    monkeys[0].test,
    monkeys[0].if_true,
    monkeys[0].if_false,
) == ([79, 98], ("old", "*", 19), 23, 2, 3)

In [4]:
# Part 1 - test
# Sample input, 20 rounds, with the divide-by-three step after each inspection.
monkeys = parse_input(test_input)
for n in range(20):
    do_round(monkeys)

# The answer is the product of the two largest inspection counts.
top_monkeys = sorted(
    monkeys,
    key=lambda m: m.inspections,
    reverse=True,
)[:2]
assert math.prod(monkey.inspections for monkey in top_monkeys) == 10605

In [5]:
# Part 1
# Same 20-round simulation as the test cell, run on the real puzzle input.
monkeys = parse_input(Path("input.txt").read_text())
for n in range(20):
    do_round(monkeys)

# Print the product of the two largest inspection counts.
top_monkeys = sorted(
    monkeys,
    key=lambda m: m.inspections,
    reverse=True,
)[:2]
print(math.prod(monkey.inspections for monkey in top_monkeys))

100345


In [6]:
# Part 2 - test
# Sample input, 10000 rounds. Worry levels are reduced modulo the product of
# every monkey's divisor, which leaves each divisibility test unchanged.
monkeys = parse_input(test_input)
test_product = math.prod(m.test for m in monkeys)
for n in range(10000):
    do_round(monkeys, crazy_bananas=test_product)

# The answer is the product of the two largest inspection counts.
top_monkeys = sorted(
    monkeys,
    key=lambda m: m.inspections,
    reverse=True,
)[:2]
assert math.prod(monkey.inspections for monkey in top_monkeys) == 2713310158

In [7]:
# Part 2
# Guesses: 18192789587 (low)
# Same 10000-round simulation as the test cell, run on the real puzzle input.

monkeys = parse_input(Path("input.txt").read_text())
test_product = math.prod(m.test for m in monkeys)
for n in range(10000):
    do_round(monkeys, crazy_bananas=test_product)

# Print the product of the two largest inspection counts.
top_monkeys = sorted(
    monkeys,
    key=lambda m: m.inspections,
    reverse=True,
)[:2]
print(math.prod(monkey.inspections for monkey in top_monkeys))

28537348205
