In [1]:
import os
from glob import glob
from collections import OrderedDict, defaultdict
import math
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from torchvision.datasets import CIFAR10

In [2]:
def entropy(p):
    """Shannon entropy, in nats, of a discrete probability vector.

    Zero-probability entries contribute nothing (the 0 * log 0 = 0 convention).

    Args:
        p: array-like of probabilities (lists are accepted as well as arrays).

    Returns:
        The entropy as a numpy float scalar.
    """
    p = np.asarray(p, dtype=float)
    # `where=` alone leaves the masked (p == 0) slots of the output uninitialized,
    # which can inject garbage / NaN into the sum; pre-fill them with 0 via `out=`.
    log_p = np.log(p, out=np.zeros_like(p), where=p > 0.0)
    nats = -p * log_p
    return nats.sum()

def get_dist(targets, num_classes: int = 10):
    """Empirical class distribution of a collection of integer labels.

    Args:
        targets: array-like of integer class labels in [0, num_classes).
        num_classes: length of the returned frequency vector.

    Returns:
        Float array of shape (num_classes,) holding the fraction of labels in
        each class. It is all zeros when `targets` is empty.

    Raises:
        IndexError: if any label lies outside [0, num_classes).
    """
    labels = np.asarray(targets).ravel()
    if labels.size == 0:
        # No samples: return zero frequencies instead of 0 / 0 = NaN.
        return np.zeros(num_classes)
    lo, hi = labels.min(), labels.max()
    if lo < 0 or hi >= num_classes:
        # Without this check a negative label would silently be counted in the last class.
        raise IndexError(
            f"labels must lie in [0, {num_classes}), got range [{lo}, {hi}]"
        )
    counts = np.bincount(labels, minlength=num_classes).astype(float)
    return counts / labels.size

In [30]:
EXP_PATH = 'saved/c10_v16_maxent'

# Every run writes into its own sub-directory of EXP_PATH; gather the
# results.json and config.json files of all runs, each list sorted by path.
result_files, config_files = (
    sorted(glob(os.path.join(EXP_PATH, "*", fname)))
    for fname in ("results.json", "config.json")
)

print(f"Found {len(result_files)} experiments.")

Found 1 experiments.


In [31]:
run_names = []
all_results = []
all_configs = []

# Pair each results.json with the config.json from the same run directory.
# Matching by directory, rather than by position in two separately sorted lists,
# keeps the pairs aligned even if some run directory has a config but no
# results.json yet (presumably a run still in progress).
config_by_dir = {os.path.dirname(c): c for c in config_files}

for r in result_files:
    run_dir = os.path.dirname(r)
    if run_dir not in config_by_dir:
        raise FileNotFoundError(f"{r} has no matching config.json in {run_dir}")
    with open(config_by_dir[run_dir], 'r') as f:
        config = json.load(f)
    run_names.append(config['run_name'])
    all_configs.append(config)
    with open(r, 'r') as f:
        all_results.append(json.load(f))

In [32]:
# Only the training-set labels are needed: queried sample indices are mapped to classes.
cifar10 = CIFAR10('/opt/datasets/cifar10', train=True)
targets = np.asarray(cifar10.targets)

result_dict = OrderedDict()
for run_name, result in zip(run_names, all_results):
    all_queried_ids = []
    all_metrics = defaultdict(list)
    for ep in result:
        ep_idx = ep['episode']

        # Some result files store the queried indices under the misspelled key
        # 'episode/indicies'; fall back to it only when the correct key is absent.
        if 'episode/indices' in ep:
            queried_ids = ep['episode/indices']
        else:
            queried_ids = ep['episode/indicies']
        all_queried_ids.extend(queried_ids)

        # Class-label entropy of this episode's query batch, and of everything queried so far.
        all_metrics['entropy'].append(entropy(get_dist(targets[queried_ids], 10)))
        all_metrics['overall_entropy'].append(entropy(get_dist(targets[all_queried_ids], 10)))

        # The float metrics of episode 0 are not recorded.
        # NOTE(review): the entropy series include episode 0 but the float metrics do
        # not, so after truncation column i pairs entropy of episode i with metrics of
        # episode i+1 — confirm this offset is intended.
        if ep_idx == 0:
            continue

        for k, v in ep.items():
            if isinstance(v, float):
                all_metrics[k].append(v)

    # Truncate every series to the shortest one so they form a rectangular table.
    # Use the true minimum (no fixed cap) so runs longer than 100 episodes are kept whole.
    min_len = min((len(v) for v in all_metrics.values()), default=0)
    all_metrics = {k: v[:min_len] for k, v in all_metrics.items()}

    result_dict[run_name] = all_metrics

In [33]:
# Pretty-print the first run's configuration; a flat dict repr of ~40 keys on one line is unreadable.
print(json.dumps(all_configs[0], indent=2))

