In [90]:
import re
from collections import Counter, defaultdict, namedtuple
from dataclasses import dataclass
from functools    import cache
from itertools import permutations
from math import gcd, ceil

import networkx as nx
import numpy as np
from matplotlib import pyplot as plt
# plt.imshow(pic)

from aocutils.common import to_int, flatten, reverse_dict, zippify, list_multiply, ints
from aocutils.grid import iterate, arr_to_dict, grid_to_dict, neighbors, arr_neighbors, dimensions, manhattan, positive, conv1d, conv2d
from aocutils.maze import bfs, dijkstra, get_path, dfs
from aocutils.math import lcm, gcd, factors, mul_inv
from aocutils.special import find_pattern_in_iter, find_repeat, find_cycle, md5, CFG, ShuntingYard, binarysearch, deduce_matches
from aocutils.visuals import visualize_graph, labelize, cat, animate_grid

In [142]:
import math
class Pair():
    """Binary-tree node for one snailfish number, built from its string form.

    Attributes:
        pair: the source string this node was parsed from, e.g. '[[1,2],3]'.
        nest: depth of this node; child Pairs are created at nest + 1.
        left, right: each is either a regular number (int) or a child Pair.
        parent: the enclosing Pair, or None for a root.
        isleft: True when this node is its parent's left child.
        exploding: set to True by explode() on the pair being exploded;
            receiveleft/receiveright skip a sibling carrying this flag.
        children: initialised to [] but not used by any method of this class.
    """
    # NOTE: the explode logic is still messy; a DFS that lists the Pairs /
    # regular numbers in order would make finding neighbours simpler.
    def __init__(self, pair, nest, parent, isleft=False):
        """Parse `pair` into this node and its whole subtree.

        Args:
            pair: snailfish pair text in compact form with no spaces (as
                produced by __repr__); a space after a comma makes parse() fail.
            nest: depth assigned to this node.
            parent: enclosing Pair, or None for a root.
            isleft: whether this node is its parent's left child.
        """
        self.pair = pair
        self.nest = nest
        self.children = []
        self.left = None
        self.right = None
        self.parent = parent
        # Builds the child nodes. Children only need `self` as their parent,
        # so it is fine that `exploding` and `isleft` are assigned afterwards.
        self.parseall()
        self.exploding = False
        self.isleft = isleft

    def incnest(self):
        """Increase `nest` by one for this node and every Pair below it.

        Needed when an existing tree is attached under a new root node.
        """
        self.nest += 1
        if not isinstance(self.left, int):
            self.left.incnest()
        if not isinstance(self.right, int):
            self.right.incnest()

    def parse(self, pair):
        """Parse the element that starts at index 1 of `pair`.

        `pair` is either a whole pair string '[a,b]' (the left element starts
        at index 1) or the remainder after the left element ',b]'.

        Returns:
            (int, idx) when the element is a regular number, or
            (str, idx) with the nested pair's text (brackets included) when it
            is a pair. `idx` is the index in `pair` just past the element.
            Returns None implicitly if the brackets never balance.
        """
        pnt = 0
        if pair[1] != '[':
            # Regular number: collect digits until a delimiter.
            pnt = 1
            s = ''
            while pnt < len(pair) and pair[pnt] not in ', ]':
                s += pair[pnt]
                pnt += 1
            return int(s), pnt
        else:
            # Nested pair: scan (excluding the outer closing bracket) until
            # the bracket opened at index 1 is balanced again.
            opens = 0
            for ch in pair[1:-1]:
                pnt += 1  # pnt is now the index of ch within `pair`
                if ch == '[':
                    opens +=1
                if ch == ']':
                    opens -= 1
                if opens == 0:
                    return pair[1:pnt+1], pnt+1

    def ispair(self, part):
        """Return True if `part` is a Pair (rather than a regular number)."""
        return True if isinstance(part, Pair) else False

    def receiveleft(self, num, child=None):
        """Add `num` to the nearest regular number on the left.

        `child` says where the request comes from:
            None  - add to the leftmost regular number of this subtree;
            True  - sent by this node's left child: keep walking up;
            False - sent by this node's right child: add to the rightmost
                    regular number of the left subtree.
        If `self.left` is an int it always receives `num` directly. When the
        upward walk passes the root, `num` is dropped.
        """
        if isinstance(self.left, int):
            self.left += num
            return

        else:
            if child == True: # child is left
                if self.parent:
                    self.parent.receiveleft(num, self.isleft)
            elif child == False:
                if not self.left.exploding:
                    self.left.receiveright(num)
            else: # child = None
                # e.g. [[6,3],[8,8]] receiveleft 3 must add it on the far left (to the 6)
                self.left.receiveleft(num)

    def receiveright(self, num, child=None):
        """Add `num` to the nearest regular number on the right.

        Mirror image of receiveleft; `child` says where the request comes from:
            None  - add to the rightmost regular number of this subtree;
            False - sent by this node's right child: keep walking up;
            True  - sent by this node's left child: add to the leftmost
                    regular number of the right subtree.
        If `self.right` is an int it always receives `num` directly. When the
        upward walk passes the root, `num` is dropped.
        """
        if isinstance(self.right, int):
            self.right += num
            return
        else:
            if child == False: # child is right
                if self.parent:
                    self.parent.receiveright(num, self.isleft)
            elif child == True:
                if not self.right.exploding:
                    self.right.receiveleft(num)
            else: # child = None
                self.right.receiveright(num)

    def explode(self):
        """Perform at most one explosion in this subtree.

        The exploding pair is the leftmost one (left-first search) at
        nest >= 4 whose two elements are both regular numbers.

        Returns:
            False if nothing exploded in this subtree;
            (left, right) from the exploding pair itself, so that its direct
                parent can replace it with 0 and distribute the two values;
            True once an explosion further down has been fully handled.
        """
        if isinstance(self.left, int) and isinstance(self.right, int):
            if self.nest < 4:
                return False
            else:
                # explode
                self.exploding = True
                return self.left, self.right
        if not isinstance(self.left, int):
            res = self.left.explode()
            if res:
                if not res == True:
                    # Tuple: our left child is the exploding pair.
                    l, r = res
                    self.left = 0
                    # Right value goes to the leftmost number on our right side.
                    if isinstance(self.right, int):
                        self.right += r
                    else:
                        self.right.receiveleft(r)
                    # Left value travels up to the nearest number on the left.
                    if self.parent:
                        self.parent.receiveleft(l, self.isleft)
                return True

        if not isinstance(self.right, int):
            res = self.right.explode()
            if res:
                if not res == True:
                    # Tuple: our right child is the exploding pair.
                    l, r = res
                    self.right = 0
                    # Left value goes to the rightmost number on our left side.
                    if isinstance(self.left, int):
                        self.left += l
                    else:
                        self.left.receiveright(l)
                    # Right value travels up to the nearest number on the right.
                    if self.parent:

                        self.parent.receiveright(r, self.isleft)
                return True
        return False

    def split(self):
        """Split the leftmost regular number greater than 9.

        The number n is replaced by a new child Pair '[n//2,ceil(n/2)]' at
        nest + 1. Returns True if a split happened, False otherwise.
        """
        if isinstance(self.left, int):
            if self.left > 9:
                res = f'[{self.left//2},{math.ceil(self.left/2)}]'
                self.left = Pair(res, self.nest + 1, self, isleft=True)
                return True
        else:
            if self.left.split():
                return True

        if isinstance(self.right, int):
            if self.right > 9:
                res = f'[{self.right//2},{math.ceil(self.right/2)}]'
                self.right = Pair(res, self.nest + 1, self, isleft=False)
                return True
            else:
                return False
        else:
            return self.right.split()

    def magnitude(self):
        """Return 3 * magnitude(left) + 2 * magnitude(right).

        A regular number's magnitude is the number itself.
        """
        if not isinstance(self.left, int):
            a = self.left.magnitude() * 3
        else:
            a = self.left * 3
        if not isinstance(self.right, int):
            b = self.right.magnitude() * 2
        else:
            b = self.right * 2
        return a + b

    def parseall(self):
        """Fill `self.left` / `self.right` from `self.pair`.

        Nested elements become child Pairs at nest + 1 with `self` as parent.
        """
        self.left, pnt = self.parse(self.pair)
        if not isinstance(self.left, int):
            self.left = Pair(self.left, self.nest + 1, self, isleft=True)
        # pnt points at the ',' after the left element; parse the rest.
        self.right, pnt = self.parse(self.pair[pnt:])
        if not isinstance(self.right, int):
            self.right = Pair(self.right, self.nest + 1, self)


    def __repr__(self):
        """Compact '[left,right]' text with no spaces (matches puzzle notation)."""
        return f'[{self.left},{self.right}]'
    def leftright(self):
        """Return the string forms of the left and right elements."""
        return str(self.left), str(self.right)



