In [1]:
# Ten sample lines, each a nested pair written as a JSON-style list.
# The string starts and ends with a newline; the resulting empty lines are
# dropped by parse_text().
SAMPLE_TEXT = """
[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]
"""

In [2]:
import json

In [5]:
def tokenize_line(line):
    """Decode one line of text into the Python object it describes.

    The line is treated as JSON (for example ``"[1,[2,3]]"`` becomes
    ``[1, [2, 3]]``); decoding is delegated to ``json.loads``, which raises
    ``json.JSONDecodeError`` on malformed text.
    """
    decoded = json.loads(line)
    return decoded

def parse_text(raw_text):
    """Parse a multi-line block of text into one decoded value per line.

    The text is split on newline characters and every non-blank line is
    handed to ``tokenize_line``.  Each line is stripped of surrounding
    whitespace first, so whitespace-only lines and the stray carriage return
    left behind by a blank CRLF-terminated line are skipped instead of being
    passed to the decoder (where they would raise).

    Returns:
        A list with one decoded value per non-blank line, in input order.
        Empty input yields an empty list.
    """
    candidates = (line.strip() for line in raw_text.split("\n"))
    return [tokenize_line(line) for line in candidates if line]

def read_input():
    """Return the complete text of ``input.txt`` from the working directory."""
    with open("input.txt", "rt") as handle:
        contents = handle.read()
    return contents

In [46]:
from dataclasses import dataclass
from typing import Union, List, Optional
from treelib import Node, Tree

