In [1]:
# allow dataclass to be recursive (Node's fields refer to Node);
# a __future__ import must be the first statement
from __future__ import annotations

import itertools
import json
import uuid
from dataclasses import dataclass, field
from typing import Optional

from collections_extended import setlist
import numpy as np

In [2]:
# Raw puzzle input: one snailfish number per line, whitespace-trimmed.
# The `with` block closes the file handle instead of leaking it.
with open('inputs/18', encoding='utf-8') as input_file:
    puzzle_input = input_file.read().strip()

In [3]:
# Sample snailfish numbers, one per line, written as nested-list literals.
# Not referenced by any of the cells shown here.
test_input = '''[1,2]
[[1,2],3]
[9,[8,7]]
[[1,9],[8,5]]
[[[[1,2],[3,4]],[[5,6],[7,8]]],9]
[[[9,[3,8]],[[0,9],6]],[[[3,7],[4,9]],3]]
[[[[1,3],[5,3]],[[1,3],[8,7]]],[[[4,9],[6,9]],[[8,2],[7,3]]]]'''

In [4]:
# Example list of snailfish numbers; the magnitude of their p1 sum is
# asserted to be 4140 further down.
test_input1 = '''[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]'''

In [5]:
# Example list whose left-to-right p1 sum is asserted against a known
# fully reduced snailfish number further down.
test_input2 = '''[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]
[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]
[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]
[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]
[7,[5,[[3,8],[1,4]]]]
[[2,[2,2]],[8,[8,1]]]
[2,9]
[1,[[[9,3],9],[[9,0],[0,7]]]]
[[[5,[7,4]],7],1]
[[[[4,2],2],6],[8,7]]'''

In [6]:
# A single addition (two numbers) whose reduced result is asserted further down.
test_input3 = '''[[[[4,3],4],4],[7,[[8,4],9]]]
[1,1]'''

In [7]:
# Two snailfish numbers for a single addition.
# Not referenced by any of the cells shown here.
test_input4 = '''[[[[4,0],[5,4]],[[7,7],[6,0]]],[[8,[7,7]],[[7,9],[5,0]]]]
[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]'''

In [8]:
@dataclass
class Node:
    """One node of a snailfish number stored as a binary tree.

    A leaf carries a regular number in ``reg`` and has no children; a pair
    node has ``reg is None`` and both ``left`` and ``right`` set.
    """

    reg: Optional[int] = None  # regular number for a leaf, None for a pair
    left: Optional[Node] = None
    right: Optional[Node] = None
    # Random id per node; get_leaf_list() uses it as a lookup key.
    uuid: Optional[str] = field(default_factory=lambda: str(uuid.uuid4()))


