In [None]:
# Small example input in the `label: number` / `label: a op b` format that
# parse() reads (presumably the puzzle's worked example) -- handy for checking
# the solver before running it on the real input.
test_input = """root: pppw + sjmn
dbpl: 5
cczh: sllz + lgvd
zczc: 2
ptdq: humn - dvpt
dvpt: 3
lfqf: 4
humn: 5
ljgn: 2
sjmn: drzm * dbpl
sllz: 4
pppw: cczh / lfqf
lgvd: ljgn * ptdq
drzm: hmdt - zczc
hmdt: 32"""

In [None]:
import operator
from collections import defaultdict
from collections.abc import Callable

In [None]:
def parse(s):
    """Parse puzzle input into a dict of monkey jobs.

    Each non-blank line is either ``label: <int>`` or ``label: a <op> b`` with
    ``<op>`` one of ``+ - * /``. Blank lines (e.g. the trailing newline of an
    input file) are ignored, and both ``\\n`` and ``\\r\\n`` line endings work.

    Returns:
        dict mapping each label to either an ``int`` (a monkey that yells a
        number) or a tuple ``(a, b, op)``, where ``op`` is the ``operator``
        function that combines the values of monkeys ``a`` and ``b``.

    Raises:
        ValueError: if a line uses an operator other than ``+ - * /``.
    """
    op_codes = {
        "+": operator.add,
        "-": operator.sub,
        "*": operator.mul,
        # Floor division keeps every value an int. This assumes the puzzle's
        # divisions are exact -- TODO confirm; a remainder would be dropped.
        "/": operator.floordiv,
    }

    monkeys = {}

    for line in s.splitlines():
        if not line.strip():
            continue

        label, rest = line.split(": ")

        try:
            monkeys[label] = int(rest)
            continue
        except ValueError:
            pass  # not a plain number, so it must be an operation

        oper1, op_code, oper2 = rest.split(" ")
        try:
            op = op_codes[op_code]
        except KeyError:
            raise ValueError(f"unknown operator {op_code!r} in line {line!r}") from None

        monkeys[label] = (oper1, oper2, op)
    return monkeys

In [46]:
# To run against the small example instead, use:
# monkeys = parse(test_input)
with open("inputs/21") as input_file:
    monkeys = parse(input_file.read())

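Do a DFS to get a topological sort, so that every monkey comes after the monkeys it depends on.


This is awkward because the edges need reversing: each parsed job points from a monkey to its operands, so the next cell builds edges going from child (operand) to parent (the monkey that uses it), and the sort then reverses the DFS post-order.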

In [47]:
# Reversed dependency graph: each operand maps to the monkeys that consume it
# (child -> parent), in the same order the jobs appear in `monkeys`.
edges = defaultdict(list)

for parent, job in monkeys.items():
    if not isinstance(job, tuple):
        continue  # a number-yelling monkey has no operands

    left, right, _op = job
    for child in (left, right):
        edges[child].append(parent)

In [48]:
all_monkeys = set(monkeys)  # every monkey label (iterating a dict yields its keys)

In [49]:
order = []  # DFS post-order: each monkey lands after every monkey that consumes it
marks = set()  # monkeys whose DFS visit has finished


def visit(n):
    """Depth-first visit of monkey ``n`` along ``edges`` (operand -> consumer).

    ``n`` is appended to ``order`` only after all of its consumers have been,
    so reversing ``order`` gives a valid evaluation order. Recursion depth is
    bounded by the longest operand -> consumer chain.
    """
    if n in marks:
        return

    # No temporary marks for cycle detection: the input is assumed acyclic.
    # `.get` avoids inserting empty lists into the defaultdict for sinks.
    for m in edges.get(n, ()):
        visit(m)

    marks.add(n)
    order.append(n)


# visit() returns immediately for already-marked monkeys, so one pass over all
# labels is enough; no need to recompute the unmarked set (O(n)) every time.
# Iterating the dict (rather than popping from a set) also makes the order
# deterministic across runs.
for start in monkeys:
    visit(start)

# Reversed post-order: every monkey now comes after all of its operands.
final_order = list(reversed(order))

In [50]:
# Forward evaluation: walking `final_order`, every operand's value is already
# in eval_dict by the time the monkey that uses it is reached.
eval_dict = {}

for label in final_order:
    job = monkeys[label]

    if isinstance(job, int):
        eval_dict[label] = job
        continue

    a, b, op = job
    eval_dict[label] = op(eval_dict[a], eval_dict[b])

In [51]:
# Value computed for the root monkey by the forward evaluation.
eval_dict['root']

168502451381566

In [52]:
# Operator that undoes each binary operator; reduce() uses it to move an
# operation from the unknown's side of an equation onto the number's side.
# mul/floordiv only invert each other exactly when the division is exact.
inv_ops = dict(
    [
        (operator.add, operator.sub),
        (operator.sub, operator.add),
        (operator.mul, operator.floordiv),
        (operator.floordiv, operator.mul),
    ]
)

In [60]:
from dataclasses import dataclass

class Expression:
    """Common base class for every node of the symbolic expression tree."""

@dataclass
class Variable(Expression):
    """The single unknown being solved for; always displayed as ``x``."""

    def __repr__(self):
        return "x"

@dataclass
class Num(Expression):
    """A known integer constant, displayed as the bare number."""

    num: int

    def __repr__(self):
        return f"{self.num}"

@dataclass
class Eq(Expression):
    """An equation between two expressions, displayed as ``left = right``."""

    left: Expression
    right: Expression

    def __repr__(self):
        return f"{self.left!r} = {self.right!r}"

@dataclass
class Opr(Expression):
    opr: any
    left: Expression
    right: Expression

    def __repr__(self):
        match self.opr:
            case operator.add:
                op = "+"
            case operator.sub:
                op = "-"
            case operator.mul:
                op = "*"
            case operator.floordiv:
                op = "//"
        
        return "(" + self.left.__repr__() + f" {op} " + self.right.__repr__() + ")"


def reduce(expr):
    '''TBH there are probably some redundant cases here.
    Should have just done scipy rootfinding. Sigh.
    '''
    match expr:
        # merge these identity cases, obviously
        case Num(x):
            return Num(x)
        case Variable():
            return Variable()
        case Opr(op, Num(a), Num(b)):
            return Num(op(a, b))
        case Opr(op, Variable(), Num(n)):
            return expr
        case Opr(op, Num(n), Variable()):
            return expr
        case Opr(op, a, b):
            return Opr(op, reduce(a), reduce(b))
        case Eq(Variable(), Num(d)):
            return expr
        case Eq(Num(d), Variable()):
            return expr
        case Eq(x, Num(d)):
            match x:
                case Opr(op, a, Num(r)):
                    return Eq(a, Num(inv_ops[op](d, r)))
                case Opr(operator.add, Num(r), a):
                    return Eq(a, Num(inv_ops[operator.add](d, r)))
                case Opr(operator.mul, Num(r), a):
                    return Eq(a, Num(inv_ops[operator.mul](d, r)))
                case Opr(op, Num(r), a):
                    return Eq(Num(r), Opr(inv_ops[op], Num(d), a))
        case Eq(Num(d), x):
            return reduce(Eq(x, Num(d)))
        case Eq(x, y):
            return reduce(Eq(reduce(x), reduce(y)))
        case _:
            raise Exception(expr)

In [61]:
# Symbolic counterpart of the forward evaluation: `humn` becomes the unknown x,
# `root` becomes an equation between its two operands, and every other monkey
# becomes a number or an operation node. `final_order` guarantees operands are
# built before the monkeys that use them.
expression_dict = {}

for label in final_order:
    instr = monkeys[label]

    if label == "humn":
        node = Variable()
    elif isinstance(instr, int):
        node = Num(instr)
    else:
        a, b, op = instr
        left, right = expression_dict[a], expression_dict[b]
        node = Eq(left, right) if label == "root" else Opr(op, left, right)

    expression_dict[label] = node

In [62]:
# Inspect the symbolic expressions built above (humn is shown as x).
expression_dict

{'gdqz': 3,
 'gwzr': 2,
 'swzv': 5,
 'mvzb': 2,
 'npfj': 1,
 'brjh': 4,
 'qnjm': 2,
 'fvqc': 6,
 'qrrt': 5,
 'bwtt': 2,
 'djcs': 3,
 'hsqh': 4,
 'hzzz': 2,
 'grdg': 7,
 'fzfg': 2,
 'vmcz': 6,
 'wmvl': 6,
 'vrbl': 8,
 'rgjs': 3,
 'zggm': 11,
 'mtgf': 4,
 'tdgs': 7,
 'lwtw': 2,
 'rljr': 3,
 'mplv': 2,
 'gzjp': 3,
 'npvh': 2,
 'wprl': 4,
 'vbwq': 5,
 'mtjr': 3,
 'fcbn': 5,
 'dttr': 11,
 'mlwg': 3,
 'vpfq': 3,
 'vlzq': 2,
 'gzll': 5,
 'gpdt': 3,
 'hcfp': 2,
 'sstb': 2,
 'mfpd': 4,
 'lbft': 3,
 'cnmp': 3,
 'fszt': 4,
 'nzfm': 5,
 'rnlm': 17,
 'brcc': 3,
 'humn': x,
 'hcwz': 5,
 'hjlz': 2,
 'bldq': 2,
 'rgzt': 8,
 'rpfb': 8,
 'lmts': 2,
 'jzql': 3,
 'tdpc': 5,
 'rgfg': 9,
 'zzsg': 2,
 'bscm': 3,
 'tpmw': 17,
 'rgpl': 9,
 'mgpp': 8,
 'spbq': 4,
 'bdvl': 5,
 'fvhf': 4,
 'vcpv': 2,
 'smvr': 1,
 'htvt': 3,
 'gvzz': 15,
 'zgrl': 1,
 'svrs': 4,
 'qtlb': 3,
 'vnrb': 2,
 'pgmp': 2,
 'nzfp': 3,
 'rlfn': 2,
 'vqlc': 3,
 'cbls': 3,
 'qrrv': 4,
 'gwzz': 4,
 'frmc': 6,
 'hwzz': 7,
 'fqbv': 2,
 'bbhs': 2,

In [63]:
# Apply reduce() until the root equation stops changing. Each round moves about
# one operation off the unknown's side, so the cap only guards against a
# reduction that never reaches a fixed point.
MAX_REDUCE_ROUNDS = 100

current_result = expression_dict['root']

for step in range(MAX_REDUCE_ROUNDS):
    next_result = reduce(current_result)

    if next_result == current_result:
        break

    current_result = next_result
else:
    # Without this, a partially solved equation would be displayed below as
    # though it were the answer.
    raise RuntimeError(f"reduce() did not reach a fixed point within {MAX_REDUCE_ROUNDS} rounds")

In [64]:
# The reduced root equation; once solved it reads `x = <value for humn>`.
current_result

x = 3343167719435