In [1]:
def gen(*, levels, gt_point=(-1,), lt_point=(2,)):
    """Yield binary index tuples of length 1..levels in depth-first pre-order.

    Pre-order over a binary prefix tree is also ascending tuple order (a
    prefix compares less than its extensions), so the output is sorted.
    Only tuples strictly between ``gt_point`` and ``lt_point`` are yielded;
    the defaults let every tuple through.  Length-1 tuples are always
    visited, even when ``levels`` < 1.
    """
    def walk(prefix):
        for digit in range(2):
            point = prefix + (digit,)
            if lt_point > point > gt_point:
                yield point
            # Descend only while the next level is still within ``levels``.
            if len(point) < levels:
                yield from walk(point)

    yield from walk(())

In [2]:
from itertools import product
from math import floor, ceil
import copy

class SnailfishNumber:
    """A snailfish number that is kept in reduced form.

    ``n`` holds the number as nested two-element lists whose leaves are
    ints, e.g. ``[[1, 2], 3]``.  Elements of ``n`` are addressed by index
    paths: tuples of 0/1 indices, so ``(0, 1)`` means ``n[0][1]``.

    Construction deep-copies its argument and then reduces it.  Explosions
    are only detected at index depth 5 (a pair nested inside four pairs),
    and positions are only reported up to depth 5.  This assumes that
    operands are already reduced, so that no sum nests any deeper.
    """

    def __init__(self, n):
        # Deep copy so reducing never mutates the caller's (possibly reused) list.
        self.n = copy.deepcopy(n)
        self.reduce()

    def __repr__(self):
        return str(self)

    def __str__(self):
        return str(self.n)

    def __add__(self, other):
        """Return the reduced sum ``[self, other]``; neither operand is modified."""
        return SnailfishNumber([self.n, other.n])

    def _get(self, path):
        """Return the element of ``self.n`` at index path ``path``.

        Raises TypeError or IndexError if the path runs through an int or
        off the end of a list.
        """
        node = self.n
        for index in path:
            node = node[index]
        return node

    def _set(self, path, value):
        """Replace the element at the non-empty index path ``path`` with ``value``."""
        *parent, last = path
        self._get(parent)[last] = value

    @property
    def positions(self):
        """Yield the index path of every regular (int) number, left to right.

        Only paths of length 1-5 are reported.
        """
        def walk(node, prefix):
            for index in range(2):
                try:
                    child = node[index]
                except (TypeError, IndexError):
                    continue  # node is an int or too short: nothing here
                path = prefix + (index,)
                if isinstance(child, int):
                    yield path
                elif len(path) < 5:
                    yield from walk(child, path)

        yield from walk(self.n, ())

    @property
    def magnitude(self):
        """Return 3 * magnitude(left) + 2 * magnitude(right), recursively.

        A regular number's magnitude is the number itself.
        """
        def value(node):
            if isinstance(node, int):
                return node
            left, right = node
            return 3 * value(left) + 2 * value(right)

        return value(self.n)

    def reduce(self):
        """Reduce ``self.n`` in place.

        While a pair is nested inside four pairs, explode the leftmost one.
        Only when none remain, split the leftmost number greater than 9, and
        then start over.  Stop when neither action applies.
        """
        while True:
            explode_point = self.get_explode_point()
            if explode_point is not None:
                self.explode(explode_point)
                continue
            split_point = self.get_split_point()
            if split_point is None:
                return
            self.split(split_point)

    def get_explode_point(self):
        """Return the first index path of length 5 that exists in ``self.n``.

        Paths are tried in lexicographic order, so the result is the left
        element of the leftmost pair nested inside four pairs.  Returns None
        if there is no such pair.
        """
        for path in product(range(2), repeat=5):
            try:
                self._get(path)
            except (TypeError, IndexError):
                continue
            return path
        return None

    def explode(self, explode_point):
        """Explode the pair whose left element is at ``explode_point``.

        The pair's left value is added to the nearest regular number to its
        left, and its right value to the nearest regular number to its right
        (when those exist).  The pair itself is then replaced by 0.
        """
        pair_path = explode_point[:-1]
        left_value, right_value = self._get(pair_path)

        left = self.get_next_number(explode_point, left=True)
        if left is not None:
            self._set(left, self._get(left) + left_value)

        # The number right after the left element is the pair's own right
        # element, so step twice to get past the pair.
        right = self.get_next_number(self.get_next_number(explode_point))
        if right is not None:
            self._set(right, self._get(right) + right_value)

        self._set(pair_path, 0)

    def get_next_number(self, explode_point, left=False):
        """Return the position after ``explode_point`` in left-to-right order.

        With ``left=True`` the order is reversed, so the position to the left
        is returned instead.  Returns None if ``explode_point`` is not a
        position or has no neighbour in that direction.
        """
        positions = list(self.positions)
        if left:
            positions.reverse()
        for current, following in zip(positions, positions[1:]):
            if current == explode_point:
                return following
        return None

    def get_split_point(self):
        """Return the path of the leftmost regular number > 9, or None.

        Only numbers at depth 1-4 are considered.
        """
        for path in self.positions:
            if len(path) <= 4 and self._get(path) > 9:
                return path
        return None

    def split(self, split_point):
        """Replace the number at ``split_point`` with [floor(n/2), ceil(n/2)]."""
        value = self._get(split_point)
        # Exact integer arithmetic: (value + 1) // 2 == ceil(value / 2)
        # without a float round-trip.
        self._set(split_point, [value // 2, (value + 1) // 2])

In [3]:
from ast import literal_eval

def get_numbers(path='input'):
    """Parse one snailfish number per non-blank line of the file at *path*.

    Each line is a nested-list literal such as ``[[1, 2], 3]``.  Blank
    lines (e.g. a trailing empty line) are skipped instead of crashing
    ``literal_eval``.  ``ast.literal_eval`` only accepts Python literals,
    so the file's contents are never executed as code.
    """
    with open(path, encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        return [literal_eval(line) for line in stripped if line]

In [4]:
# Part 1: fold every input number together with snailfish addition, left to
# right, and report the magnitude of the final sum.
numbers = get_numbers()
first, *rest = numbers

a = SnailfishNumber(first)
for b in rest:
    a = a + SnailfishNumber(b)

print("Part 1:")
print(a.magnitude)

Part 1:
3793


In [5]:
# Part 2: the largest magnitude obtainable by adding two different input
# numbers.  Both orders of each pair are tried.
max_score = 0
for i, a in enumerate(numbers):
    for j, b in enumerate(numbers):
        if i == j:
            continue  # a number may not be added to itself
        n = SnailfishNumber(a) + SnailfishNumber(b)
        score = n.magnitude
        max_score = max(max_score, score)

# This cell reports part 2; it was previously mislabelled "Part 1:".
print("Part 2:")
print(max_score)

Part 1:
4695
