# Binary trees

- <https://adventofcode.com/2021/day/18>

The snailfish numbers are, in essence, [binary trees](https://en.wikipedia.org/wiki/Binary_tree#Methods_for_storing_binary_trees), and because each node in the tree always has exactly two children that are either other nodes or numbers (leaves), it is also a _full_ binary tree.

The two operations, exploding and splitting, are very similar to the kinds of operations that [self-balancing binary search trees](https://en.wikipedia.org/wiki/Self-balancing_binary_search_tree) perform whenever you insert or remove a node.

To 'explode' a node, we need to find the preceding and succeeding leaf nodes in [in-order traversal](https://en.wikipedia.org/wiki/Tree_traversal#In-order,_LNR). I chose to implement the binary tree [using nodes and references](https://en.wikipedia.org/wiki/Binary_tree#Nodes_and_references), with a recursive `__iter__` method to handle the traversal. While processing explosions, I also track the preceding and following nodes, so the nodes don't need parent pointers: starting from the previous (or next) node you can walk down its subtree to find the parent of the exploding pair, and replace that pair with a 0 leaf.


In [1]:
from __future__ import annotations
from copy import deepcopy
from enum import IntEnum
from functools import reduce
from itertools import chain, islice
from operator import add
from typing import ClassVar, Iterator, Final, Optional


class Dir(IntEnum):
    """Direction of a child within a pair; the value is its index in (left, right)."""

    left = 0
    right = 1

    def __invert__(self) -> Dir:
        # Make ``~`` flip the direction instead of int's bitwise complement
        # (which would produce -1 / -2).
        return Dir.right if self is Dir.left else Dir.left


# Short module-level names for the two child directions.
LEFT: Final[Dir] = Dir["left"]
RIGHT: Final[Dir] = Dir["right"]
# Alias used in type hints for a node's nesting depth (the root is depth 0).
Depth = int


class SnailfishNumber:
    """A snailfish number: a full binary tree whose leaves are regular numbers.

    Pairs (inner nodes) are instances of this class; regular numbers are
    ``Leaf`` instances, told apart by the ``is_leaf`` class attribute. Nodes
    store no parent pointers: operations that need a parent find it by
    descending from a neighbouring node in in-order traversal.
    """

    is_leaf: ClassVar[bool] = False

    def __init__(self, left: SnailfishNumber, right: SnailfishNumber) -> None:
        self.left = left
        self.right = right

    @classmethod
    def from_line(cls, line: str) -> SnailfishNumber:
        """Parse the textual form (e.g. ``"[[1,2],3]"``) into a tree.

        Anything after the outermost closing bracket is ignored.
        """
        return cls._parse(line.encode())[0]

    @classmethod
    def _parse(cls, line: bytes, i: int = 0) -> tuple[SnailfishNumber, int]:
        """Recursive-descent parser for the element starting at offset ``i``.

        Returns the parsed node and the offset just past it. Regular numbers
        may span several digits, so ``str()`` output round-trips even for
        numbers that still hold values of 10 or more.
        """
        if 0x30 <= line[i] <= 0x39:  # digits
            end = i + 1
            while end < len(line) and 0x30 <= line[end] <= 0x39:
                end += 1
            return Leaf(int(line[i:end])), end
        # line[i] is "[": parse the left element, skip the ",", parse the
        # right element, then step past the closing "]".
        left, i = cls._parse(line, i + 1)
        right, i = cls._parse(line, i + 1)
        return SnailfishNumber(left, right), i + 1

    def __str__(self) -> str:
        return f"[{self.left},{self.right}]"

    def __getitem__(self, index: Dir) -> SnailfishNumber:
        """Dynamic access to the node left and right children"""
        return (self.left, self.right)[index]

    def __setitem__(self, index: Dir, value: SnailfishNumber) -> None:
        """Replace the child in direction ``index`` (``LEFT``/0 or ``RIGHT``/1)."""
        # Compare by value rather than identity so a plain int index behaves
        # the same as it does in __getitem__.
        if index == LEFT:
            self.left = value
        else:
            self.right = value

    def __len__(self) -> int:
        return 2

    @property
    def magnitude(self) -> int:
        return 3 * self.left.magnitude + 2 * self.right.magnitude

    def explode(
        self, prev: Optional[SnailfishNumber], next: Optional[SnailfishNumber]
    ) -> None:
        """Explode this pair in place.

        ``prev`` and ``next`` are the pair nodes immediately before and after
        this one in the node-only in-order traversal of the tree (``None``
        when there is no such node). The left value is added to the nearest
        regular number on the left, the right value to the nearest regular
        number on the right, and this pair is replaced in its parent by a
        ``Leaf()`` (value 0).
        """
        # Both children must be regular numbers (the previous version only
        # checked the left one: the comma made the right check a message).
        assert self.left.is_leaf and self.right.is_leaf
        parent = pdir = None
        for node, ldir in (prev, LEFT), (next, RIGHT):
            if node is None:
                continue
            # The neighbouring regular number is the leaf of node's ldir
            # subtree closest to this pair: descend in the opposite direction.
            n = node[ldir]
            while not n.is_leaf:
                n = n[~ldir]
            n.value += self[ldir].value  # type: ignore
            if parent is None:
                # This pair sits in node's other subtree, reached from
                # node[~ldir] by repeatedly stepping in direction ldir.
                parent, pdir = node[~ldir], ldir

        if parent is self:
            # This pair is a direct child of the neighbour node itself.
            parent, pdir = next if prev is None else prev, ~pdir
        assert parent is not None
        while parent[pdir] is not self:
            parent = parent[pdir]
        parent[pdir] = Leaf()

    def split(self) -> bool:
        """Split the first direct child leaf (left before right) that is >= 10.

        The leaf is replaced by a pair of its value halved, rounded down on
        the left and up on the right. Returns True if a split happened.
        """
        for d in (LEFT, RIGHT):
            if self[d].is_leaf and (v := self[d].value) >= 10:  # type: ignore
                self[d] = SnailfishNumber(Leaf(v // 2), Leaf((v + 1) // 2))
                return True
        return False

    def __iter__(self) -> Iterator[tuple[Depth, SnailfishNumber]]:
        """In-order traversal of nodes only, as (depth, node) tuples.

        Depth is relative to ``self`` (which is yielded at depth 0); leaves
        are not yielded.
        """
        if not self.left.is_leaf:
            yield from ((depth + 1, n) for depth, n in self.left)
        yield 0, self
        if not self.right.is_leaf:
            yield from ((depth + 1, n) for depth, n in self.right)

    def __add__(self, other: SnailfishNumber) -> SnailfishNumber:
        """Return the reduced sum ``[self,other]``; the operands are not modified."""
        # Reduction mutates the tree in place, so build it from copies.
        new = SnailfishNumber(deepcopy(self), deepcopy(other))

        while True:
            # explodes: `lookahead` runs one node ahead of the main traversal
            # to supply each node's in-order successor, ending with a
            # (None, None) sentinel so zip() does not drop the last node.
            # Assuming both operands are already reduced, no pair is deeper
            # than depth 4, and exploding never makes anything deeper, so a
            # single left-to-right pass suffices.
            prev, lookahead = None, chain(islice(new, 1, None), [(None, None)])
            for (depth, node), (_, successor) in zip(new, lookahead):
                if depth == 4:
                    node.explode(prev, successor)
                else:
                    # Exploded pairs leave the tree, so only surviving nodes
                    # become the predecessor of later nodes.
                    prev = node
            # splits: split only the leftmost eligible number, then go back to
            # exploding, since a split may create a new pair at depth 4.
            for _, node in new:
                if node.split():
                    break
            else:
                break

        return new


class Leaf(SnailfishNumber):
    """A regular number: a terminal node holding a single integer."""

    is_leaf = True

    def __init__(self, value: int = 0) -> None:
        # Deliberately does not call SnailfishNumber.__init__: a leaf has no
        # left/right children, only a value.
        self.value = value

    def __str__(self) -> str:
        return f"{self.value}"

    @property
    def magnitude(self) -> int:
        # The magnitude of a regular number is just its value.
        return self.value


# Parsing round-trips: str() of a parsed number reproduces its input line.
testlines = """\
[1,2]
[[1,2],3]
[9,[8,7]]
[[1,9],[8,5]]
[[[[1,2],[3,4]],[[5,6],[7,8]]],9]
[[[9,[3,8]],[[0,9],6]],[[[3,7],[4,9]],3]]
[[[[1,3],[5,3]],[[1,3],[8,7]]],[[[4,9],[6,9]],[[8,2],[7,3]]]]
""".splitlines()
for line in testlines:
    assert str(SnailfishNumber.from_line(line)) == line

# A single addition whose reduction involves both explodes and splits.
assert (
    str(
        SnailfishNumber.from_line("[[[[4,3],4],4],[7,[[8,4],9]]]")
        + SnailfishNumber.from_line("[1,1]")
    )
    == "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]"
)

# Summing whole lists (left fold of +) must give the reduced final sums.
testsums = {
    "[1,1]\n[2,2]\n[3,3]\n[4,4]": "[[[[1,1],[2,2]],[3,3]],[4,4]]",
    "[1,1]\n[2,2]\n[3,3]\n[4,4]\n[5,5]": "[[[[3,0],[5,3]],[4,4]],[5,5]]",
    "[1,1]\n[2,2]\n[3,3]\n[4,4]\n[5,5]\n[6,6]": "[[[[5,0],[7,4]],[5,5]],[6,6]]",
    (
        "[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]\n"
        "[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]\n"
        "[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]\n"
        "[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]\n"
        "[7,[5,[[3,8],[1,4]]]]\n[[2,[2,2]],[8,[8,1]]]\n"
        # The stray U+FFFC characters after the next number were removed.
        "[2,9]\n[1,[[[9,3],9],[[9,0],[0,7]]]]\n[[[5,[7,4]],7],1]\n"
        "[[[[4,2],2],6],[8,7]]"
    ): "[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]",
}
for lines, expected_str in testsums.items():
    nodes = map(SnailfishNumber.from_line, lines.splitlines())
    assert str(reduce(add, nodes)) == expected_str

# Magnitude: 3 * left + 2 * right, applied recursively.
testmagnitudes = {
    "[[1,2],[[3,4],5]]": 143,
    "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]": 1384,
    "[[[[1,1],[2,2]],[3,3]],[4,4]]": 445,
    "[[[[3,0],[5,3]],[4,4]],[5,5]]": 791,
    "[[[[5,0],[7,4]],[5,5]],[6,6]]": 1137,
    "[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]": 3488,
}
for testnum, expected_mag in testmagnitudes.items():
    assert SnailfishNumber.from_line(testnum).magnitude == expected_mag

# The example homework assignment: its total and that total's magnitude.
testhomework = [
    SnailfishNumber.from_line(line)
    for line in """\
[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]
""".splitlines()
]
testsum = reduce(add, testhomework)
assert str(testsum) == "[[[[6,6],[7,6]],[[7,7],[7,0]]],[[[7,7],[7,7]],[[7,8],[9,9]]]]"
assert testsum.magnitude == 4140


In [2]:
import aocd

# Parse every line of the puzzle input into a snailfish number; kept as a
# list because it is reused for part 2.
homework = list(
    map(SnailfishNumber.from_line, aocd.get_data(day=18, year=2021).splitlines())
)
print(f"Part 1: {reduce(add, homework).magnitude}")


Part 1: 4132


# Part 2

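All we have to do for part 2 is loop over all ordered pairs (2-permutations) of the input snailfish numbers to find the highest magnitude. We need permutations rather than combinations because snailfish addition is not commutative: `a + b` and `b + a` generally have different magnitudes.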

This takes a bit more time for the puzzle input, as each addition can involve a large number of explosions and splits, which in turn require many traversals.


In [3]:
from itertools import permutations

def maximize(numbers: list[SnailfishNumber]) -> int:
    """Return the largest magnitude of ``a + b`` over two different entries.

    Both orders of every pair are tried: the left operand's magnitude is
    weighted 3 and the right operand's 2, so ``a + b`` and ``b + a`` usually
    differ. Raises ValueError when fewer than two numbers are given.
    """
    candidate_magnitudes = (
        (first + second).magnitude
        for first, second in permutations(numbers, r=2)
    )
    return max(candidate_magnitudes)

# On the example homework, the best ordered pair sums to magnitude 3993.
assert 3993 == maximize(testhomework)


In [4]:
print(f"Part 2: {maximize(homework)}")


Part 2: 4685
