In [1]:
import ast
import json
import os
import sys
from copy import deepcopy

sys.path.append('../grow_and_prune/')

from treelib import Tree, Node

from txf_dataset import TxfNode, TxfDataset

In [2]:
# Test basic treelib.Tree object and its load/store
# Each entry describes one node: (model hash, parent hash or None for the root, mode, loss)
demo_nodes = [
    ('model_hash', None, None, 1),
    ('child_model_hash', 'model_hash', 'grow_attn_head', 0.9),
    ('child_model_hash2', 'model_hash', 'grow_attn_head', 0.89),
    ('grand_child_model_hash', 'child_model_hash2', 'prune_attn_head', 0.91),
]

tree = Tree()
for node_hash, parent_hash, node_mode, node_loss in demo_nodes:
    # Parents are listed before their children, so the lookup always succeeds
    parent_node = None if parent_hash is None else tree.get_node(parent_hash)
    tree.create_node(tag=node_hash, identifier=node_hash, parent=parent_node,
                     data=TxfNode(node_hash, mode=node_mode, loss=node_loss))

# Print the tree with each node's loss shown next to its identifier
tree.show(data_property='loss', idhidden=False, line_type='ascii-exr')

1[model_hash]
├── 0.9[child_model_hash]
╰── 0.89[child_model_hash2]
    ╰── 0.91[grand_child_model_hash]



In [5]:
# Identifier of the root node, written to stdout
root_id = tree.root
print(root_id)

# Direct children of 'model_hash'; as the last expression this list is the cell output
tree.children('model_hash')

model_hash


[Node(tag=child_model_hash, identifier=child_model_hash, data={'model_hash': 'child_model_hash', 'mode': 'grow_attn_head', 'loss': 0.9}),
 Node(tag=child_model_hash2, identifier=child_model_hash2, data={'model_hash': 'child_model_hash2', 'mode': 'grow_attn_head', 'loss': 0.89})]

In [3]:
# Test storing of the tree object
tree = Tree()

tree.create_node(tag='model_hash', identifier='model_hash', 
                 data=TxfNode('model_hash', mode=None, loss=1))
tree.create_node(tag='child_model_hash', identifier='child_model_hash', parent=tree.get_node('model_hash'),
                 data=TxfNode('child_model_hash', mode='grow_attn_head', loss=0.9))

# to_dict(with_data=True) embeds each node's data object in the result. Printing the
# dictionary and parsing the text back turns those objects into plain dicts; this
# relies on TxfNode printing as a dict literal (TODO confirm against txf_dataset).
# ast.literal_eval only accepts Python literals, so unlike eval() it cannot execute
# arbitrary code found in the string.
tree_dict = ast.literal_eval(str(tree.to_dict(with_data=True)))

# The context manager flushes and closes the file before any later cell reads it back
with open('test_dataset.json', 'w+') as dataset_fp:
    json.dump(tree_dict, dataset_fp)

In [53]:
# Test loading of tree object
# Read the stored tree dictionary; the context manager closes the file handle
# instead of leaving it to the garbage collector
with open('test_dataset.json', 'r') as dataset_fp:
    json_dict = json.load(dataset_fp)

def _load_tree(tree: Tree, tree_dict: dict, parent=None):
    """Recursively rebuild a treelib tree from its nested-dictionary form.

    ``tree_dict`` holds a single ``{model_hash: payload}`` entry. ``payload['data']``
    provides the keyword arguments for ``TxfNode`` and the optional
    ``payload['children']`` is a list of dictionaries of the same shape.

    Args:
        tree (Tree): treelib.Tree object to populate (modified in place)
        tree_dict (dict): tree dictionary loaded from dataset_file
        parent (Node, optional): node under which this sub-tree is attached;
            None creates the root node

    Returns:
        Tree: the same ``tree`` object that was passed in
    """
    # An empty dictionary describes an empty tree: nothing to add
    if not tree_dict:
        return tree

    model_hash, model_value = next(iter(tree_dict.items()))

    # parent=None is create_node's default and produces the root node
    node = tree.create_node(tag=model_hash, identifier=model_hash, parent=parent,
                            data=TxfNode(**model_value['data']))

    # Children must hang off the node created above. Passing the incoming `parent`
    # down instead would attach every grandchild to the root.
    for child in model_value.get('children', []):
        _load_tree(tree, child, node)

    return tree

# Populate a fresh tree from the loaded dictionary, then print it with losses next to ids
tree = Tree()
_load_tree(tree, json_dict)
tree.show(data_property='loss', idhidden=False, line_type='ascii-exr')

