In [1]:
import math
import re
from functools import reduce
from pathlib import Path

In [259]:
test_input_1 = """[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]
"""

#test_input = Path("input_1.txt").read_text()

In [263]:
class NumberNode:
    def __init__(self, number=None, parent=None):
        self.parent = parent
        if isinstance(number, int):
            self.data = number
            self.left = self.right = None
        elif isinstance(number, list):
            self.data = None
            self.left = NumberNode(number[0], parent=self)
            self.right = NumberNode(number[1], parent=self)
        else:
            self.data = self.left = self.right = None
            
    def leftmost_four(self, level=0):
        level += 1
        if self.data is not None:
            return None
        if self.left and self.right and None not in (self.left.data, self.right.data):
            if level > 4:
                return self
            return None
        else:
            next_four = self.left.leftmost_four(level)
            if not next_four:
                next_four = self.right.leftmost_four(level)
            return next_four
        
    def leftmost_big(self):
        if self.data is not None:
            if self.data > 9:
                return self
            return None
        next_big = self.left.leftmost_big()
        if not next_big:
            next_big = self.right.leftmost_big()
        return next_big
                          
    def next_left_leaf(self, called_by=None):
        if self.data is not None:
            return self
        else:
            if called_by is self.left:
                if self.parent is None:
                    return None
                return self.parent.next_left_leaf(called_by=self)
            elif called_by is self.right:
                return self.left.next_left_leaf(called_by=self)
            elif called_by is self.parent:
                return self.right.next_left_leaf()
            else:
                return self.parent.next_left_leaf(called_by=self)
            
    def next_right_leaf(self, called_by=None):
        if self.data is not None:
            return self
        else:
            if called_by is self.right:
                if self.parent is None:
                    return None
                return self.parent.next_right_leaf(called_by=self)
            elif called_by is self.left:
                return self.right.next_right_leaf(called_by=self)
            elif called_by is self.parent:
                return self.left.next_right_leaf(called_by=self)
            else:
                return self.parent.next_right_leaf(called_by=self)
    
    def explode(self):
        if four := self.leftmost_four():
            if left := four.next_left_leaf():
                left.data += four.left.data
            four.left.data = 0
            
            if right := four.next_right_leaf():
                right.data += four.right.data
            four.right.data = 0
            
            if four.left.data == 0 and four.right.data == 0:
                four.data = 0
                four.left = four.right = None
            return True
        return False
    
    def split(self):
        if big := self.leftmost_big():
            big.left = NumberNode()
            big.left.data = math.floor(big.data / 2)
            
            big.right = NumberNode()
            big.right.data = math.ceil(big.data / 2)
            
            big.data = None
            return True
        return False
    
    def __repr__(self):
        return self.__str__()
    
    def __str__(self):
        if self.data is not None:
            return str(self.data)
        else:
            return f"[{self.left},{self.right}]"
        
        
def build_tree(number):
    lists = eval(number)
    return NumberNode(lists)

def build_trees(numbers):
    return [build_tree(number) for number in numbers]

def parse_input(input_string):
    return build_trees(input_string.strip().splitlines())

def add_trees(a, b):
    new_root = NumberNode()
    new_root.left = a
    new_root.right = b
    a.parent = new_root
    b.parent = new_root
    return new_root

def snail_add(a: NumberNode, b: NumberNode):
    result = add_trees(a, b)
    reducing = True
    while reducing:
        initial_state = str(result)
        if not result.explode():
            if not result.split():
                reducing = False
    return result
    

In [281]:
# Part 1 - Test
tree = build_tree("[[[[[9,8],1],2],3],4]")
tree.explode()
assert str(tree) == "[[[[0,9],2],3],4]"


tree = build_tree("[7,[6,[5,[4,[3,2]]]]]")
tree.explode()
assert str(tree) == "[7,[6,[5,[7,0]]]]"

tree = build_tree("[[6,[5,[4,[3,2]]]],1]")
tree.explode()
assert str(tree) == "[[6,[5,[7,0]]],3]"

tree = build_tree("[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]")
tree.explode()
assert str(tree) == "[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]"

tree = build_tree("[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]")
tree.explode()
assert str(tree) == "[[3,[2,[8,0]]],[9,[5,[7,0]]]]"

assert str(add_trees(build_tree("[[[[4,3],4],4],[7,[[8,4],9]]]"), build_tree("[1,1]"))) == "[[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]]"

