In [None]:
from typing import Callable, List, NamedTuple, Tuple, Iterable
from math import log, exp
import re


class Monkey(NamedTuple):
    """State and throwing rule for a single monkey.

    The tuple itself is immutable, but ``items`` is a list that the round
    functions append to and clear in place.
    """

    # Worry levels currently held, in the order they will be inspected
    # (natural logs of worry levels when parsed with part != 1).
    items: List[int]
    # Maps an item's current worry level to its new worry level.
    operation: Callable[[int], int]
    # (divisor, target monkey if divisible, target monkey otherwise).
    test: Tuple[int, int, int]


# Matches a "Monkey <n>" header line; group 1 captures the monkey index.
start_monkey = re.compile(r"Monkey\s(\d+)")
# Applied to the text after "Operation: new = old ": group 1 is the operator
# ("+" or "*"), group 2 the operand (a decimal literal or the word "old").
operation_re = re.compile(r"([+*])\s+((\d+)|(old))")
# Finds the first run of decimal digits in a line.
find_number_re = re.compile(r"(\d+)")

def get_operation(m, part=1):
    """Build the worry-update function for one monkey's operation.

    Args:
        m: match object produced by ``operation_re`` on the text after
            ``"new = old "``: group 1 is the operator (``+`` or ``*``) and
            group 2 the operand, either a decimal literal or ``"old"``.
        part: ``1`` returns a function on integer worry levels. Any other
            value returns the same operation in natural-log space, i.e. a
            function ``f`` with ``f(log(x)) == log(op(x))`` up to rounding.

    Returns:
        A one-argument callable implementing ``new = old <op> <operand>``.

    Raises:
        ValueError: if ``m`` is None (the line did not parse) or the operator
            is neither ``+`` nor ``*``.
    """
    from math import log1p

    if m is None:
        raise ValueError("operation line did not match the expected pattern")
    op, operand = m.group(1), m.group(2)
    if op not in ("+", "*"):
        raise ValueError(f"unsupported operator: {op!r}")
    uses_old = operand == "old"

    if part != 1:
        if op == "+":
            if uses_old:
                # log(x + x) == log(x) + log(2)
                return lambda logx: logx + log(2)
            addend = int(operand)

            def add_in_log_space(logx, a=addend):
                # log(x + a) without materializing x = exp(logx), which
                # overflows once logx exceeds ~709.
                if logx > 0:
                    return logx + log1p(a * exp(-logx))
                return log(exp(logx) + a)

            return add_in_log_space
        if uses_old:
            # log(x * x) == 2 * log(x)
            return lambda logx: 2 * logx
        # log(x * a) == log(x) + log(a)
        return lambda logx, a=log(int(operand)): logx + a

    if op == "+":
        if uses_old:
            return lambda x: x + x
        return lambda x, a=int(operand): x + a
    if uses_old:
        return lambda x: x * x
    return lambda x, a=int(operand): x * a


def parse_monkey(lines: List[str], part=1):
    """Parse one six-line monkey description into a ``Monkey``.

    Args:
        lines: exactly six (already stripped) lines. Line 1 holds the
            comma-separated starting items after ``"Starting items:"``, line 2
            the operation after ``"Operation: new = old "``, and lines 3-5 each
            contribute their first integer, giving
            ``(divisor, target if divisible, target otherwise)``.
        part: ``1`` keeps integer worry levels; any other value stores their
            natural logs and builds a log-space operation.

    Returns:
        The parsed ``Monkey``.

    Raises:
        ValueError: if the block does not have six lines, or the operation or
            any test line cannot be parsed.
    """
    if len(lines) != 6:
        raise ValueError(f"expected 6 lines per monkey, got {len(lines)}")

    items_text = lines[1].split("Starting items:")[-1]
    # Skip blanks so a monkey with no starting items parses to [].
    items = [
        int(tok) if part == 1 else log(int(tok))
        for tok in items_text.split(",")
        if tok.strip()
    ]

    op_text = lines[2].split("Operation: new = old ")[-1]
    op_match = operation_re.match(op_text)
    if op_match is None:
        raise ValueError(f"unparseable operation line: {lines[2]!r}")
    operation = get_operation(op_match, part=part)

    test_numbers = []
    for line in lines[3:6]:
        number = find_number_re.search(line)
        if number is None:
            raise ValueError(f"no number found in test line: {line!r}")
        test_numbers.append(int(number.group(0)))

    return Monkey(items, operation, tuple(test_numbers))


def parse_monkeys(filename="input.txt", part=1):
    """Read every monkey description from ``filename``.

    A description starts at each line matching ``start_monkey`` and spans that
    line plus the following five.

    Args:
        filename: path of the puzzle input.
        part: forwarded to ``parse_monkey`` (``1`` = integer worry levels,
            otherwise log-space).

    Returns:
        List of ``Monkey`` objects in file order.
    """
    with open(filename) as f:
        lines = [line.strip() for line in f]
    return [
        parse_monkey(lines[start : start + 6], part=part)
        for start, line in enumerate(lines)
        if start_monkey.match(line)
    ]

In [None]:

