## --- Day 18: Snailfish ---

Add up all of the snailfish numbers from the homework assignment in the order they appear. What is the magnitude of the final sum?

In [1]:
from copy import deepcopy
from math import ceil, floor

class SnailfishNumber:
    """A tree representing a single Snailfish number.

    Internal nodes are pairs (``value is None`` with a ``left`` and a ``right``
    child); leaves hold a regular ``int`` in ``value``. Every node keeps a
    ``parent`` reference so an explosion can walk to its neighbouring leaves.
    """

    def __init__(self, pairs_list):
        """Build the tree from a nested list such as ``[[1, 2], 3]``."""
        self.root = self.Node(pairs_list)

    def __iadd__(self, other):
        """Add ``other`` onto this number in place, then reduce the result.

        ``other`` is deep-copied before being attached, so the right-hand
        operand is never shared with (or later mutated through) ``self``.

        Raises:
            TypeError: if ``other`` is not a ``SnailfishNumber``.
        """
        if not isinstance(other, SnailfishNumber):
            raise TypeError(f"Incompatible type {type(other)}")

        # The sum is a new pair whose two children are the operands.
        new_root = self.Node()
        other_root = deepcopy(other).root
        self.root.parent = new_root
        new_root.left = self.root
        other_root.parent = new_root
        new_root.right = other_root
        self.root = new_root
        self.reduce()
        return self

    def __add__(self, other):
        """Return the reduced sum ``self + other`` without mutating either operand.

        Raises:
            TypeError: if ``other`` is not a ``SnailfishNumber``.
        """
        if not isinstance(other, SnailfishNumber):
            raise TypeError(f"Incompatible type {type(other)}")
        result = deepcopy(self)
        result += other
        return result

    def reduce(self, verbose=False):
        """Repeatedly explode or split until neither action applies.

        Explosions take priority: a split is only attempted in a round where
        nothing could explode.

        Args:
            verbose: if true, print the tree after each action.
        """
        while True:
            if self.explode() is not None:
                if verbose:
                    print(f"After explode {self.root}")
            elif self.split() is not None:
                if verbose:
                    print(f"After split {self.root}")
            else:
                return

    def explode(self):
        """Explode the left-most pair nested inside four pairs, if there is one.

        The pair's left value is added to the nearest regular number on its
        left, and its right value to the nearest one on its right (when those
        exist). The pair is then replaced by the regular number 0.

        Returns:
            The exploded node (now a leaf holding 0), or ``None`` if no pair
            was deep enough to explode.
        """
        explodable_node = self.search_explodable()
        if explodable_node is None:
            return None

        prev_node = self.find_prev_value(explodable_node)
        next_node = self.find_next_value(explodable_node)

        # NOTE: assumes both children of the exploding pair are regular
        # numbers, which holds when adding numbers that were already reduced.
        if prev_node is not None:
            prev_node += explodable_node.left.value
        if next_node is not None:
            next_node += explodable_node.right.value

        explodable_node.left = None
        explodable_node.right = None
        explodable_node.value = 0
        return explodable_node

    def search_explodable(self):
        """Return the left-most pair at depth 4 (nested inside four pairs), or ``None``."""
        def _search(level, node):
            if node.value is not None:
                # Regular numbers never explode.
                return None
            if level == 4:
                return node
            # Depth-first, left before right, so the left-most match wins.
            found = _search(level + 1, node.left)
            return found if found is not None else _search(level + 1, node.right)

        return _search(0, self.root)

    def find_next_value(self, node):
        """Return the nearest leaf to the right of ``node``, or ``None`` if there is none."""
        curr = node
        # Climb while curr is a right child; stop at the first left child.
        while curr.parent is not None and curr is curr.parent.right:
            curr = curr.parent
        if curr.parent is None:
            return None

        # Descend to the left-most leaf of the sibling (right) subtree.
        curr = curr.parent.right
        while curr.value is None:
            curr = curr.left
        return curr

    def find_prev_value(self, node):
        """Return the nearest leaf to the left of ``node``, or ``None`` if there is none."""
        curr = node
        # Climb while curr is a left child; stop at the first right child.
        while curr.parent is not None and curr is curr.parent.left:
            curr = curr.parent
        if curr.parent is None:
            return None

        # Descend to the right-most leaf of the sibling (left) subtree.
        curr = curr.parent.left
        while curr.value is None:
            curr = curr.right
        return curr

    def split(self):
        """Split the left-most regular number >= 10 into a pair, if there is one.

        The new pair is ``[value // 2, value - value // 2]``, i.e. the value
        halved rounding down and rounding up. Integer arithmetic keeps this
        exact even for values too large to be represented exactly as floats.

        Returns:
            The split node (now a pair), or ``None`` if nothing was split.
        """
        splittable_node = self.search_splittable()
        if splittable_node is None:
            return None

        value = splittable_node.value
        half_down = value // 2
        splittable_node.left = self.Node(half_down, parent=splittable_node)
        splittable_node.right = self.Node(value - half_down, parent=splittable_node)
        splittable_node.value = None
        return splittable_node

    def search_splittable(self):
        """Return the left-most leaf whose value is at least 10, or ``None``."""
        def _search(node):
            if node.value is None:
                # A pair: search the left child first so the left-most match wins.
                found = _search(node.left)
                return found if found is not None else _search(node.right)
            return node if node.value >= 10 else None

        return _search(self.root)

    def magnitude(self):
        """Recursively calculate magnitude, ie 3*left + 2*right (a leaf's magnitude is its value)."""
        def _mag(node):
            if node.value is not None:
                return node.value
            return 3 * _mag(node.left) + 2 * _mag(node.right)

        return _mag(self.root)

    def __str__(self):
        return str(self.root)

    class Node:
        """A node is either a "pair" with no value and exactly 2 children, or an int value, in which case it has no children."""

        def __init__(self, value=None, parent=None):
            """Create a node from a 2-element list (a pair), an ``int`` (a leaf), or ``None``.

            ``None`` yields a childless pair whose children the caller must set.

            Raises:
                TypeError: for any other ``value``.
            """
            self.parent = parent
            if type(value) is list and len(value) == 2:
                self.value = None
                self.left = SnailfishNumber.Node(value[0], parent=self)
                self.right = SnailfishNumber.Node(value[1], parent=self)
            elif type(value) is int:
                self.value = value
                self.left = None
                self.right = None
            elif value is not None:
                raise TypeError(f"Node can only contain list of length 2 or int, not {value}")
            else:
                self.value = None
                self.left = None
                self.right = None

        def __str__(self):
            if self.value is not None:
                return str(self.value)
            return f"[{self.left},{self.right}]"

        def __iadd__(self, value):
            """Add ``value`` to this leaf in place and return the leaf.

            Returning ``self`` is required: otherwise ``node += x`` would
            rebind ``node`` to ``None``.

            Raises:
                TypeError: if this node is a pair (has no value).
            """
            if self.value is None:
                raise TypeError(f"Can not add {type(value)} to Node with no value")
            self.value += value
            return self



