In [1]:
import os
from glob import glob
from collections import OrderedDict, defaultdict
import math
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from torchvision.datasets import CIFAR10

In [2]:
def entropy(p):
    nats = -p * np.log(p, where=p>0.0)
    return nats.sum()

def get_dist(targets, num_classes:int=10):
    freq = np.zeros(num_classes)
    uniques, cnts = np.unique(targets, return_counts=True)
    for u, c in zip(uniques, cnts):
        freq[u] = c
    freq = freq / sum(cnts)
    return freq

In [9]:
EXP_PATH = 'saved/c100_r18_se_bald'

result_files = sorted(glob(os.path.join(EXP_PATH, "*", "results.json")))
config_files = sorted(glob(os.path.join(EXP_PATH, "*", "config.json")))

print(f"Found {len(result_files)} experiments.")

Found 1 experiments.


In [10]:
run_names = []
all_results = []
all_configs = []

for r, c in zip(result_files, config_files):
    assert os.path.split(r)[0] == os.path.split(c)[0]
    with open(c, 'r') as f:
        config = json.load(f)
        run_names.append(config['run_name'])
        all_configs.append(config)
    with open(r, 'r') as f:
        all_results.append(json.load(f))

In [11]:
# cifar10 = CIFAR10('datasets/cifar10', train=True)
# targets = np.asarray(cifar10.targets)

result_dict = OrderedDict()
for run_name, result in zip(run_names, all_results):
    all_queried_ids = []
    all_metrics = defaultdict(list)
    for ep in result:
        ep_idx = ep['episode']

        # Typo
        try:
            queried_ids = ep['episode/indices']
        except:
            queried_ids = ep['episode/indicies']
        all_queried_ids.extend(queried_ids)
        
        # all_metrics['entropy'].append(entropy(get_dist(targets[queried_ids], 10)))
        # all_metrics['overall_entropy'].append(entropy(get_dist(targets[all_queried_ids], 10)))
        
        if ep_idx == 0:
            continue
        
        for k, v in ep.items():
            if isinstance(v, float):
                all_metrics[k].append(v)

    min_len = 100
    for k, v in all_metrics.items():
        if len(v) < min_len:
            min_len = len(v)
    all_metrics = {k: v[:min_len] for k, v in all_metrics.items()}
    
    result_dict[run_name] = all_metrics

In [12]:
print(all_configs[0])

{'file': 'configs/cifar100_resnet18.json', 'run_name': 'cifar100_snapshot_20220924_083812', 'save_path': 'saved/c100_r18_se_bald/cifar100_snapshot_20220924_083812', 'dataset_name': 'cifar100', 'dataset_path': 'datasets', 'seed': 7, 'arch': 'resnet18', 'disable_tqdm': False, 'resume_from': None, 'learning_rate': 0.001, 'weight_decay': 0.01, 'momentum': 0.9, 'batch_size': 64, 'num_epochs': 250, 'optimizer_type': 'sgd', 'lr_scheduler_type': 'onecycle', 'lr_scheduler_param': 10.0, 'num_workers': 4, 'use_fp16': False, 'log_every': 10, 'eval_every': 10, 'swa_start': 200, 'swa_anneal_epochs': 10, 'swa_lr_multiplier': 5.0, 'swa_scheduler_type': 'constant', 'start_swa_at_end': True, 'num_episodes': 30, 'num_ensembles': 5, 'query_size': 1000, 'query_type': 'ensbald', 'init_query_size': 5000, 'init_query_type': 'random', 'eval_query_size': 1000, 'eval_query_type': 'random'}


