In [6]:
from pathlib import Path
import os
import math
import ast

# `__file__` is not defined inside a notebook, so the data folder is located
# relative to the kernel's current working directory (normally the folder the
# notebook lives in). `.resolve()` canonicalises symlinks like realpath did.
FOLDER = Path.cwd().resolve() / 'data'
in_file = 'day18.txt'

# Read with an explicit encoding so the result does not depend on the
# platform's locale default.
with open(FOLDER / in_file, encoding='utf-8') as f:
    data = f.readlines()

In [19]:
class TreeNode():
    """A snailfish number stored as a binary tree.

    Leaves hold a regular number in ``val``; pairs have ``val is None`` and
    two children in ``left`` / ``right``. Every node keeps a back-reference to
    its ``parent`` (``None`` for a root).
    """

    @classmethod
    def from_list(cls, n, parent=None):
        """Build a tree from an ``int`` or a nested ``[left, right]`` list.

        :param n: an ``int`` (becomes a leaf) or a two-element list.
        :param parent: node the new (sub)tree is attached to, if any.
        :returns: the root of the newly built (sub)tree.
        """
        if isinstance(n, int):
            return cls(val=n, parent=parent)
        node = cls(parent=parent)
        node.left = cls.from_list(n[0], parent=node)
        node.right = cls.from_list(n[1], parent=node)
        return node

    def __init__(self, val=None, left=None, right=None, parent=None):
        self.val = val        # regular number for a leaf, None for a pair
        self.left = left
        self.right = right
        self.parent = parent

    def is_leaf(self):
        """Return True if this node holds a regular number."""
        return self.val is not None

    def to_zero(self):
        """Replace this pair, in place, with the regular number 0."""
        self.val = 0
        self.left = None
        self.right = None

    def magnitude(self):
        """Return the magnitude: a leaf's value, or 3*left + 2*right for a pair."""
        if self.is_leaf():
            return self.val
        return 3 * self.left.magnitude() + 2 * self.right.magnitude()

    def to_list(self):
        """Convert back to the nested-list form accepted by ``from_list``."""
        if self.is_leaf():
            return self.val
        return [self.left.to_list(), self.right.to_list()]

    def _copy(self):
        """Return an independent deep copy of this (sub)tree with no parent."""
        return type(self).from_list(self.to_list())

    def post_add(self):
        """Reduce this tree in place after an addition.

        Explodes pairs until none remain, then performs a single split. A
        split can create a pair that must explode, so the cycle repeats until
        neither action applies.
        """
        while True:
            while self.explode():
                pass
            if not self.split():
                return

    def explode(self):
        """Explode the leftmost pair whose leaves report level 4, if any.

        The pair's left value is added to the nearest regular number on its
        left, its right value to the nearest regular number on its right, and
        the pair itself becomes the regular number 0.

        :returns: True if a pair exploded, False if it got through the whole
            tree without exploding.
        """
        left_neighbor = None
        # traverse() yields leaves left to right. The leaf seen just before
        # the exploding pair is its left neighbour, and the leaf right after
        # the pair is its right neighbour (None at the edge of the tree).
        leaves = self.traverse()
        for current, level in leaves:
            if level == 4:
                # In a sum of two reduced numbers a level-4 leaf's pair holds
                # two regular numbers, so the next leaf is the pair's right one.
                right, _ = next(leaves)
                right_neighbor = next(leaves, None)
                if left_neighbor is not None:
                    left_neighbor.val += current.val
                if right_neighbor is not None:
                    right_neighbor[0].val += right.val

                # The exploding pair's parent node becomes 0.
                current.parent.to_zero()

                # There might be more to explode.
                return True
            left_neighbor = current
        return False

    def split(self):
        """Split the leftmost regular number >= 10 into ``[n // 2, ceil(n / 2)]``.

        :returns: True if a split happened (the tree may now need exploding),
            False otherwise.
        """
        for node, _ in self.traverse():
            if node.val > 9:
                # Exact integer halves; avoids rounding through floats.
                half, odd = divmod(node.val, 2)
                node.left = type(self)(val=half, parent=node)
                node.right = type(self)(val=half + odd, parent=node)
                node.val = None
                return True
        return False

    def traverse(self, level=-1, child=None):
        """Yield ``(leaf, level)`` for every leaf, left to right.

        Called on a root with the default ``level``, a leaf directly inside
        the root pair reports 0 and each extra level of nesting adds 1.
        ``child`` is unused and kept only for backward compatibility.
        """
        if self.is_leaf():
            yield (self, level)
        else:
            yield from self.left.traverse(level + 1, self.left)
            yield from self.right.traverse(level + 1, self.right)

    def __repr__(self):
        if self.is_leaf():
            return f"{self.__class__.__name__}({self.val})"
        return f"{self.__class__.__name__}({self.left} | {self.right})"

    def __add__(self, other):
        """Return the reduced snailfish sum ``[self, other]``.

        Both operands are copied first, so neither is modified by the
        reduction. ``other`` may also be a nested list.
        """
        if isinstance(other, list):
            other = type(self).from_list(other)
        if not isinstance(other, TreeNode):
            return NotImplemented
        result = type(self)(left=self._copy(), right=other._copy())
        result.left.parent = result
        result.right.parent = result
        result.post_add()
        return result

    def __radd__(self, other):
        """Handle ``other + self`` when ``other`` is not a TreeNode.

        ``0 + self`` returns a copy of ``self``, so the built-in ``sum()``
        works without an explicit start value. A nested-list ``other`` is
        converted and added on the left.
        """
        if isinstance(other, int) and other == 0:
            return self._copy()
        if isinstance(other, list):
            return type(self).from_list(other) + self
        return NotImplemented


### Part One

In [26]:
# Each input line is a nested-list literal. Blank lines (e.g. a trailing
# empty line in the input file) are skipped because ast.literal_eval rejects them.
raw_lists = [ast.literal_eval(line) for line in data if line.strip()]

# Snailfish addition is left-associative: ((a + b) + c) + ...
# Seeding sum() with the first tree keeps every addition TreeNode + TreeNode.
t = TreeNode.from_list(raw_lists[0])
t = sum(map(TreeNode.from_list, raw_lists[1:]), t)

total_magnitude = t.magnitude()

print("Solution 1: ", total_magnitude)

Solution 1:  3359


### Part Two

In [21]:
from itertools import permutations

# Snailfish addition is not commutative, so every ordered pair (x, y) of
# distinct input lines is tried. Fresh trees are built for each sum so that no
# addition sees operands left over from an earlier one.
pair_magnitudes = (
    (TreeNode.from_list(x) + TreeNode.from_list(y)).magnitude()
    for x, y in permutations(raw_lists, 2)
)
solution = max(pair_magnitudes)

print("Solution 2: ", solution)


Solution 2:  4616
