In [1]:
import collections
from itertools import permutations
from typing import List, Union

In [2]:
class SnailNumber:
    """A snailfish number: a pair whose elements are regular numbers or pairs.

    The tree is stored as linked ``SnailNumber`` nodes. Each node holds
    ``left`` and ``right`` (an ``int`` or a nested ``SnailNumber``) and a
    ``parent`` back-pointer (``None`` for the outermost pair), which the
    explode logic uses to locate neighbouring regular numbers.
    """

    def __init__(self, left: Union[int, 'SnailNumber'], right: Union[int, 'SnailNumber']):
        self.left = left
        self.right = right
        self.parent = None

        # Child pairs are adopted: their parent pointer now refers to this node.
        if not isinstance(self.left, int):
            self.left.parent = self
        if not isinstance(self.right, int):
            self.right.parent = self

    def __repr__(self):
        return f"[{self.left},{self.right}]"

    def __eq__(self, other):
        """Compare by textual form, so structurally identical trees are equal."""
        return self.__repr__() == other.__repr__()

    def __add__(self, other):
        """Return the reduced pair ``[self, other]``, leaving both operands intact.

        Building a pair re-parents its children and ``reduce`` rewrites the
        tree in place, so the operands are copied first. Without the copy,
        ``a + b`` would corrupt ``a`` and ``b``, and ``a + a`` would put the
        very same node on both sides of the new pair.
        """
        left = self.copy()
        right = other.copy() if isinstance(other, SnailNumber) else other
        snum_sum = SnailNumber(left, right)
        snum_sum.reduce()
        return snum_sum

    def copy(self) -> 'SnailNumber':
        """Return an independent deep copy of this pair (its ``parent`` is ``None``)."""
        left = self.left if isinstance(self.left, int) else self.left.copy()
        right = self.right if isinstance(self.right, int) else self.right.copy()
        return type(self)(left, right)

    @classmethod
    def from_raw(cls, tokens: str):
        """Parse text such as ``'[[1,2],3]'`` into a snail number.

        Text that does not start with ``[`` is parsed as a regular ``int``.

        Raises:
            AssertionError: if a pair's text does not end with ``]``.
            Exception: if a pair has no top-level comma separating its halves.
        """
        def find_mid_comma(tokens, start: int) -> int:
            """Return the index of the first comma at bracket depth 0, scanning from ``start``."""
            num_open_brackets = 0
            for i, token in enumerate(tokens[start:]):
                if token == "[":
                    num_open_brackets += 1
                elif token == "]":
                    num_open_brackets -= 1
                elif token == "," and num_open_brackets == 0:
                    return start + i
            raise Exception("Expected to find a comma!")

        def inner_parse(tokens: str, start: int, end: int) -> Union[int, 'SnailNumber']:
            """Parse the tokens in the half-open index range [start, end)."""
            if tokens[start] != "[":
                return int("".join(tokens[start:end]))

            # Check the closing bracket of *this* slice, not of the whole input.
            assert tokens[end - 1] == "]"

            comma_idx = find_mid_comma(tokens, start + 1)

            left = inner_parse(tokens, start + 1, comma_idx)
            right = inner_parse(tokens, comma_idx + 1, end - 1)

            return cls(left, right)

        return inner_parse(tokens, 0, len(tokens))

    def walk(self):
        """Yield the regular numbers of this tree from left to right."""
        if isinstance(self.left, int):
            yield self.left
        else:
            yield from self.left.walk()

        if isinstance(self.right, int):
            yield self.right
        else:
            yield from self.right.walk()

    def reduce(self):
        """Reduce this number in place until nothing explodes or splits.

        Each round explodes every pair nested inside four pairs, then performs
        at most one split: a split can create a new pair that must explode
        before any further split is considered.
        """
        exploding = True
        splitting = True

        while exploding or splitting:
            exploding = self.explode_all()
            splitting = self.split()  # only do it once!

    def explode_all(self):
        """Explode, left to right, every pair nested inside exactly four pairs.

        Returns:
            True if at least one pair exploded.
        """
        exploded = False

        # Stack of (depth, pair) with the root at depth 0. Right children are
        # pushed before left children, so pairs are visited left to right.
        to_explore = [(0, self)]
        while to_explore:
            level, node = to_explore.pop()

            if level == 4:
                node.explode()
                exploded = True

            if isinstance(node.right, SnailNumber):
                to_explore.append((level + 1, node.right))
            if isinstance(node.left, SnailNumber):
                to_explore.append((level + 1, node.left))

        return exploded

    def explode(self):
        """Explode this pair: spread its values to its neighbours, then become 0.

        The left value is added to the nearest regular number on the left and
        the right value to the nearest one on the right (when they exist);
        the pair is then replaced by the regular number 0 in its parent.
        """
        # An exploding pair must consist of two regular numbers.
        assert isinstance(self.left, int)
        assert isinstance(self.right, int)

        # Add to the neighbouring regular numbers first, while still attached.
        self.explode_left()
        self.explode_right()

        # Now replace this pair with 0 in its parent.
        if self.parent.left is self:
            self.parent.left = 0
        elif self.parent.right is self:
            self.parent.right = 0
        else:
            raise Exception("You made mistake!")

    def explode_left(self):
        """Add ``self.left`` to the nearest regular number to this pair's left, if any.

        Walks the tree via parent pointers while tracking the node it came
        from. It climbs while arriving from a left child. On arriving from a
        right child, it crosses into the left sibling and descends along right
        children to a regular number. Climbing past the root means there is no
        number to the left, and nothing changes.
        """
        prev_node = self
        curr_node = self.parent

        while curr_node is not None:
            if curr_node.left is prev_node:
                # Came up from a left child: keep climbing.
                curr_node, prev_node = curr_node.parent, curr_node

            elif curr_node.right is prev_node:
                if isinstance(curr_node.left, int):
                    # The left sibling itself is the nearest regular number.
                    curr_node.left += self.left
                    return

                else:
                    # Cross into the left sibling subtree.
                    curr_node, prev_node = curr_node.left, curr_node

            elif curr_node.parent is prev_node:
                if isinstance(curr_node.right, int):
                    # Rightmost regular number of the left sibling subtree.
                    curr_node.right += self.left
                    return

                else:
                    # Keep descending along right children.
                    curr_node, prev_node = curr_node.right, curr_node

            else:
                raise Exception("Oops? Didn't expect to get here.")

        # If we get here, there is no regular number to the left.

    def explode_right(self):
        """Add ``self.right`` to the nearest regular number to this pair's right, if any.

        This mirrors ``explode_left``. It climbs while arriving from a right
        child. On arriving from a left child, it crosses into the right sibling
        and descends along left children to a regular number.
        """
        prev_node = self
        curr_node = self.parent

        while curr_node is not None:
            if curr_node.right is prev_node:
                # Came up from a right child: keep climbing.
                curr_node, prev_node = curr_node.parent, curr_node

            elif curr_node.left is prev_node:
                if isinstance(curr_node.right, int):
                    # The right sibling itself is the nearest regular number.
                    curr_node.right += self.right
                    return

                else:
                    # Cross into the right sibling subtree.
                    curr_node, prev_node = curr_node.right, curr_node

            elif curr_node.parent is prev_node:
                if isinstance(curr_node.left, int):
                    # Leftmost regular number of the right sibling subtree.
                    curr_node.left += self.right
                    return

                else:
                    # Keep descending along left children.
                    curr_node, prev_node = curr_node.left, curr_node

            else:
                raise Exception("Oops? Didn't expect to get here.")

        # If we get here, there is no regular number to the right.
        return

    def split(self):
        """Split the leftmost regular number >= 10, if there is one.

        Returns:
            True if a number was split (at most one split per call).
        """
        # Stack of (element, its parent pair); the left element is pushed last,
        # so it is popped first, giving a left-to-right traversal.
        to_explore = [(self, None)]
        while to_explore:
            node, parent = to_explore.pop()

            if isinstance(node, SnailNumber):
                to_explore.append((node.right, node))
                to_explore.append((node.left, node))

            elif node >= 10:
                parent.split_child(node)
                return True

        return False

    def split_child(self, child: int):
        """Replace the direct child equal to ``child`` with ``[child // 2, ceil(child / 2)]``.

        The left element is checked first. This matches ``split``'s
        left-to-right search when both elements hold the same value.

        Raises:
            Exception: if neither element equals ``child``.
        """
        if self.left == child:
            self.left = SnailNumber(self.left // 2, (self.left + 1) // 2)
            self.left.parent = self
            return

        elif self.right == child:
            self.right = SnailNumber(self.right // 2, (self.right + 1) // 2)
            self.right.parent = self
            return

        else:
            raise Exception("Oops? Can't find the child you want to split...")

    def magnitude(self):
        """Return 3 * magnitude(left) + 2 * magnitude(right); a regular number is its own magnitude."""
        left_mag = 3 * self.left if isinstance(self.left, int) else 3 * self.left.magnitude()
        right_mag = 2 * self.right if isinstance(self.right, int) else 2 * self.right.magnitude()

        return left_mag + right_mag

# Test Cases

In [3]:
def sum_snail_numbers(snums: List[SnailNumber]) -> SnailNumber:
    """Add the snail numbers strictly in order: ``((s0 + s1) + s2) + ...``.

    Snail addition is not associative, so the left-to-right order matters.
    An empty ``snums`` raises ``IndexError``.
    """
    running_total = snums[0]
    for position in range(1, len(snums)):
        running_total = running_total + snums[position]
    return running_total

def parse_homework_assignment(input_filename: str) -> List[SnailNumber]:
    """Read one snail number per line from ``input_filename``.

    Surrounding whitespace is stripped. Blank lines, such as an empty line at
    the end of the file, are skipped rather than passed to
    ``SnailNumber.from_raw``, which cannot parse an empty string.
    """
    with open(input_filename) as input_file:
        stripped_lines = (line.strip() for line in input_file)
        return [SnailNumber.from_raw(line) for line in stripped_lines if line]

In [4]:
def test_simple_add():
    """Adding two numbers that need no reduction simply pairs them up."""
    print("...test_simple_add...")
    lhs, rhs = (SnailNumber.from_raw(raw) for raw in ('[1,2]', '[[3,4],5]'))
    expected = SnailNumber.from_raw('[[1,2],[[3,4],5]]')
    assert lhs + rhs == expected


def test_explode_all():
    """``explode_all`` should explode every depth-4 pair, left to right, in one call."""
    print('...test_explode_all...')
    cases = (
        ('[[[[0,7],4],[7,[[8,4],9]]],[1,1]]', '[[[[0,7],4],[15,[0,13]]],[1,1]]'),
        ('[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]', '[[3,[2,[8,0]]],[9,[5,[7,0]]]]'),
    )
    for before, after in cases:
        snum = SnailNumber.from_raw(before)
        snum.explode_all()
        assert snum == SnailNumber.from_raw(after)


def test_reduce():
    """Check ``SnailNumber.reduce`` against the puzzle's worked examples.

    Every case is evaluated so one run reports all mismatches. An
    ``AssertionError`` is then raised if any case failed, so a broken
    reduction fails the test run instead of only printing a message.
    """
    print('...test_reduce...')
    test_cases = [
        # (unreduced snail number, expected result)

        # Explosion-only cases
        ('[[[[[9,8],1],2],3],4]', '[[[[0,9],2],3],4]'),
        ('[7,[6,[5,[4,[3,2]]]]]', '[7,[6,[5,[7,0]]]]'),
        ('[[6,[5,[4,[3,2]]]],1]', '[[6,[5,[7,0]]],3]'),
        ('[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]', '[[3,[2,[8,0]]],[9,[5,[7,0]]]]'),
        ('[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]', '[[3,[2,[8,0]]],[9,[5,[7,0]]]]'),

        # Split-only cases
        ('[11,1]', '[[5,6],1]'),
        ('[11,[12,13]]', '[[5,6],[[6,6],[6,7]]]'),

        # More complicated cases
        ('[[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]]', '[[[[0,7],4],[[7,8],[6,0]]],[8,1]]'),
    ]

    failures = []
    for unreduced, expected in test_cases:
        snum = SnailNumber.from_raw(unreduced)
        snum.reduce()
        if snum != SnailNumber.from_raw(expected):
            print("Test case failed:", unreduced)
            print("- Expected:", expected)
            print("- Actual:  ", snum)
            failures.append(unreduced)

    assert not failures, f"test_reduce: {len(failures)} case(s) failed: {failures}"


def test_small_sum_examples():
    """Check ``sum_snail_numbers`` on the short lists from the puzzle text.

    Every case is evaluated so one run reports all mismatches. An
    ``AssertionError`` is then raised if any case failed, instead of the
    failure only being printed.
    """
    print('...test_small_sum_examples...')
    test_cases = [
        (("[1,1]", "[2,2]", "[3,3]", "[4,4]"), '[[[[1,1],[2,2]],[3,3]],[4,4]]'),
        (("[1,1]", "[2,2]", "[3,3]", "[4,4]", "[5,5]"), '[[[[3,0],[5,3]],[4,4]],[5,5]]'),
        (("[1,1]", "[2,2]", "[3,3]", "[4,4]", "[5,5]", "[6,6]"), '[[[[5,0],[7,4]],[5,5]],[6,6]]'),
    ]

    failures = []
    for snums_to_sum, expected in test_cases:
        snums = [SnailNumber.from_raw(s) for s in snums_to_sum]
        actual = sum_snail_numbers(snums)
        if actual != SnailNumber.from_raw(expected):
            print("Test case failed:", snums_to_sum)
            print("- Expected:", expected)
            print("- Actual:  ", actual)
            failures.append(snums_to_sum)

    assert not failures, f"test_small_sum_examples: {len(failures)} case(s) failed: {failures}"

            
def test_larger_sum_example():
    """Summing the numbers in ``input-example-1.txt`` should give the stated result."""
    print('...test_larger_sum_example...')
    example_snums = parse_homework_assignment('input-example-1.txt')
    example_total = sum_snail_numbers(example_snums)
    expected_total = SnailNumber.from_raw('[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]')
    assert example_total == expected_total


def test_magnitude():
    """Check ``SnailNumber.magnitude`` against the puzzle's worked examples.

    Every case is evaluated before reporting, and any mismatch raises
    ``AssertionError`` at the end. Only wrong results count as failures: an
    exception raised by ``magnitude`` itself propagates instead of being
    swallowed by a bare ``except``.
    """
    print('...test_magnitude...')
    test_cases = [
        ('[9,1]', 29),
        ('[1,9]', 21),
        ('[[9,1],[1,9]]', 129),
        ('[[1,2],[[3,4],5]]', 143),
        ('[[[[0,7],4],[[7,8],[6,0]]],[8,1]]', 1384),
        ('[[[[1,1],[2,2]],[3,3]],[4,4]]', 445),
        ('[[[[3,0],[5,3]],[4,4]],[5,5]]', 791),
        ('[[[[5,0],[7,4]],[5,5]],[6,6]]', 1137),
        ('[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]', 3488),
    ]

    failures = []
    for s, expected in test_cases:
        snum = SnailNumber.from_raw(s)
        actual = snum.magnitude()
        if actual != expected:
            print("Test case failed:", snum)
            print("- Expected:", expected)
            print("- Actual:  ", actual)
            failures.append(s)

    assert not failures, f"test_magnitude: {len(failures)} case(s) failed: {failures}"


def test_homework():
    """The larger worked example: check both the final sum and its magnitude."""
    print('...test_homework...')
    homework_sum = sum_snail_numbers(parse_homework_assignment('input-example-2.txt'))
    expected_sum = SnailNumber.from_raw('[[[[6,6],[7,6]],[[7,7],[7,0]]],[[[7,7],[7,7]],[[7,8],[9,9]]]]')
    assert homework_sum == expected_sum
    assert homework_sum.magnitude() == 4140
    
def run_test_cases():
    """Run each test function in notebook order, then report completion."""
    test_functions = (
        test_simple_add,
        test_explode_all,
        test_reduce,
        test_small_sum_examples,
        test_larger_sum_example,
        test_magnitude,
        test_homework,
    )
    for test_function in test_functions:
        test_function()
    print("Done with tests!")


run_test_cases()

...test_simple_add...
...test_explode_all...
...test_reduce...
...test_small_sum_examples...
...test_larger_sum_example...
...test_magnitude...
...test_homework...
Done with tests!


# Part 1

In [5]:
# Part 1: magnitude of the in-order sum of every snail number in the puzzle input.
(
    sum_snail_numbers(parse_homework_assignment('input.txt'))
    .magnitude()
)

3884

# Part 2

In [6]:
homework = "input.txt"  # puzzle input; the Part 2 cell reads its lines from this file

In [7]:
with open(homework) as input_file:
    # Skip blank lines (e.g. a trailing empty line): from_raw cannot parse "".
    raw_snums = [line.strip() for line in input_file if line.strip()]


def helper(a: str, b: str) -> int:
    """Return the magnitude of ``a + b`` for two snail numbers given as text.

    Both operands are parsed fresh on every call, so no parsed tree is ever
    shared between different additions.
    """
    snum = SnailNumber.from_raw(a) + SnailNumber.from_raw(b)
    return snum.magnitude()

# Snail addition is not commutative, so every ordered pair of distinct lines
# is a candidate; a generator avoids materialising all the magnitudes.
max(helper(a, b) for a, b in permutations(raw_snums, 2))

4595