In [144]:
class PairNode:
    """One node of a binary tree of nested number pairs.

    A node is either a *regular value* (``value`` holds an int and there are
    no children) or a *pair* (``left`` and ``right`` are child nodes and
    ``value`` is ``None``).  ``parent`` points to the enclosing pair and is
    ``None`` for the root.
    """

    # A pair nested inside at least this many other pairs may explode.
    EXPLODE_DEPTH = 4
    # A regular value strictly greater than this may split.
    SPLIT_THRESHOLD = 9

    def __init__(self):
        self.left: Optional["PairNode"] = None
        self.parent: Optional["PairNode"] = None
        self.right: Optional["PairNode"] = None
        self.value: Optional[int] = None

    def is_value(self):
        """Return True when this node is a regular value (a leaf)."""
        return self.value is not None

    def is_pair(self):
        """Return True when this node has a left child, i.e. is a pair."""
        return self.left is not None

    def could_split(self):
        """Return True when this node is a regular value above the split threshold."""
        return self.is_value() and self.value > self.SPLIT_THRESHOLD

    def do_split(self):
        """Turn this regular value into a pair of two regular values.

        The left child receives the value halved and rounded down, the right
        child receives the remainder (the value halved and rounded up).
        """
        left_value = self.value // 2
        right_value = self.value - left_value
        self.value = None

        self.left = PairNode()
        self.left.value = left_value
        self.left.parent = self

        self.right = PairNode()
        self.right.value = right_value
        self.right.parent = self

    def depth(self) -> int:
        """Return the number of ancestors of this node (0 for the root)."""
        depth = 0
        ancestor = self.parent
        while ancestor is not None:
            depth += 1
            ancestor = ancestor.parent
        return depth

    def could_explode(self):
        """Return True when this pair is deep enough and safe to explode.

        Besides the depth requirement, both children must be regular values:
        ``do_explode`` adds the children's ``value`` to the neighbouring
        leaves, which would fail with ``TypeError`` on a child that is itself
        a pair.  With this check an over-deep pair simply waits until its
        nested pairs have exploded first.
        """
        return (
            self.is_pair()
            and self.left.is_value()
            and self.right is not None
            and self.right.is_value()
            and self.depth() >= self.EXPLODE_DEPTH
        )

    def do_explode(self):
        """Explode this pair in place.

        The left child's value is added to the nearest regular value to the
        left (if any), the right child's value to the nearest regular value
        to the right (if any), and this node becomes the regular value 0.
        Both children are expected to be regular values, which
        ``could_explode`` guarantees.
        """
        left_neighbour = self.find_value_left()
        right_neighbour = self.find_value_right()
        if left_neighbour is not None:
            left_neighbour.value += self.left.value
        if right_neighbour is not None:
            right_neighbour.value += self.right.value
        # Detach the children before turning this node into a leaf.
        self.left.parent = None
        self.right.parent = None
        self.left = None
        self.right = None
        self.value = 0

    def find_leftmost_value(self, node: "PairNode"):
        """Return the leftmost regular-value node in the subtree rooted at ``node``."""
        if node.is_value():
            return node
        else:
            return self.find_leftmost_value(node.left)

    def find_rightmost_value(self, node: "PairNode"):
        """Return the rightmost regular-value node in the subtree rooted at ``node``."""
        if node.is_value():
            return node
        else:
            return self.find_rightmost_value(node.right)

    def find_value_right(self):
        """Return the nearest regular-value node to the right, or None."""
        node_from = self
        parent = self.parent
        while parent is not None:
            # Climb until we arrive from a left child: that ancestor's right
            # subtree holds everything immediately to our right.
            if parent.left is node_from:
                return self.find_leftmost_value(parent.right)
            node_from = parent
            parent = parent.parent
        return None

    def find_value_left(self):
        """Return the nearest regular-value node to the left, or None."""
        node_from = self
        parent = self.parent
        while parent is not None:
            # Climb until we arrive from a right child: that ancestor's left
            # subtree holds everything immediately to our left.
            if parent.right is node_from:
                return self.find_rightmost_value(parent.left)
            node_from = parent
            parent = parent.parent
        return None

    def pprint(self, tree: "Tree"):
        """Add this node and all its descendants to the treelib ``tree``.

        Nodes are keyed by ``id(node)``; regular values are labelled with
        their value, pairs with their depth.
        """
        if self.parent:
            parent = id(self.parent)
        else:
            parent = None
        if self.is_value():
            tree.create_node(str(self.value), id(self), parent=parent)
        else:
            tree.create_node(f"[ depth={self.depth()}", id(self), parent=parent)
            self.left.pprint(tree)
            self.right.pprint(tree)

    def __str__(self):
        """Render the subtree in the input notation, e.g. ``[[1,2],3]``."""
        if self.is_value():
            return str(self.value)
        else:
            return f"[{self.left},{self.right}]"

    def __repr__(self):
        """Render as ``Pair(left, right)``; the children are formatted with ``str``."""
        if self.is_value():
            return str(self.value)
        else:
            return f"Pair({self.left}, {self.right})"

    def magnitude(self):
        """Return the value of a leaf, or 3 * left magnitude + 2 * right magnitude."""
        if self.is_value():
            return self.value
        else:
            return 3 * self.left.magnitude() + 2 * self.right.magnitude()

@dataclass
class PairTree:
    """A tree of nested pairs, identified by its root ``PairNode``.

    Supports in-place addition of another tree, reduction (repeated
    explode/split steps) and magnitude computation.
    """

    root: PairNode

    def pprint(self):
        """Print the tree using a treelib ``Tree`` rendering."""
        rendering = Tree()
        self.root.pprint(rendering)
        print(rendering)

    def reduce(self):
        """Apply explode and split steps until neither is possible.

        On every pass the first explodable pair (in left-to-right order) is
        exploded; only when nothing can explode is the first splittable value
        split.  The loop ends when both searches come back empty.
        """
        reducing = True
        while reducing:
            pending = self.find_explosions(self.root)
            if pending:
                pending[0].do_explode()
            else:
                pending = self.find_splits(self.root)
                if pending:
                    pending[0].do_split()
                else:
                    reducing = False

    def _collect(self, node, wanted):
        """Return, in left-to-right order, the nodes under ``node`` for which ``wanted`` is true."""
        if not isinstance(node, PairNode):
            return []
        before = self._collect(node.left, wanted)
        here = [node] if wanted(node) else []
        after = self._collect(node.right, wanted)
        return before + here + after

    def find_explosions(self, node: PairNode) -> List[PairNode]:
        """List every node under ``node`` whose ``could_explode()`` is true, left to right."""
        return self._collect(node, lambda candidate: candidate.could_explode())

    def find_splits(self, node: PairNode) -> List[PairNode]:
        """List every node under ``node`` whose ``could_split()`` is true, left to right."""
        return self._collect(node, lambda candidate: candidate.could_split())

    def add(self, other: "PairTree"):
        """Replace the root with a new pair of (current root, ``other``'s root).

        The tree is modified in place and ``other``'s root node is re-parented
        into it; no reduction is performed here.
        """
        joined = PairNode()

        left_operand = self.root
        joined.left = left_operand
        left_operand.parent = joined

        right_operand = other.root
        joined.right = right_operand
        right_operand.parent = joined

        self.root = joined

    def magnitude(self):
        """Return the magnitude of the root node."""
        return self.root.magnitude()

