In [None]:
# Reload imported modules automatically before each cell runs, so edits to
# project .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import glob
import os.path
import pickle

import numpy as np
import pandas as pd
import qgrid
import torch


def get_keychain_value_iter(d, key_chain=None):
    """Recursively walk a nested dict and yield ``(key_chain, leaf)`` pairs.

    Args:
        d: A (possibly nested) dict. Any non-dict value is treated as a leaf.
        key_chain: Keys already traversed to reach ``d``; ``None`` means no prefix.

    Yields:
        ``(tuple_of_keys, value)`` for every non-dict value reachable from ``d``,
        depth first, in dict iteration order. An empty dict yields nothing.
    """
    prefix = list(key_chain) if key_chain is not None else []

    if isinstance(d, dict):
        for key, child in d.items():
            yield from get_keychain_value_iter(child, prefix + [key])
        return

    # Anything that is not a dict (types, lists, scalars, ...) is a leaf.
    yield tuple(prefix), d
            
def get_keychain_value(d, key_chain):
    """Follow ``key_chain`` through nested containers and return the value it reaches.

    Args:
        d: Root container, e.g. an unpickled experiment-result dict.
        key_chain: Iterable of keys/indices applied in order, e.g.
            ``('exp_cfg', 'training', 'num_epochs')``.

    Returns:
        ``d[k0][k1]...[kn]`` for ``key_chain == (k0, ..., kn)``; ``d`` itself
        when the chain is empty.

    Raises:
        KeyError: if any lookup step fails (missing key, index out of range,
            or a non-subscriptable intermediate value). The exception argument
            is the full key chain, and the original error is kept as ``__cause__``.
    """
    keys = tuple(key_chain)
    value = d
    try:
        for k in keys:
            value = value[k]
    except (LookupError, TypeError) as ex:
        # LookupError covers KeyError/IndexError; TypeError covers indexing a
        # non-container (e.g. an int) or using an unhashable key.
        raise KeyError(keys) from ex

    return value

In [None]:
# Schema of a pickled experiment result. Leaves are the expected Python types
# (or, for 'model_type', the expected literal value); nesting mirrors the
# result dict. The kc cell walks exp_res_meta to list every available key chain.
training_cfg = dict(
    lr=float,
    lr_drop_fact=float,
    num_epochs=int,
    epoch_step=int,
    batch_size=int,
    weight_decay=float,
    validation_ratio=float,
)

model_cfg_meta = dict(
    model_type='PershomModel',
    model_kwargs=dict(
        use_sup_lvlset_filt=bool,
        filtration_kwargs=dict(
            use_node_deg=bool,
            use_node_lab=bool,
            num_gin=int,
            hidden_dim=int,
            use_mlp=bool,
        ),
        classifier_kwargs=dict(
            num_struct_elem=int,
        ),
    ),
)

exp_cfg_meta = dict(
    dataset_name=str,
    training=training_cfg,
    model=model_cfg_meta,
)

exp_res_meta = dict(
    exp_cfg=exp_cfg_meta,
    cv_test_acc=list,
    cv_val_acc=list,
    cv_indices_trn_tst_val=list,
    cv_epoch_loss=list,
    start_time=list,
    id=str,
)

In [None]:
# Template for COL_NAMES: every leaf key chain of exp_res_meta mapped to its
# last key (the default column name). The trailing ';' suppresses cell output.
kc = {chain: chain[-1] for chain, _leaf in get_keychain_value_iter(exp_res_meta)}
kc;

In [None]:
# Columns copied from each experiment result into the summary table built by
# pd_frame. Keys are key chains into the pickled result dict (looked up with
# get_keychain_value); values are the resulting DataFrame column names.
# Commented-out entries are further available key chains (compare kc above);
# uncomment one to add that column.
COL_NAMES = {
    ('exp_cfg', 'dataset_name'): 'dataset_name',

    # --- training ---
    # ('exp_cfg', 'training', 'lr'): 'lr',
    # ('exp_cfg', 'training', 'lr_drop_fact'): 'lr_drop_fact',
    ('exp_cfg', 'training', 'num_epochs'): 'num_epochs',
    # ('exp_cfg', 'training', 'epoch_step'): 'epoch_step',
    # ('exp_cfg', 'training', 'batch_size'): 'batch_size',
    # ('exp_cfg', 'training', 'weight_decay'): 'weight_decay',
    # ('exp_cfg', 'training', 'validation_ratio'): 'validation_ratio',

    # --- model ---
    # ('exp_cfg', 'model', 'model_type'): 'model_type',
    ('exp_cfg', 'model', 'model_kwargs', 'use_sup_lvlset_filt'): 'use_sup_lvlset_filt',
    ('exp_cfg', 'model', 'model_kwargs', 'filtration_kwargs', 'use_node_deg'): 'use_node_deg',
    ('exp_cfg', 'model', 'model_kwargs', 'filtration_kwargs', 'use_node_lab'): 'use_node_lab',
    # ('exp_cfg', 'model', 'model_kwargs', 'filtration_kwargs', 'num_gin'): 'num_gin',
    ('exp_cfg', 'model', 'model_kwargs', 'filtration_kwargs', 'hidden_dim'): 'hidden_dim',
    ('exp_cfg', 'model', 'model_kwargs', 'filtration_kwargs', 'use_mlp'): 'use_mlp',
    ('exp_cfg', 'model', 'model_kwargs', 'classifier_kwargs', 'num_struct_elem'): 'num_struct_elem',

    # --- raw results ---
    # ('cv_test_acc',): 'cv_test_acc',
    # ('cv_val_acc',): 'cv_val_acc',  # keep the comma: ('cv_val_acc') is a str, not a 1-tuple
    # ('cv_indices_trn_tst_val',): 'cv_indices_trn_tst_val',  # may not be existent in older versions
    # ('cv_epoch_loss',): 'cv_epoch_loss',
    # ('start_time',): 'start_time',
    # ('id',): 'id',
}

