In [1]:
%load_ext autoreload

%autoreload 2
from utils import parse_table, read_csv
import pandas as pd
import numpy as np

from pretty import plot_kde, ColorTheme



# Human-readable dataset names for printed reports, keyed by the dataset
# identifiers that appear as keys of the parsed result tables.
print_mapping = dict([
    ('fashion-mnist', 'FMNIST'),
    ('mnist', 'MNIST'),
    ('emnist', 'EMNIST'),
    ('omniglot', 'Omniglot'),
    ('celeba-small', 'CelebA'),
    ('svhn', 'SVHN'),
    ('cifar10', 'CIFAR10'),
    ('cifar100', 'CIFAR100'),
    ('tiny-imagenet', 'Tiny'),
])

# CCLR hyperparameter values; each one has its own grayscale and rgb result CSV.
cclr_hps = ['0.1', '0.25', '0.33', '0.5', '0.66']

# all_tasks[hp] -> {'grayscale': parsed table, 'rgb': parsed table}
all_tasks = {}

for hp in cclr_hps:
    df_grayscale = read_csv(f'grayscale_cclr_{hp}.csv')
    df_rgb = read_csv(f'rgb_cclr_{hp}.csv')
    all_grayscale_tasks = parse_table(df_grayscale)
    all_rgb_tasks = parse_table(df_rgb)
    all_tasks_internal = {
        'grayscale': all_grayscale_tasks,
        'rgb': all_rgb_tasks,
    }
    all_tasks[hp] = all_tasks_internal

# The dataset pairs are derived from the last hyperparameter's tables (still bound
# to all_tasks_internal), but later cells index *every* hp with them. Fail here,
# with a clear message, if any hp is missing one of those datasets.
for hp, tasks_internal in all_tasks.items():
    for tp, tables in all_tasks_internal.items():
        missing = set(tables.keys()) - set(tasks_internal[tp].keys())
        assert not missing, f'hp={hp}, type={tp}: missing datasets {sorted(missing)}'

# Every ordered (in-distribution, OOD, image type) pair within the same image type.
# Built once after loading instead of being rebuilt on each hp iteration.
all_pairs = [
    (in_distr, ood, tp)
    for tp, tables in all_tasks_internal.items()
    for in_distr in tables.keys()
    for ood in tables.keys()
    if in_distr != ood
]

3/5: 100%|██████████| 4/4 [00:01<00:00,  3.36it/s]
4/6: 100%|██████████| 5/5 [00:02<00:00,  2.36it/s]
3/5: 100%|██████████| 4/4 [00:01<00:00,  3.30it/s]
4/6: 100%|██████████| 5/5 [00:02<00:00,  2.27it/s]
3/5: 100%|██████████| 4/4 [00:01<00:00,  3.35it/s]
4/6: 100%|██████████| 5/5 [00:02<00:00,  2.34it/s]
3/5: 100%|██████████| 4/4 [00:01<00:00,  3.39it/s]
4/6: 100%|██████████| 5/5 [00:02<00:00,  2.32it/s]
3/5: 100%|██████████| 4/4 [00:01<00:00,  3.22it/s]
4/6: 100%|██████████| 5/5 [00:02<00:00,  2.27it/s]


In [6]:
# Sanity check: which image types were loaded for the '0.1' CCLR hyperparameter.
tasks_first_hp = all_tasks['0.1']
tasks_first_hp.keys()

dict_keys(['grayscale', 'rgb'])

In [3]:

