In [1]:
import inspect

In [2]:
# Reload edited Python modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
from sklearn import decomposition
from scipy.stats import norm
import scipy.stats
from sklearn.manifold import TSNE
import seaborn as sns
from pathlib import Path
import collections
import itertools
from numpy import linalg as LA

In [4]:
# Superclass id per entry of the CSV (presumably one row per class index — TODO confirm).
superclass_list = np.genfromtxt('restricted_superclass.csv', delimiter=',', dtype='i8')
# superclasses[k] = tuple of indices whose superclass id is k, for k = 0..8.
superclasses = [
    tuple(np.where(superclass_list == superclass_id)[0])
    for superclass_id in range(9)
]
# Display names indexed by superclass id; only ids 0..8 are grouped above.
superclass_names = ['Dog', 'Cat', 'Frog', 'Turtle', 'Bird', 'Monkey', 'Fish', 'Crab', 'Insect', '', 'Other']

In [5]:
def get_predictions(logits: np.ndarray) -> dict:
    """Derive argmax predictions and softmax confidences from a logit matrix.

    Parameters
    ----------
    logits : np.ndarray
        2-D array of raw scores; softmax and argmax are taken along axis 1.

    Returns
    -------
    dict
        'logits': the input array (not copied);
        'predicted_classes': argmax column index of each row;
        'class_probabilities': row-wise softmax of ``logits``;
        'confidences_classifier': maximum softmax probability of each row;
        'number_of_class_predictions': ``collections.Counter`` of predicted indices.
    """
    import scipy.special  # explicit: the notebook itself only imports scipy.stats

    probabilities = scipy.special.softmax(logits, axis=1)
    predicted_classes = logits.argmax(axis=1)
    return {
        'logits': logits,
        'predicted_classes': predicted_classes,
        'class_probabilities': probabilities,
        # Reuse the softmax above instead of computing it a second time.
        'confidences_classifier': probabilities.max(axis=1),
        'number_of_class_predictions': collections.Counter(predicted_classes),
    }

In [6]:
# Per-dataset containers keyed by dataset name; load_outputs fills logits/labels/classes.
features = {}
logits = {}
labels = {}
classes = {}

In [7]:
# Optional suffix appended to the base architecture name.
mod = ''
model_name = 'resnet18' + mod

In [8]:
# Reset the per-dataset output containers (refilled by load_outputs).
logits = {}
labels = {}
classes = {}

In [9]:
# Directories searched, in order, by load_outputs for '<dataset>.npz' files.
# NOTE(review): the second entry is a machine-specific absolute path.
outputs_load_paths = [
    f'outputs/{model_name}',
    f'/mnt/SHARED/Julian/feature_density_ood/notebooks/outputs/{model_name}',
]

In [10]:
def load_outputs(*dset_names, search_paths=None, namespace=None):
    """Load saved model outputs for each dataset into per-dataset dicts.

    For every name, the first ``<path>/<name>.npz`` that exists among the search
    paths is loaded (its path is printed). Its 'logits', 'labels' and 'classes'
    arrays are stored as ``namespace['logits'][name]`` etc.

    Parameters
    ----------
    *dset_names : str
        Dataset names; each needs a matching ``.npz`` file.
    search_paths : sequence of str, optional
        Directories searched in order. Defaults to the global
        ``outputs_load_paths``, read at call time.
    namespace : dict, optional
        Mapping holding the 'logits', 'labels' and 'classes' dicts to fill.
        Defaults to ``globals()`` of the defining module (the notebook namespace
        when run in the notebook).

    Raises
    ------
    FileNotFoundError
        If none of the search paths contains ``<name>.npz``.
    """
    if search_paths is None:
        search_paths = outputs_load_paths
    if namespace is None:
        namespace = globals()
    for dset_name in dset_names:
        for path in search_paths:
            filename = os.path.join(path, f'{dset_name}.npz')
            if os.path.exists(filename):
                break
        else:
            raise FileNotFoundError(f'{dset_name}.npz')
        print(filename)
        # NOTE(review): allow_pickle=True can execute code embedded in the file;
        # only load outputs produced by a trusted pipeline.
        # The `with` closes the archive's file handle (np.load keeps it open otherwise);
        # indexing inside the block reads each array into memory first.
        with np.load(filename, allow_pickle=True) as outputs:
            for key in ('logits', 'labels', 'classes'):
                namespace[key][dset_name] = outputs[key]

