In [1]:
import inspect

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
from sklearn import decomposition
from scipy.stats import norm
import scipy.stats
from sklearn.manifold import TSNE
import seaborn as sns
from pathlib import Path
import collections
import itertools
from numpy import linalg as LA

In [4]:
def get_predictions(logits:np.ndarray) -> np.ndarray:
    preds = {'logits': logits,
             'predicted_classes': logits.argmax(axis=1),
             'class_probabilities': scipy.special.softmax(logits, axis=1),
             'confidences_classifier': scipy.special.softmax(logits, axis=1).max(axis=1),
            }
    preds['number_of_class_predictions'] = collections.Counter(preds['predicted_classes'])
    return preds

In [5]:
features, logits, labels, classes = {}, {}, {}, {}

In [6]:
mod = ''
model_name = f"resnet18{mod}"

In [7]:
logits, labels, classes = {}, {}, {}

In [8]:
outputs_load_paths = [f'outputs/{model_name}', f'/mnt/SHARED/Julian/feature_density_ood/notebooks/outputs/{model_name}']

In [9]:
def load_outputs(*dset_names):
    for dset_name in dset_names:
        for path in outputs_load_paths:
            filename = os.path.join(path, f'{dset_name}.npz')
            if os.path.exists(filename):
                print(filename)
                outputs = np.load(filename, allow_pickle=True)
                break
        else:
            raise FileNotFoundError(f'{dset_name}.npz')
        for d in ['logits', 'labels', 'classes']:
           globals()[d][dset_name] = outputs[d]

In [10]:
load_outputs('imagenet_val')

outputs/resnet18/imagenet_val.npz


In [11]:
predictions_IN_val = get_predictions(logits['imagenet_val'])

In [12]:
predictions_IN_val['predicted_classes']

array([ 48,   0, 391, ..., 999, 333, 999])

In [13]:
def standard_accuracy(logits, targets):
    preds = get_predictions(logits)
    accuracy = (preds['predicted_classes'] == targets).mean()
    return accuracy

In [14]:
def classwise_accuracies(logits, targets):
    preds = get_predictions(logits)
    clw_acc = {}
    for i in set(targets):
        clw_acc[i] = np.equal(preds['predicted_classes'][np.where(targets == i)], i).mean()
    return clw_acc

In [15]:
def classwise_topk_accuracies(logits, targets, k):
    preds = get_predictions(logits)
    clw_topk_acc = {}
    for i in set(targets):
        clw_topk_acc[i] = np.equal(i, np.argsort(preds['class_probabilities'][np.where(targets == i)], axis=1)[:,-k:]).sum(axis=-1).mean()
    return clw_topk_acc

In [16]:
def standard_balanced_topk_accuracy(logits, targets, k):
    clw_topk_acc = classwise_topk_accuracies(logits, targets, k)
    return np.array(list(clw_topk_acc.values())).mean()

In [17]:
def worst_class_accuracy(logits, targets):
    cwa = classwise_accuracies(logits, targets)
    return min(cwa.items(), key=lambda x: x[1])

In [18]:
def worst_class_topk_accuracy(logits, targets,k):
    clw_topk_acc = classwise_topk_accuracies(logits, targets, k)
    return min(clw_topk_acc.items(), key=lambda x: x[1])

In [19]:
def worst_balanced_n_classes_accuracy(logits, targets, n):
    cwa = classwise_accuracies(logits, targets)
    sorted_cwa =  sorted(cwa.items(), key=lambda item: item[1])
    n_worst = sorted_cwa[:n]
    return np.array([x[1] for x in n_worst]).mean(), n_worst

In [20]:
def worst_balanced_n_classes_topk_accuracy(logits, targets, n, k):
    clw_topk_acc = classwise_topk_accuracies(logits, targets, k)
    sorted_clw_topk_acc =  sorted(clw_topk_acc.items(), key=lambda item: item[1])
    n_worst = sorted_clw_topk_acc[:n]
    return np.array([x[1] for x in n_worst]).mean(), n_worst

In [21]:
def worst_balanced_two_class_binary_accuracy(logits, targets):
    classes = list(set(targets))
    binary_accuracies = {}
    for i,j in itertools.combinations(classes, 2):
        i_labelled = logits[np.where(targets == i)]
        j_labelled = logits[np.where(targets == j)]
        i_correct = np.greater(i_labelled[:,i], i_labelled[:,j]).mean()
        j_correct = np.greater(j_labelled[:,j], j_labelled[:,i]).mean()
        binary_accuracies[(i,j)] = (i_correct + j_correct)/2
    sorted_binary_accuracies = sorted(binary_accuracies.items(), key=lambda item: item[1])
    return sorted_binary_accuracies[0][1], sorted_binary_accuracies[0], sorted_binary_accuracies, binary_accuracies

