In [173]:
def find_split(s: str):
    """Split a snailfish pair string into the text of its two children.

    The split point is the comma at bracket depth 1, i.e. directly inside the
    outermost brackets: ``"[[1,2],3]"`` -> ``("[1,2]", "3")``.

    Args:
        s: a pair in snailfish notation, outer brackets included.

    Returns:
        ``(left, right)`` substrings with the outer brackets removed.

    Raises:
        ValueError: if ``s`` has no comma at depth 1 (not a well-formed pair).
    """
    depth = 0
    for i, ch in enumerate(s):
        if ch == '[':
            depth += 1
        elif ch == ']':
            depth -= 1
        elif ch == ',' and depth == 1:
            return s[1:i], s[i + 1:-1]
    # Previously fell through and returned None, which only failed later with an
    # opaque unpacking TypeError at the call site.
    raise ValueError(f"not a snailfish pair: {s!r}")

def leftmost(p):
    """Follow ``.left`` links from ``p`` until reaching a non-Pair node; return it.

    If ``p`` itself is not a ``Pair`` it is returned unchanged.
    """
    node = p
    while isinstance(node, Pair):
        node = node.left
    return node
def rightmost(p):
    """Follow ``.right`` links from ``p`` until reaching a non-Pair node; return it.

    If ``p`` itself is not a ``Pair`` it is returned unchanged.
    """
    node = p
    while isinstance(node, Pair):
        node = node.right
    return node

import math

class Leaf:
    """Shared base for snailfish tree nodes: ``SN`` numbers and ``Pair`` nodes.

    Despite its name it is also the base of the internal ``Pair`` nodes. Each
    node has a ``parent`` attribute: the enclosing ``Pair``, or None at the root.
    """

    def __init__(self, parent=None):
        # SN and Pair assign ``parent`` themselves; defining it here too means a
        # bare ``Leaf()`` no longer raises AttributeError in the helpers below.
        self.parent = parent

    def next_right(self) -> 'SN | None':
        """Return the nearest regular number to the right of this node, or None.

        Climbs until this subtree is a left child, then returns the leftmost
        node of the parent's right subtree.
        """
        if self.parent is None:
            return None
        if self.parent.left is self:
            return leftmost(self.parent.right)
        return self.parent.next_right()

    def next_left(self) -> 'SN | None':
        """Return the nearest regular number to the left of this node, or None.

        Mirror image of ``next_right``: climbs until this subtree is a right
        child, then returns the rightmost node of the parent's left subtree.
        """
        if self.parent is None:
            return None
        if self.parent.right is self:
            return rightmost(self.parent.left)
        return self.parent.next_left()

    def _am_left(self) -> bool:
        """True if this node is its parent's left child (requires a parent)."""
        # Identity check: nodes define no __eq__, so this matches the old ``==``.
        return self.parent.left is self

class SN(Leaf):
    """A regular snailfish number: a leaf node holding the integer ``v``."""

    def __init__(self, v, parent=None) -> None:
        """Store ``int(v)`` and the enclosing ``Pair``.

        Args:
            v: an int or a decimal string such as ``"13"``.
            parent: the ``Pair`` holding this number, or None if detached.
        """
        self.v = int(v)
        self.parent = parent

    def _split(self) -> bool:
        """Replace this number by ``[floor(v/2), ceil(v/2)]`` if it is >= 10.

        The new ``Pair`` is created one level deeper than the parent and takes
        this node's slot (left or right) in the parent.

        Returns:
            True if a split was performed, otherwise False.
        """
        if self.v < 10:
            return False
        # Exact integer halves; ``v/2`` goes through a float and can lose
        # precision for very large values (above 2**53).
        low, odd = divmod(self.v, 2)
        high = low + odd
        p = Pair(f"[{low},{high}]", self.parent.d + 1, self.parent)
        if self._am_left():
            self.parent.left = p
        else:
            self.parent.right = p
        return True

    def __abs__(self) -> int:
        """The magnitude of a regular number is its value."""
        return self.v

    def __repr__(self):
        return str(self.v)

