# Binary trees

- <https://adventofcode.com/2021/day/18>


In [1]:
from __future__ import annotations

from dataclasses import dataclass, field
from functools import reduce
from itertools import chain
from operator import add
from string import digits
from typing import Iterator, Optional, Union


@dataclass
class SnailfishNumber:
    """A snailfish number stored as a binary tree.

    A pair ``[left,right]`` is a ``SnailfishNumber`` whose ``left`` and
    ``right`` are child nodes; a regular number is a ``Leaf`` (whose
    ``left``/``right`` are ``None``).  Every node also points back at its
    ``parent`` (``None`` for a root) so that ``explode`` can walk to the
    neighbouring regular numbers.
    """

    left: Optional[SnailfishNumber]
    right: Optional[SnailfishNumber]
    # The parent back-pointer makes the structure cyclic.  Keeping it out of
    # the generated __repr__/__eq__ keeps repr readable and stops == from
    # recursing child -> parent -> child until RecursionError.
    parent: Optional[SnailfishNumber] = field(default=None, repr=False, compare=False)

    @classmethod
    def from_line(cls, line: str) -> SnailfishNumber:
        """Parse one snailfish number written like ``"[[1,2],3]"``.

        Only the first complete number is read; anything after it is ignored.
        """
        return cls._parse(line.encode())[0]

    @classmethod
    def _parse(cls, line: bytes, i: int = 0) -> tuple[SnailfishNumber, int]:
        """Parse the element starting at ``line[i]``.

        Returns the parsed node and the index just past it.  Regular numbers
        may have several digits, so unreduced values such as ``[10,1]`` (which
        ``str()`` can produce) parse back correctly.

        Raises:
            ValueError: if ``line[i]`` is neither ``[`` nor a digit.
        """
        if line[i] == 0x5B:  # b"["
            left, i = cls._parse(line, i + 1)  # skip "["
            right, i = cls._parse(line, i + 1)  # skip ","
            new = left.parent = right.parent = SnailfishNumber(left, right)
            return new, i + 1  # skip "]"
        end = i
        while end < len(line) and 0x30 <= line[end] <= 0x39:  # ASCII digits
            end += 1
        if end == i:
            raise ValueError(
                f"expected '[' or a digit at offset {i}, got {line[i:i + 1]!r}"
            )
        return Leaf(value=int(line[i:end])), end

    def __str__(self) -> str:
        return f"[{self.left},{self.right}]"

    @property
    def magnitude(self) -> int:
        """Three times the left magnitude plus twice the right magnitude."""
        return 3 * self.left.magnitude + 2 * self.right.magnitude

    def split(self) -> bool:
        """Split the first direct child regular number >= 10, left before right.

        A value ``v`` is replaced by the pair ``[v // 2, (v + 1) // 2]``.
        Returns True if a split happened, False otherwise.
        """
        lv = getattr(self.left, "value", 0)  # a pair has no .value -> 0
        if lv >= 10:
            self.left.parent = None
            lo, hi = Leaf(value=lv // 2), Leaf(value=(lv + 1) // 2)
            self.left = lo.parent = hi.parent = SnailfishNumber(lo, hi, parent=self)
            return True

        rv = getattr(self.right, "value", 0)
        if rv >= 10:
            self.right.parent = None
            lo, hi = Leaf(value=rv // 2), Leaf(value=(rv + 1) // 2)
            self.right = lo.parent = hi.parent = SnailfishNumber(lo, hi, parent=self)
            return True

        return False

    def explode(self) -> None:
        """Explode this pair in place.

        Adds the left value to the nearest regular number on the left, adds
        the right value to the nearest one on the right (when they exist) and
        replaces this pair with ``Leaf(0)`` in its parent.  Assumes both
        children are leaves and that this pair has a parent.
        """
        # Climb until we arrive from a right child; that ancestor's left
        # subtree holds the left neighbour, found by descending rightwards.
        n, p = self, self.parent
        while p is not None and p.left is n:
            n, p = p, p.parent
        if p is not None:
            n = p.left
            while n.right is not None:
                n = n.right
            n.value += self.left.value

        # Mirror image for the right neighbour.
        n, p = self, self.parent
        while p is not None and p.right is n:
            n, p = p, p.parent
        if p is not None:
            n = p.right
            while n.left is not None:
                n = n.left
            n.value += self.right.value

        p = self.parent
        if p.left is self:
            p.left = Leaf(value=0, parent=p)
        else:
            p.right = Leaf(value=0, parent=p)
        self.parent = None

    def __iter__(self) -> Iterator[tuple[int, SnailfishNumber]]:
        """Yield ``(depth, pair)`` for every pair in this subtree, in order.

        Traversal is in-order (left subtree, this node, right subtree) and
        depth 0 is this node; leaves are not yielded.
        """
        if self.left.left is not None:
            yield from ((d + 1, n) for d, n in self.left)
        yield 0, self
        if self.right.right is not None:
            yield from ((d + 1, n) for d, n in self.right)

    def _copy(self) -> SnailfishNumber:
        """Return a deep copy of this subtree, detached from any parent."""
        left, right = self.left._copy(), self.right._copy()
        new = left.parent = right.parent = SnailfishNumber(left, right)
        return new

    def __add__(self, other: SnailfishNumber) -> SnailfishNumber:
        """Return the reduced sum ``[self,other]``.

        Reduction mutates nodes in place, so it runs on copies: the operands
        (and their parent links) are left untouched and may be reused.
        """
        if not isinstance(other, SnailfishNumber):
            return NotImplemented
        left, right = self._copy(), other._copy()
        new = left.parent = right.parent = SnailfishNumber(left, right)

        while True:
            # Explode every pair nested inside four pairs, leftmost first.
            for depth, node in new:
                if depth == 4:
                    node.explode()
            # Then do at most one split (the leftmost) and go back to
            # exploding, because a split can create a new depth-4 pair.
            for _, node in new:
                if node.split():
                    break
            else:
                break  # nothing left to split: fully reduced

        return new


@dataclass
class Leaf(SnailfishNumber):
    """A regular number: a tree node with a ``value`` and no children."""

    value: int = 0
    left: Optional[SnailfishNumber] = None
    right: Optional[SnailfishNumber] = None

    def __str__(self) -> str:
        return str(self.value)

    @property
    def magnitude(self) -> int:
        """A regular number's magnitude is its value."""
        return self.value

    def _copy(self) -> Leaf:
        """Return a detached copy of this regular number."""
        return Leaf(value=self.value)


# Round trip: parsing an example and printing it back must reproduce the text.
testlines = [
    "[1,2]",
    "[[1,2],3]",
    "[9,[8,7]]",
    "[[1,9],[8,5]]",
    "[[[[1,2],[3,4]],[[5,6],[7,8]]],9]",
    "[[[9,[3,8]],[[0,9],6]],[[[3,7],[4,9]],3]]",
    "[[[[1,3],[5,3]],[[1,3],[8,7]]],[[[4,9],[6,9]],[[8,2],[7,3]]]]",
]
for line in testlines:
    parsed = SnailfishNumber.from_line(line)
    assert str(parsed) == line

# A single addition whose reduction needs both explosions and splits.
lhs = SnailfishNumber.from_line("[[[[4,3],4],4],[7,[[8,4],9]]]")
rhs = SnailfishNumber.from_line("[1,1]")
assert str(lhs + rhs) == "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]"

# Each key is a newline-separated homework list; adding its numbers left to
# right (reducing after every addition) must give the paired value.
testsums = {
    "[1,1]\n[2,2]\n[3,3]\n[4,4]": "[[[[1,1],[2,2]],[3,3]],[4,4]]",
    "[1,1]\n[2,2]\n[3,3]\n[4,4]\n[5,5]": "[[[[3,0],[5,3]],[4,4]],[5,5]]",
    "[1,1]\n[2,2]\n[3,3]\n[4,4]\n[5,5]\n[6,6]": "[[[[5,0],[7,4]],[5,5]],[6,6]]",
    (
        "[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]\n"
        "[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]\n"
        "[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]\n"
        "[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]\n"
        "[7,[5,[[3,8],[1,4]]]]\n"
        "[[2,[2,2]],[8,[8,1]]]\n"
        "[2,9]\n"
        "[1,[[[9,3],9],[[9,0],[0,7]]]]\n"
        "[[[5,[7,4]],7],1]\n"
        "[[[[4,2],2],6],[8,7]]"
    ): "[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]",
}
for lines, expected in testsums.items():
    nodes = map(SnailfishNumber.from_line, lines.splitlines())
    assert str(reduce(add, nodes)) == expected, lines

# Magnitude is 3 * (left magnitude) + 2 * (right magnitude), applied recursively.
testmagnitudes = dict(
    [
        ("[[1,2],[[3,4],5]]", 143),
        ("[[[[0,7],4],[[7,8],[6,0]]],[8,1]]", 1384),
        ("[[[[1,1],[2,2]],[3,3]],[4,4]]", 445),
        ("[[[[3,0],[5,3]],[4,4]],[5,5]]", 791),
        ("[[[[5,0],[7,4]],[5,5]],[6,6]]", 1137),
        ("[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]", 3488),
    ]
)
for testnum in testmagnitudes:
    expected = testmagnitudes[testnum]
    assert SnailfishNumber.from_line(testnum).magnitude == expected

# The larger example homework: the sum of all ten numbers and its magnitude.
testhomework = [
    "[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]",
    "[[[5,[2,8]],4],[5,[[9,9],0]]]",
    "[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]",
    "[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]",
    "[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]",
    "[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]",
    "[[[[5,4],[7,7]],8],[[8,3],8]]",
    "[[9,3],[[9,9],[6,[4,9]]]]",
    "[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]",
    "[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]",
]
testsum = reduce(add, (SnailfishNumber.from_line(row) for row in testhomework))
expected_sum = "[[[[6,6],[7,6]],[[7,7],[7,0]]],[[[7,7],[7,7]],[[7,8],[9,9]]]]"
assert str(testsum) == expected_sum
assert testsum.magnitude == 4140


In [2]:
import aocd

# Fetch the day-18 puzzle input via aocd and add every number in order.
homework_lines = aocd.get_data(day=18, year=2021).splitlines()
homework_total = reduce(add, map(SnailfishNumber.from_line, homework_lines))
print("Part 1:", homework_total.magnitude)

Part 1: 4132


# Part 2

In [3]:
from itertools import permutations

# Largest magnitude over ordered pairs of distinct example lines.  Order
# matters because magnitude weights left and right differently.
example_pair_magnitudes = (
    (SnailfishNumber.from_line(first) + SnailfishNumber.from_line(second)).magnitude
    for first, second in permutations(testhomework, 2)
)
assert max(example_pair_magnitudes) == 3993


In [4]:
# Same search over the real puzzle input; each operand is parsed afresh per pair.
best_pair_magnitude = max(
    (SnailfishNumber.from_line(first) + SnailfishNumber.from_line(second)).magnitude
    for first, second in permutations(homework_lines, 2)
)
print("Part 2:", best_pair_magnitude)


Part 2: 4685
