In [244]:
from pathlib import Path
from functools import cache

In [245]:
class LookUp:
    """Hashable wrapper around a dict so it can be passed to ``@cache`` functions.

    The wrapped dict is copied on construction and the class exposes no
    mutating methods, only item lookup. Equality and hashing are based on the
    table's content, so two ``LookUp`` objects built from equal dicts are
    interchangeable as cache keys regardless of insertion order.

    All values stored in the dict must themselves be hashable, otherwise
    ``hash()`` raises ``TypeError``.
    """

    def __init__(self, D : dict):
        self._D = dict(D)
        # Computed lazily on first hash() call; the content never changes
        # afterwards, so the value can be reused on every cache lookup.
        self._hash = None

    def __getitem__(self, k):
        return self._D[k]

    def __eq__(self, other):
        if not isinstance(other, LookUp):
            return NotImplemented
        return self._D == other._D

    def __hash__(self):
        if self._hash is None:
            # frozenset makes the hash independent of insertion order, which
            # keeps it consistent with the dict-based __eq__ above.
            self._hash = hash(frozenset(self._D.items()))
        return self._hash

In [246]:
@cache
def apply(D : "LookUp", l, op=None, r=None):
    """Recursively evaluate one entry of the table ``D``.

    Parameters
    ----------
    D : hashable table supporting ``D[name]``; each entry unpacks into the
        remaining positional arguments of this function.
    l : when ``op`` is None, a literal convertible with ``int()``; otherwise
        the name of the left operand.
    op : one of ``'+'``, ``'-'``, ``'/'``, ``'*'``, ``'='`` or None for a literal.
    r : name of the right operand (unused for literals).

    Returns
    -------
    int for literals and arithmetic operators, bool for ``'='``.

    Raises
    ------
    Exception if ``op`` is not one of the known operators.
    ZeroDivisionError if the right operand of ``'/'`` evaluates to zero.
    """
    if op is None:
        return int(l)

    l = apply(D,*D[l])
    r = apply(D,*D[r])
    if op == '+':
        return l + r
    elif op == '-':
        return l - r
    elif op == '/':
        # Truncate toward zero using integer arithmetic only: going through
        # float division loses precision once operands exceed 2**53.
        quotient = abs(l) // abs(r)
        return quotient if (l < 0) == (r < 0) else -quotient
    elif op == '*':
        return l * r
    elif op == '=':
        return l == r
    else:
        raise Exception(f'{op=} not known')

In [247]:
def parse_line(l):
    """Split one ``name: tokens`` line into ``(name, tokens_tuple)``.

    The text after the colon is split on whitespace into a tuple of strings.
    Raises ``ValueError`` if the line does not contain exactly one ``':'``.
    """
    name, expression = l.split(':')
    tokens = tuple(expression.split())
    return name, tokens

def read(prefix='data'):
    """Load ``{prefix}/21.txt`` and return a dict of its parsed lines.

    Trailing whitespace of the file is stripped, the text is split on
    newlines and every line is handed to ``parse_line``; the resulting
    ``(key, value)`` pairs form the returned dict.
    """
    text = Path(f'{prefix}/21.txt').read_text()
    lines = text.rstrip().split('\n')
    return {name: tokens for name, tokens in map(parse_line, lines)}

In [248]:
# Evaluate the `root` entry of the table read from the `test/` directory.
D = LookUp(read('test'))
apply(D, *D['root'])

152

In [249]:
# Evaluate the `root` entry of the table read from the default `data/` directory.
D = LookUp(read())
apply(D, *D['root'])

82225382988628

In [250]:
@cache
def humncount(D : LookUp, l, op=None, r=None):
    """Count how many of the two operands lead to ``'humn'``.

    An operand counts once if it is literally ``'humn'`` or if the entry it
    names has a non-zero count itself. Literal entries (``op`` is None)
    yield 0, so the result is always 0, 1 or 2.
    """
    if op is None:
        return 0

    def reaches_humn(name):
        # Direct reference first; only recurse into the table otherwise.
        return name == 'humn' or humncount(D, *D[name]) > 0

    left_hit = reaches_humn(l)
    right_hit = reaches_humn(r)
    return int(left_hit) + int(right_hit)