class Pair(Leaf):
    """A snailfish pair: a tree node with ``left`` and ``right`` children.

    Each child is either an ``SN`` (regular number) or a nested ``Pair``.
    ``d`` is the nesting depth, 0 for a pair built without an explicit depth.
    """

    def __init__(self, s: str, d=0, parent=None) -> None:
        """Parse ``s`` (snailfish notation, outer brackets included) into a subtree.

        Args:
            s: textual pair such as ``"[1,[2,3]]"``.
            d: depth of this pair; nested child pairs are created at ``d + 1``.
            parent: enclosing ``Pair``, or None for a root.
        """
        self.d = d
        self.parent = parent
        left_text, right_text = find_split(s)
        self.left = self._make_child(left_text)
        self.right = self._make_child(right_text)

    def _make_child(self, text: str):
        # Numeric text becomes a regular number; anything else is a nested pair.
        if text.isnumeric():
            return SN(text, self)
        return Pair(text, self.d + 1, self)

    def reduce(self):
        """Repeatedly explode, else split, until neither applies; return self.

        Explosions take priority: a split is attempted only when no pair can
        explode.
        """
        while self._explode() or self._split():
            pass
        return self

    def explode(self) -> 'Pair':
        """Apply at most one explosion in place and return self for chaining."""
        self._explode()
        return self

    def _explode(self) -> bool:
        """Explode the first depth-4 pair found (children before self, left first).

        The exploding pair's left value is added to the nearest regular number
        on its left, its right value to the nearest one on its right, and the
        pair is replaced by ``SN(0)`` in its parent. Assumes both children of a
        depth-4 pair are regular numbers.

        Returns:
            True if an explosion happened (the tree was modified), else False.
        """
        # A child returning False has not modified the tree, so reading both
        # children up front is safe.
        for child in (self.left, self.right):
            if isinstance(child, Pair) and child._explode():
                return True
        if self.d != 4:
            return False
        # Locate neighbours while this pair is still attached: the upward walk
        # in next_right/next_left depends on self's slot in its parent.
        right_neighbour = self.next_right()
        if right_neighbour is not None:
            right_neighbour.v += self.right.v
        left_neighbour = self.next_left()
        if left_neighbour is not None:
            left_neighbour.v += self.left.v
        replacement = SN(0, self.parent)
        if self._am_left():
            self.parent.left = replacement
        else:
            self.parent.right = replacement
        return True

    def split(self) -> 'Pair':
        """Apply at most one split in place and return self for chaining."""
        self._split()
        return self

    def _split(self) -> bool:
        """Split the leftmost regular number >= 10 in this subtree; True if one was split."""
        return self.left._split() or self.right._split()

    def __abs__(self) -> int:
        """Magnitude: three times the left magnitude plus twice the right."""
        left_part = 3 * abs(self.left)
        right_part = 2 * abs(self.right)
        return left_part + right_part

    def __repr__(self) -> str:
        return "[%s,%s]" % (self.left, self.right)

    def __add__(self, other: 'Pair') -> 'Pair':
        """Snailfish addition: ``[self,other]`` parsed as a new tree, then reduced.

        Neither operand is modified, since the sum is rebuilt from their text.
        """
        combined = Pair(f"[{self},{other}]")
        return combined.reduce()

In [174]:
# Explosion examples from the puzzle statement; the asserts make the cell self-checking.
p = Pair("[[[[[9,8],1],2],3],4]")
print(p.explode())
assert repr(p) == "[[[[0,9],2],3],4]"

right_heavy = Pair("[7,[6,[5,[4,[3,2]]]]]").explode()
print(right_heavy)
assert repr(right_heavy) == "[7,[6,[5,[7,0]]]]"

# Two explosions in a row: the first pair's right value lands in the second subtree.
exploded_twice = Pair("[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]").explode().explode()
print(exploded_twice)
assert repr(exploded_twice) == "[[3,[2,[8,0]]],[9,[5,[7,0]]]]"