In [143]:
# Single-explosion examples from the puzzle statement: one call to explode()
# must apply exactly the leftmost explosion and render as the expected text.
for src, expected in [
    ('[[[[[9,8],1],2],3],4]', '[[[[0,9],2],3],4]'),
    ('[7,[6,[5,[4,[3,2]]]]]', '[7,[6,[5,[7,0]]]]'),
    ('[[6,[5,[4,[3,2]]]],1]', '[[6,[5,[7,0]]],3]'),
    ('[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]', '[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]'),
    ('[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]', '[[3,[2,[8,0]]],[9,[5,[7,0]]]]'),
]:
    p = Pair(src, 0, None)
    p.explode()
    assert repr(p) == expected

In [144]:
def cat(a,b):
    """Snailfish-add two Pair trees and fully reduce the sum.

    Both operands are attached in place under a new root Pair (their `parent`,
    `isleft` and `nest` are updated), so `a` and `b` must not be reused after
    the call. Returns the new root.

    Note: this definition shadows `cat` imported from aocutils.visuals.
    """
    # '[1,1]' is parsed only to obtain a Pair instance; its children are
    # replaced right away. nest=-1 so that incnest() below brings the root to 0
    # and pushes both operands one level deeper.
    root = Pair('[1,1]', -1, None)
    root.left, root.right = a, b
    for operand, on_left in ((a, True), (b, False)):
        operand.isleft = on_left
        operand.parent = root
    root.incnest()

    # Explosions take priority over splits; stop once neither changes anything.
    while root.explode() or root.split():
        pass
    return root



