In [51]:
import json
from pathlib import Path

In [52]:
# Read the whole input file as text and split it into a list of lines; the
# Part 1 and Part 2 cells turn each line into a nested list.
lines = Path("day18.txt").read_text().splitlines()

In [85]:
def reduce(tree):
    """Fully reduce ``tree`` in place and return its magnitude.

    One action is applied per iteration: an explode whenever one is
    available, otherwise a split. The loop ends when neither applies.
    """
    while True:
        if explode(tree):
            # something exploded; look for further explosions before splitting
            continue
        if not split(tree):
            # nothing left to explode or to split
            break

    return magnitude(tree)

def explode(tree, stack=None, depth=0):
    stack = stack or []
    match tree:
        case [left, right] if depth == 4:
            # find the node to the left, or the rightmost cousin
            c = tree
            for p in stack[::-1]:
                if p[1] is c:
                    # if the node is directly to the left, update it
                    if isinstance(p[0], int):
                        p[0] += left
                    else:
                        # otherwise, descend on the right
                        p = p[0]
                        while isinstance(p[1], list):
                            p = p[1]
                        p[1] += left
                    break
                c = p

            # find the node to the right, or the leftmost cousin
            c = tree
            for p in stack[::-1]:
                if p[0] is c:
                    # if the node is directly to the right, update it
                    if isinstance(p[1], int):
                        p[1] += right
                    else:
                        # otherwise, descend on the left
                        p = p[1]
                        while isinstance(p[0], list):
                            p = p[0]
                        p[0] += right
                    break
                c = p
            
            # finally, replace the exploded node with 0
            if stack[-1][0] is tree:
                stack[-1][0] = 0
            else:
                stack[-1][1] = 0
            return True
        case [_, _]:
            # continue searching for nodes to explode
            return any(explode(n, stack + [tree], depth + 1) for n in tree)
        
    return False

def split(tree, stack=None):
    stack = stack or []
    
    match tree:
        case [left, right]:
            # continue searching for nodes to explode
            return any(split(n, stack + [tree]) for n in tree)
        case value if value >= 10:
            # split the current node and update the parent
            new = [value // 2, -(-value // 2)]
            if stack[-1][0] is tree:
                stack[-1][0] = new
            else:
                stack[-1][1] = new
            return True
        
    return False

def magnitude(tree):
    match tree:
        case [left, right]:
            return 3*magnitude(left) + 2*magnitude(right)
        case value:
            return value


## Part 1

In [86]:
# Parse every input line into a nested list. json.loads is used instead of
# eval so that text coming from the input file is never executed as code.
trees = [json.loads(line) for line in lines]

# Add the numbers left to right: pair the running total with the next number,
# then let reduce() normalise that pair in place.
total = trees[0]
for tree in trees[1:]:
    total = [total, tree]
    reduce(total)

magnitude(total)

4072

## Part 2

In [89]:
# Largest magnitude produced by adding two different input lines; iterating
# over every ordered pair (i, j) with i != j covers both operand orders.
# Each operand is parsed afresh because reduce() rewrites its argument in
# place, and json.loads is used instead of eval so that text from the input
# file is never executed as code.
max(
    reduce([json.loads(first), json.loads(second)])
    for i, first in enumerate(lines)
    for j, second in enumerate(lines)
    if i != j
)

4483