In [1]:
import itertools
import json
import math
#from copy import deepcopy

In [2]:
class SnailfishBase:
    """Common base for the nodes of a snailfish-number tree.

    Calling ``SnailfishBase(value)`` dispatches on the exact type of ``value``:
    an ``int`` builds a regular-number leaf (``SnailfishN``) and a ``list``
    builds a pair (``SnailfishP``). Every node stores a ``parent`` link and its
    nesting depth ``lev`` (0 for a root).
    """

    def __new__(cls, value, parent=None):
        # Pick the concrete node class from the raw value's exact type; any
        # other type (including bool) raises KeyError here.
        submap = {int: SnailfishN, list: SnailfishP}
        subclass = submap[type(value)]
        return super(SnailfishBase, subclass).__new__(subclass)

    def __init__(self, parent=None):
        self.parent = parent
        # Depth in the tree: one more than the parent's, or 0 for a root.
        self.lev = parent.lev + 1 if parent is not None else 0

    def __add__(self, other):
        """Return the reduced snailfish sum ``[self, other]`` as a new root.

        Both operands are linked under the new root and mutated in place:
        their depths grow by one and reduction may explode or split their
        descendants. They should therefore be distinct root nodes that are not
        reused afterwards. Non-snailfish operands yield NotImplemented, so
        Python raises the usual TypeError.
        """
        if not isinstance(other, SnailfishBase):
            return NotImplemented
        new_root = SnailfishP([self, other])
        # SnailfishP adopts already-built nodes as-is, so link them to the new
        # root here and push every node of both subtrees one level deeper.
        self.parent = new_root
        other.parent = new_root
        self.level_up()
        other.level_up()
        new_root.reduce()
        return new_root

    def __radd__(self, other):
        """Let ``sum()`` work: its integer start value 0 acts as the identity."""
        if other == 0:
            return self
        # Defer to Python, which then raises TypeError for unsupported operands.
        return NotImplemented

    def reduce(self):
        """Reduce this tree in place until nothing explodes and nothing splits.

        Explosions take priority. Every explosion is applied before a single
        split is done, and explosions are re-checked after each split, because
        the newly created pair may be nested deeply enough to explode.
        """
        while True:
            exploder = self.find_exploder()
            if exploder is not None:
                exploder.explode()
                continue
            splitter = self.find_splitter()
            if splitter is None:
                return
            splitter.split()

class SnailfishN(SnailfishBase):
    """Leaf node of a snailfish-number tree holding a regular (integer) number."""

    def __init__(self, value, parent=None):
        super().__init__(parent)
        self.value = value

    def level_up(self):
        """Increase this leaf's depth by one (it was nested under a new root)."""
        self.lev += 1

    def get_list(self):
        """Return the plain-Python representation of this leaf: the integer."""
        return self.value

    def find_exploder(self):
        """Regular numbers never explode."""
        return None

    def find_rightmostN(self):
        """A leaf is its own rightmost regular number."""
        return self

    def find_leftmostN(self):
        """A leaf is its own leftmost regular number."""
        return self

    def split(self):
        """Replace this leaf in its parent by the pair [floor(v/2), ceil(v/2)].

        Integer arithmetic keeps the halves exact for arbitrarily large values;
        float division would lose precision above 2**53. Requires a parent:
        a root leaf cannot be replaced in place.
        """
        half_down = self.value // 2
        values = [half_down, self.value - half_down]  # second half == ceil(v / 2)
        new_child = SnailfishP(values, parent=self.parent)
        # Identity check: find which slot of the parent holds this node.
        if self.parent.left is self:
            self.parent.left = new_child
        else:
            self.parent.right = new_child

    def find_splitter(self):
        """Return this leaf if its value is 10 or more (it must split), else None."""
        if self.value >= 10:
            return self
        return None

    def magnitude(self):
        """The magnitude of a regular number is the number itself."""
        return self.value