In [145]:
# Addition example from the puzzle: [[[[4,3],4],4],[7,[[8,4],9]]] + [1,1]
# must reduce to exactly this text.
a = Pair('[[[[4,3],4],4],[7,[[8,4],9]]]', 0, None)
b = Pair('[1,1]', 0, None)
p = cat(a, b)
assert repr(p) == '[[[[0,7],4],[[7,8],[6,0]]],[8,1]]'

In [146]:
def parse(filename):
    """Read one snailfish number per line from `filename` and add them all.

    Numbers are added left to right with cat(); the running sum is returned
    as a Pair tree. Blank lines are skipped.

    Raises:
        ValueError: if the file contains no snailfish numbers.
    """
    # Context manager so the file handle is closed promptly.
    with open(filename) as f:
        lines = [line for line in f.read().splitlines() if line.strip()]
    if not lines:
        raise ValueError(f'no snailfish numbers found in {filename!r}')
    a = Pair(lines[0], 0, None)
    for line in lines[1:]:
        a = cat(a, Pair(line, 0, None))
    return a

In [147]:
# Whole-file sums from the puzzle examples (Pair-tree implementation).
for test_file, expected in (
    ('test4.txt', '[[[[1,1],[2,2]],[3,3]],[4,4]]'),
    ('test5.txt', '[[[[3,0],[5,3]],[4,4]],[5,5]]'),
    ('test6.txt', '[[[[5,0],[7,4]],[5,5]],[6,6]]'),
    ('test2.txt', '[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]'),
):
    assert repr(parse(test_file)) == expected


In [148]:
# Part 1: magnitude of the sum of every snailfish number in the input.
total = parse('input.txt')
total.magnitude()

4347

In [149]:
# Part 2: largest magnitude from adding two *different* numbers of the input.
# Snailfish addition is not commutative, so both orders are tried. Pairs are
# chosen by position, so a line is never added to itself.
# Fresh Pair trees are parsed for every sum because cat() mutates its operands.
with open('input.txt') as f:
    lines = f.read().splitlines()
