In [1]:
import json
from itertools import permutations
from math import ceil, floor

In [2]:
# Read the puzzle input; the `with` block closes the file once it is consumed.
INPUT_PATH = '../inputs/input-day18'
with open(INPUT_PATH) as input_file:
    data = input_file.read().splitlines()

Each line is a valid list literal (in both Python and JSON syntax), so it parses directly into nested lists:

In [3]:
# Each input row is a JSON array literal such as [[1,2],3]. json.loads yields
# the same nested lists of ints as eval() would, but without executing the
# file's contents as code.
lines = [json.loads(row) for row in data]

Define a node class so I can move up and down the tree using `parent`, `left`, and `right`. It also has an `add` method that defines the snailfish sum $a + b \to [a, b]$.

In [4]:
class Node:
    """One node of a snailfish-number binary tree.

    A leaf stores an integer in ``value`` and has no children. A pair node
    leaves ``value`` as None and keeps its two halves in ``left`` and
    ``right``. ``parent`` points back up the tree and is None for the root.
    """

    def __init__(self, value=None):
        self.value = value
        # All links start empty; callers attach children and parents.
        self.left = self.right = self.parent = None

    def __repr__(self):
        # Pairs render as "[left, right]", recursing through the children's
        # own reprs; leaves render as their number.
        if self.value is None:
            return f"[{self.left}, {self.right}]"
        return str(self.value)

    def add(self, right):
        """Return a new root pair [self, right] (snailfish addition).

        Both operands are re-parented under the new root. The result is
        not reduced here.
        """
        pair = Node()
        pair.left, pair.right = self, right
        self.parent = right.parent = pair
        return pair

Converting the nested lists into a tree structure using the Node class:

In [5]:
def build_nodes(line):
    """Convert a parsed snailfish number into a tree of Node objects.

    line: an int, which becomes a leaf, or a list whose first two elements
    are themselves ints or such lists (left half, right half).
    Returns the root Node; every child gets its `parent` link set.
    """
    if type(line) is int:
        return Node(line)

    pair = Node()
    pair.left, pair.right = build_nodes(line[0]), build_nodes(line[1])
    for child in (pair.left, pair.right):
        child.parent = pair
    return pair

Two functions find the previous and next leaves. To find the previous leaf, go up until we're no longer the left child, then step into that parent's left subtree and find its right-most leaf. To find the next leaf, swap left and right.

In [6]:
def find_prev(node):
    """Return the leaf immediately to the left of `node`, or None.

    Climbs while `node` is its parent's left child. The first ancestor
    reached from its right side has the predecessor in its left subtree,
    found by following right links down to the right-most node there.
    Returns None when `node` lies on the tree's left-most spine.
    """
    while node.parent is not None and node.parent.left is node:
        node = node.parent
    ancestor = node.parent
    if ancestor is None:
        return None
    leaf = ancestor.left
    while leaf.right is not None:
        leaf = leaf.right
    return leaf

In [7]:
def find_next(node):
    """Return the leaf immediately to the right of `node`, or None.

    Climbs while `node` is its parent's right child. The first ancestor
    reached from its left side has the successor in its right subtree,
    found by following left links down to the left-most node there.
    Returns None when `node` lies on the tree's right-most spine.
    """
    while node.parent is not None and node.parent.right is node:
        node = node.parent
    ancestor = node.parent
    if ancestor is None:
        return None
    leaf = ancestor.right
    while leaf.left is not None:
        leaf = leaf.left
    return leaf

### Split
This is the meat of the question. Splitting finds the leftmost value >= 10 and replaces it with a pair of two leaves: half rounded down and half rounded up. The trick is to propagate the fact that we split somewhere in the tree, so that we only do it once per call, as the rules require.

In [8]:
def split(node):
    """Split the leftmost leaf whose value is >= 10, in place.

    The tree is searched depth-first, left subtree before right, and at most
    one leaf is split. The chosen leaf becomes a pair whose left child holds
    the value halved and rounded down, and whose right child holds it halved
    and rounded up.

    Returns True if a split happened, False otherwise.
    """
    if node.value is None:
        # Pair node: `or` short-circuits, so the right subtree is searched
        # only when nothing in the left subtree was split.
        return split(node.left) or split(node.right)

    if node.value <= 9:
        return False

    # Exact integer halves. floor(v / 2) / ceil(v / 2) go through a float and
    # lose precision once v exceeds 2**53.
    half_down = node.value // 2
    half_up = node.value - half_down
    node.left = Node(half_down)
    node.right = Node(half_up)
    node.left.parent = node
    node.right.parent = node
    node.value = None
    return True

### Explode
Here we explode every pair at depth 5, counting the root as depth 1, so these pairs are nested inside four pairs. We can't be any deeper: each pass is exhaustive, and every addition increases the depth by only 1. We use the auxiliary `find_prev` and `find_next` to add the pair's values to its neighbouring leaves. Finally, we remove the pair by going to its parent and replacing it with a numeric `Node(0)`.

In [9]:
def explode(node, depth=1):
    """Explode, in place, every pair found at depth 5 under `node`.

    depth: depth of `node`, counting the root as 1, so depth 5 means nested
    inside four pairs. The tree is walked left to right. Each exploding pair
    adds its left value to the previous leaf and its right value to the next
    leaf (when those exist). It is then replaced in its parent by a leaf
    holding 0. Returns None.
    """
    if depth != 5 or node.value is not None:
        # Not an explodable pair: descend into each existing child, left first.
        for side in ("left", "right"):
            child = getattr(node, side)
            if child is not None:
                explode(child, depth + 1)
        return

    # Assumes both children are leaves, i.e. nothing is nested below depth 5.
    before = find_prev(node)
    if before is not None:
        before.value += node.left.value
    after = find_next(node)
    if after is not None:
        after.value += node.right.value

    zero = Node(0)
    zero.parent = node.parent
    if node.parent.left is node:
        node.parent.left = zero
    else:
        node.parent.right = zero

The reduce step consists of repeated rounds: explode every pair that needs exploding, then split once. This is the order in which the rules define the operations. To stop the while loop, I convert the tree to a string and check whether it stopped changing. This trick avoids defining tree equality.

In [10]:
def reduce(root):
    """Fully reduce the snailfish number rooted at `root`, in place.

    Each round calls explode() on the whole tree and then split() once.
    Rounds repeat until a whole round leaves the tree's string form
    unchanged. Returns `root` for convenience.
    """
    # String snapshots stand in for a tree-equality check; each snapshot is
    # taken once per round.
    previous = None
    current = str(root)
    while previous != current:
        previous = current
        explode(root)
        split(root)
        current = str(root)
    return root

-----
# Task 1
Adding all the numbers, one after the other:

In [11]:
# Fold the numbers left to right; each partial sum is reduced before the
# next number is added.
ans = build_nodes(lines[0])
for snail_list in lines[1:]:
    ans = reduce(ans.add(build_nodes(snail_list)))

In [12]:
# Bare last expression: IPython displays it via Node.__repr__, which gives the same text as print().
ans

[[[[7, 7], [7, 8]], [[8, 7], [0, 7]]], [[[6, 6], [6, 7]], 6]]


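The magnitude (computed by `final_sum`) is defined recursively. A regular number's magnitude is its value, and a pair's magnitude is $3 \cdot \text{left} + 2 \cdot \text{right}$: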

In [13]:
def final_sum(node):
    """Return the magnitude of the snailfish number rooted at `node`.

    A leaf's magnitude is its value. A pair's magnitude is three times the
    magnitude of its left half plus twice that of its right half.
    """
    if node.value is None:
        left_part = final_sum(node.left)
        right_part = final_sum(node.right)
        return 3 * left_part + 2 * right_part
    return node.value

In [14]:
# Task 1 answer: magnitude of the reduced sum of all numbers.
final_sum(node=ans)

3411

----------
# Task 2
Checking all ordered pairs is $O(n^2)$, and we brute-force it. Snailfish addition isn't commutative, so both $a + b$ and $b + a$ are tried. It takes me about 6 seconds.

In [17]:
# Every ordered pair of distinct input lines (both a + b and b + a) is tried.
# Trees are rebuilt from `lines` for each pair because explode/split mutate
# them in place. default=0 matches the empty case (fewer than two lines).
big = max(
    (final_sum(reduce(build_nodes(a).add(build_nodes(b))))
     for a, b in permutations(lines, 2)),
    default=0,
)

In [18]:
# Task 2 answer: largest magnitude over all ordered pairs (bare expression for rich display).
big

4680