In [251]:
@cache
def simplify(D : LookUp, l, op=None, r=None):
    """Replace every operand that does not depend on ``'humn'`` by its value.

    Literal entries (``op`` is None) are returned unchanged. For an operation,
    each operand that is not ``'humn'`` itself and whose ``humncount`` is 0 is
    evaluated with ``apply`` and substituted by the result; operands that are
    or lead to ``'humn'`` keep their name.

    Returns the tuple ``(l, op, r)`` after substitution.
    """
    if op is None:
        return (l, op, r)

    if (l != 'humn') and humncount(D, *D[l]) == 0:
        l = apply(D,*D[l])

    if (r != 'humn') and humncount(D, *D[r]) == 0:
        r = apply(D,*D[r])

    return (l, op, r)

In [252]:
def _is_number(x):
    """Return True when ``x`` is an int or converts cleanly with ``int()``."""
    try:
        int(x)
    except (TypeError, ValueError):
        return False
    return True


def follow(D):
    """Walk from ``'root'`` down the chain of non-numeric operands.

    At every step the operand that is not a number is taken as the next key
    to look up in ``D``. The walk stops once an entry whose first element is
    None has been reached.

    Returns the list of visited entries, starting with ``D['root']`` and
    ending with the terminating entry.
    """
    (l, op, r) = D['root']
    L = [(l, op, r)]
    while l is not None:
        # Negative constants must count as numbers too; a plain
        # str.isdigit() check rejects the leading minus sign.
        k = r if _is_number(l) else l
        (l, op, r) = D[k]
        L.append(D[k])

    return L

In [253]:
def _trunc_div(a, b):
    """Integer division truncated toward zero, without going through floats."""
    quotient = abs(a) // abs(b)
    return quotient if (a < 0) == (b < 0) else -quotient


def invert(L):
    """Solve a chain of single-unknown operations for the innermost unknown.

    ``L`` is a list of ``(l, op, r)`` entries ending with the sentinel
    ``(None, None, None)``. In every other entry one operand is a lowercase
    name (the unknown sub-expression) and the other is a number. Starting
    from the ``'='`` entry, each operation is undone in turn so that ``val``
    always holds the value the unknown side must take.

    Returns the resulting value as an int.

    Raises
    ------
    AssertionError if ``L`` does not end with ``(None, None, None)``.
    Exception if an operator is not one of ``+ - / * =``.
    """
    val = 0
    assert L[-1] == (None, None, None)
    for (l, op, r) in L[:-1]:
        left_is_var = str(l).islower()
        o = int(r if left_is_var else l)
        if op == '+':
            # var + o == val  (addition is commutative)
            val -= o
        elif op == '-':
            if left_is_var:
                # var - o == val
                val = val + o
            else:
                # o - var == val
                val = o - val
        elif op == '/':
            if left_is_var:
                # var / o == val
                val = val * o
            else:
                # o / var == val; integer math keeps large values exact
                val = _trunc_div(o, val)
        elif op == '*':
            # var * o == val; integer math keeps large values exact
            val = _trunc_div(val, o)
        elif op == '=':
            val = o
        else:
            raise Exception(f'{op=} not known')

    return val

In [254]:
def pt2(D):
    """Find the value of ``'humn'`` that makes both operands of ``'root'`` equal.

    Steps:
      1. Turn ``root`` into an ``'='`` comparison and blank out ``humn``.
      2. Check that exactly one operand of ``root`` depends on ``humn``.
      3. Collapse every ``humn``-free operand to a constant (``simplify``),
         walk the remaining chain (``follow``) and undo it (``invert``).
      4. Plug the result back in and assert that ``root`` now evaluates truthy.

    ``D`` is the parsed table; it is copied first so the caller's dict is
    left untouched. Returns the computed ``humn`` value.
    """
    D = dict(D)  # work on a private copy instead of mutating the argument
    D['root'] = (D['root'][0], '=', D['root'][2])
    D['humn'] = (None, None, None)
    DL = LookUp(D)
    assert humncount(DL, *DL['root']) == 1
    S = {k : simplify(DL, *DL[k]) for k in D}
    humn = invert(follow(S))

    # Verify the answer by evaluating the full tree with humn filled in.
    D['humn'] = (humn, None, None)
    DL = LookUp(D)
    assert apply(DL, *DL['root'])
    return humn

In [255]:
# Run pt2 on the table read from the `test/` directory.
pt2(read('test'))

301

In [256]:
# Run pt2 on the table read from the default `data/` directory.
pt2(read())

3429411069028