class Tree:
    """A mutable snailfish number backed by a tree of ``Node`` objects."""

    def __init__(self, lst):
        """Build a tree from a snailfish number written as nested lists.

        ``lst`` is a two-element list such as ``[[1, 2], 3]`` or a bare int
        (which yields a single-leaf tree).
        """
        self.root = Node()
        Tree.from_lst(lst, self.root)

    @staticmethod
    def from_lst(lst, node):
        """Populate ``node`` in place from ``lst``.

        An int makes ``node`` a leaf; a ``[left, right]`` list makes it a pair
        whose children are built recursively.
        """
        if isinstance(lst, int):
            node.reg = lst
            return

        a, b = lst[0], lst[1]
        node.left = Node()
        Tree.from_lst(a, node.left)
        node.right = Node()
        Tree.from_lst(b, node.right)

    def to_lst(self):
        """Return the tree as a nested-list snailfish number (or int for a leaf)."""
        return self._to_lst(self.root)

    def _to_lst(self, node):
        if node.reg is not None:
            return node.reg
        return [self._to_lst(node.left), self._to_lst(node.right)]

    def add_tree(self, other):
        """Replace this tree with the pair ``[self, other]`` without reducing.

        ``other``'s nodes are linked in, not copied, so later mutation of this
        tree (e.g. ``reduce``) also changes ``other``.
        """
        new_root = Node()
        new_root.left = self.root
        new_root.right = other.root
        self.root = new_root

    def reduce(self):
        """Apply explosions and splits until neither applies.

        Each pass tries an explosion first and only splits when nothing can
        explode; the loop restarts after every successful action.
        """
        while self.search_explosions() or self.split():
            pass

    def magnitude(self):
        """Return 3 * magnitude(left) + 2 * magnitude(right), recursively."""
        return self._magnitude(self.root)

    def _magnitude(self, node):
        if node.reg is not None:
            return node.reg
        return 3 * self._magnitude(node.left) + 2 * self._magnitude(node.right)

    def in_order(self):
        """Debug helper: print every node's ``reg`` in in-order (None for pairs)."""
        return self._in_order(self.root)

    def _in_order(self, node: Optional[Node]) -> None:
        if node:
            self._in_order(node.left)
            print(node.reg, end="->")
            self._in_order(node.right)

    def get_leaf_list(self):
        """Return ``(uuids, by_uuid)``: leaf uuids left to right and a uuid -> Node map."""
        return self._get_leaf_list(self.root, [], {})

    def _get_leaf_list(self, node: Optional[Node], ll, d):
        if node:
            if node.reg is not None:
                d[node.uuid] = node
                ll.append(node.uuid)

            self._get_leaf_list(node.left, ll, d)
            self._get_leaf_list(node.right, ll, d)

        return ll, d

    def _leaves(self):
        """Return the leaf nodes in left-to-right order."""
        leaves = []

        def visit(node):
            if node is None:
                return
            if node.reg is not None:
                leaves.append(node)
            visit(node.left)
            visit(node.right)

        visit(self.root)
        return leaves

    def next_leaf(self, leaf, right):
        """Return the leaf just right (``right=True``) or left of ``leaf``, or None."""
        assert leaf.reg is not None

        ll, d = self.get_leaf_list()
        i = ll.index(leaf.uuid)

        if right:
            return d[ll[i + 1]] if i + 1 < len(ll) else None
        else:
            return d[ll[i - 1]] if i - 1 >= 0 else None

    def search_explosions(self):
        """Explode the leftmost pair nested inside four pairs.

        Returns True if a pair exploded, False otherwise.
        """
        return self._search_explosions(self.root, depth=1)

    def explode_pair(self, node):
        """Explode ``node``, a pair of two regular numbers, in place.

        Its left value is added to the nearest regular number to its left and
        its right value to the nearest one to its right (when they exist);
        ``node`` then becomes the regular number 0.
        """
        assert node.left.reg is not None and node.right.reg is not None

        # One leaf scan instead of one per neighbour: node.left and node.right
        # are adjacent in left-to-right leaf order.
        leaves = self._leaves()
        i = leaves.index(node.left)
        if i > 0:
            leaves[i - 1].reg += node.left.reg
        if i + 2 < len(leaves):
            leaves[i + 2].reg += node.right.reg

        node.left = None
        node.right = None
        node.reg = 0

    def _search_explosions(self, node: Node, depth):
        # The root is depth 1, so the children of a depth-4 node are nested
        # inside four pairs. Deeper levels are never inspected; this presumably
        # relies on inputs being reduced, so nesting never exceeds that depth.
        if node is None:
            return False

        if depth == 4:
            if node.left and node.right:
                if node.left.reg is None:
                    self.explode_pair(node.left)
                    return True
                elif node.right.reg is None:
                    self.explode_pair(node.right)
                    return True
        else:
            if self._search_explosions(node.left, depth + 1):
                return True

            if self._search_explosions(node.right, depth + 1):
                return True

        return False

    def split(self):
        """Split the leftmost regular number >= 10 into ``[n // 2, n - n // 2]``.

        Returns True if a split happened, False otherwise.
        """
        return self._split(self.root)

    def _split(self, node: Node):
        if node is None:
            return False

        if node.reg is not None and node.reg >= 10:
            a = node.reg // 2
            b = node.reg - a

            node.reg = None
            node.left = Node(a)
            node.right = Node(b)

            return True

        if self._split(node.left):
            return True

        if self._split(node.right):
            return True

        # Explicit False (previously fell off the end and returned None).
        return False

In [9]:
def p1(puzzle_input):
    """Add up all snailfish numbers in ``puzzle_input`` and return the result.

    ``puzzle_input`` holds one snailfish number per line as a nested-list
    literal; blank lines are ignored. Numbers are added left to right and the
    running sum is reduced after every addition.

    Returns the final reduced ``Tree``. Raises ``ValueError`` if the input has
    no numbers or a line is not a valid list literal.
    """
    # json.loads instead of eval: the text comes from a file, and eval would
    # execute arbitrary code. Snailfish numbers are valid JSON arrays.
    numbers = [
        json.loads(line)
        for line in puzzle_input.strip().splitlines()
        if line.strip()
    ]
    if not numbers:
        raise ValueError("puzzle_input contains no snailfish numbers")

    tree = Tree(numbers[0])
    for number in numbers[1:]:
        tree.add_tree(Tree(number))
        tree.reduce()

    return tree

