In [1]:
import sys

sys.path.append('../utils')

In [2]:
import json

from aoc import *

In [3]:
# Puzzle coordinates passed to get_input / submit_answer further down.
YEAR, DAY = 2021, 18

In [4]:
# Example homework: ten snailfish numbers, one per line, no trailing newline.
# part1/part2 below are evaluated on it after do_transform splits it into lines.
sample = '\n'.join([
    '[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]',
    '[[[5,[2,8]],4],[5,[[9,9],0]]]',
    '[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]',
    '[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]',
    '[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]',
    '[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]',
    '[[[[5,4],[7,7]],8],[[8,3],8]]',
    '[[9,3],[[9,9],[6,[4,9]]]]',
    '[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]',
    '[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]',
])

In [5]:
# Raw puzzle input for YEAR/DAY. get_input is not defined in this notebook; it
# presumably comes from the star-imported aoc helpers in ../utils. It must return
# text, because do_transform calls .split on it.
inp = get_input(YEAR, DAY)

In [6]:
def do_transform(inp):
    """Split raw puzzle text into a list of non-blank lines.

    Args:
        inp: the raw input text, one snailfish number per line.

    Returns:
        The lines with surrounding whitespace removed. Empty and
        whitespace-only lines are dropped.
    """
    # splitlines() also handles '\r\n' endings. Stripping keeps stray
    # whitespace-only lines from reaching json.loads in part1/part2.
    return [line.strip() for line in inp.splitlines() if line.strip()]

In [7]:
# Convert the raw text into lists of lines. Guard on str so that re-running this
# cell (after sample/inp are already lists) is a no-op instead of raising
# AttributeError from list.split.
if isinstance(sample, str):
    sample = do_transform(sample)
if isinstance(inp, str):
    inp = do_transform(inp)

In [8]:
class Node:
    """A snailfish number stored as a binary tree.

    A regular number is a leaf: ``value`` holds the int, and ``left``/``right``
    are None. A pair has ``value`` None and two child ``Node``s. Every node keeps
    a back-reference to its ``parent`` (None at the root). This lets ``pred`` and
    ``succ`` find neighbouring regular numbers during an explosion.
    """

    def __init__(self, parent, value):
        """Build a (sub)tree from a parsed value.

        Args:
            parent: the enclosing ``Node``, or None for a root.
            value: an ``int`` for a regular number, or a two-element list/tuple
                for a pair (elements are parsed recursively). Anything else
                (e.g. None) gives an empty node whose fields the caller fills in.
        """
        self.parent = parent
        self.value = None
        self.left = None
        self.right = None

        if type(value) is int:
            self.value = value
        elif isinstance(value, (list, tuple)):
            self.left = Node(self, value[0])
            self.right = Node(self, value[1])

    def __str__(self):
        """Render in bracket notation, e.g. ``[[1,2],3]``."""
        if self.value is not None:
            return str(self.value)

        return '[' + str(self.left) + ',' + str(self.right) + ']'

    def __add__(a, b):
        """Return the reduced sum ``[a,b]``.

        Both operands are copied first, so neither is modified.
        """
        node = Node(None, None)
        node.left = a.copy()
        node.right = b.copy()
        node.left.parent = node
        node.right.parent = node
        node.reduce()
        return node

    def pred(self):
        """Return the nearest regular-number leaf to the left of this node, or None.

        Climbs while this subtree is a left child. Then it descends to the
        rightmost leaf of the left sibling.
        """
        if self.parent is None:
            return None

        if self is self.parent.left:
            return self.parent.pred()

        ans = self.parent.left
        while ans.right is not None:
            ans = ans.right

        return ans

    def succ(self):
        """Return the nearest regular-number leaf to the right of this node, or None.

        Climbs while this subtree is a right child. Then it descends to the
        leftmost leaf of the right sibling.
        """
        if self.parent is None:
            return None

        if self is self.parent.right:
            return self.parent.succ()

        ans = self.parent.right
        while ans.left is not None:
            ans = ans.left

        return ans

    def mag(self):
        """Magnitude: the value for a regular number, else 3*mag(left) + 2*mag(right)."""
        if self.value is not None:
            return self.value

        return 3 * self.left.mag() + 2 * self.right.mag()

    def copy(self):
        """Return a deep copy of this subtree. The copy's root has parent None."""
        node = Node(None, None)
        node.value = self.value
        if self.left is not None:
            node.left = self.left.copy()
            node.left.parent = node
        if self.right is not None:
            node.right = self.right.copy()
            node.right.parent = node
        return node

    def explode(self):
        """Explode this pair in place.

        Its left value is added to the nearest regular number on the left, and
        its right value to the nearest one on the right (when they exist). The
        pair itself becomes the regular number 0. Both children must be regular
        numbers.
        """
        lv = self.left.value
        rv = self.right.value

        l = self.pred()
        if l is not None:
            assert l.value is not None
            l.value += lv

        r = self.succ()
        if r is not None:
            assert r.value is not None
            r.value += rv

        self.value = 0
        self.left = None
        self.right = None

    def split(self):
        """Replace this regular number with a pair [floor(v/2), ceil(v/2)] in place."""
        lv = self.value // 2
        rv = self.value - lv
        self.value = None
        self.left = Node(self, lv)
        self.right = Node(self, rv)

    def reduce(self):
        """Reduce this number in place.

        Each step applies the first applicable action:
          1. explode the leftmost pair nested inside four or more pairs; otherwise
          2. split the leftmost regular number >= 10.
        Steps repeat until neither applies. This is a loop rather than a
        recursive call after every action, so the call stack does not grow with
        the number of reduction steps.
        """
        while self._explode_leftmost(0) or self._split_leftmost():
            pass

    def _explode_leftmost(self, depth):
        """Explode the leftmost pair at nesting depth >= 4; return True if one exploded."""
        if self.left is None or self.right is None:
            return False  # regular number (or unfilled node): nothing to explode
        # Exploding pairs always consist of two regular numbers.
        if depth >= 4 and self.left.value is not None and self.right.value is not None:
            self.explode()
            return True
        return (self.left._explode_leftmost(depth + 1)
                or self.right._explode_leftmost(depth + 1))

    def _split_leftmost(self):
        """Split the leftmost regular number >= 10; return True if one was split."""
        if self.value is not None:
            if self.value >= 10:
                self.split()
                return True
            return False
        for child in (self.left, self.right):
            if child is not None and child._split_leftmost():
                return True
        return False