In [13]:
pd.DataFrame(result_dict[run_names[0]]).T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
eval/acc,0.355,0.368,0.319,0.409,0.384,0.401,0.426,0.395,0.355,0.397,0.423,0.399,0.406,0.421,0.398,0.399
eval/max_acc,0.482,0.528,0.554,0.58,0.616,0.62,0.611,0.642,0.64,0.659,0.67,0.681,0.693,0.696,0.691,0.705
test/swa_acc,0.5176,0.546,0.5697,0.5956,0.6114,0.6193,0.6241,0.6318,0.6404,0.6407,0.6453,0.6425,0.6461,0.6456,0.6504,0.6479
test/swa_nll,1.854197,1.713237,1.615115,1.502358,1.431764,1.39159,1.358834,1.324044,1.295043,1.292729,1.2792,1.272465,1.258585,1.261906,1.240329,1.249689
test/swa_ece,0.05075,0.03691,0.036262,0.022628,0.021022,0.024544,0.031458,0.034079,0.040874,0.044908,0.048035,0.046877,0.052851,0.052911,0.059903,0.057782
test/swa_top5,0.8055,0.8315,0.848,0.8642,0.8772,0.8837,0.8883,0.8918,0.9007,0.8992,0.9029,0.9044,0.9043,0.9053,0.9077,0.9051
ens/acc,0.4383,0.4652,0.4826,0.5087,0.512,0.526,0.5333,0.5398,0.5358,0.5479,0.5404,0.5634,0.5451,0.5487,0.5643,0.5467
ens/nll,2.270604,2.147083,2.060256,1.970901,1.955554,1.898804,1.842378,1.826268,1.852305,1.827896,1.827619,1.805142,1.878925,1.822114,1.774191,1.813301
ens/ece,0.091145,0.12536,0.135015,0.148194,0.165492,0.170108,0.170287,0.179277,0.182906,0.201395,0.185681,0.217029,0.219392,0.196892,0.21477,0.196229
ens/top5,0.7219,0.7531,0.7766,0.7959,0.8023,0.8149,0.8275,0.8317,0.8297,0.8424,0.836,0.8406,0.8324,0.8417,0.8479,0.8363


In [14]:
METRICS = ['test/swa_acc', 'test/swa_nll', 'test/swa_ece', 'test/swa_top5', 'before/acc', 'before/nll', 'before/ece']

all_dfs = []
for run_name, metric in result_dict.items():
    df = pd.DataFrame(metric)[METRICS]
    all_dfs.append(df.T)
pd.concat(all_dfs, axis=0)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
test/swa_acc,0.5176,0.546,0.5697,0.5956,0.6114,0.6193,0.6241,0.6318,0.6404,0.6407,0.6453,0.6425,0.6461,0.6456,0.6504,0.6479
test/swa_nll,1.854197,1.713237,1.615115,1.502358,1.431764,1.39159,1.358834,1.324044,1.295043,1.292729,1.2792,1.272465,1.258585,1.261906,1.240329,1.249689
test/swa_ece,0.05075,0.03691,0.036262,0.022628,0.021022,0.024544,0.031458,0.034079,0.040874,0.044908,0.048035,0.046877,0.052851,0.052911,0.059903,0.057782
test/swa_top5,0.8055,0.8315,0.848,0.8642,0.8772,0.8837,0.8883,0.8918,0.9007,0.8992,0.9029,0.9044,0.9043,0.9053,0.9077,0.9051
before/acc,0.4927,0.5268,0.5429,0.5846,0.6108,0.6139,0.6338,0.6516,0.6621,0.6736,0.6785,0.6883,0.694,0.7013,0.7073,0.707
before/nll,2.173026,2.017881,1.945873,1.772382,1.644741,1.597764,1.515287,1.435381,1.374218,1.355008,1.295995,1.270305,1.224363,1.201637,1.177052,1.163177
before/ece,0.04223,0.043934,0.042109,0.041031,0.047733,0.032394,0.038721,0.040342,0.031604,0.037367,0.025364,0.030599,0.03142,0.03037,0.029872,0.023149


In [108]:
result_dict.keys()

odict_keys(['cifar10_snapshot_20220919_141947', 'cifar10_snapshot_20220919_141955', 'cifar10_snapshot_20220919_141959', 'cifar10_snapshot_20220919_142002', 'cifar10_snapshot_20220919_142006'])