In [10]:
# Summing test_input2 left to right must give this fully reduced number.
assert p1(test_input2).to_lst() == [[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]

In [11]:
# Adding the two numbers of test_input3 and reducing must give this result.
assert p1(test_input3).to_lst() == [[[[0,7],4],[[7,8],[6,0]]],[8,1]]

In [12]:
# Summing [1,1]..[4,4]: the result is plain left-nested pairs, i.e. no
# explosion or split changes anything.
assert p1('''[1,1]
[2,2]
[3,3]
[4,4]''').to_lst() == [[[[1,1],[2,2]],[3,3]],[4,4]]

In [13]:
# Adding [5,5] to the previous sum nests the earliest pairs deeper, so
# explosions now change the regular numbers.
assert p1('''[1,1]
[2,2]
[3,3]
[4,4]
[5,5]''').to_lst() == [[[[3,0],[5,3]],[4,4]],[5,5]]

In [14]:
# Same pattern extended to [6,6]; checks repeated explosions across additions.
assert p1('''[1,1]
[2,2]
[3,3]
[4,4]
[5,5]
[6,6]''').to_lst() == [[[[5,0],[7,4]],[5,5]],[6,6]]

In [15]:
# Single explosion of [9,8] with no regular number to its left: the 9 is
# dropped and the 8 is added to the 1 on its right.
t = Tree([[[[[9,8],1],2],3],4])
t.search_explosions()
assert t.to_lst() == [[[[0,9],2],3],4]

In [16]:
# Single explosion of [3,2] with no regular number to its right: the 3 is
# added to the 4 on its left and the 2 is dropped.
t = Tree([7,[6,[5,[4,[3,2]]]]])
t.search_explosions()
assert  t.to_lst() == [7,[6,[5,[7,0]]]]

In [17]:
# Explosion of [3,2] with neighbours on both sides (4 on the left, 1 on the right).
t = Tree([[6,[5,[4,[3,2]]]],1])
t.search_explosions()
assert t.to_lst() == [[6,[5,[7,0]]],3]

In [18]:
# Two explodable pairs: one call explodes only the leftmost ([7,3]) and leaves
# the other ([3,2]) in place.
t = Tree([[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]])
t.search_explosions()
assert  t.to_lst() == [[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]

In [19]:
# Continues the previous example: the remaining pair [3,2] explodes next.
t = Tree([[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]])
t.search_explosions()
assert t.to_lst() == [[3,[2,[8,0]]],[9,[5,[7,0]]]]

In [20]:
# Magnitude of the reduced sum of the test_input1 example list.
assert p1(test_input1).magnitude() == 4140

In [21]:
# Part 1 answer for this notebook's puzzle input (regression check).
assert p1(puzzle_input).magnitude() == 3884

In [22]:
def p2(puzzle_input):
    """Return the largest magnitude of the sum of any two different numbers.

    ``puzzle_input`` holds one snailfish number per line as a nested-list
    literal; blank lines are ignored. Returns ``-np.inf`` when fewer than two
    numbers are given. Raises ``ValueError`` if a line is not a valid list
    literal.
    """
    # json.loads instead of eval: the text comes from a file, and eval would
    # execute arbitrary code. Snailfish numbers are valid JSON arrays.
    numbers = [
        json.loads(line)
        for line in puzzle_input.strip().splitlines()
        if line.strip()
    ]

    best = -np.inf
    # Magnitude weights left and right differently (3 vs 2), so both orders of
    # every pair are tried. Trees are rebuilt for each pair because add_tree
    # and reduce mutate nodes in place.
    for lst1, lst2 in itertools.permutations(numbers, 2):
        tree = Tree(lst1)
        tree.add_tree(Tree(lst2))
        tree.reduce()
        best = max(best, tree.magnitude())

    return best

In [23]:
# Part 2 answer for this notebook's puzzle input (regression check).
assert p2(puzzle_input) == 4595