In [267]:
from dataclasses import dataclass
from typing import Union
import math
import json
from functools import reduce
from itertools import permutations
import copy

In [301]:
@dataclass
class Node:
    """One slot of a snailfish number: holds either a regular int or a nested Pair.

    The slot is mutable so reduction can rewrite it in place (a regular number
    becomes a Pair when split, a Pair becomes 0 when exploded).
    """
    value: Union[int, 'Pair']

    def __repr__(self):
        # Print exactly like the wrapped value so a whole tree renders as nested lists.
        return f'{self.value!r}'

@dataclass
class Pair:
    """An ordered pair of child Nodes; ``left`` is the first element."""
    left: Node
    right: Node

    def __repr__(self):
        # Render as "[left, right]", formatting each child with str().
        return '[{}, {}]'.format(self.left, self.right)
        
def split(n):
    """Replace the regular number held by ``n`` with a Pair of its halves, in place.

    The left half is rounded down and the right half up, so the two halves
    always add back up to the original value.
    """
    whole = n.value
    low = whole // 2
    high = whole - low  # same as low for even values, low + 1 for odd ones
    n.value = Pair(Node(low), Node(high))
    
def explode(n, l, r):
    """Explode the pair held by ``n`` in place.

    The pair's left value is added to neighbour node ``l`` and its right value
    to neighbour node ``r``. Either neighbour may be None, in which case its
    value is dropped. Finally ``n`` is reset to the regular number 0.
    """
    pair = n.value
    if l:
        l.value += pair.left.value
    if r:
        r.value += pair.right.value
    n.value = 0

def build_nodes(xs):
    """Convert a parsed number (nested two-element lists) into a tree of Nodes.

    A list is unpacked into exactly two children, which are built recursively
    and wrapped as ``Node(Pair(...))``. Any non-list value is wrapped as-is in
    a Node.
    """
    if not isinstance(xs, list):
        return Node(xs)
    first, second = xs
    return Node(Pair(build_nodes(first), build_nodes(second)))
    
def reduce_number(n):
    """Reduce the snailfish number rooted at ``n`` in place and return ``n``.

    Each loop iteration re-scans the tree from the root and performs exactly
    one action:
      1. explode the first pair at nesting level > 4 (root = level 1) met in a
         left-to-right traversal, if any; otherwise
      2. split the leftmost regular number >= 10, if any.
    The loop ends when neither action applies.

    NOTE(review): the explode step assumes the selected pair's children are
    regular numbers (``explode`` adds their ``.value``s to the neighbours).
    That holds when the input is the sum of two already-reduced numbers.
    Confirm this before calling it on arbitrarily deep input.
    """
    def _reduce_split(n, state):
        """Find the leftmost regular number >= 10 (depth-first, left to right).

        On success, store that node in ``state['node']`` and return True.
        Otherwise return False.
        """
        if isinstance(n.value, Pair):
            return (_reduce_split(n.value.left, state) or
                    _reduce_split(n.value.right, state))
        elif n.value >= 10:
            state['node'] = n
            return True
        else:
            return False
            
    def _reduce_explode(n, state, level=1):
        """Locate the pair to explode and its regular-number neighbours.

        ``level`` counts pairs from the root (the root is level 1). Fills
        ``state`` with:
          'node'  - the first pair at ``level > 4`` met left to right,
          'left'  - the last regular number visited before it (if any),
          'right' - the first regular number visited after it (if any).
        Returns True once a pair has been selected: either when its right
        neighbour is reached, or back at the root when no right neighbour
        exists. Returns False when no pair qualifies.
        """
        if isinstance(n.value, Pair):
            if level > 4 and 'node' not in state:
                state['node'] = n
                # Do not descend into the selected pair; fall through to
                # `return False` so the caller keeps scanning rightwards for
                # state['right'].
            else:
                # Recurse left, then right. Only the root (level 1) reports
                # success when the scan ended without finding a right neighbour.
                return (_reduce_explode(n.value.left, state, level+1) or
                        _reduce_explode(n.value.right, state, level+1) or
                        (level == 1 and 'node' in state))
        else:
            if 'node' in state:
                # First regular number after the selected pair: its right neighbour.
                state['right'] = n
                return True
            else:
                # Overwritten on every regular number, so it ends up holding the
                # last one seen before the selected pair.
                state['left'] = n
        return False

    while True:
        # Fresh scan state for every step; explosions take priority over splits.
        state = {}
        if _reduce_explode(n, state):
            explode(state['node'], state.get('left'), state.get('right'))
        elif _reduce_split(n, state):
            split(state['node'])
        else: 
            return n

