In [1]:
# Standard library
import json
from collections import defaultdict

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seml
# Kept at column 0: an indented import after unindented statements raises
# IndentationError when the cell is executed.
from tqdm.autonotebook import tqdm


In [2]:
# Fetch every document of the seml collection and keep only runs whose status
# is 'COMPLETED', reduced to the fields used below (config, result, id).
collection_name = 'week4_density'
collection = seml.database.get_collection(collection_name)
results = []
for run in collection.find():
    if run['status'] not in ('COMPLETED',):
        continue
    results.append({'config': run['config'], 'result': run['result'], 'id': run['_id']})

In [3]:
def flatten_dict(d):
    """Flatten a nested dict into a single level with dot-joined keys.

    Values that are themselves dicts are flattened recursively and their keys
    are prefixed with the parent key followed by '.'; every other value is
    copied unchanged. Empty nested dicts contribute no entries.

    Returns a new dict; `d` itself is not modified.
    """
    flat = {}
    for key, value in d.items():
        if not isinstance(value, dict):
            flat[key] = value
            continue
        nested = flatten_dict(value)
        flat.update((key + '.' + sub_key, sub_value) for sub_key, sub_value in nested.items())
    return flat
    

In [4]:
# Collect everything into a data frame: one row per entry of `results`.
# Each run is assembled as its own record, so runs whose config keys or metrics
# differ produce NaN cells instead of ragged (or silently misaligned) columns.
records = []
for result in results:
    # Flattened config entries become columns such as 'model.hidden_sizes'.
    record = dict(flatten_dict(result['config']))
    # result['result'] is used as the path of a JSON file holding the metrics.
    with open(result['result']) as f:
        metrics = json.load(f)
    # Assumes every metric maps to a sequence of numbers - TODO confirm.
    for metric, values in metrics.items():
        values = np.array(values)
        record[metric + '.mean'] = values.mean()
        record[metric + '.std'] = values.std()
    records.append(record)
df = pd.DataFrame(records)
    

In [5]:
# Display the column names of the collected frame (flattened config keys plus
# the '<metric>.mean' / '<metric>.std' columns).
df.columns

Index(['overwrite', 'db_collection', 'data.dataset', 'data.num_dataset_splits',
       'data.split_type', 'data.test_portion', 'data.test_portion_fixed',
       'data.train_labels', 'data.train_labels_remove_other',
       'data.train_portion', 'data.val_labels', 'data.val_portion',
       'evaluation.pipeline', 'model.activation',
       'model.freeze_residual_projection', 'model.hidden_sizes',
       'model.leaky_relu_slope', 'model.model_type',
       'model.num_initializations', 'model.residual', 'model.use_bias',
       'model.use_spectral_norm', 'model.weight_scale', 'run.args', 'run.name',
       'training.early_stopping.min_delta', 'training.early_stopping.mode',
       'training.early_stopping.monitor', 'training.early_stopping.patience',
       'training.gpus', 'training.learning_rate', 'training.max_epochs',
       'seed', 'val_loss.mean', 'val_loss.std', 'val_accuracy.mean',
       'val_accuracy.std', 'auroc_gpc-train.mean', 'auroc_gpc-train.std',
       'auroc_gpc-all.mean

In [16]:
# Keep a handful of model / data configuration columns next to every column
# whose name contains 'auroc'.
config_columns = [
    'model.hidden_sizes',
    'model.weight_scale',
    'model.residual',
    'model.freeze_residual_projection',
    'data.train_labels_remove_other',
]
auroc_columns = [name for name in df.columns if 'auroc' in name]
df_auroc = df[config_columns + auroc_columns]

In [17]:
# Sort runs in descending order by the mean AUROC columns; earlier columns take
# precedence and later ones only break ties.
auroc_mean_columns = [
    name
    for name in df_auroc.columns
    if name.startswith('auroc') and name.endswith('.mean')
]
df_auroc.sort_values(by=auroc_mean_columns, ascending=False)

Unnamed: 0,model.hidden_sizes,model.weight_scale,model.residual,model.freeze_residual_projection,data.train_labels_remove_other,auroc_gpc-train.mean,auroc_gpc-train.std,auroc_gpc-all.mean,auroc_gpc-all.std,auroc_gpc-32-pca-train.mean,...,auroc_gpc-32-pca-all.mean,auroc_gpc-32-pca-all.std,auroc_gpc-16-pca-train.mean,auroc_gpc-16-pca-train.std,auroc_gpc-16-pca-all.mean,auroc_gpc-16-pca-all.std,auroc_gpc-8-pca-train.mean,auroc_gpc-8-pca-train.std,auroc_gpc-8-pca-all.mean,auroc_gpc-8-pca-all.std
36,"[64, 32]",2.0,True,True,False,0.67545,0.083071,0.22285,0.079137,0.56815,...,0.349783,0.055596,0.62005,0.088374,0.31185,0.057903,0.687383,0.087294,0.20575,0.080682
37,"[64, 32]",3.0,True,True,False,0.6664,0.079207,0.229867,0.080365,0.608533,...,0.36525,0.052076,0.6525,0.100362,0.323217,0.060713,0.685,0.094364,0.226817,0.075828
38,"[64, 32]",4.0,True,True,False,0.60585,0.08742,0.303317,0.075797,0.56815,...,0.360933,0.063543,0.5965,0.112022,0.326533,0.060859,0.621267,0.113897,0.246267,0.068315
35,"[64, 32]",1.5,True,True,False,0.596733,0.095442,0.256067,0.078967,0.53505,...,0.342133,0.053783,0.576167,0.086557,0.313783,0.057556,0.64595,0.09238,0.2081,0.079036
40,"[64, 64]",1.0,True,True,False,0.59465,0.08504,0.263283,0.078793,0.5336,...,0.3595,0.079731,0.573883,0.092809,0.303167,0.073811,0.62915,0.095098,0.216333,0.071635
39,"[64, 32]",5.0,True,True,False,0.566167,0.074192,0.347433,0.084136,0.540883,...,0.360333,0.074183,0.5724,0.08674,0.33125,0.066373,0.615567,0.076883,0.274217,0.063248
0,[64],0.9,True,True,True,0.564533,0.075408,0.415108,0.059423,0.523167,...,0.43565,0.066202,0.511283,0.066261,0.400133,0.069305,0.512367,0.065066,0.368217,0.050517
3,[64],1.5,True,True,True,0.5616,0.072563,0.406033,0.058863,0.515617,...,0.43815,0.066612,0.524533,0.076479,0.402133,0.079835,0.528233,0.072532,0.3626,0.04597
2,[64],1.1,True,True,True,0.560117,0.071771,0.411017,0.061375,0.528383,...,0.435717,0.067551,0.516683,0.066143,0.3973,0.074612,0.519517,0.061365,0.359917,0.050991
26,[64],1.1,True,True,False,0.557133,0.080958,0.241083,0.065088,0.4724,...,0.369567,0.06858,0.492383,0.089233,0.3185,0.077549,0.56755,0.078376,0.24105,0.066699