def is_list(arg):
    """Tell whether ``arg`` is a ``list`` (or an instance of a ``list`` subclass)."""
    verdict = isinstance(arg, list)
    return verdict

def get_pairs(pairs, parent):
    """Recursively attach the two elements of ``pairs`` as children of ``parent``.

    ``pairs`` must unpack into exactly two items.  For each item a new
    ``PairNode`` is created and linked to ``parent`` (left item first); an int
    becomes the node's value, anything else is expanded recursively as a
    nested pair.
    """
    left_value, right_value = pairs
    for side, item in (("left", left_value), ("right", right_value)):
        child = PairNode()
        setattr(parent, side, child)
        child.parent = parent
        if isinstance(item, int):
            child.value = item
        else:
            get_pairs(item, child)

def build_tree(line):
    """Build a ``PairTree`` from one parsed line (a two-element nested list)."""
    top = PairNode()
    get_pairs(line, top)
    return PairTree(root=top)

In [145]:
# Explosion tests

def verify_reduction(input, expected):
    """Assert that the first line of ``input`` reduces to the text ``expected``."""
    first_line = parse_text(input)[0]
    tree = build_tree(first_line)
    tree.reduce()
    rendered = str(tree.root)
    assert rendered == expected

# Each call passes the input text and the text expected after reduce().
verify_reduction("[[[[[9,8],1],2],3],4]", "[[[[0,9],2],3],4]")
verify_reduction("[7,[6,[5,[4,[3,2]]]]]", "[7,[6,[5,[7,0]]]]")
verify_reduction("[[6,[5,[4,[3,2]]]],1]", "[[6,[5,[7,0]]],3]")
verify_reduction("[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]", "[[3,[2,[8,0]]],[9,[5,[7,0]]]]")
verify_reduction("[[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]]", "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]")

In [151]:
def verify_addition(*args):
    """Assert that summing the given operands yields the expected text.

    All arguments but the last are operand texts, added left to right with a
    reduction after every addition; the last argument is the expected
    rendering of the final tree.
    """

    def tree_from(text):
        # Only the first parsed line of each operand text is used.
        return build_tree(parse_text(text)[0])

    expected = args[-1]
    total = tree_from(args[0])
    for addend_text in args[1:-1]:
        addend = tree_from(addend_text)
        total.add(addend)
        total.reduce()
    assert str(total.root) == expected

# Two operands followed by the expected reduced sum (last argument).
verify_addition("[[[[4,3],4],4],[7,[[8,4],9]]]", "[1,1]", "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]")

In [152]:
# Four operands; the expected sum (last argument) is their plain left-to-right nesting.
verify_addition("[1,1]", "[2,2]", "[3,3]", "[4,4]", "[[[[1,1],[2,2]],[3,3]],[4,4]]")

In [153]:
# Five operands; the expected sum (last argument) is no longer the plain nesting of the operands.
verify_addition("[1,1]", "[2,2]", "[3,3]", "[4,4]", "[5,5]", "[[[[3,0],[5,3]],[4,4]],[5,5]]")

In [154]:
# Six operands followed by the expected reduced sum (last argument).
verify_addition("[1,1]", "[2,2]", "[3,3]", "[4,4]", "[5,5]", "[6,6]", "[[[[5,0],[7,4]],[5,5]],[6,6]]")

