In [9]:
from ast import literal_eval
from itertools import permutations, product
import functools as ft
from math import floor, ceil
from aoc import submit

DAY = 18


In [10]:
import re

# A pair made of two regular numbers, e.g. "[9,8]"; groups capture both numbers.
rx_explode_pair = re.compile(r"\[(\d+),(\d+)]")
# The last regular number in a string: a digit run with no further digit after it.
rx_explode_ls = re.compile(r"(\d+)(?!.*\d+)")
# The first regular number in a string.
rx_explode_rs = re.compile(r"\d+")
# Any regular number. NOTE(review): not referenced by the code visible in this notebook.
rx_d = re.compile(r"\d+")


def replace_match(string, replacement, match, offset=0):
    """Return *string* with the span of *match* (shifted by *offset*) replaced.

    *offset* lets a match found in a slice of *string* be applied to the
    full string: pass the index at which that slice started.
    """
    begin, end = match.span()
    begin += offset
    end += offset
    return ''.join((string[:begin], replacement, string[end:]))


def find_explode_index(snailfish):
    """Return the index of the first '[' that opens a pair nested inside four others.

    Brackets are scanned left to right while tracking the nesting depth; the
    first '[' that reaches depth 5 starts the pair that must explode.

    Returns None when no pair is nested that deeply.
    """
    depth = 0
    # Raw string: "\[" in a normal literal is an invalid escape sequence.
    for bracket in re.finditer(r"[\[\]]", snailfish):
        depth += 1 if bracket.group() == '[' else -1
        if depth == 5:
            return bracket.start()
    return None


def explode(snailfish):
    """Explode the leftmost pair nested inside four other pairs.

    The pair's left value is added to the nearest regular number on its left,
    its right value to the nearest regular number on its right, and the pair
    itself is replaced by 0.

    Returns ``(new_string, True)`` if a pair exploded, else ``(snailfish, False)``.
    """
    start = find_explode_index(snailfish)
    if start is None:
        return snailfish, False

    pair = rx_explode_pair.search(snailfish[start:])
    pair_end = start + pair.end()
    left_num = rx_explode_ls.search(snailfish[:start])
    right_num = rx_explode_rs.search(snailfish[pair_end:])
    x, y = pair.groups()

    # Edit from right to left so the spans computed above remain valid.
    if right_num:
        snailfish = replace_match(snailfish, str(int(y) + int(right_num.group())), right_num, pair_end)
    snailfish = replace_match(snailfish, '0', pair, start)
    if left_num:
        snailfish = replace_match(snailfish, str(int(x) + int(left_num.group())), left_num)
    return snailfish, True


def split(snailfish):
    """Split the leftmost regular number >= 10 into a pair.

    The number n becomes ``[n // 2, n // 2 + n % 2]`` (rounded down, rounded up).
    Returns ``(new_string, True)`` if a split happened, else ``(snailfish, False)``.
    """
    big = re.search(r"\d{2,}", snailfish)
    if big is None:
        return snailfish, False
    number = int(big.group())
    low = number // 2
    high = low + (number % 2)
    return replace_match(snailfish, f"[{low},{high}]", big), True


def add(snail, fish):
    """Return the unreduced snailfish sum of two numbers in string form."""
    return "[{},{}]".format(snail, fish)


def reduce(snailfish):
    """Fully reduce a snailfish number given in string form.

    Explosions always take priority; a split is only attempted when nothing
    can explode. Stops once neither action changes the string.
    """
    while True:
        snailfish, changed = explode(snailfish)
        if changed:
            continue
        snailfish, changed = split(snailfish)
        if not changed:
            return snailfish


def magnitude(snailfish):
    """Return the magnitude of a snailfish number given in string form.

    A pair's magnitude is ``3 * magnitude(left) + 2 * magnitude(right)``; a
    regular number is its own magnitude.

    The string is parsed with ``ast.literal_eval`` (as ``parse_input`` does)
    instead of being rewritten into Python source and ``eval``-ed. Arbitrary
    input text can therefore never execute code: non-literal input raises
    ValueError.
    """
    def _mag(value):
        # Pairs arrive as two-element lists, regular numbers as ints.
        if isinstance(value, list):
            left, right = value
            return 3 * _mag(left) + 2 * _mag(right)
        return value

    return _mag(literal_eval(snailfish))


In [17]:
@submit(day=DAY)
def part_one(raw):
    """Magnitude of the reduced sum of all lines, using the string implementation."""
    first, *others = raw.splitlines()
    total = ft.reduce(lambda acc, line: reduce(add(acc, line)), others, first)
    return magnitude(total)

