In [None]:
import json
import os
import re
import time
from collections import deque
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed

import h5py
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from zss import Node, simple_distance

In [None]:
# Seed NumPy's global (legacy) random generator so later np.random.* draws
# are reproducible from run to run.
np.random.seed(seed=42)

def rev_bfs(tree, root_index=0):
    """Return node indices in reverse breadth-first order.

    The traversal starts at ``root_index`` and enqueues children in the order
    they appear in ``tree[i]['children']``. The resulting BFS order is then
    reversed, so every node appears after all of its descendants (the root
    comes last).

    Args:
        tree: Mapping from node index to a dict holding a ``'children'`` list
            of node indices.
        root_index: Index of the node to start the traversal from.

    Returns:
        A list of node indices in reverse BFS order.

    Raises:
        KeyError: If ``root_index`` or any listed child is not a key of ``tree``.
    """
    order = []
    # deque gives O(1) pops from the front; list.pop(0) would be O(n) each time.
    queue = deque([root_index])
    while queue:
        current = queue.popleft()
        order.append(current)
        queue.extend(tree[current]['children'])
    # Build in BFS order and reverse once instead of inserting at index 0.
    order.reverse()
    return order

In [None]:
def parse_graphviz_tree_zss(gv_tree, root_index=0, constant_bins=None, **args):
    """Parse a Graphviz tree description into a ``zss.Node`` tree.

    Node declarations are taken from the text preceding ``fillcolor`` on each
    line and split at `` [label=``. Edges are matched as ``<parent> -> <child>``.
    Labels that parse as numbers (after stripping surrounding double quotes)
    become ``float`` values. If ``constant_bins`` is given, they become
    ``'const_%.2f' % get_bin(constant_bins, value)`` instead. All other labels
    are kept as the raw label text with commas removed, including any quotes.

    Args:
        gv_tree: Graphviz source text.
        root_index: Id of the node to use as the root of the returned tree.
        constant_bins: Optional bins forwarded to ``get_bin`` for numeric labels.
        **args: Accepted for call compatibility and ignored.

    Returns:
        The ``zss.Node`` built for ``root_index``. Each node's children are
        ordered by ascending node id.

    Raises:
        ValueError: If ``gv_tree`` contains no node declarations.
        KeyError: If ``root_index`` or an edge endpoint is not a declared node.
    """
    tree = dict()
    # Assumed input line shape: '3 [label="0.500", fillcolor=...' -> the regex
    # captures '3 [label="0.500",', giving id 3 and raw label '"0.500",'.
    for decl in re.findall(r'^(.*?)(?=\s+fillcolor)', gv_tree, re.MULTILINE):
        parts = decl.split(' [label=')
        # Dropping commas removes the separator before the next attribute.
        tree[int(parts[0])] = {'value': parts[1].replace(',', ''), 'children': []}

    if not tree:
        raise ValueError('no node declarations found in gv_tree')

    for parent, child in re.findall(r'(\d+) -> (\d+)', gv_tree):
        tree[int(parent)]['children'].append(int(child))
    # Sort once per node rather than after every append.
    for node in tree.values():
        node['children'].sort()

    # Keyed by node id so ids need not be contiguous or start at 0.
    zss_nodes = dict()
    # Reverse BFS visits every child before its parent, so each child's Node
    # already exists when the parent is built.
    for idx in rev_bfs(tree, root_index):
        label = tree[idx]['value']
        try:
            # Labels are typically quoted ("0.500"); without stripping the
            # quotes float() would always fail and constants would never bin.
            number = float(label.strip().strip('"'))
        except ValueError:
            pass  # non-numeric label (operator / variable name): keep as-is
        else:
            if constant_bins is None:
                label = number
            else:
                label = 'const_%.2f' % get_bin(constant_bins, number)
        zss_nodes[idx] = Node(label, [zss_nodes[c] for c in tree[idx]['children']])
    return zss_nodes[root_index]