class SnailfishP(SnailfishBase):
    """Pair node of a snailfish-number tree with ``left`` and ``right`` children.

    Each child is either another pair (``SnailfishP``) or a regular number
    (``SnailfishN``).
    """

    # A pair at this depth or deeper (nested inside this many pairs) explodes.
    EXPLODE_DEPTH = 4

    def __init__(self, value, parent=None):
        """Build a pair from a two-element sequence ``value``.

        If ``value`` holds already-built nodes, they are adopted as-is and the
        caller must update their ``parent`` and ``lev``. Otherwise ``value``
        holds raw ints / nested lists, which are turned into child nodes
        recursively. No reduction is done here: inputs are expected to be
        already reduced.
        """
        super().__init__(parent)
        if isinstance(value[0], SnailfishBase):
            self.left = value[0]
            self.right = value[1]
        else:
            self.left = SnailfishBase(value[0], parent=self)
            self.right = SnailfishBase(value[1], parent=self)

    def level_up(self):
        """Increase the depth of this pair and of every node below it by one."""
        self.lev += 1
        self.left.level_up()
        self.right.level_up()

    def get_list(self):
        """Return the plain nested-list representation of this subtree."""
        return [self.left.get_list(), self.right.get_list()]

    def find_exploder(self):
        """Return the leftmost pair at depth >= EXPLODE_DEPTH, or None."""
        if self.lev >= self.EXPLODE_DEPTH:
            return self
        # Nodes define no __bool__, so any found node is truthy; the right
        # subtree is searched only when the left one found nothing.
        return self.left.find_exploder() or self.right.find_exploder()

    def find_rightmostN(self):
        """Return the rightmost regular number in this subtree."""
        return self.right.find_rightmostN()

    def find_leftmostN(self):
        """Return the leftmost regular number in this subtree."""
        return self.left.find_leftmostN()

    def find_regleft(self):
        """Return the first regular number to the left of this node, or None."""
        if self.parent is None:
            return None
        # A right child's left neighbour is the rightmost number of its sibling.
        if self.parent.right is self:
            return self.parent.left.find_rightmostN()
        # A left child: keep climbing until we come from a right branch.
        return self.parent.find_regleft()

    def find_regright(self):
        """Return the first regular number to the right of this node, or None."""
        if self.parent is None:
            return None
        # A left child's right neighbour is the leftmost number of its sibling.
        if self.parent.left is self:
            return self.parent.right.find_leftmostN()
        # A right child: keep climbing until we come from a left branch.
        return self.parent.find_regright()

    def explode(self):
        """Explode this pair in place.

        Adds the left value to the nearest regular number on the left and the
        right value to the nearest one on the right (if any). Then replaces
        this pair in its parent with the regular number 0.

        Assumes both children are regular numbers. That holds when both
        operands of an addition were already reduced.
        """
        left_neighbour, right_neighbour = self.find_regleft(), self.find_regright()
        if left_neighbour is not None:
            left_neighbour.value += self.left.value
        if right_neighbour is not None:
            right_neighbour.value += self.right.value
        zero = SnailfishN(0, parent=self.parent)
        # Identity check: find which slot of the parent holds this pair.
        if self.parent.left is self:
            self.parent.left = zero
        else:
            self.parent.right = zero

    def find_splitter(self):
        """Return the leftmost regular number >= 10 in this subtree, or None."""
        return self.left.find_splitter() or self.right.find_splitter()

    def magnitude(self):
        """Return 3 * magnitude(left) + 2 * magnitude(right)."""
        return 3*self.left.magnitude() + 2*self.right.magnitude()

In [3]:
# Each input line is a snailfish number written as a nested list, which is also
# valid JSON. json.loads parses it without executing anything; eval would run
# arbitrary expressions found in the file. Blank lines (e.g. a trailing newline)
# are skipped.
with open('day18.txt') as f:
    data = [json.loads(line) for line in f if line.strip()]

snailnumbers = [SnailfishBase(d) for d in data]

In [4]:
# Add all input numbers in order. sum() starts from 0, which __radd__ treats
# as the identity. Fresh trees are built because `+` mutates its operands.
total = sum(map(SnailfishBase, data))
total.get_list()

[[[[7, 7], [8, 8]], [[8, 9], [7, 8]]], [[[6, 7], [0, 8]], [8, 8]]]

In [5]:
# Magnitude of the reduced sum of all input numbers.
total_magnitude = total.magnitude()
total_magnitude

3793

In [6]:
%%time
# Largest magnitude of a sum of two *different* input numbers. Snailfish
# addition is not commutative, so permutations() tries both orders of every
# pair of distinct entries, but never adds an entry to itself. Fresh trees are
# built for every sum because `+` mutates its operands. With fewer than two
# numbers there is no valid pair, so the result defaults to 0.
maxmag = max(
    ((SnailfishBase(s1) + SnailfishBase(s2)).magnitude()
     for s1, s2 in itertools.permutations(data, 2)),
    default=0,
)

print(maxmag)

4695
CPU times: user 3.15 s, sys: 3.64 ms, total: 3.15 s
Wall time: 3.15 s