def parse_numbers(s):
    """Parse one snailfish number per line of ``s`` into Node trees.

    Each non-blank line is decoded with ``json.loads`` and converted with
    ``build_nodes``. Blank or whitespace-only lines are skipped instead of
    raising ``json.JSONDecodeError``; an extra empty line at the end of an
    input file is the typical source of these.

    Returns:
        A list of root Nodes, in input order.
    """
    return [build_nodes(json.loads(line)) for line in s.splitlines() if line.strip()]

def sum_numbers(ns):
    """Add snailfish numbers left to right and return the reduced total.

    Each addition wraps the running total and the next number in a new Pair
    and reduces the result with ``reduce_number``. Reduction mutates nodes in
    place, so the inputs are deep-copied first and the caller's numbers are
    left untouched.

    Args:
        ns: any iterable of Node trees (list, tuple, generator, ...).

    Returns:
        The Node holding the sum. A single input number is returned as a
        copy, without being reduced.

    Raises:
        ValueError: if ``ns`` is empty.
    """
    # Materialize first: copy.deepcopy cannot copy a generator, and indexing
    # is not available on arbitrary iterables.
    numbers = copy.deepcopy(list(ns))
    if not numbers:
        raise ValueError('sum_numbers() requires at least one number')
    # With no initializer, reduce() starts from numbers[0].
    return reduce(lambda acc, n: reduce_number(Node(Pair(acc, n))), numbers)

def magnitude(n):
    """Return the magnitude of the snailfish number rooted at ``n``.

    A regular number's magnitude is the number itself. A pair's magnitude is
    3 times its left child's magnitude plus 2 times its right child's.
    """
    inner = n.value
    if not isinstance(inner, Pair):
        return inner
    return 3 * magnitude(inner.left) + 2 * magnitude(inner.right)
    
def max_magnitude(ns):
    """Largest magnitude obtainable by adding two different numbers from ``ns``.

    Both orders of every pair are tried (``permutations`` rather than
    combinations) because ``magnitude`` weights the left and right halves
    differently. The inputs are not mutated, since ``sum_numbers`` works on
    a deep copy.
    """
    ordered_pairs = permutations(ns, 2)
    return max(map(lambda pair: magnitude(sum_numbers(pair)), ordered_pairs))

In [302]:
# Part 2 on the worked example. The assert pins the value this cell displays
# (3993), turning a visual check into a regression check; the value is still shown.
example_max = max_magnitude(parse_numbers("""[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]"""))
assert example_max == 3993, example_max
example_max

3993

In [303]:
# Adding two numbers that need both explodes and splits. The assert pins the
# displayed result so a regression in reduction is caught on Run All.
ns = parse_numbers("""[[[[4,3],4],4],[7,[[8,4],9]]]
[1,1]""")
small_total = sum_numbers(ns)
assert repr(small_total) == '[[[[0, 7], 4], [[7, 8], [6, 0]]], [8, 1]]', small_total
small_total

[[[[0, 7], 4], [[7, 8], [6, 0]]], [8, 1]]

In [304]:
# Summing a longer list. The assert pins the displayed result as a regression check.
ns = parse_numbers("""[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]
[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]
[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]
[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]
[7,[5,[[3,8],[1,4]]]]
[[2,[2,2]],[8,[8,1]]]
[2,9]
[1,[[[9,3],9],[[9,0],[0,7]]]]
[[[5,[7,4]],7],1]
[[[[4,2],2],6],[8,7]]""")
larger_total = sum_numbers(ns)
assert repr(larger_total) == '[[[[8, 7], [7, 7]], [[8, 6], [7, 7]]], [[[0, 7], [6, 6]], [8, 7]]]', larger_total
larger_total

[[[[8, 7], [7, 7]], [[8, 6], [7, 7]]], [[[0, 7], [6, 6]], [8, 7]]]

In [305]:
%%time
with open('../data/day18.txt') as infile:
    ns = parse_numbers(infile.read())
    print('[p1] Magnitude of sum:', magnitude(sum_numbers(ns)))
    print('[p2] Max magnitude:', max_magnitude(ns))

[p1] Magnitude of sum: 3816
[p2] Max magnitude: 4819
CPU times: user 4.64 s, sys: 0 ns, total: 4.64 s
Wall time: 4.64 s
