In [None]:
import re
from copy import deepcopy

In [None]:
# Set to True to run on the example input and assert against TEST_ANSWER.
TEST = False

In [None]:
# Choose the example input when testing, otherwise the real puzzle input.
filename = "data/input_11_test" if TEST else "data/input_11"

In [None]:
# Read the whole input file into a single string.
with open(filename) as file:
    input_str = file.read()

In [None]:
# Split on "\n" (blank lines are kept) so every monkey occupies a fixed
# 7-line stride; the parsing cells below index into it with step 7.
input_lines = input_str.split("\n")
input_lines[:20]

In [None]:
# Monkey indices. Each monkey spans a 7-line stride, but (assuming monkeys are
# separated by single blank lines) the last one has no separator of its own.
# Adding 1 before dividing counts that last monkey whether or not the file
# ends with a newline; plain len(input_lines) // 7 silently drops it when the
# trailing newline is missing, while the per-field lists parsed with
# range(k, len(input_lines), 7) still contain it.
monkeys = list(range((len(input_lines) + 1) // 7))
monkeys

In [None]:
# Starting worry levels per monkey: every integer found on the line at
# offset 1 of each 7-line monkey block.
starting_items = [
    [int(number) for number in re.findall(r"\d+", input_lines[row])]
    for row in range(1, len(input_lines), 7)
]
starting_items

In [None]:
def get_op(op, var):
    match op, var:
        case '+','old':
            function = lambda x: x+x
        case '*','old':
            function = lambda x: x*x
        case '+',_:
            function = lambda x: x+int(var)
        case '*',_:
            function = lambda x: x*int(var)
    return function


In [None]:
# One worry-update function per monkey, built from the last two tokens
# (operator, operand) of the line at offset 2 in each 7-line block.
# split() with no argument also discards trailing spaces / '\r', which
# split(' ') would leave attached to the operand (e.g. 'old\r').
operations = [get_op(*input_lines[row].split()[-2:]) for row in range(2, len(input_lines), 7)]

In [None]:
# Divisor each monkey tests worry levels against: the last token on the line
# at offset 3 of each block. split() tolerates trailing whitespace, where
# split(' ') would yield '' and make int() fail.
tests = [int(input_lines[row].split()[-1]) for row in range(3, len(input_lines), 7)]
tests

In [None]:
# Receiving monkey when the divisibility test passes (offset 4) or fails
# (offset 5); the target index is the last token of the line. split()
# tolerates trailing whitespace.
true_throw = [int(input_lines[row].split()[-1]) for row in range(4, len(input_lines), 7)]
print(true_throw)
false_throw = [int(input_lines[row].split()[-1]) for row in range(5, len(input_lines), 7)]
print(false_throw)

In [None]:
# One inspection counter per monkey, all starting at zero.
total_inspected_items = [0] * len(monkeys)
total_inspected_items

In [None]:
def inspect_item(item, monkey):
    """Inspect one item (part 1 rules) and decide where it is thrown.

    The worry level is updated with the monkey's operation, floor-divided
    by 3, then tested for divisibility by the monkey's divisor. As a side
    effect, ``total_inspected_items[monkey]`` is incremented.

    Returns:
        Tuple ``(throw, new_item)``: index of the receiving monkey and the
        item's new worry level.
    """
    worry = operations[monkey](item) // 3
    divisible = worry % tests[monkey] == 0
    target = true_throw[monkey] if divisible else false_throw[monkey]

    total_inspected_items[monkey] += 1

    return target, worry

In [None]:
def turn(items, monkey):
    """Let ``monkey`` inspect and throw every item it currently holds.

    Args:
        items: List indexed by monkey; ``items[m]`` is the list of items
            monkey ``m`` holds. Mutated in place.
        monkey: Index of the monkey taking its turn.

    Returns:
        The same ``items`` list.
    """
    # Detach the monkey's items and give it a fresh empty list *before*
    # throwing anything. The original appended to items[throw] while
    # iterating items[monkey], so a throw back to the same monkey would
    # keep extending the list being iterated and never terminate.
    items_to_inspect, items[monkey] = items[monkey], []
    for item in items_to_inspect:
        # Work out the receiving monkey and the item's new worry value.
        throw, new_item = inspect_item(item, monkey)
        items[throw].append(new_item)
    return items

In [None]:
def round(items):
    """Play one full round: each monkey, in index order, takes its turn.

    Note: this name shadows the builtin ``round`` for the rest of the notebook.

    Returns:
        The item state after the round (the list returned by ``turn``).
    """
    state = items
    for current_monkey in monkeys:
        state = turn(state, current_monkey)
    return state

In [None]:
def n_rounds(items, n):
    """Play ``n`` consecutive rounds starting from ``items``.

    Returns:
        The item state after the last round.
    """
    state = items
    for _ in range(n):
        state = round(state)
    return state

In [None]:
# Reset the per-monkey inspection counters before the part 1 simulation.
total_inspected_items = [0] * len(monkeys)
total_inspected_items

In [None]:
# Part 1: simulate 20 rounds on a deep copy, because turn() mutates the item
# lists in place; only the inspection counters are used afterwards.
_ = n_rounds(deepcopy(starting_items),20)

In [None]:
# Sort ascending so the two busiest monkeys' counts are the last two entries.
total_inspected_items.sort()

In [None]:
# Inspection counts after part 1, in ascending order.
total_inspected_items

In [None]:
# Monkey business = product of the two largest inspection counts
# (relies on total_inspected_items having been sorted ascending).
monkey_business = total_inspected_items[-2] * total_inspected_items[-1]

In [None]:
# Expected part 1 answer for the example input; only checked when TEST is True.
TEST_ANSWER = 10605

In [None]:
# Report the real answer, or verify the example answer in test mode.
if not TEST:
    print("Monkey Business is at {0}".format(monkey_business))
else:
    assert monkey_business == TEST_ANSWER

Part 2 - More worry
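Because we no longer divide the worry by 3, worry levels grow without bound (an `old * old` operation roughly doubles the number of digits each time). Python integers do not overflow, but computing the values explicitly at every turn and round quickly becomes far too slow and memory-hungry. However, we don't need to do that — all we need to track is whether each new worry level passes or fails the monkeys' tests.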

Each test checks whether the worry level is divisible by some number k (in the puzzle inputs these are all primes). If we know a worry level's remainder modulo k, we can work out the new worry level's remainder modulo k — and therefore whether it is divisible by k — without ever knowing the full value:
 - for addition operations: (x + y) mod k == ((x mod k) + (y mod k)) mod k.
 - for multiplication operations: (x\*y) mod k == ((x mod k) \* (y mod k)) mod k. This identity holds for every positive integer k. By contrast, the shortcut "x\*y is divisible by k when x or y is" is only valid for prime k. So the approach below does not actually require the divisors to be prime.
 
Hence, for each item, instead of storing its worry level x, we only need to store x mod k for every divisor k that the monkeys test with. We keep the remainder for every monkey's divisor, not just the current holder's, because items are thrown between monkeys.

First, we convert each starting item x into a list of numbers, one per divisor, where entry i is x mod tests[i].

In [None]:
# Raw starting worry levels per monkey, for comparison with the residues below.
starting_items

In [None]:
# Re-encode every starting item as its list of remainders, one per test
# divisor (entry i is worry % tests[i]), grouped by monkey as before.
new_starting_items = [
    [[worry % divisor for divisor in tests] for worry in held_items]
    for held_items in starting_items
]

In [None]:
# Residue-encoded starting items per monkey.
new_starting_items

Next, we create a new set of operations. Given an item's remainder x and a divisor k, each one returns the new worry level modulo k instead of the explicit new worry level.

In [None]:
def get_op_mod(op, var):
    match op, var:
        case '+','old':
            function = lambda x,k: ((x%k) + (x%k))%k
        case '*','old':
            function = lambda x,k: ((x%k) * (x%k))%k
        case '+',_:
            function = lambda x,k: ((x%k) + (int(var)%k))%k
        case '*',_:
            function = lambda x,k: ((x%k) * (int(var)%k))%k
    return function

In [None]:
# Modular counterparts of `operations`, built from the last two tokens
# (operator, operand) of each monkey's operation line (offset 2). split()
# with no argument ignores trailing whitespace / '\r'.
operations_mod = [get_op_mod(*input_lines[row].split()[-2:]) for row in range(2, len(input_lines), 7)]

In [None]:
def inspect_item(item, monkey):
    """Inspect one residue-encoded item (part 2 rules) and decide where it goes.

    Args:
        item: The item's worry level modulo each test divisor; entry i
            corresponds to ``tests[i]``.
        monkey: Index of the inspecting monkey.

    Returns:
        Tuple ``(throw, new_item)``: index of the receiving monkey and the
        updated residue list. As a side effect,
        ``total_inspected_items[monkey]`` is incremented.
    """
    apply_op = operations_mod[monkey]
    new_item = [apply_op(item[idx], tests[idx]) for idx in monkeys]

    # The residue for this monkey's own divisor decides the test outcome.
    if new_item[monkey] != 0:
        throw = false_throw[monkey]
    else:
        throw = true_throw[monkey]

    total_inspected_items[monkey] += 1

    return throw, new_item

In [None]:
# Reset the per-monkey inspection counters before the part 2 simulation.
total_inspected_items = [0] * len(monkeys)
total_inspected_items

In [None]:
# Part 2: 10000 rounds on a deep copy of the residue-encoded items; turn()
# now dispatches to the residue-based inspect_item redefined for part 2.
_ = n_rounds(deepcopy(new_starting_items),10000)

In [None]:
# Sort ascending so the two busiest monkeys' counts are the last two entries.
total_inspected_items.sort()

In [None]:
# Inspection counts after part 2, in ascending order.
total_inspected_items

In [None]:
# Monkey business = product of the two largest inspection counts
# (relies on total_inspected_items having been sorted ascending).
monkey_business = total_inspected_items[-2] * total_inspected_items[-1]

In [None]:
# Expected part 2 answer for the example input; only checked when TEST is True.
TEST_ANSWER = 2713310158

In [None]:
# Report the real answer, or verify the example answer in test mode.
if not TEST:
    print("Monkey Business is at {0}".format(monkey_business))
else:
    assert monkey_business == TEST_ANSWER