In [11]:
# Split whose saved outputs are analysed next.
dataset = "imagenet_val_clean"

In [12]:
# Fill logits/labels/classes[dataset] from the first outputs path containing the file.
load_outputs(dataset)

outputs/resnet18/imagenet_val_clean.npz


In [13]:
# Argmax predictions, softmax probabilities and prediction counts for the outputs loaded above.
predictions_IN_val = get_predictions(logits['imagenet_val_clean'])

In [14]:
# Inspect the argmax class index of each sample.
predictions_IN_val['predicted_classes']

array([ 48,   0,   0, ..., 999, 333, 999])

In [15]:
def standard_accuracy(logits, targets):
    """Top-1 accuracy: fraction of samples whose highest-scoring class equals the target.

    Parameters
    ----------
    logits : np.ndarray
        2-D score matrix; each row's prediction is its argmax along axis 1.
    targets : np.ndarray
        Class index of each sample.

    Returns
    -------
    float
        Mean of ``argmax(logits, axis=1) == targets``.
    """
    # Only the argmax is needed; going through get_predictions would also build
    # softmax probability matrices and a Counter that are discarded here.
    predicted_classes = np.asarray(logits).argmax(axis=1)
    return (predicted_classes == np.asarray(targets)).mean()

In [16]:
def classwise_accuracies(logits, targets):
    """Per-class accuracy (recall) of the argmax prediction.

    Parameters
    ----------
    logits : np.ndarray
        2-D score matrix; column ``c`` is assumed to correspond to class id ``c``.
    targets : np.ndarray
        Class id of each sample (same length as ``logits``).

    Returns
    -------
    dict
        Maps every class id present in ``targets``, in ascending order, to the
        fraction of its samples whose argmax prediction equals that id.
    """
    targets = np.asarray(targets)
    hits = (np.asarray(logits).argmax(axis=1) == targets).astype(float)
    # One pass over the data instead of one full boolean mask per class (O(N) vs O(C*N)).
    class_ids, class_index = np.unique(targets, return_inverse=True)
    correct_per_class = np.bincount(class_index, weights=hits)
    samples_per_class = np.bincount(class_index)
    return dict(zip(class_ids, correct_per_class / samples_per_class))

In [17]:
def classwise_sample_numbers(targets):
    """Number of samples per class.

    Parameters
    ----------
    targets : np.ndarray
        Class id of each sample.

    Returns
    -------
    dict
        Maps each class id present in ``targets``, in ascending order, to its count.
    """
    # A single np.unique call instead of one full comparison per class.
    class_ids, counts = np.unique(np.asarray(targets), return_counts=True)
    return dict(zip(class_ids, counts))

In [18]:
def classwise_topk_accuracies(logits, targets, k):
    """Per-class top-k accuracy.

    A sample counts as a top-k hit when fewer than ``k`` classes have a strictly
    larger logit than its target class. Ties are therefore resolved in the
    sample's favour deterministically, whereas an unstable argsort leaves them
    arbitrary. Softmax is monotonic, so ranking logits equals ranking
    probabilities, without softmax underflow creating spurious ties.

    Parameters
    ----------
    logits : np.ndarray
        2-D score matrix; column ``c`` is assumed to correspond to class id ``c``.
    targets : np.ndarray
        Class id of each sample.
    k : int
        Number of top-ranked classes that count as a hit.

    Returns
    -------
    dict
        Class id (ascending) -> fraction of that class's samples that are top-k hits.
    """
    logits = np.asarray(logits)
    targets = np.asarray(targets)
    topk_acc = {}
    for c in np.unique(targets):
        rows = logits[targets == c]
        # Rank of the true class = number of classes scoring strictly higher (O(C) per row, no sort).
        rank = (rows > rows[:, [c]]).sum(axis=1)
        topk_acc[c] = (rank < k).mean()
    return topk_acc

In [19]:
def standard_balanced_topk_accuracy(logits, targets, k):
    """Class-balanced top-k accuracy: unweighted mean of the per-class top-k accuracies."""
    per_class = classwise_topk_accuracies(logits, targets, k)
    return np.mean(list(per_class.values()))