In [155]:
# Ten operands for a running-sum check, one nested pair per line.
SAMPLE_ADDITION = """
[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]
[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]
[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]
[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]
[7,[5,[[3,8],[1,4]]]]
[[2,[2,2]],[8,[8,1]]]
[2,9]
[1,[[[9,3],9],[[9,0],[0,7]]]]
[[[5,[7,4]],7],1]
[[[[4,2],2],6],[8,7]]
"""
lines = parse_text(SAMPLE_ADDITION)
# Start from the first line, then add each following line and reduce
# before moving on to the next one.
tree = build_tree(lines[0])
for l in lines[1:]:
    other = build_tree(l)
    tree.add(other)
    tree.reduce()
assert str(tree.root) == "[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]"

In [160]:
def verify_magnitude(input, expected):
    """Assert that the first line of ``input`` has magnitude ``expected`` (no reduction is applied)."""
    first_line = parse_text(input)[0]
    actual = build_tree(first_line).magnitude()
    assert actual == expected

# Each call passes the input text and its expected magnitude
# (3 * left + 2 * right, applied recursively).
verify_magnitude("[9,1]", 29)
verify_magnitude("[[1,2],[[3,4],5]]", 143)
verify_magnitude("[[[[0,7],4],[[7,8],[6,0]]],[8,1]]", 1384)
verify_magnitude("[[[[1,1],[2,2]],[3,3]],[4,4]]", 445)
verify_magnitude("[[[[3,0],[5,3]],[4,4]],[5,5]]", 791)
verify_magnitude("[[[[5,0],[7,4]],[5,5]],[6,6]]", 1137)
verify_magnitude("[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]", 3488)

In [162]:
# Second running-sum sample; the ten lines are the same as in SAMPLE_TEXT.
SAMPLE_ADDITION_2 = """
[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]
"""
lines = parse_text(SAMPLE_ADDITION_2)
# Add the lines in order, reducing after every addition, then check both the
# rendered sum and its magnitude.
tree = build_tree(lines[0])
for l in lines[1:]:
    other = build_tree(l)
    tree.add(other)
    tree.reduce()
assert str(tree.root) == "[[[[6,6],[7,6]],[[7,7],[7,0]]],[[[7,7],[7,7]],[[7,8],[9,9]]]]"
assert tree.magnitude() == 4140

In [163]:
# Full input: add every line of input.txt in order, reducing after each
# addition; the final expression displays the magnitude of the total.
lines = parse_text(read_input())
tree = build_tree(lines[0])
for l in lines[1:]:
    other = build_tree(l)
    tree.add(other)
    tree.reduce()
tree.magnitude()

3216

In [177]:
def largest_magnitude(lines):
    """Return the largest magnitude obtained by adding two different lines.

    Every ordered combination ``(line1, line2)`` of unequal lines is built
    into fresh trees, added (``line1`` on the left), reduced and measured.
    Operand order matters because the left and right sides are weighted
    differently in the magnitude.

    Args:
        lines: A re-iterable sequence of parsed lines (nested lists), as
            returned by ``parse_text``.

    Returns:
        The maximum magnitude found, or 0 when no two unequal lines exist.
    """

    def get_magnitude(line1, line2):
        # Trees are rebuilt for every sum because add()/reduce() mutate them.
        tree = build_tree(line1)
        tree.add(build_tree(line2))
        tree.reduce()
        return tree.magnitude()

    max_magnitude = 0
    for line1 in lines:
        for line2 in lines:
            # Lines are compared by value, so a line is never added to itself
            # (nor to an identical duplicate).
            if line1 == line2:
                continue
            # The nested loops also reach (line2, line1) on a later iteration,
            # so evaluating a single ordering here covers both without
            # reducing every sum twice.
            max_magnitude = max(max_magnitude, get_magnitude(line1, line2))
    return max_magnitude

In [178]:
# Largest magnitude over all ordered pairs of the SAMPLE_ADDITION_2 lines.
largest_magnitude(parse_text(SAMPLE_ADDITION_2))

3993

In [179]:
# Largest magnitude over all ordered pairs of lines read from input.txt.
largest_magnitude(parse_text(read_input()))

4643