In [22]:
def worst_balanced_superclass_recall(logits, targets, superclasses):
    cwa = classwise_accuracies(logits, targets)
    scwa = {i: np.array([cwa[c] for c in s]).mean() for i,s in enumerate(superclasses)}
    return min(scwa.items(), key=lambda x: x[1])

In [23]:
def inter_superclass_accuracies(logits, targets, superclasses):
    isa = {}
    for i,s in enumerate(superclasses):
        internal_samples = np.isin(targets, s)
        internal_targets = targets[internal_samples]
        internal_logits = logits[internal_samples][:,s]
        s_targets = np.vectorize(lambda x: s[x])
        internal_preds = s_targets(get_predictions(internal_logits)['predicted_classes'])
        isa[i] = (internal_preds == internal_targets).mean()
    return isa

In [24]:
def worst_balanced_inter_superclass_accuracy(logits, targets, superclasses):
    isa = inter_superclass_accuracies(logits, targets, superclasses)
    return min(isa.items(), key=lambda x: x[1])

In [25]:
def worst_class_precision(logits, targets):
    preds = get_predictions(logits)
    classes = list(set(targets))
    sc = {}
    for c in classes:
        erroneous_c = (preds['predicted_classes'] == c)*(targets != c)
        correct_c = (preds['predicted_classes'] == c)*(targets == c)
        predicted_c = (preds['predicted_classes'] == c)
        if predicted_c.sum():
            sc[c] = correct_c.sum()/predicted_c.sum() #1-erroneous_c.sum()/predicted_c.sum()
        else:
            sc[c] = 1
    sorted_sc = sorted(sc.items(), key=lambda item: item[1])
    return sorted_sc[0]

In [26]:
def class_confusion(logits, targets):
    preds = get_predictions(logits)
    classes = list(set(targets))
    confusion = np.zeros((len(classes), len(classes)))
    for i,c in enumerate(targets):
        confusion[c, preds['predicted_classes'][i]] += 1
    return confusion

In [27]:
predictions_IN_val['predicted_classes'][np.where(labels['imagenet_val'] == 0)]

array([ 48,   0, 391,   0,   0,   0,   0,   0,   0,   0, 133,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0, 391,   0,   0, 389,  33,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0])

In [28]:
standard_accuracy(logits['imagenet_val'], labels['imagenet_val'])

0.69758

In [29]:
cwa = classwise_accuracies(logits['imagenet_val'], labels['imagenet_val'])

In [30]:
np.array(list(cwa.values())).mean()

0.69758

In [31]:
worst_class_accuracy(logits['imagenet_val'], labels['imagenet_val'])

(836, 0.08)

In [32]:
clw_top5_acc = classwise_topk_accuracies(logits['imagenet_val'], labels['imagenet_val'], 5)

In [33]:
np.array(list(clw_top5_acc.values())).mean()

0.89078

In [34]:
standard_balanced_topk_accuracy(logits['imagenet_val'], labels['imagenet_val'], 5)

0.89078

In [35]:
worst_class_topk_accuracy(logits['imagenet_val'], labels['imagenet_val'], 5)

(885, 0.38)

In [36]:
worst_balanced_n_classes_accuracy(logits['imagenet_val'], labels['imagenet_val'], 10)

(0.15399999999999997,
 [(836, 0.08),
  (282, 0.12),
  (620, 0.12),
  (885, 0.12),
  (585, 0.18),
  (600, 0.18),
  (618, 0.18),
  (813, 0.18),
  (906, 0.18),
  (638, 0.2)])

In [37]:
worst_balanced_n_classes_accuracy(logits['imagenet_val'], labels['imagenet_val'], 100)[0]

0.34099999999999986

In [38]:
worst_balanced_n_classes_topk_accuracy(logits['imagenet_val'], labels['imagenet_val'], 10, 5)[0]

0.49400000000000005

In [39]:
worst_balanced_n_classes_topk_accuracy(logits['imagenet_val'], labels['imagenet_val'], 100, 5)[0]

0.6810000000000003

In [40]:
superclass_list = np.genfromtxt('restricted_superclass.csv', delimiter=',', dtype='i8')

In [41]:
superclasses = [tuple(np.where(superclass_list == i)[0]) for i in range(0,9)]