part_one:
✅ example: 4140           (31.77 ms)
✅ input:   4116           (207.62 ms)


In [21]:
@submit(day=DAY)
def part_two(raw):
    """Largest magnitude of any ordered sum of two different lines (string version).

    ``permutations(lines, 2)`` pairs up distinct *positions*. No line is added
    to itself, and identical lines that appear twice are still combined. This
    avoids relying on ``str`` object identity, which is a CPython
    implementation detail.
    """
    return max(
        magnitude(reduce(add(snail, fish)))
        for snail, fish in permutations(raw.splitlines(), 2)
    )

part_two:
✅ example: 3993           (48.54 ms)
✅ input:   4638           (3784.28 ms)


In [13]:
import ipytest
ipytest.autoconfig()

In [14]:
%%ipytest -qq --color=no

import pytest


@pytest.mark.parametrize('snailfish,expected', [
    ("[[[[[9,8],1],2],3],4]", "[[[[0,9],2],3],4]"),
    ("[1,[[[[9,8],1],2],3],4]", "[10,[[[0,9],2],3],4]"),
    ("[7,[6,[5,[4,[3,2]]]]]", "[7,[6,[5,[7,0]]]]"),
    ("[[6,[5,[4,[3,2]]]],1]", "[[6,[5,[7,0]]],3]"),
    ("[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]", "[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]"),
    ("[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]", "[[3,[2,[8,0]]],[9,[5,[7,0]]]]"),
    ("[[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]],[[[5,[2,8]],4],[5,[[9,9],0]]]]",
     '[[[[5,0],[[9,7],[9,6]]],[[4,[1,2]],[[1,4],2]]],[[[5,[2,8]],4],[5,[[9,9],0]]]]'),
    ("[[[[5,0],[[9,7],[9,6]]],[[4,[1,2]],[[1,4],2]]],[[[5,[2,8]],4],[5,[[9,9],0]]]]",
     "[[[[5,9],[0,[16,6]]],[[4,[1,2]],[[1,4],2]]],[[[5,[2,8]],4],[5,[[9,9],0]]]]")
])
def test_explode(snailfish, expected):
    """One explode step produces the expected string and reports that it exploded."""
    exploded, matched = explode(snailfish)
    # Every case above contains a pair nested four deep, so the flag must be set.
    assert matched
    assert exploded == expected


def test_add():
    """Adding two numbers wraps them in a new, unreduced pair."""
    left = "[[[[4,3],4],4],[7,[[8,4],9]]]"
    right = "[1,1]"
    expected = "[[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]]"
    assert add(left, right) == expected


def test_magnitude():
    """Magnitude of the worked example from the puzzle statement."""
    number = "[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]"
    expected = 3488
    assert magnitude(number) == expected

..........                                                                                   [100%]


In [19]:
def isnode(obj):
    """Return True if *obj* is a pair (``Node``) rather than a regular number."""
    return isinstance(obj, Node)