In [20]:
def worst_class_accuracy(logits, targets):
    """Lowest per-class accuracy and the class attaining it.

    Returns ``(accuracy, class_id)``; ties go to the class that comes first in
    the iteration order of ``classwise_accuracies``' result.
    """
    per_class = classwise_accuracies(logits, targets)
    worst_class = min(per_class, key=per_class.get)
    return per_class[worst_class], worst_class

In [21]:
def worst_class_topk_accuracy(logits, targets,k):
    """Lowest per-class top-k accuracy and the class attaining it.

    Returns ``(topk_accuracy, class_id)``; ties go to the first class in the
    iteration order of ``classwise_topk_accuracies``' result.
    """
    per_class = classwise_topk_accuracies(logits, targets, k)
    worst_class = min(per_class, key=per_class.get)
    return per_class[worst_class], worst_class

In [22]:
def worst_balanced_n_classes_accuracy(logits, targets, n):
    """Unweighted mean accuracy of the ``n`` classes with the lowest per-class accuracy."""
    per_class = classwise_accuracies(logits, targets)
    lowest_accuracies = sorted(per_class.values())[:n]
    return np.mean(lowest_accuracies)

In [23]:
def worst_heuristic_n_classes_recall(logits, targets, n):
    """Pooled recall over the ``n`` classes with the lowest per-class accuracy.

    The ``n`` worst classes are chosen by per-class accuracy. The result is their
    sample-weighted accuracy: total correct samples / total samples in those classes.
    """
    per_class = classwise_accuracies(logits, targets)
    counts = classwise_sample_numbers(targets)
    worst_ids = [c for c, _ in sorted(per_class.items(), key=lambda item: item[1])[:n]]
    weighted_hits = np.array([per_class[c] * counts[c] for c in worst_ids]).sum()
    n_samples = np.array([counts[c] for c in worst_ids]).sum()
    return weighted_hits / n_samples

In [24]:
def worst_balanced_n_classes_topk_accuracy(logits, targets, n, k):
    """Unweighted mean top-k accuracy of the ``n`` classes with the lowest per-class top-k accuracy."""
    per_class = classwise_topk_accuracies(logits, targets, k)
    lowest_accuracies = sorted(per_class.values())[:n]
    return np.mean(lowest_accuracies)

In [25]:
def worst_heuristic_n_classes_topk_recall(logits, targets, n, k):
    """Pooled top-k recall over the ``n`` classes with the lowest per-class top-k accuracy.

    The result is total top-k hits / total samples across those ``n`` classes.
    """
    per_class = classwise_topk_accuracies(logits, targets, k)
    counts = classwise_sample_numbers(targets)
    worst_ids = [c for c, _ in sorted(per_class.items(), key=lambda item: item[1])[:n]]
    weighted_hits = np.array([per_class[c] * counts[c] for c in worst_ids]).sum()
    n_samples = np.array([counts[c] for c in worst_ids]).sum()
    return weighted_hits / n_samples

In [26]:
def worst_balanced_two_class_binary_accuracy(logits, targets):
    """Worst balanced binary accuracy over all pairs of classes present in ``targets``.

    For a pair (i, j), a class-i sample is correct when its logit for i is
    strictly larger than its logit for j, and symmetrically for class j. The
    pair's score is the mean of the two per-class rates.

    Parameters
    ----------
    logits : np.ndarray
        2-D score matrix; column ``c`` is assumed to correspond to class id ``c``.
    targets : np.ndarray
        Class id of each sample.

    Returns
    -------
    tuple
        ``(score, i, j)`` for the first minimal pair in ``itertools.combinations``
        order over ``list(set(targets))``.

    Raises
    ------
    ValueError
        If fewer than two distinct classes are present.
    """
    class_ids = list(set(targets))
    if len(class_ids) < 2:
        raise ValueError('need at least two distinct classes in targets')
    # beats[c][j] = fraction of class-c samples whose logit for c exceeds their logit for j.
    # Computed once per class, instead of re-masking every sample for each of the O(C^2) pairs.
    beats = {}
    for c in class_ids:
        rows = logits[targets == c]
        beats[c] = (rows[:, [c]] > rows).mean(axis=0)

    def pair_score(pair):
        i, j = pair
        return (beats[i][j] + beats[j][i]) / 2

    worst_i, worst_j = min(itertools.combinations(class_ids, 2), key=pair_score)
    return pair_score((worst_i, worst_j)), worst_i, worst_j