In [42]:
superclass_names = ['Dog', 'Cat', 'Frog', 'Turtle', 'Bird', 'Monkey', 'Fish', 'Crab', 'Insect', '', 'Other']

In [43]:
sum([len(x) for x in superclasses])

203

In [44]:
len(np.where(superclass_list == 10)[0])

797

In [45]:
worst_balanced_superclass_recall(logits['imagenet_val'], labels['imagenet_val'], superclasses)

(1, 0.608)

In [46]:
inter_superclass_accuracies(logits['imagenet_val'], labels['imagenet_val'], superclasses)

{0: 0.7359322033898306,
 1: 0.732,
 2: 0.7733333333333333,
 3: 0.756,
 4: 0.9114285714285715,
 5: 0.76,
 6: 0.86,
 7: 0.79,
 8: 0.78}

In [70]:
superclass_names[1]

'Cat'

In [47]:
worst_balanced_inter_superclass_accuracy(logits['imagenet_val'], labels['imagenet_val'], superclasses)

(1, 0.732)

In [48]:
worst_class_precision(logits['imagenet_val'], labels['imagenet_val'])

(620, 0.17142857142857143)

In [49]:
confusion_full = class_confusion(logits['imagenet_val'], labels['imagenet_val'])

In [50]:
np.trace(confusion_full)/np.sum(confusion_full)

0.69758

In [51]:
confusion_normalized = confusion_full/confusion_full.sum(axis=1)

In [52]:
confusion_normalized.trace()/len(confusion_normalized)

0.69758

In [53]:
confusion_normalized.trace()/len(confusion_normalized)

0.69758

In [54]:
confusion_normalized_errors = np.copy(confusion_normalized)
np.fill_diagonal(confusion_normalized_errors, 0)

In [55]:
confusion_normalized_errors.sum()/len(confusion_normalized_errors) #standard error

0.30242000000000013

In [56]:
LA.norm(confusion_normalized_errors, 1)

1.2800000000000007

In [57]:
LA.norm(confusion_normalized_errors, np.inf) #worst class error

0.9199999999999999

In [58]:
LA.norm(confusion_normalized_errors, 2)

0.6681436581563724

In [59]:
1 - .696969696969697

0.303030303030303

In [60]:
np.max(confusion_normalized_errors)

0.64

In [61]:
np.unravel_index(np.argmax(confusion_normalized_errors), confusion_normalized_errors.shape)

(620, 681)

In [67]:
model_names = ['resnet18', 'resnet50', 'vgg16', 'vit_base_patch16_224']

In [68]:
table = {}

In [69]:
for model in model_names:
    logits, labels, classes = {}, {}, {}

    outputs_load_paths = [f'outputs/{model_name}', f'/mnt/SHARED/Julian/feature_density_ood/notebooks/outputs/{model_name}']

    load_outputs('imagenet_val')

    table_entries = {
        'Standard accuracy': standard_accuracy,
        'Worst class accuracy': lambda logits, targets: worst_class_accuracy(logits, targets)[1],
        'Worst superclass recall': lambda logits, targets: worst_balanced_superclass_recall(logits, targets, superclasses)[1],
        'Worst intra-superclass accuracy': lambda logits, targets: worst_balanced_inter_superclass_accuracy(logits, targets, superclasses)[1],
        'Worst 10-classes accuracy:': lambda logits, targets: worst_balanced_n_classes_accuracy(logits, targets, 10)[0],
        'Worst 100-classes accuracy:': lambda logits, targets: worst_balanced_n_classes_accuracy(logits, targets, 100)[0],
        'Worst class top-5 accuracy:': lambda logits, targets: worst_class_topk_accuracy(logits, targets, 5)[1],
        'Worst 100-classes top-5 accuracy:': lambda logits, targets: worst_balanced_n_classes_topk_accuracy(logits, targets, 100, 5)[0],
        'Worst class precision': lambda logits, targets: worst_class_precision(logits, targets)[1],
        'Worst 2-class binary accuracy': lambda logits, targets: worst_balanced_two_class_binary_accuracy(logits, targets)[0],
    }

    table = {k: v(logits['imagenet_val'], labels['imagenet_val']) for k,v in table_entries.items()}

    table_str = f'{model_name}\n'
    for k,v in table.items():
        table_str += f'{k:<40}{100*v:.2F}\n'

    table[model_name] = table_str

outputs/resnet18/imagenet_val.npz
outputs/resnet18/imagenet_val.npz
outputs/resnet18/imagenet_val.npz
outputs/resnet18/imagenet_val.npz
