In [118]:
from collections import Counter, deque, defaultdict
from itertools import product, count
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

In [217]:
def load_input(fn):
    with open(fn) as fh:
        snns = [p for p in fh.read().split('\n')]
        
    return snns


def get_pairs(p):
    b_count = 0
    for ind, s in enumerate(p[1:]):
        if s == '[':
            b_count += 1
        elif s == ']':
            b_count -= 1
        if b_count == 0:
            split_ind = ind + 2
            p1, p2 = p[1:split_ind], p[split_ind+1:-1]
            return p1, p2

        
def add_sn(sn1, sn2):
    sn = f'[{sn1},{sn2}]'
    return sn


def sn_graph(sn_top):
    G = nx.DiGraph()
    
    names = count(0)
    
    level = 0
    name = next(names)
    parent = None
    G.add_node(name, s=sn_top, l=level, p=parent, side=None, ind = 0, exp=False)
    
    to_add = deque([name]) #string, level, name, parent
    while to_add:
        pname = to_add.popleft()
        plevel = G.nodes[pname]['l']
        pstring = G.nodes[pname]['s']
        pind = G.nodes[pname]['ind']
        d_ind = 1
        
        partners = []
        for p, side in zip(get_pairs(pstring), ['left', 'right']):            
            name = next(names)
            partners.append(name)
            exp = False
            if plevel + 1 > 4:
                exp = True
            G.add_node(name, s=p, l=plevel + 1, p=pname, side=side, ind=pind+d_ind, exp=exp)
            G.add_edge(pname, name)
            d_ind = len(p) + 2
            if '[' in p:
                to_add.append(name)
        for p1, p2 in zip((0, 1), (1, 0)):
            G.nodes[partners[p1]]['pairw'] = partners[p2]
    G_inds = {G.nodes[n]['ind']:n for n in G}
    
    return G, G_inds


def explode(G, G_inds):
    # get pair to explode
    ind_sort = sorted(list(G_inds.keys()))
    for ind_i, ind in enumerate(ind_sort):
        if G.nodes[G_inds[ind]]['exp']:
            eind_i, eind = ind_i, ind
            break
    print(G_inds[eind])
    if eind_i -1 >= 0:
        l = G_inds[ind_sort[eind_i - 1]]
        print(l)

# def graph_to_str(G, G_inds):
#     sn = ''
#     for ind in sorted(list(G_inds.keys())):
#         print(ind)
#     return sn

In [180]:
day = '18'
fn = f'../{day}/{day}.txt'

snns = load_input(fn)

In [181]:
print(snns[0])
get_pairs(snns[0])

[[[[6,3],7],0],[[7,0],0]]


('[[[6,3],7],0]', '[[7,0],0]')

In [207]:
# explode test
# sn = '[[[[[9,8],1],2],3],4]'
# sn = '[9,[8,7]]'
# sn = '[[[[1,3],[5,3]],[[1,3],[8,7]]],[[[4,9],[6,9]],[[8,2],[7,3]]]]'
sn = '[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]'

G, G_inds = sn_graph(sn)
for n in G.nodes(data=True):
    print(n)
    
# for e in G.edges:
#     print(e)

(0, {'s': '[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]', 'l': 0, 'p': None, 'side': None, 'ind': 0, 'exp': False})
(1, {'s': '[3,[2,[1,[7,3]]]]', 'l': 1, 'p': 0, 'side': 'left', 'ind': 1, 'exp': False, 'pairw': 2})
(2, {'s': '[6,[5,[4,[3,2]]]]', 'l': 1, 'p': 0, 'side': 'right', 'ind': 19, 'exp': False, 'pairw': 1})
(3, {'s': '3', 'l': 2, 'p': 1, 'side': 'left', 'ind': 2, 'exp': False, 'pairw': 4})
(4, {'s': '[2,[1,[7,3]]]', 'l': 2, 'p': 1, 'side': 'right', 'ind': 4, 'exp': False, 'pairw': 3})
(5, {'s': '6', 'l': 2, 'p': 2, 'side': 'left', 'ind': 20, 'exp': False, 'pairw': 6})
(6, {'s': '[5,[4,[3,2]]]', 'l': 2, 'p': 2, 'side': 'right', 'ind': 22, 'exp': False, 'pairw': 5})
(7, {'s': '2', 'l': 3, 'p': 4, 'side': 'left', 'ind': 5, 'exp': False, 'pairw': 8})
(8, {'s': '[1,[7,3]]', 'l': 3, 'p': 4, 'side': 'right', 'ind': 7, 'exp': False, 'pairw': 7})
(9, {'s': '5', 'l': 3, 'p': 6, 'side': 'left', 'ind': 23, 'exp': False, 'pairw': 10})
(10, {'s': '[4,[3,2]]', 'l': 3, 'p': 6, 'side': 'right', 'ind'

In [218]:
explode(G, G_inds)

15
12