In [27]:
def worst_balanced_superclass_recall(logits, targets, superclasses):
    """Lowest unweighted mean of member-class accuracies over the superclasses.

    Returns ``(score, -(index + 1))``. The negative encoding is what
    class_number_to_name decodes into a superclass name.
    """
    per_class = classwise_accuracies(logits, targets)
    per_superclass = {}
    for idx, members in enumerate(superclasses):
        per_superclass[idx] = np.array([per_class[c] for c in members]).mean()
    worst_idx = min(per_superclass, key=per_superclass.get)
    return per_superclass[worst_idx], -worst_idx - 1

In [28]:
def worst_superclass_recall(logits, targets, superclasses):
    """Lowest pooled (sample-weighted) recall of member classes over the superclasses.

    A superclass's score is its total correctly classified samples divided by its
    total samples. Returns ``(score, -(index + 1))``.
    """
    per_class = classwise_accuracies(logits, targets)
    counts = classwise_sample_numbers(targets)

    def pooled_recall(members):
        hits = np.array([per_class[c] * counts[c] for c in members]).sum()
        total = np.array([counts[c] for c in members]).sum()
        return hits / total

    per_superclass = {idx: pooled_recall(members) for idx, members in enumerate(superclasses)}
    worst_idx = min(per_superclass, key=per_superclass.get)
    return per_superclass[worst_idx], -worst_idx - 1

In [29]:
def intra_superclass_accuracies(logits, targets, superclasses):
    """Accuracy inside each superclass when predictions are restricted to its members.

    For each superclass, only samples whose target is a member are used. Each
    such sample is predicted as the member class with the highest logit.

    Parameters
    ----------
    logits : np.ndarray
        2-D score matrix; column ``c`` is assumed to correspond to class id ``c``.
    targets : np.ndarray
        Class id of each sample.
    superclasses : sequence of sequences of int
        Member class ids of each superclass.

    Returns
    -------
    dict
        Superclass index -> fraction of its samples predicted correctly.

    Raises
    ------
    ValueError
        If a superclass has no samples in ``targets`` (accuracy undefined).
    """
    isa = {}
    for idx, members in enumerate(superclasses):
        member_ids = np.asarray(members, dtype=int)
        in_superclass = np.isin(targets, member_ids)
        if not in_superclass.any():
            raise ValueError(f'superclass {idx} has no samples in targets')
        local_logits = logits[in_superclass][:, member_ids]
        # Map each argmax position within the superclass back to its global class id.
        predicted = member_ids[local_logits.argmax(axis=1)]
        isa[idx] = (predicted == targets[in_superclass]).mean()
    return isa

In [30]:
def worst_intra_superclass_accuracy(logits, targets, superclasses):
    """Lowest intra-superclass accuracy; returns ``(accuracy, -(index + 1))``."""
    per_superclass = intra_superclass_accuracies(logits, targets, superclasses)
    worst_idx = min(per_superclass, key=per_superclass.get)
    return per_superclass[worst_idx], -worst_idx - 1

In [31]:
def worst_class_precision(logits, targets):
    """Lowest per-class precision among the classes present in ``targets``.

    Precision of class c is (# samples predicted c whose target is c) divided by
    (# samples predicted c). A class that is never predicted gets precision 1.

    Parameters
    ----------
    logits : np.ndarray
        2-D score matrix; predictions are the argmax along axis 1.
    targets : np.ndarray
        Class id of each sample.

    Returns
    -------
    tuple
        ``(precision, class_id)`` of the first minimal class in
        ``list(set(targets))`` order.
    """
    targets = np.asarray(targets)
    predicted = np.asarray(logits).argmax(axis=1)
    precision = {}
    for c in list(set(targets)):
        predicted_c = predicted == c
        n_predicted = predicted_c.sum()
        if n_predicted:
            precision[c] = (predicted_c & (targets == c)).sum() / n_predicted
        else:
            # Never predicted: no false positives, so treated as perfect precision.
            precision[c] = 1
    worst_class = min(precision, key=precision.get)
    return precision[worst_class], worst_class