tree = build_tree("[[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]]")
tree.explode()
assert str(tree) == "[[[[0,7],4],[7,[[8,4],9]]],[1,1]]"

tree = build_tree("[[[[0,7],4],[7,[[8,4],9]]],[1,1]]")
tree.explode()
assert str(tree) == "[[[[0,7],4],[15,[0,13]]],[1,1]]"

tree = build_tree("[[[[0,7],4],[15,[0,13]]],[1,1]]")
tree.split()
assert str(tree) == "[[[[0,7],4],[[7,8],[0,13]]],[1,1]]"

tree = build_tree("[[[[0,7],4],[[7,8],[0,13]]],[1,1]]")
tree.split()
assert str(tree) == "[[[[0,7],4],[[7,8],[0,[6,7]]]],[1,1]]"

tree = build_tree("[[[[0,7],4],[[7,8],[0,[6,7]]]],[1,1]]")
tree.explode()
assert str(tree) == "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]"

assert str(snail_add(build_tree("[[[[4,3],4],4],[7,[[8,4],9]]]"), build_tree("[1,1]"))) == "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]"

assert str(reduce(snail_add, (build_trees(("[1,1]", "[2,2]", "[3,3]", "[4,4]"))))) == "[[[[1,1],[2,2]],[3,3]],[4,4]]"
assert str(reduce(snail_add, (build_trees(("[1,1]", "[2,2]", "[3,3]", "[4,4]", "[5,5]"))))) == "[[[[3,0],[5,3]],[4,4]],[5,5]]"
assert str(reduce(snail_add, (build_trees(("[1,1]", "[2,2]", "[3,3]", "[4,4]", "[5,5]", "[6,6]"))))) == "[[[[5,0],[7,4]],[5,5]],[6,6]]"

number_trees = parse_input(test_input_1)
str(reduce(snail_add, number_trees[0:3]))
#assert str(reduce(snail_add, number_trees)) == "[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]"

KeyboardInterrupt: 

In [269]:
# Part 1
number_trees = parse_input(test_input_1)
reduce(snail_add, number_trees)

KeyboardInterrupt: 

In [591]:
PAIR = re.compile("^(\d+),(\d+)")
NUMBER = re.compile("\d+")
DOUBLE = re.compile("\d{2}")
EXPLODE = re.compile("(\d+)([^\d]+)(\d+)")

def parse_input(input_string):
    return input_string.strip().splitlines()

def snail_add(a, b):
    result = add_numbers(a, b)
    reducing = True
    while reducing:
        new_result = explode_number(result)
        if new_result == result:
            result = new_result
            new_result = split_number(result)
            if new_result == result:
                result = new_result
                reducing = False
        result = new_result
    return result

def add_numbers(a, b):
    return f"[{a},{b}]"


def explode_number(number, verbose=False):
    if level_four := first_level_four_index(number):
        a_i, b_i = level_four
    else:
        return number
    
    left_i = next_number_left(a_i, number)
    right_i = next_number_right(b_i, number)
    
    if left_i:
        left = explode_left(number[left_i[0]:a_i[1]], verbose=verbose)
        left_cut = left_i[0]
    else:
        left = "0"
        left_cut = a_i[0]-1

    if right_i:
        #print(f"{b_i=} {right_i=}")
        right = explode_right(number[b_i[0]:right_i[1]], verbose=verbose)
        right_cut = right_i[1]
    else:
        right = "0"
        right_cut = b_i[1]+1
    new_pair = f"{left},{right}".replace("0,0", "0")
    new_number = f"{number[:left_cut]}{new_pair}{number[right_cut:]}".replace("[0],", "")
    if verbose:
        print(f"{number} => {new_number}")
    return new_number


def explode_left(substring, verbose=False):
    # 7,[[8 => 15,[0
    # 0,[6 => 6
    a, middle, b = EXPLODE.match(substring).groups()
    a, b = int(a), int(b)
    if middle.count("[") > 1:
        middle = substring[1:-1].replace("[", "", 1)
        new_substring =  f"{a+b}{middle}0"
    else:
        new_substring = str(a+b)
    if verbose:
        print(f"  L: {substring} => {new_substring}")
    return new_substring

def explode_right(substring, verbose=False):
    # 4],9 => 13
    # 7]]]],[1 => 0]]],[8
    a, middle, b = EXPLODE.match(substring).groups()
    a, b = int(a), int(b)
    if middle.count("]") > 0:
        middle = "".join(middle.rsplit("]", 1))
        new_substring = f"0{middle}{a+b}"
    else:
        new_substring = str(a+b)
    if verbose:
        print(f"  R: {substring} => {new_substring}")
    return new_substring