{'file': 'configs/cifar10_vgg16.json', 'run_name': 'cifar10_snapshot_20220919_174546', 'save_path': 'saved/c10_v16_maxent/cifar10_snapshot_20220919_174546', 'dataset_name': 'cifar10', 'dataset_path': '/opt/datasets', 'seed': 42, 'arch': 'vgg16', 'disable_tqdm': False, 'resume_from': None, 'learning_rate': 0.001, 'weight_decay': 0.01, 'momentum': 0.9, 'batch_size': 64, 'num_epochs': 200, 'optimizer_type': 'sgd', 'lr_scheduler_type': 'constant', 'lr_scheduler_param': 10.0, 'num_workers': 4, 'log_every': 10, 'eval_every': 10, 'swa_start': 100, 'swa_anneal_epochs': 50, 'swa_lr_multiplier': 10.0, 'swa_scheduler_type': 'constant', 'start_swa_at_end': True, 'num_episodes': 30, 'num_ensembles': 5, 'query_size': 500, 'query_type': 'maxentropy', 'init_query_size': 500, 'init_query_type': 'random', 'eval_query_size': 500, 'eval_query_type': 'random', 'eval_size': 500}


In [34]:
# Per-episode metrics of the first run, laid out with one row per metric and one column per episode.
pd.DataFrame.from_dict(result_dict[run_names[0]]).transpose()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
entropy,2.291243,2.248366,2.276165,2.183876,2.121027,2.186897,2.228213,2.157343,2.207487,2.184636,2.07464,1.901184,1.993374,1.937773,2.19838,1.866092,2.176153,2.208813,2.195562,2.091958
overall_entropy,2.291243,2.284166,2.29352,2.282068,2.28909,2.287332,2.285109,2.279786,2.283619,2.279585,2.285705,2.279408,2.28624,2.279615,2.282042,2.279221,2.278855,2.277678,2.278527,2.280042
eval/acc,0.402,0.332,0.39,0.528,0.48,0.508,0.522,0.538,0.554,0.53,0.576,0.628,0.668,0.572,0.684,0.68,0.62,0.704,0.704,0.76
eval/max_acc,0.402,0.43,0.49,0.528,0.548,0.554,0.576,0.606,0.626,0.626,0.628,0.644,0.68,0.68,0.684,0.682,0.668,0.704,0.704,0.76
test/swa_acc,0.4492,0.5153,0.5844,0.6012,0.6185,0.6644,0.6821,0.7015,0.7256,0.7369,0.7706,0.774,0.7823,0.7894,0.8082,0.8091,0.8199,0.8347,0.827,0.8417
test/swa_nll,2.312959,1.941686,1.630996,1.529135,1.29962,1.194669,1.06569,1.001198,0.886611,0.854641,0.71644,0.691743,0.674876,0.641045,0.58405,0.577012,0.559762,0.518876,0.532278,0.510366
test/swa_ece,0.342425,0.293719,0.243289,0.231784,0.18644,0.173881,0.148843,0.13752,0.107079,0.104168,0.054092,0.047032,0.03675,0.025469,0.016743,0.014856,0.013409,0.017686,0.017217,0.03753
test/swa_top5,0.8919,0.9099,0.9291,0.9336,0.9484,0.9533,0.9649,0.9694,0.969,0.9727,0.9772,0.9821,0.9774,0.9811,0.9873,0.9844,0.9869,0.9894,0.9882,0.9888
ens/acc,0.4464,0.4719,0.5353,0.5837,0.6287,0.6417,0.6403,0.6842,0.7105,0.6699,0.711,0.7122,0.7394,0.7671,0.7436,0.7652,0.7936,0.7837,0.7974,0.8069
ens/nll,1.871005,1.615373,1.380093,1.23028,1.06166,1.046766,1.057427,0.933162,0.858658,0.972901,0.861337,0.830404,0.802452,0.72864,0.778188,0.707454,0.667927,0.672652,0.645757,0.610079


In [39]:
METRICS = ['test/swa_acc', 'test/swa_nll', 'test/swa_ece']

# One (metric x episode) table per run. Stack them with the run name as the outer
# index level; without `keys` the rows of different runs share identical labels
# and cannot be told apart.
all_dfs = [pd.DataFrame(metric)[METRICS].T for metric in result_dict.values()]
# Show only episode columns from position 10 onward.
pd.concat(all_dfs, axis=0, keys=list(result_dict.keys())).iloc[:, 10:]

Unnamed: 0,10,11,12,13,14,15,16,17,18,19
test/swa_acc,0.7706,0.774,0.7823,0.7894,0.8082,0.8091,0.8199,0.8347,0.827,0.8417
test/swa_nll,0.71644,0.691743,0.674876,0.641045,0.58405,0.577012,0.559762,0.518876,0.532278,0.510366
test/swa_ece,0.054092,0.047032,0.03675,0.025469,0.016743,0.014856,0.013409,0.017686,0.017217,0.03753


In [108]:
# Run names loaded into result_dict, in the order the runs were read.
# NOTE(review): the recorded output lists five runs while the discovery cell reported
# one experiment, and the execution count is out of order; the output looks stale
# (possibly from a different EXP_PATH). Re-run from a fresh kernel to confirm.
result_dict.keys()

odict_keys(['cifar10_snapshot_20220919_141947', 'cifar10_snapshot_20220919_141955', 'cifar10_snapshot_20220919_141959', 'cifar10_snapshot_20220919_142002', 'cifar10_snapshot_20220919_142006'])