In [32]:
def class_confusion(logits, targets):
    """Confusion-count matrix of argmax predictions.

    Entry ``[t, p]`` counts the samples with target ``t`` whose argmax prediction is ``p``.

    Parameters
    ----------
    logits : np.ndarray
        2-D score matrix; predictions are the argmax along axis 1.
    targets : np.ndarray
        Non-negative integer class id of each sample.

    Returns
    -------
    np.ndarray
        Float square matrix whose side is the number of distinct targets, enlarged
        when needed to fit the largest target or predicted id.
    """
    targets = np.asarray(targets)
    predicted = np.asarray(logits).argmax(axis=1)
    n = len(set(targets.tolist()))
    if targets.size:
        # The number of distinct targets alone is too small when class ids are not
        # contiguous from 0 or when a class absent from targets is predicted.
        n = max(n, int(targets.max()) + 1, int(predicted.max()) + 1)
    confusion = np.zeros((n, n))
    # Unbuffered add so repeated (target, prediction) pairs are all counted.
    np.add.at(confusion, (targets, predicted), 1)
    return confusion

In [33]:
# Maps dataset name -> {model name: {metric: value}} for that dataset.
dataset_tables = dict()

In [38]:
# Split evaluated by the model loop below.
dataset = "imagenet_val_clean"

In [39]:
# Models evaluated in the loop below; outputs are read from outputs/<name>/.
model_names = [
    'resnet18',
    'resnet50',
    'vgg16',
    'vit_base_patch16_224',
]

In [40]:
# full_table: {model name: {metric: value}} for the current dataset.
# NOTE(review): print_table is never used in the visible notebook.
full_table, print_table = {}, {}

In [41]:
# Spot check: mean accuracy of the 10 worst classes for the currently loaded outputs.
worst_balanced_n_classes_accuracy(logits[dataset], labels[dataset], 10)

0.13981971306116694

In [42]:
# NOTE(review): reloads the same file as the earlier load_outputs(dataset) call
# (the printed path is identical); appears redundant.
load_outputs('imagenet_val_clean')

outputs/resnet18/imagenet_val_clean.npz


In [43]:
# Evaluate every model on `dataset`. The metric table does not depend on the
# model, so it is built once. `full_table` is re-created so re-running this cell
# never mutates a table that was already stored elsewhere.
table_entries = {
    'Standard accuracy': standard_accuracy,
    'Worst class accuracy': worst_class_accuracy,
    'Worst-class precision': worst_class_precision,
    'Worst superclass-accuracy': lambda logits, targets: worst_intra_superclass_accuracy(logits, targets, superclasses),
    'Worst superclass recall': lambda logits, targets: worst_superclass_recall(logits, targets, superclasses),
    'Worst 10-class recall': lambda logits, targets: worst_heuristic_n_classes_recall(logits, targets, 10),
    'Worst 100-class recall': lambda logits, targets: worst_heuristic_n_classes_recall(logits, targets, 100),
    'Worst 2-class binary accuracy': worst_balanced_two_class_binary_accuracy,
    'Worst class top-5 accuracy': lambda logits, targets: worst_class_topk_accuracy(logits, targets, 5),
    'Worst 10-class top-5 recall': lambda logits, targets: worst_heuristic_n_classes_topk_recall(logits, targets, 10, 5),
    'Worst 100-class top-5 recall': lambda logits, targets: worst_heuristic_n_classes_topk_recall(logits, targets, 100, 5),
}

full_table = {}
for model_name in model_names:
    # load_outputs writes into these globals and searches outputs_load_paths.
    logits, labels, classes = {}, {}, {}
    outputs_load_paths = [f'outputs/{model_name}']
    load_outputs(dataset)
    full_table[model_name] = {
        metric: compute(logits[dataset], labels[dataset])
        for metric, compute in table_entries.items()
    }

outputs/resnet18/imagenet_val_clean.npz
outputs/resnet50/imagenet_val_clean.npz
outputs/vgg16/imagenet_val_clean.npz
outputs/vit_base_patch16_224/imagenet_val_clean.npz


In [44]:
# Store a snapshot: `full_table` may be mutated when the next dataset is evaluated,
# and storing the dict itself would let those results overwrite this entry.
dataset_tables[dataset] = dict(full_table)

In [45]:
def class_number_to_name(x):
    """Resolve an id stored in a metric result to its name entry.

    A non-negative ``x`` indexes ``classes[dataset]`` (the 'classes' array
    loaded by load_outputs), and that entry is returned unchanged. A negative
    ``x`` encodes a superclass as ``-(index + 1)`` and is returned as a 1-tuple
    ``(superclass_names[index],)``. write_table_entry uses element ``[0]`` of the result.
    """
    if x < 0:
        return (superclass_names[-x - 1],)
    return classes[dataset][x]