def split_number(number):
    for match in DOUBLE.finditer(number):
        value = int(match[0])
        left = math.floor(value / 2)
        right = math.ceil(value / 2)
        return f"{number[:match.start()]}[{left},{right}]{number[match.end():]}"
    return number
    
def next_number_left(i, number):
    for number in list(NUMBER.finditer(number[:i[0]]))[::-1]:
        return number.start(), number.end()
    #for n, char in reversed(list(enumerate(number[:i]))):
    #    if char.isdigit():
    #        return n
    return None

def next_number_right(i, number):
    for number in NUMBER.finditer(number[i[1]:]):
        return number.start()+i[1], number.end()+i[1]
    return None

def first_level_four_index(number):
    level = 0
    for n, char in enumerate(number):
        if char == "[":
            level += 1
        elif char == "]":
            level -= 1
        elif level > 4:
            if pair := PAIR.match(number[n:]):
                return (pair.start(1)+n, pair.end(1)+n), (pair.start(2)+n, pair.end(2)+n)
    return None

In [592]:
# Part 1 - Test
assert explode_number("[[[[[9,8],1],2],3],4]") == "[[[[0,9],2],3],4]"
assert explode_number("[7,[6,[5,[4,[3,2]]]]]") == "[7,[6,[5,[7,0]]]]"
assert explode_number("[[6,[5,[4,[3,2]]]],1]") == "[[6,[5,[7,0]]],3]"
assert explode_number("[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]") == "[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]"
assert explode_number("[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]") == "[[3,[2,[8,0]]],[9,[5,[7,0]]]]"

assert add_numbers("[[[[4,3],4],4],[7,[[8,4],9]]]", "[1,1]") == "[[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]]"
assert explode_number("[[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]]") == "[[[[0,7],4],[7,[[8,4],9]]],[1,1]]"
assert explode_number("[[[[0,7],4],[7,[[8,4],9]]],[1,1]]") == "[[[[0,7],4],[15,[0,13]]],[1,1]]"
assert split_number("[[[[0,7],4],[15,[0,13]]],[1,1]]") == "[[[[0,7],4],[[7,8],[0,13]]],[1,1]]"
assert split_number("[[[[0,7],4],[[7,8],[0,13]]],[1,1]]") == "[[[[0,7],4],[[7,8],[0,[6,7]]]],[1,1]]"
assert explode_number("[[[[0,7],4],[[7,8],[0,[6,7]]]],[1,1]]") == "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]"
assert snail_add("[[[[4,3],4],4],[7,[[8,4],9]]]", "[1,1]") == "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]"

assert reduce(snail_add, ("[1,1]", "[2,2]", "[3,3]", "[4,4]")) == "[[[[1,1],[2,2]],[3,3]],[4,4]]"
assert reduce(snail_add, ("[1,1]", "[2,2]", "[3,3]", "[4,4]", "[5,5]")) == "[[[[3,0],[5,3]],[4,4]],[5,5]]"
assert reduce(snail_add, ("[1,1]", "[2,2]", "[3,3]", "[4,4]", "[5,5]", "[6,6]")) == "[[[[5,0],[7,4]],[5,5]],[6,6]]"

a = "[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]"
b = "[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]"
c = "[[[[4,0],[5,4]],[[7,7],[6,0]]],[[8,[7,7]],[[7,9],[5,0]]]]"
assert snail_add(a, b) == c


a = "[[[[4,0],[5,4]],[[7,7],[6,0]]],[[8,[7,7]],[[7,9],[5,0]]]]"
b = "[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]"
c = "[[[[6,7],[6,7]],[[7,7],[0,7]]],[[[8,7],[7,7]],[[8,8],[8,0]]]]"

#print(snail_add(a, b))
#assert snail_add(a, b) == c

#numbers = parse_input(test_input_1)
#print(reduce(snail_add, numbers))
#assert reduce(snail_add, numbers) == "[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]"

In [519]:
[[[[[0,0],[3,2]],[3,3]],[4,4]],[5,5]]

[[[[[0, 0], [3, 2]], [3, 3]], [4, 4]], [5, 5]]

In [6]:
# Part 1

In [7]:
# Part 2 - Test

In [185]:
# Part 2

2

In [396]:
PAIR.match("12,132],[9,67]").span(2)

(3, 6)