def get_scores(in_distr, ood, type, all_tasks):
    """Extract CCLR scores for one (in-distribution, OOD) dataset pair.

    Looks up the results table ``all_tasks[type][in_distr][ood]`` and splits the
    first column whose name starts with ``'Cclr with frac'`` by the value of the
    ``name`` column (``'test'``, ``'ood'`` and ``'generated'``).

    The scores are returned as a tuple. For backward compatibility with cells
    that read them as globals, they are also stored in the module-level names
    ``score_in``, ``score_ood`` and ``score_generated``.

    Args:
        in_distr: key of the in-distribution dataset.
        ood: key of the out-of-distribution dataset.
        type: image-type key (e.g. ``'grayscale'`` or ``'rgb'``). It shadows the
            builtin but is kept as-is for caller compatibility.
        all_tasks: nested mapping ``type -> in_distr -> ood -> DataFrame`` for a
            single CCLR hyperparameter.

    Returns:
        Tuple ``(score_in, score_ood, score_generated)`` of ``.values`` arrays.

    Raises:
        KeyError: if the table has no column starting with ``'Cclr with frac'``
            (or if any of the lookup keys is missing).
    """
    global score_in, score_ood, score_generated
    in_vs_out = all_tasks[type][in_distr][ood]
    # Match the score column by prefix; presumably the suffix is the CCLR fraction.
    col = next((c for c in in_vs_out.columns if c.startswith('Cclr with frac')), None)
    if col is None:
        raise KeyError(
            f"no column starting with 'Cclr with frac' in {list(in_vs_out.columns)}"
        )
    names = in_vs_out['name']
    score_generated = in_vs_out.loc[names == 'generated', col].values
    score_in = in_vs_out.loc[names == 'test', col].values
    score_ood = in_vs_out.loc[names == 'ood', col].values
    return score_in, score_ood, score_generated

In [8]:
from roc_analysis import get_roc_graph, get_convex_hull, get_auc
from tqdm import tqdm

def _convex_hull_auc(pos_scores, neg_scores):
    """Return the AUC of the convex hull of the ROC graph for pos vs. neg scores.

    Runs get_roc_graph -> get_convex_hull -> get_auc. The pipeline was
    previously duplicated inline for the two comparisons below.
    """
    x_naive, y_naive = get_roc_graph(
        pos_x=pos_scores,
        neg_x=neg_scores,
        verbose=0,
    )
    x_curve, y_curve = get_convex_hull(x_naive, y_naive)
    return get_auc(x_curve, y_curve)


# For every ordered dataset pair, print the best (max over CCLR hyperparameters)
# convex-hull AUC of test-vs-OOD and of generated-vs-OOD.
# NOTE(review): the hyperparameter is chosen on the same scores the AUC is
# reported on, so these are optimistic (oracle-hyperparameter) numbers.
for in_distr, ood, tp in all_pairs:
    # Fall back to the raw key for datasets absent from print_mapping rather than
    # aborting the report with a KeyError.
    print(f'{print_mapping.get(in_distr, in_distr)} vs {print_mapping.get(ood, ood)}')
    max_auc = -1.
    max_auc_generated = -1.
    for hp in cclr_hps:
        # get_scores publishes its results via the globals
        # score_in / score_ood / score_generated.
        get_scores(in_distr, ood, tp, all_tasks[hp])
        max_auc = max(max_auc, _convex_hull_auc(score_in, score_ood))
        max_auc_generated = max(
            max_auc_generated, _convex_hull_auc(score_generated, score_ood)
        )
    print("best AUC of test-vs-out", "{:.3f}".format(max_auc))
    print("best AUC of generated-vs-out", "{:.3f}".format(max_auc_generated))
    print("----")

EMNIST vs FMNIST
best AUC of test-vs-out 0.238
best AUC of generated-vs-out 0.997
----
EMNIST vs Omniglot
best AUC of test-vs-out 0.285
best AUC of generated-vs-out 0.997
----
EMNIST vs MNIST
best AUC of test-vs-out 0.436
best AUC of generated-vs-out 0.998
----
FMNIST vs EMNIST
best AUC of test-vs-out 0.576
best AUC of generated-vs-out 0.997
----
FMNIST vs Omniglot
best AUC of test-vs-out 0.388
best AUC of generated-vs-out 0.996
----
FMNIST vs MNIST
best AUC of test-vs-out 0.781
best AUC of generated-vs-out 0.998
----
Omniglot vs EMNIST
best AUC of test-vs-out 0.263
best AUC of generated-vs-out 0.996
----
Omniglot vs FMNIST
best AUC of test-vs-out 0.169
best AUC of generated-vs-out 0.995
----
Omniglot vs MNIST
best AUC of test-vs-out 0.338
best AUC of generated-vs-out 0.997
----
MNIST vs EMNIST
best AUC of test-vs-out 0.391
best AUC of generated-vs-out 0.998
----
MNIST vs FMNIST
best AUC of test-vs-out 0.224
best AUC of generated-vs-out 0.996
----
MNIST vs Omniglot
best AUC of test-vs-