class Node:
    """A snailfish pair stored as a binary tree.

    Children are either nested ``Node`` instances or plain ints (regular
    numbers). Every child ``Node`` keeps a ``parent`` back-reference so that
    the neighbouring regular numbers can be found when a pair explodes.
    """

    def __init__(self, left, right, parent=None):
        self.left = left
        self.right = right
        self.parent = parent

        # Adopt child pairs so they can walk back up to this node.
        if isnode(left):
            left.parent = self
        if isnode(right):
            right.parent = self

    def __repr__(self):
        return f"[{self.left},{self.right}]"

    def __iter__(self):
        """Yield every pair of this subtree in left-to-right (in-order) order."""
        if isnode(self.left):
            yield from self.left

        yield self

        if isnode(self.right):
            yield from self.right

    def __add__(self, other):
        """Return the reduced snailfish sum of ``self`` and ``other``.

        Both operands are cloned first. Building the new root re-parents its
        children, and ``reduce`` rewrites the tree in place, so adding the
        originals directly would silently modify them.
        """
        left = self.clone()
        right = other.clone() if isnode(other) else other
        return Node(left, right).reduce()

    @property
    def is_root(self):
        return self.parent is None

    @property
    def depth(self):
        """Number of pairs enclosing this one (the root has depth 0)."""
        return 0 if self.is_root else self.parent.depth + 1

    @property
    def magnitude(self):
        """``3 * magnitude(left) + 2 * magnitude(right)``; regular numbers count as themselves."""
        left = self.left.magnitude if isnode(self.left) else self.left
        right = self.right.magnitude if isnode(self.right) else self.right

        return 3 * left + 2 * right

    @property
    def first_node(self):
        """Leftmost pair of this subtree; its ``left`` is a regular number."""
        return self.left.first_node if isnode(self.left) else self

    @property
    def last_node(self):
        """Rightmost pair of this subtree; its ``right`` is a regular number."""
        return self.right.last_node if isnode(self.right) else self

    @property
    def next_node(self):
        """Pair holding the nearest regular number to the right, or None.

        The number is that pair's ``left`` if ``left`` is a regular number,
        otherwise its ``right``.
        """
        if self.is_root:
            return None
        parent = self.parent

        if parent.right is self:
            # Nothing to our right at this level; look from the parent upwards.
            return parent.next_node

        return parent.right.first_node if isnode(parent.right) else parent

    @property
    def prev_node(self):
        """Pair holding the nearest regular number to the left, or None.

        The number is that pair's ``right`` if ``right`` is a regular number,
        otherwise its ``left``.
        """
        if self.is_root:
            return None
        parent = self.parent

        if parent.left is self:
            # Nothing to our left at this level; look from the parent upwards.
            return parent.prev_node

        return parent.left.last_node if isnode(parent.left) else parent

    def split(self):
        """Split this pair's first regular child >= 10 (left before right).

        A number n becomes the pair ``[n // 2, n - n // 2]`` (rounded down,
        rounded up), computed with exact integer arithmetic.
        Returns True if a split happened.
        """
        left, right = self.left, self.right

        if not isnode(left) and left > 9:
            half = left // 2
            self.left = Node(half, left - half, self)
            return True

        if not isnode(right) and right > 9:
            half = right // 2
            self.right = Node(half, right - half, self)
            return True

        return False

    def explode(self):
        """Explode this pair if it is nested inside four others.

        Its left value is added to the nearest regular number on the left, its
        right value to the nearest on the right, and the pair is replaced by 0
        in its parent. Assumes both children are regular numbers.
        Returns True if the pair exploded.
        """
        if self.depth != 4:
            return False

        previous_node, next_node = self.prev_node, self.next_node

        if previous_node is not None:
            if isnode(previous_node.right):
                previous_node.left += self.left
            else:
                previous_node.right += self.left

        if next_node is not None:
            if isnode(next_node.left):
                next_node.right += self.right
            else:
                next_node.left += self.right

        if self.parent.left is self:
            self.parent.left = 0
        else:
            self.parent.right = 0

        return True

    def reduce(self):
        """Reduce this tree in place and return it.

        Each round performs the leftmost possible explosion, or, if nothing can
        explode, the leftmost split; ``any`` stops after the first action.
        """
        while (
                any(node.explode() for node in self)
                or any(node.split() for node in self)
        ):
            continue
        return self

    def clone(self):
        """Return a deep copy of this subtree as a new root (``parent`` is None)."""
        left = self.left.clone() if isnode(self.left) else self.left
        right = self.right.clone() if isnode(self.right) else self.right

        return Node(left, right)


def tree(value):
    """Convert a nested two-element list (as from ``literal_eval``) into ``Node``s.

    A value that is not iterable (a regular number) is returned unchanged.
    """
    try:
        lhs, rhs = value
    except TypeError:
        return value
    return Node(tree(lhs), tree(rhs))


def parse_input(raw):
    """Parse one snailfish number per line into a list of ``Node`` trees."""
    numbers = []
    for line in raw.splitlines():
        numbers.append(tree(literal_eval(line)))
    return numbers


@submit(day=DAY)
def part_one(raw):
    """Magnitude of the running sum of all lines, using the ``Node`` implementation."""
    def _plus(acc, nxt):
        return acc + nxt

    total = ft.reduce(_plus, parse_input(raw))
    return total.magnitude

part_one:
✅ example: 4140           (47.73 ms)
✅ input:   4116           (347.94 ms)


In [20]:
@submit(day=DAY)
def part_two(raw):
    """Largest magnitude of any ordered sum of two distinct input numbers (``Node`` version).

    Each operand is cloned so every pairing starts from the pristine parsed numbers.
    """
    numbers = parse_input(raw)
    magnitudes = (
        (left.clone() + right.clone()).magnitude
        for left, right in product(numbers, repeat=2)
        if left is not right
    )
    return max(magnitudes)


part_two:
✅ example: 3993           (67.64 ms)
✅ input:   4638           (5113.80 ms)