In [46]:
def write_table_entry(a):
    """Format one metric result as a table cell.

    ``a`` is either a bare number or a tuple ``(value, id, ...)``. The value is
    printed as a percentage with two decimals, zero-padded to at least five
    characters (e.g. ``06.52``). Each id is appended as a space followed by the
    first element of ``class_number_to_name(id)``.
    """
    entry = a if type(a) is tuple else (a,)
    cell = f'{100*entry[0]:.2f}'.zfill(5)
    for encoded_id in entry[1:]:
        cell += ' ' + class_number_to_name(encoded_id)[0]
    return cell

In [47]:
# Display names used in the LaTeX table header, keyed by model identifier.
latex_model_names = dict(
    resnet18='ResNet-18',
    resnet50='ResNet-50',
    vgg16='VGG16',
    vit_base_patch16_224='ViT base patch 16 224',
)

In [48]:
# Row labels of the LaTeX table, in output order; keys match the metric names in table_entries.
metric_latex_dict = dict([
    ('Standard accuracy', 'A'),
    ('Worst class accuracy', 'WCA'),
    ('Worst-class precision', 'WCP'),
    ('Worst superclass-accuracy', 'WSupCA'),
    ('Worst superclass recall', 'WSupCR'),
    ('Worst 10-class recall', 'W10CR'),
    ('Worst 100-class recall', 'W100CR'),
    ('Worst 2-class binary accuracy', 'W2CA'),
    ('Worst class top-5 accuracy', 'WCA@5'),
    ('Worst 10-class top-5 recall', 'W10CR@5'),
    ('Worst 100-class top-5 recall', 'W100CR@5'),
])

In [49]:
# Raw metric values per model; tuple entries carry class ids or encoded (negative) superclass ids.
full_table