best = 0
for i, j in permutations(lines, 2):
    a = Pair(i, 0, None)
    b = Pair(j, 0, None)
    best = max(cat(a, b).magnitude(), best)
best

4721

In [151]:
# String-based implementation: the same reduction rules, applied directly to the text of a snailfish number.

import re
import math

def explode(pair):
    """Apply the leftmost explosion to a snailfish number in text form.

    `pair` is a compact string such as '[[1,2],3]', or a nested list (rendered
    via repr with spaces removed). At the first position where the bracket
    depth reaches 5, the pair starting there is expected to hold exactly two
    regular numbers. Its left value is added to the nearest number before it
    and its right value to the nearest number after it. The pair itself
    becomes '0'. Returns the new string, or the (string) input unchanged when
    nothing explodes.
    """
    if isinstance(pair, list):
        pair = pair.__repr__().replace(' ', '')

    depth = 0
    for idx, ch in enumerate(pair):
        if ch == '[':
            depth += 1
        if ch == ']':
            depth -= 1
        if depth != 5:
            continue

        close = pair.find(']', idx)
        left_val, right_val = re.findall(r'\d+', pair[idx:close + 1])
        head = pair[:idx]
        tail = pair[close + 1:]

        head_nums = re.findall(r'\d+', head)
        if head_nums:
            # The last digit run is also the last occurrence of its own text,
            # so rsplit(..., 1) replaces exactly that number.
            last = head_nums[-1]
            head = str(int(last) + int(left_val)).join(head.rsplit(last, 1))

        tail_nums = re.findall(r'\d+', pair[close:])
        if tail_nums:
            # Likewise, the first digit run is the first occurrence of its text.
            first = tail_nums[0]
            tail = str(int(first) + int(right_val)).join(tail.split(first, 1))

        return head + '0' + tail

    return pair

def split(pair):
    """Split the leftmost regular number >= 10 in a snailfish string.

    The number n is replaced by '[n//2,ceil(n/2)]'. Returns the string
    unchanged when it contains no multi-digit number.
    """
    large = re.search(r'\d\d+', pair)
    if not large:
        return pair
    n = int(large[0])
    # Integer ceiling: exact for any int, unlike math.ceil(n / 2) on a float.
    newpart = f'[{n // 2},{(n + 1) // 2}]'
    # Splice at the match position instead of re-using the digits as a regex
    # pattern in re.sub.
    return pair[:large.start()] + newpart + pair[large.end():]


def final(s):
    """Fully reduce a snailfish number given as a string (or nested list).

    Repeatedly applies the leftmost explosion. Only when nothing explodes does
    it apply the leftmost split. It stops when neither action changes the
    string and returns the reduced string.

    Written as a loop rather than one recursive call per reduction step, so
    long reduction sequences cannot exceed Python's recursion limit.
    """
    current = s
    while True:
        exploded = explode(current)
        if exploded != current:
            current = exploded
            continue
        splitted = split(exploded)
        if splitted == exploded:
            return splitted
        current = splitted

def cat(a,b):
    """Snailfish addition on the text representation.

    Each operand may be a compact string such as '[1,2]' or a nested Python
    list. Returns the fully reduced sum as a string.

    Note: this redefinition shadows both the Pair-based `cat` defined earlier
    in the notebook and `cat` imported from aocutils.visuals.
    """
    def _as_text(x):
        # repr/str of a list inserts ', ' separators. Strip the spaces for BOTH
        # operands: score() expects ',' immediately followed by '[' or a digit.
        return x.__repr__().replace(' ', '') if isinstance(x, list) else x

    return final(f'[{_as_text(a)},{_as_text(b)}]')

def parse(filename):
    """Add every snailfish number in `filename` (one per line), in order.

    Uses the string-based cat(). Blank lines are skipped. Returns the running
    sum as a string.

    Raises:
        ValueError: if the file contains no snailfish numbers.
    """
    # Context manager so the file handle is closed promptly.
    with open(filename) as f:
        lines = [line for line in f.read().splitlines() if line.strip()]
    if not lines:
        raise ValueError(f'no snailfish numbers found in {filename!r}')
    a = lines[0]
    for line in lines[1:]:
        a = cat(a, line)
    return a