1[model_hash]
╰── 0.9[child_model_hash]



In [25]:
# Test TxfDataset
# Build the dataset wrapper from the stored JSON file and the models directory
test_dataset_file, test_models_dir = './test_dataset.json', '../models/'
txf_dataset = TxfDataset(test_dataset_file, test_models_dir)

# Dictionary form of the dataset; as the last expression it is the cell output
txf_dataset.to_dict()

{'model_hash': {'children': [{'child_model_hash': {'data': {'model_hash': 'child_model_hash',
      'mode': 'grow_attn_head',
      'loss': 0.9}}}],
  'data': {'model_hash': 'model_hash', 'mode': None, 'loss': 1}}}

In [41]:
# Test update_node()
# Alternative not used here: replace the whole payload with
# txf_dataset.dataset.update_node('child_model_hash', data=TxfNode(...)).
# Changing the loss on the existing data object keeps the other fields untouched.
child_node = txf_dataset.dataset['child_model_hash']
child_node.data.loss = 0.85

# Show the dataset again so the new loss is visible in the cell output
txf_dataset.to_dict()

{'model_hash': {'children': [{'child_model_hash': {'data': {'model_hash': 'child_model_hash',
      'mode': 'grow_attn_head',
      'loss': 0.85}}}],
  'data': {'model_hash': 'model_hash', 'mode': None, 'loss': 1}}}

In [7]:
# Preparing and saving preliminary dataset
BERT_BASE_HASH = '8b20da51c159887b310cabce176da7fb'
BERT_BASE_LOSS = 1.322 
BERT_BASE_STEPS = 100000

models_dir = '../models'
txf_dataset_file = '../grow_and_prune/dataset/dataset_base.json'

txf_dataset = TxfDataset(txf_dataset_file, models_dir, debug=True)

# The base model is the root of the dataset tree; its loss and steps are given explicitly
txf_dataset.add_node(model_hash=BERT_BASE_HASH, 
                     mode=None, 
                     loss=BERT_BASE_LOSS, 
                     steps=BERT_BASE_STEPS, 
                     parent_model_hash=None)

# sorted() makes the insertion order independent of the platform's directory order
for model_hash in sorted(os.listdir(models_dir)):
    if model_hash == BERT_BASE_HASH:
        continue
    # Only entries that contain a log_history.json are added. os.path.exists() is
    # simply False for stray files in models_dir (e.g. .DS_Store), whereas calling
    # os.listdir() on such an entry raises NotADirectoryError.
    if not os.path.exists(os.path.join(models_dir, model_hash, 'log_history.json')):
        continue
    # loss and steps are left as None here; judging by the output below they are
    # filled in by update_dataset() (presumably from log_history.json - TODO confirm)
    txf_dataset.add_node(model_hash=model_hash, 
                         mode='grow_attn_head', 
                         loss=None, 
                         steps=None, 
                         parent_model_hash=BERT_BASE_HASH)
    
txf_dataset.update_dataset()

# One tree print-out per node attribute
for data_property in ('loss', 'steps', 'mode'):
    txf_dataset.show_dataset(data_property=data_property)

Model with best loss (1.318) has hash: fa8fab36a056ef491e17f7100b8ccbf5
1.322[8b20da51c159887b310cabce176da7fb]
├── 1.3197[4813d0282fdffa4cd219318c04249a5d]
├── 1.3191[b8f4c354531c9499aaab1727c5e3e5e8]
├── 1.322[c9d0e9133b10da6af36b6c1643da3db5]
├── 1.319[f811b9a9f5d93fd00e2c9d8d7017fa02]
╰── 1.318[fa8fab36a056ef491e17f7100b8ccbf5]

100000[8b20da51c159887b310cabce176da7fb]
├── 110000[4813d0282fdffa4cd219318c04249a5d]
├── 110000[b8f4c354531c9499aaab1727c5e3e5e8]
├── 110000[c9d0e9133b10da6af36b6c1643da3db5]
├── 110000[f811b9a9f5d93fd00e2c9d8d7017fa02]
╰── 110000[fa8fab36a056ef491e17f7100b8ccbf5]

None[8b20da51c159887b310cabce176da7fb]
├── grow_attn_head[4813d0282fdffa4cd219318c04249a5d]
├── grow_attn_head[b8f4c354531c9499aaab1727c5e3e5e8]
├── grow_attn_head[c9d0e9133b10da6af36b6c1643da3db5]
├── grow_attn_head[f811b9a9f5d93fd00e2c9d8d7017fa02]
╰── grow_attn_head[fa8fab36a056ef491e17f7100b8ccbf5]