In [9]:
def part1(inp):
    """Add up every snailfish number left to right; return the total's magnitude.

    Args:
        inp: list of lines, each a snailfish number in JSON list notation.

    Returns:
        The magnitude of the final sum. Raises IndexError if ``inp`` is empty.
    """
    numbers = [Node(None, json.loads(line)) for line in inp]
    total = numbers[0]
    for addend in numbers[1:]:
        total = total + addend
    return total.mag()

In [10]:
# 4140 is the known answer for the sample homework; fail loudly on a regression.
part1_sample = part1(sample)
assert part1_sample == 4140, part1_sample
part1_sample

4140

In [11]:
# Part 1 answer: magnitude of the sum of every snailfish number in the real input.
part1_ans = part1(inp)
part1_ans

3756

In [None]:
# NOTE(review): submit_answer comes from the aoc helpers and presumably posts the
# answer to Advent of Code. Restart & Run All will execute this cell as well;
# confirm that re-submitting is acceptable.
submit_answer(part1_ans, YEAR, DAY)

In [12]:
def part2(inp):
    """Return the largest magnitude from adding any two different numbers.

    Both orders (i, j) and (j, i) are tried, since a + b and b + a can differ.
    The running maximum starts at 0, so fewer than two numbers yields 0.

    Args:
        inp: list of lines, each a snailfish number in JSON list notation.
    """
    numbers = [Node(None, json.loads(line)) for line in inp]
    best = 0
    for i, left in enumerate(numbers):
        for j, right in enumerate(numbers):
            if i == j:
                continue  # a number is never added to itself
            best = max(best, (left + right).mag())
    return best

In [13]:
# 3993 is the known answer for the sample homework; fail loudly on a regression.
part2_sample = part2(sample)
assert part2_sample == 3993, part2_sample
part2_sample

3993

In [14]:
# Part 2 answer: largest magnitude from adding any two different numbers of the real input.
part2_ans = part2(inp)
part2_ans

4585

In [None]:
# NOTE(review): submit_answer comes from the aoc helpers and presumably posts the
# part-2 answer (level=2) to Advent of Code. Restart & Run All will execute this
# cell as well; confirm that re-submitting is acceptable.
submit_answer(part2_ans, YEAR, DAY, level=2)