In [152]:
# Reduction examples for the string implementation.
for raw, reduced in (
    ('[[[[[9,8],1],2],3],4]', '[[[[0,9],2],3],4]'),
    ('[7,[6,[5,[4,[3,2]]]]]', '[7,[6,[5,[7,0]]]]'),
    ('[[6,[5,[4,[3,2]]]],1]', '[[6,[5,[7,0]]],3]'),
    ('[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]', '[[3,[2,[8,0]]],[9,[5,[7,0]]]]'),
):
    assert final(raw) == reduced

# Whole-file sums from the puzzle examples.
for test_file, total in (
    ('test5.txt', '[[[[3,0],[5,3]],[4,4]],[5,5]]'),
    ('test6.txt', '[[[[5,0],[7,4]],[5,5]],[6,6]]'),
    ('test2.txt', '[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]'),
):
    assert parse(test_file) == total

In [160]:
def findclose(start, s=None):
    """Return the index of the ']' that matches the '[' at index `start`.

    Args:
        start: index of an opening bracket in the string.
        s: the snailfish string to scan. When omitted, the notebook-global
            `res` is used (the original calling convention).

    Raises:
        IndexError: if the bracket is never closed.
    """
    if s is None:
        s = res
    opens = 1
    cur = start
    while opens > 0:
        cur += 1
        if s[cur] == ']':
            opens -= 1
        if s[cur] == '[':
            opens += 1
    return cur


def score(start, end, s=None):
    """Magnitude of the pair whose contents (without outer brackets) are s[start:end].

    The magnitude is 3*left + 2*right, applied recursively to nested pairs.
    For a whole number call score(1, len(s) - 1, s).

    Args:
        start, end: slice bounds of the pair's contents inside `s`.
        s: the snailfish string. When omitted, the notebook-global `res` is
            used (the original calling convention).
    """
    if s is None:
        s = res
    body = s[start:end]
    if '[' not in body:
        # Innermost case: two regular numbers "a,b".
        a, b = body.split(',')
        return 3*int(a) + 2*int(b)

    if s[start] == '[':
        # Left element is a nested pair: recurse into its contents.
        endfirst = findclose(start, s)
        a = score(start+1, endfirst, s)
    else:
        # Left element is a regular number ending right before the comma.
        endfirst = s.find(',', start) - 1
        a = s[start:endfirst+1]
    assert s[endfirst+1] == ','

    if s[endfirst+2] == '[':
        # Right element is a nested pair spanning up to end-1 (its ']').
        b = score(endfirst+3, end-1, s)
    else:
        b = s[endfirst+2:end]
    return 3*int(a) + 2*int(b)

# Magnitude examples from the puzzle statement. score() reads the
# notebook-global `res`, so the loop variable is deliberately named `res`.
for res, expected in [
    ('[9,1]', 29),
    ('[1,9]', 21),
    ('[[1,2],[[3,4],5]]', 143),
    ('[[[[1,1],[2,2]],[3,3]],[4,4]]', 445),
    ('[[[[3,0],[5,3]],[4,4]],[5,5]]', 791),
    ('[[[[0,7],4],[[7,8],[6,0]]],[8,1]]', 1384),
    ('[[[[5,0],[7,4]],[5,5]],[6,6]]', 1137),
    ('[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]', 3488),
]:
    assert score(1, len(res)-1) == expected

# Part 1 again, using the string implementation.
res = parse('input.txt')
score(1, len(res)-1)
        

4347

In [161]:
# Part 2 with the string implementation: the largest magnitude from adding
# two different input lines, in both orders. Pairs are chosen by position
# (not by comparing text), so duplicate lines are still paired with each
# other. `res` stays a notebook global because score() reads it by default.
with open('input.txt') as f:
    lines = f.read().splitlines()
best = 0
for i, j in permutations(lines, 2):
    res = cat(i, j)
    s = score(1, len(res)-1)
    best = max(s, best)
best

4721