In [2]:
import ast
import json
import os
import sys
from copy import deepcopy

sys.path.append('../grow_and_prune/')

from treelib import Tree, Node

from txf_dataset import TxfNode, TxfDataset

In [9]:
# Build a small grow/prune lineage by hand (TxfNode payloads on treelib nodes)
# and render it with each node's loss next to its hash.
tree = Tree()

# One row per node: (model_hash, parent_hash, mode, loss).
# A parent_hash of None marks the root; rows are ordered so parents come first.
node_specs = [
    ('model_hash', None, None, 1),
    ('child_model_hash', 'model_hash', 'grow_attn_head', 0.9),
    ('child_model_hash2', 'model_hash', 'grow_attn_head', 0.89),
    ('grand_child_model_hash', 'child_model_hash2', 'prune_attn_head', 0.91),
]

for spec_hash, parent_hash, spec_mode, spec_loss in node_specs:
    parent_node = None if parent_hash is None else tree.get_node(parent_hash)
    tree.create_node(
        tag=spec_hash,
        identifier=spec_hash,
        parent=parent_node,
        data=TxfNode(spec_hash, mode=spec_mode, loss=spec_loss),
    )

tree.show(data_property='loss', idhidden=False, line_type='ascii-exr')

1[model_hash]
├── 0.9[child_model_hash]
╰── 0.89[child_model_hash2]
    ╰── 0.91[grand_child_model_hash]



In [3]:
# Test storing of the tree object: build a two-level tree and write it to
# test_dataset.json for the loading cells below.
tree = Tree()

tree.create_node(tag='model_hash', identifier='model_hash', 
                 data=TxfNode('model_hash', mode=None, loss=1))
tree.create_node(tag='child_model_hash', identifier='child_model_hash', parent=tree.get_node('model_hash'),
                 data=TxfNode('child_model_hash', mode='grow_attn_head', loss=0.9))

# Round-trip through str() to turn the embedded node data into plain dicts that
# json can serialize.
# NOTE(review): this assumes TxfNode's repr is a Python dict literal (the
# nested-dict shape seen in TxfDataset.to_dict() output below). Confirm this.
# literal_eval parses the same literals as eval but cannot execute arbitrary code.
tree_dict = ast.literal_eval(str(tree.to_dict(with_data=True)))

# Close (and therefore flush) the file before the next cell reads it back.
with open('test_dataset.json', 'w') as dataset_file:
    json.dump(tree_dict, dataset_file)

In [53]:
# Test loading of tree object
# Read back the JSON written by the storing cell; the context manager closes the handle.
with open('test_dataset.json', 'r') as dataset_file:
    json_dict = json.load(dataset_file)

def _load_tree(tree: Tree, tree_dict: dict, parent=None) -> Tree:
    """Recursively rebuild a treelib.Tree from its nested-dict form.

    Each (sub)tree dict has a single key, the model hash, mapping to
    ``{'data': {...TxfNode kwargs...}, 'children': [<same shape>, ...]}``.
    ``'children'`` may be absent for leaves.

    Args:
        tree (Tree): treelib.Tree object that nodes are added to (mutated in place)
        tree_dict (dict): tree dictionary loaded from dataset_file
        parent (Node, optional): node to attach this subtree's root to; None
            makes it the tree's root

    Returns:
        Tree: the same ``tree`` object that was passed in
    """
    if not tree_dict:
        # Nothing to load (e.g. an empty dataset): leave the tree untouched.
        return tree

    # The single key is the hash of this subtree's root model.
    model_hash, model_value = next(iter(tree_dict.items()))

    # parent=None (create_node's default) creates the root; otherwise attach under parent.
    tree.create_node(tag=model_hash, identifier=model_hash, parent=parent,
                     data=TxfNode(**model_value['data']))
    node = tree.get_node(model_hash)

    # Recurse with the node just created as parent. Using this node's own parent
    # here would attach grandchildren one level too high.
    for child in model_value.get('children', []):
        _load_tree(tree, child, node)

    return tree

# Rebuild the tree from the loaded JSON and render it for comparison with the stored one.
empty_tree = Tree()
tree = _load_tree(empty_tree, tree_dict=json_dict)
tree.show(
    data_property='loss',
    idhidden=False,
    line_type='ascii-exr',
)

1[model_hash]
╰── 0.9[child_model_hash]



In [25]:
# Test TxfDataset: construct it from the JSON file written earlier and display
# its nested-dict form.
dataset_path = './test_dataset.json'
models_dir = '../models/'
txf_dataset = TxfDataset(dataset_path, models_dir)

txf_dataset.to_dict()

{'model_hash': {'children': [{'child_model_hash': {'data': {'model_hash': 'child_model_hash',
      'mode': 'grow_attn_head',
      'loss': 0.9}}}],
  'data': {'model_hash': 'model_hash', 'mode': None, 'loss': 1}}}

In [41]:
# Test in-place node updates: mutate a node's TxfNode payload directly (rather
# than replacing it via Tree.update_node) and check that to_dict() reflects it.
txf_dataset.dataset['child_model_hash'].data.loss = 0.85

updated_dict = txf_dataset.to_dict()
assert updated_dict['model_hash']['children'][0]['child_model_hash']['data']['loss'] == 0.85
updated_dict

{'model_hash': {'children': [{'child_model_hash': {'data': {'model_hash': 'child_model_hash',
      'mode': 'grow_attn_head',
      'loss': 0.85}}}],
  'data': {'model_hash': 'model_hash', 'mode': None, 'loss': 1}}}