def play_round(monkeys: List[Monkey], worry_div: int, stats: List[int]):
    """Play one round of keep-away with integer worry levels.

    Monkeys take turns in list order. For every item it holds, a monkey applies
    its ``operation``, floor-divides the result by ``worry_div``, and throws the
    new value to monkey ``test[1]`` if it is divisible by ``test[0]``,
    otherwise to monkey ``test[2]``. The monkey's hand is emptied at the end
    of its turn.

    With ``worry_div == 1`` worry levels would otherwise grow without bound
    (squaring doubles their digit count every round). They are instead reduced
    modulo the product of all monkeys' divisors. This leaves every divisibility
    test, and therefore every throw and inspection count, unchanged, because
    the operations are built only from ``+`` and ``*``.

    Args:
        monkeys: monkeys whose ``items`` lists are mutated in place.
        worry_div: divisor applied to each worry level after the operation.
        stats: running per-monkey inspection totals, incremented in place.

    Returns:
        Per-monkey number of items inspected during this round.
    """
    # Only valid without relief: (w // d) mod M is not determined by w mod M.
    modulus = 0
    if worry_div == 1:
        modulus = 1
        for monkey in monkeys:
            modulus *= monkey.test[0]

    inspected = [0] * len(monkeys)
    for im, m in enumerate(monkeys):
        for i in m.items:
            w = m.operation(i) // worry_div
            if modulus:
                w %= modulus
            throw_to = m.test[1] if w % m.test[0] == 0 else m.test[2]
            monkeys[throw_to].items.append(w)
            inspected[im] += 1
            stats[im] += 1
        m.items.clear()
    return inspected

def play_for(monkeys: List[Monkey], k: int, worry_div: int):
    """Play ``k`` rounds with integer worry levels and return the score.

    Prints the round number every 100 rounds and, at the end, the sorted
    per-monkey inspection counts.

    Args:
        monkeys: monkeys to play with; mutated in place by ``play_round``.
        k: number of rounds.
        worry_div: divisor applied to worry levels after each operation.

    Returns:
        Product of the two largest per-monkey inspection counts.

    Raises:
        ValueError: if fewer than two monkeys are given.
    """
    if len(monkeys) < 2:
        raise ValueError("need at least two monkeys to score a game")
    stats = [0] * len(monkeys)
    for round_no in range(k):
        play_round(monkeys, worry_div, stats)
        if round_no % 100 == 0:
            print(round_no)
    ranked = sorted(stats)
    print(ranked)
    # Returned (not printed) so callers' print(play_for(...)) shows the score.
    return ranked[-2] * ranked[-1]
# Each run re-parses its input because play_for mutates the monkeys in place.
for input_file, rounds, relief in (
    ("test.txt", 20, 3),
    ("test.txt", 20, 1),
    ("input.txt", 20, 3),
    ("test.txt", 1000, 1),
):
    print(play_for(parse_monkeys(input_file), rounds, relief))


In [None]:
def is_divisable(logw, a, tol=1E-10):
    """Decide whether ``w = exp(logw)`` is divisible by ``a``.

    The quotient ``w / a`` is computed in log space and compared with its
    nearest integer. That catches quotients landing just below an integer
    (e.g. 9.999999999999996) as well as just above.

    Floating-point roundoff grows with the magnitude of ``w``. Beyond a
    modest size the answer is no longer trustworthy, so this is a heuristic,
    not an exact test.

    Args:
        logw: natural log of the worry level.
        a: candidate divisor.
        tol: maximum distance from an integer still treated as exact.

    Returns:
        True if ``exp(logw) / a`` is within ``tol`` of an integer.
    """
    quotient = exp(logw - log(a))
    return abs(quotient - round(quotient)) < tol

# Sanity check: 30 is not a multiple of 7.
assert not is_divisable(log(30), 7)

In [None]:


def play_round2(monkeys: List[Monkey], stats: List[int]):
    """Play one round of keep-away with worry levels stored as natural logs.

    Same turn structure as ``play_round`` but without relief. Each item is a
    log worry level, ``m.operation`` is expected to be the log-space operation
    (monkeys parsed with ``part=2``), and divisibility is decided by
    ``is_divisable``.

    NOTE(review): divisibility is tested on floating-point values. Results
    become unreliable once worry levels get large.

    Args:
        monkeys: monkeys whose ``items`` lists are mutated in place.
        stats: running per-monkey inspection totals, incremented in place.

    Returns:
        Per-monkey number of items inspected during this round.
    """
    inspected = [0] * len(monkeys)
    for im, m in enumerate(monkeys):
        for logw in m.items:
            new_logw = m.operation(logw)
            throw_to = m.test[1] if is_divisable(new_logw, m.test[0]) else m.test[2]
            monkeys[throw_to].items.append(new_logw)
            inspected[im] += 1
            stats[im] += 1
        m.items.clear()
    return inspected

def play_for2(monkeys: List[Monkey], k: int):
    """Play ``k`` log-space rounds (see ``play_round2``) and return the score.

    Prints the round number every 100 rounds and, at the end, the sorted
    per-monkey inspection counts.

    Args:
        monkeys: monkeys parsed with ``part=2``; mutated in place.
        k: number of rounds.

    Returns:
        Product of the two largest per-monkey inspection counts.

    Raises:
        ValueError: if fewer than two monkeys are given.
    """
    if len(monkeys) < 2:
        raise ValueError("need at least two monkeys to score a game")
    stats = [0] * len(monkeys)
    for round_no in range(k):
        play_round2(monkeys, stats)
        if round_no % 100 == 0:
            print(round_no)
    ranked = sorted(stats)
    print(ranked)
    # Returned (not printed) so callers' print(play_for2(...)) shows the score.
    return ranked[-2] * ranked[-1]
print(play_for2(parse_monkeys("test.txt", part=2), 20))
# Disabled run on the real input. play_for2 takes (monkeys, rounds) and needs
# log-space monkeys, so parse with part=2.
# print(play_for2(parse_monkeys("input.txt", part=2), 20))