In [2]:
# Explode tests: each number gets one explosion (ext4 gets two) and the
# rendered tree is compared with the expected result.
ext1 = SnailfishNumber([[[[[9, 8], 1], 2], 3], 4])
ext1.explode()
assert str(ext1) == "[[[[0,9],2],3],4]"

ext2 = SnailfishNumber([7, [6, [5, [4, [3, 2]]]]])
ext2.explode()
assert str(ext2) == "[7,[6,[5,[7,0]]]]"

ext3 = SnailfishNumber([[6, [5, [4, [3, 2]]]], 1])
ext3.explode()
assert str(ext3) == "[[6,[5,[7,0]]],3]"

ext4 = SnailfishNumber([[3, [2, [1, [7, 3]]]], [6, [5, [4, [3, 2]]]]])
ext4.explode()  # only the left-most deep pair explodes first
assert str(ext4) == "[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]"
ext4.explode()
assert str(ext4) == "[[3,[2,[8,0]]],[9,[5,[7,0]]]]"

In [3]:
# Split test: 15 becomes [7,8]; 13 is left alone because only the
# left-most splittable number is split per call.
spt1 = SnailfishNumber([[[[0, 7], 4], [15, [0, 13]]], [1, 1]])
spt1.split()
assert str(spt1) == "[[[[0,7],4],[[7,8],[0,13]]],[1,1]]"

In [4]:
# Reduce test: run the full explode/split loop, printing each step.
rt = SnailfishNumber([[[[[4, 3], 4], 4], [7, [[8, 4], 9]]], [1, 1]])
rt.reduce(verbose=True)
assert str(rt) == "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]"

After explode [[[[0,7],4],[7,[[8,4],9]]],[1,1]]
After explode [[[[0,7],4],[15,[0,13]]],[1,1]]
After split [[[[0,7],4],[[7,8],[0,13]]],[1,1]]
After split [[[[0,7],4],[[7,8],[0,[6,7]]]],[1,1]]
After explode [[[[0,7],4],[[7,8],[6,0]]],[8,1]]


In [5]:
# Addition test: in-place addition pairs the operands and then reduces.
left = SnailfishNumber([[[[4, 3], 4], 4], [7, [[8, 4], 9]]])
right = SnailfishNumber([1, 1])
left += right
assert str(left) == "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]"