[[[[0,9],2],3],4]
[7,[6,[5,[7,0]]]]
[[3,[2,[8,0]]],[9,[5,[7,0]]]]


In [129]:
# One split step: 15 -> [7,8]; the 13 further right is left for a later step.
split_example = Pair("[[[[0,7],4],[15,[0,13]]],[1,1]]").split()
assert repr(split_example) == "[[[[0,7],4],[[7,8],[0,13]]],[1,1]]"
split_example

[[[[0,7],4],[[7,8],[0,13]]],[1,1]]

In [130]:
p1 = Pair("[[[[4,3],4],4],[7,[[8,4],9]]]")
p2 = Pair("[1,1]")
# Addition example from the puzzle statement: add, then fully reduce.
sum_example = p1 + p2
assert repr(sum_example) == "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]"
sum_example


[[[[0,7],4],[[7,8],[6,0]]],[8,1]]

In [172]:
def part1(inp: str) -> int:
    """Add the snailfish numbers in ``inp`` (one per line) and return the sum's magnitude.

    Numbers are added in input order; blank lines are ignored.

    Raises:
        ValueError: if ``inp`` contains no snailfish numbers.
    """
    pairs = [Pair(line.strip()) for line in inp.strip().splitlines() if line.strip()]
    if not pairs:
        raise ValueError("part1: input contains no snailfish numbers")
    # ``sum`` needs an explicit start: Pair defines no __radd__ for the default 0.
    return abs(sum(pairs[1:], pairs[0]))

In [175]:
ex1 = """[1,1]
[2,2]
[3,3]
[4,4]"""
ex2 = """[1,1]
[2,2]
[3,3]
[4,4]
[5,5]"""
ex3 = """[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]
[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]
[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]
[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]
[7,[5,[[3,8],[1,4]]]]
[[2,[2,2]],[8,[8,1]]]
[2,9]
[1,[[[9,3],9],[[9,0],[0,7]]]]
[[[5,[7,4]],7],1]
[[[[4,2],2],6],[8,7]]"""

# Worked examples from the puzzle statement with their expected magnitudes.
for example, expected in ((ex1, 445), (ex2, 791), (ex3, 3488)):
    result = part1(example)
    print(result)
    assert result == expected, (result, expected)


445
791
3488


In [176]:


ex4 = """[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]"""

# Homework example: the magnitude of the final sum is expected to be 4140.
ex4_magnitude = part1(ex4)
assert ex4_magnitude == 4140, ex4_magnitude
ex4_magnitude

4140

In [168]:
# Personal puzzle input; the relative path resolves against the notebook's working directory.
# Explicit encoding so the read does not depend on the platform's locale default.
with open("data/18.txt", encoding="utf-8") as f:
    pzl = f.read()
part1(pzl)

(3699, [[[[7,7],[7,8]],[[7,0],[8,7]]],[[[5,6],[6,6]],[[6,4],[0,6]]]])

In [177]:
from itertools import permutations
def part2(inp: str) -> int:
    """Return the largest magnitude of ``a + b`` over ordered pairs of distinct input numbers.

    Both orders are tried because ``a + b`` and ``b + a`` generally differ
    (magnitude weights the left side by 3 and the right side by 2). Blank
    lines are ignored.

    Raises:
        ValueError: if ``inp`` contains fewer than two snailfish numbers.
    """
    pairs = [Pair(line.strip()) for line in inp.strip().splitlines() if line.strip()]
    if len(pairs) < 2:
        raise ValueError("part2: input needs at least two snailfish numbers")
    # ``+`` builds a fresh tree from the operands' text, so the parsed operands
    # can be reused across all additions without being mutated.
    return max(abs(a + b) for a, b in permutations(pairs, 2))
# Homework example: the largest pairwise magnitude is expected to be 3993.
ex4_best_pair = part2(ex4)
assert ex4_best_pair == 3993, ex4_best_pair
ex4_best_pair

3993

In [171]:
# Part 2 on the personal puzzle input read into ``pzl`` above.
part2(inp=pzl)

4735