{'resnet18': {'Standard accuracy': 0.697919381461211,
  'Worst class accuracy': (0.06521739130434782, 836),
  'Worst-class precision': (0.15384615384615385, 618),
  'Worst superclass-accuracy': (0.7180616740088106, -2),
  'Worst superclass recall': (0.5947136563876652, -2),
  'Worst 10-class recall': 0.13973799126637554,
  'Worst 100-class recall': 0.3415004336513443,
  'Worst 2-class binary accuracy': (0.4417874396135266, 620, 681),
  'Worst class top-5 accuracy': (0.3958333333333333, 885),
  'Worst 10-class top-5 recall': 0.49122807017543857,
  'Worst 100-class top-5 recall': 0.6816202090592335},
 'resnet50': {'Standard accuracy': 0.7615758839371036,
  'Worst class accuracy': (0.1875, 638),
  'Worst-class precision': (0.25, 620),
  'Worst superclass-accuracy': (0.75, -4),
  'Worst superclass recall': (0.6343612334801763, -2),
  'Worst 10-class recall': 0.22317596566523606,
  'Worst 100-class recall': 0.41853932584269665,
  'Worst 2-class binary accuracy': (0.44130434782608696, 620, 6

In [50]:
def write_latex_table(full_table):
    """Render ``full_table`` ({model: {metric: value}}) as space-padded LaTeX tabular rows.

    The header row lists the models' display names (global ``latex_model_names``)
    and ends with a LaTeX comment naming the global ``dataset``. Each following
    row is one metric from the global ``metric_latex_dict``, with cells
    formatted by ``write_table_entry``.
    """
    separator = ' &   '
    header = 'Model '.ljust(50) + ''.join(
        separator + latex_model_names[name].ljust(25) for name in full_table
    )
    rows = [header + r' \\ \hline' + ' %' + dataset + '\n']
    for metric_key, short_label in metric_latex_dict.items():
        cells = ''.join(
            separator + write_table_entry(metrics[metric_key]).ljust(25)
            for metrics in full_table.values()
        )
        rows.append(short_label.ljust(50) + cells + r' \\' + '\n')
    return ''.join(rows)

In [51]:
# Number of evaluated samples in this split (last loaded model's outputs).
len(logits[dataset])

46044

In [52]:
# Emit the LaTeX rows for this dataset's metric table.
print(write_latex_table(dataset_tables[dataset]))

Model                                              &   ResNet-18                 &   ResNet-50                 &   VGG16                     &   ViT base patch 16 224     \\ \hline %imagenet_val_clean
A                                                  &   69.79                     &   76.16                     &   71.62                     &   84.62                     \\
WCA                                                &   06.52 sunglass            &   18.75 maillot             &   09.09 letter opener       &   21.74 tiger cat           \\
WCP                                                &   15.38 ladle               &   25.00 laptop              &   26.32 laptop              &   30.00 notebook            \\
WSupCA                                             &   71.81 Cat                 &   75.00 Turtle              &   70.93 Cat                 &   76.65 Cat                 \\
WSupCR                                             &   59.47 Cat                 &   63.44 Cat         

In [57]:
# Switch to the full validation split for the second evaluation.
dataset = "imagenet_val"

In [58]:
# Models evaluated in the loop below; outputs are read from outputs/<name>/.
model_names = [
    'resnet18',
    'resnet50',
    'vgg16',
    'vit_base_patch16_224',
]

In [59]:
# Evaluate every model on `dataset`. `full_table` is re-created first: the
# previous dataset's table was stored in dataset_tables by reference, so
# mutating the old dict here would silently overwrite those results.
table_entries = {
    'Standard accuracy': standard_accuracy,
    'Worst class accuracy': worst_class_accuracy,
    'Worst-class precision': worst_class_precision,
    'Worst superclass-accuracy': lambda logits, targets: worst_intra_superclass_accuracy(logits, targets, superclasses),
    'Worst superclass recall': lambda logits, targets: worst_superclass_recall(logits, targets, superclasses),
    'Worst 10-class recall': lambda logits, targets: worst_heuristic_n_classes_recall(logits, targets, 10),
    'Worst 100-class recall': lambda logits, targets: worst_heuristic_n_classes_recall(logits, targets, 100),
    'Worst 2-class binary accuracy': worst_balanced_two_class_binary_accuracy,
    'Worst class top-5 accuracy': lambda logits, targets: worst_class_topk_accuracy(logits, targets, 5),
    'Worst 10-class top-5 recall': lambda logits, targets: worst_heuristic_n_classes_topk_recall(logits, targets, 10, 5),
    'Worst 100-class top-5 recall': lambda logits, targets: worst_heuristic_n_classes_topk_recall(logits, targets, 100, 5),
}

full_table = {}
for model_name in model_names:
    # load_outputs writes into these globals and searches outputs_load_paths.
    logits, labels, classes = {}, {}, {}
    outputs_load_paths = [f'outputs/{model_name}']
    load_outputs(dataset)
    full_table[model_name] = {
        metric: compute(logits[dataset], labels[dataset])
        for metric, compute in table_entries.items()
    }

outputs/resnet18/imagenet_val.npz
outputs/resnet50/imagenet_val.npz
outputs/vgg16/imagenet_val.npz
outputs/vit_base_patch16_224/imagenet_val.npz


In [60]:
# Store a snapshot so a later re-run of the evaluation loop cannot change this entry.
dataset_tables[dataset] = dict(full_table)

In [61]:
def class_number_to_name(x):
    """Resolve an id stored in a metric result to its name entry.

    NOTE(review): redefines the identical function defined earlier in the notebook.

    A non-negative ``x`` indexes ``classes[dataset]`` (the 'classes' array
    loaded by load_outputs), and that entry is returned unchanged. A negative
    ``x`` encodes a superclass as ``-(index + 1)`` and is returned as a 1-tuple
    ``(superclass_names[index],)``. write_table_entry uses element ``[0]`` of the result.
    """
    if x < 0:
        return (superclass_names[-x - 1],)
    return classes[dataset][x]

In [62]:
def write_table_entry(a):
    """Format one metric result as a table cell.

    NOTE(review): redefines the identical function defined earlier in the notebook.

    ``a`` is either a bare number or a tuple ``(value, id, ...)``. The value is
    printed as a percentage with two decimals, zero-padded to at least five
    characters (e.g. ``08.00``). Each id is appended as a space followed by the
    first element of ``class_number_to_name(id)``.
    """
    entry = a if type(a) is tuple else (a,)
    cell = f'{100*entry[0]:.2f}'.zfill(5)
    for encoded_id in entry[1:]:
        cell += ' ' + class_number_to_name(encoded_id)[0]
    return cell

In [63]:
# Display names for the LaTeX table header (same mapping as defined earlier in the notebook).
latex_model_names = dict(
    resnet18='ResNet-18',
    resnet50='ResNet-50',
    vgg16='VGG16',
    vit_base_patch16_224='ViT base patch 16 224',
)

In [64]:
# Row labels of the LaTeX table, in output order (same mapping as defined earlier in the notebook).
metric_latex_dict = dict([
    ('Standard accuracy', 'A'),
    ('Worst class accuracy', 'WCA'),
    ('Worst-class precision', 'WCP'),
    ('Worst superclass-accuracy', 'WSupCA'),
    ('Worst superclass recall', 'WSupCR'),
    ('Worst 10-class recall', 'W10CR'),
    ('Worst 100-class recall', 'W100CR'),
    ('Worst 2-class binary accuracy', 'W2CA'),
    ('Worst class top-5 accuracy', 'WCA@5'),
    ('Worst 10-class top-5 recall', 'W10CR@5'),
    ('Worst 100-class top-5 recall', 'W100CR@5'),
])

In [65]:
# Raw metric values per model on the full validation split.
full_table

{'resnet18': {'Standard accuracy': 0.69758,
  'Worst class accuracy': (0.08, 836),
  'Worst-class precision': (0.17142857142857143, 620),
  'Worst superclass-accuracy': (0.732, -2),
  'Worst superclass recall': (0.608, -2),
  'Worst 10-class recall': 0.154,
  'Worst 100-class recall': 0.341,
  'Worst 2-class binary accuracy': (0.42, 620, 681),
  'Worst class top-5 accuracy': (0.38, 885),
  'Worst 10-class top-5 recall': 0.494,
  'Worst 100-class top-5 recall': 0.681},
 'resnet50': {'Standard accuracy': 0.7613,
  'Worst class accuracy': (0.18, 638),
  'Worst-class precision': (0.2391304347826087, 620),
  'Worst superclass-accuracy': (0.756, -4),
  'Worst superclass recall': (0.656, -2),
  'Worst 10-class recall': 0.226,
  'Worst 100-class recall': 0.4192,
  'Worst 2-class binary accuracy': (0.44, 620, 681),
  'Worst class top-5 accuracy': (0.52, 818),
  'Worst 10-class top-5 recall': 0.632,
  'Worst 100-class top-5 recall': 0.7744},
 'vgg16': {'Standard accuracy': 0.71592,
  'Worst clas

In [66]:
def write_latex_table(full_table):
    """Render ``full_table`` ({model: {metric: value}}) as space-padded LaTeX tabular rows.

    NOTE(review): redefines the identical function defined earlier in the notebook.

    The header row lists the models' display names (global ``latex_model_names``)
    and ends with a LaTeX comment naming the global ``dataset``. Each following
    row is one metric from the global ``metric_latex_dict``, with cells
    formatted by ``write_table_entry``.
    """
    separator = ' &   '
    header = 'Model '.ljust(50) + ''.join(
        separator + latex_model_names[name].ljust(25) for name in full_table
    )
    rows = [header + r' \\ \hline' + ' %' + dataset + '\n']
    for metric_key, short_label in metric_latex_dict.items():
        cells = ''.join(
            separator + write_table_entry(metrics[metric_key]).ljust(25)
            for metrics in full_table.values()
        )
        rows.append(short_label.ljust(50) + cells + r' \\' + '\n')
    return ''.join(rows)

In [67]:
# Number of evaluated samples in this split (last loaded model's outputs).
len(logits[dataset])

50000

In [68]:
# Emit the LaTeX rows for the full validation split's metric table.
print(write_latex_table(dataset_tables[dataset]))

Model                                              &   ResNet-18                 &   ResNet-50                 &   VGG16                     &   ViT base patch 16 224     \\ \hline %imagenet_val
A                                                  &   69.76                     &   76.13                     &   71.59                     &   84.53                     \\
WCA                                                &   08.00 sunglass            &   18.00 maillot             &   10.00 velvet              &   24.00 tiger cat           \\
WCP                                                &   17.14 laptop              &   23.91 laptop              &   24.39 laptop              &   30.30 notebook            \\
WSupCA                                             &   73.20 Cat                 &   75.60 Turtle              &   72.40 Cat                 &   77.60 Cat                 \\
WSupCR                                             &   60.80 Cat                 &   65.60 Cat               