In [None]:
def read_files(path):
    """Unpickle every ``*.pickle`` file directly inside ``path``.

    Files are read in sorted filename order so that the returned list, and the
    row order of any table built from it, is reproducible. ``glob`` alone
    returns matches in a file-system-dependent order.

    Args:
        path: Directory containing the result pickles.

    Returns:
        List of unpickled objects, one per file (empty if nothing matches).

    Note:
        ``pickle.load`` can execute arbitrary code. Only point this at result
        directories you produced yourself.
    """
    files = sorted(glob.glob(os.path.join(path, "*.pickle")))
    res = []
    for file_name in files:
        # ``with`` closes each handle right away instead of leaking it until GC.
        with open(file_name, 'rb') as fh:
            res.append(pickle.load(fh))

    print("# Found", len(res), "files.")
    return res


def _summarize_result(res):
    """Flatten one unpickled experiment-result dict into a summary row (a dict)."""
    # NOTE(review): cv_test_acc appears to hold one per-epoch accuracy history
    # per CV fold (each entry is compared against num_epochs) — confirm.
    cv_test_acc = res['cv_test_acc']
    num_epochs = res['exp_cfg']['training']['num_epochs']

    # "Finished" means the most recent history reached the configured number of
    # epochs. An empty cv_test_acc (nothing recorded yet) counts as unfinished
    # instead of raising IndexError.
    finished_training = len(cv_test_acc) > 0 and len(cv_test_acc[-1]) == num_epochs

    # Final-epoch accuracy of every history that recorded at least one epoch.
    final_accs = [hist[-1] for hist in cv_test_acc if len(hist) > 0]
    if final_accs:
        acc_mean, acc_std = np.mean(final_accs), np.std(final_accs)
    else:
        # np.mean/np.std of [] also give NaN, but with a RuntimeWarning.
        acc_mean = acc_std = float('nan')

    row = {
        'finished_training': finished_training,
        'cv_test_acc_last_mean': acc_mean,
        'cv_test_acc_last_std': acc_std,
    }
    for key_chain, col_name in COL_NAMES.items():
        try:
            row[col_name] = get_keychain_value(res, key_chain)
        except KeyError:
            # Deliberately best-effort: e.g. older result files may lack some
            # keys, so that column is left NaN for this row.
            pass
    return row


def pd_frame(path):
    """Build a one-row-per-experiment summary DataFrame from the pickles in ``path``.

    Args:
        path: Directory passed to ``read_files``.

    Returns:
        DataFrame indexed 0..n-1 with alphabetically sorted columns:
        ``finished_training``, ``cv_test_acc_last_mean``, ``cv_test_acc_last_std``,
        plus every COL_NAMES column found in at least one result (NaN where a
        result lacks it). An empty DataFrame is returned when no files are found.
    """
    # Build plain records and create the frame once. This avoids concatenating
    # many one-row frames, and pd.concat([]) would raise when no files exist.
    rows = [_summarize_result(res) for res in read_files(path)]
    # Sort columns to keep the layout produced by pd.concat(..., sort=True).
    return pd.DataFrame(rows).sort_index(axis=1)

In [None]:
# Summarise every result pickle under ./results/ into RES and wrap it in an
# interactive qgrid view (with toolbar) for exploration.
path = './results/'
RES = pd_frame(path=path)
qgrid_widget = qgrid.show_grid(RES, show_toolbar=True)


In [None]:
# Display the interactive qgrid view of RES.
qgrid_widget

In [None]:
# DataFrame reflecting the current state of the grid UI (any sorting/filtering/
# edits made interactively in the widget displayed above).
# NOTE(review): this output depends on manual widget interaction, so it is not
# reproducible from code alone.
qgrid_widget.get_changed_df()