In [6]:
# Magnitude tests: magnitude is 3 * left + 2 * right, applied recursively.
for _pairs, _expected in (
    ([9, 1], 29),
    ([1, 9], 21),
    ([[9, 1], [1, 9]], 129),
):
    assert SnailfishNumber(_pairs).magnitude() == _expected

In [7]:
import json

class SnailfishCalculator:
    """A homework list of snailfish numbers, with helpers to combine them."""

    def __init__(self, sf_nums):
        """Wrap ``sf_nums``, a list of ``SnailfishNumber`` instances."""
        self.sf_nums = sf_nums

    @staticmethod
    def _parse(text):
        """Parse every JSON value in ``text`` into a ``SnailfishNumber``.

        Values may be separated by any whitespace (newlines or spaces), and a
        value may itself contain spaces, e.g. ``[1, 2]``. Blank lines are
        ignored.

        Raises:
            json.JSONDecodeError: if the text contains malformed JSON.
        """
        decoder = json.JSONDecoder()
        sf_nums = []
        pos, end = 0, len(text)
        while True:
            # raw_decode does not skip leading whitespace, so do it here.
            while pos < end and text[pos].isspace():
                pos += 1
            if pos >= end:
                return sf_nums
            pairs_list, pos = decoder.raw_decode(text, pos)
            sf_nums.append(SnailfishNumber(pairs_list))

    @classmethod
    def fromfile(cls, filename):
        """Build a calculator from a file of snailfish numbers (one per line)."""
        with open(filename, encoding="utf-8") as f:
            return cls(cls._parse(f.read()))

    @classmethod
    def fromstr(cls, num_str):
        """Build a calculator from whitespace-separated snailfish numbers in ``num_str``."""
        return cls(cls._parse(num_str))

    def sum(self):
        """Return the reduced sum of all numbers in list order, or ``None`` if empty.

        The stored numbers are not mutated: the first one is deep-copied, and
        ``SnailfishNumber.__iadd__`` copies each right-hand operand.
        """
        result = None
        for num in self.sf_nums:
            if result is None:
                result = deepcopy(num)
            else:
                result += num
        return result

    def max_magnitude(self):
        """Return the largest magnitude of ``a + b`` for numbers at two different positions.

        Snailfish addition is not commutative, so both orders are tried.
        Positions (not object identity) decide "different", so a number that
        appears twice in the list can be added to its copy. Returns 0 when
        fewer than two numbers are stored.
        """
        best = 0
        for i, left_num in enumerate(self.sf_nums):
            for j, right_num in enumerate(self.sf_nums):
                if i != j:
                    best = max(best, (left_num + right_num).magnitude())
        return best

    def print(self):
        """Print every stored number, one per line."""
        for num in self.sf_nums:
            print(num.root)



In [8]:
# Examples from the puzzle statement.

# Example 1: the numbers [1,1] .. [6,6] added in order.
ex1_input = """
[1,1]
[2,2]
[3,3]
[4,4]
[5,5]
[6,6]
"""
ex1sc = SnailfishCalculator.fromstr(ex1_input)
ex1sum = ex1sc.sum()
assert str(ex1sum) == "[[[[5,0],[7,4]],[5,5]],[6,6]]"

# Example 2: check both the final sum and its magnitude.
ex2sc = SnailfishCalculator.fromfile("./inputs/Day18ex.txt")
ex2sum = ex2sc.sum()
assert str(ex2sum) == "[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]"
assert ex2sum.magnitude() == 3488

# Example 3: the homework example, reused for Part 2 below.
ex3sc = SnailfishCalculator.fromfile("./inputs/Day18ex2.txt")
ex3sum = ex3sc.sum()
assert str(ex3sum) == "[[[[6,6],[7,6]],[[7,7],[7,0]]],[[[7,7],[7,7]],[[7,8],[9,9]]]]"
assert ex3sum.magnitude() == 4140

In [9]:
# Part 1 solution: add the whole homework list in order and report the
# magnitude of the final sum.
part1_sc = SnailfishCalculator.fromfile("./inputs/Day18.txt")
part1_sum = part1_sc.sum()
part1_magnitude = part1_sum.magnitude()
print(part1_magnitude)

4017


## Part 2

What is the largest magnitude of any sum of two different snailfish numbers from the homework assignment?

In [10]:
# Part 2 on the homework example: list the numbers, then check the largest
# magnitude of any sum of two different numbers (expected 3993), like the
# other example cells do. The final bare name keeps the notebook display.
ex3sc.print()
ex3_max_magnitude = ex3sc.max_magnitude()
assert ex3_max_magnitude == 3993
ex3_max_magnitude


[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]


3993

In [11]:
# Part 2 solution: largest magnitude of any sum of two different homework
# numbers. Reuses the homework list loaded in the Part 1 cell (part1_sc).
part1_sc.max_magnitude()

4583