In [1]:
import inspect

In [2]:
%load_ext autoreload
%autoreload 2

In [115]:
import collections
import itertools
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import scipy.special
import scipy.stats
import seaborn as sns
from numpy import linalg as LA
from scipy.stats import norm
from sklearn import decomposition
from sklearn.manifold import TSNE

In [24]:
def get_predictions(logits: np.ndarray) -> dict:
    """Derive class predictions and confidences from a matrix of logits.

    Parameters
    ----------
    logits : np.ndarray
        2-D array; every reduction below runs over axis 1, so axis 0 indexes
        samples and axis 1 indexes classes.

    Returns
    -------
    dict
        'logits' : the input array, unchanged (same object).
        'predicted_classes' : argmax over axis 1 for every sample.
        'class_probabilities' : softmax of the logits over axis 1.
        'confidences_classifier' : largest softmax probability per sample.
        'number_of_class_predictions' : ``collections.Counter`` mapping each
            predicted class to the number of samples assigned to it.
    """
    # Compute the softmax once and reuse it for the confidences instead of
    # running a second full pass over the logit matrix.
    class_probabilities = scipy.special.softmax(logits, axis=1)
    predicted_classes = logits.argmax(axis=1)
    preds = {'logits': logits,
             'predicted_classes': predicted_classes,
             'class_probabilities': class_probabilities,
             'confidences_classifier': class_probabilities.max(axis=1),
            }
    preds['number_of_class_predictions'] = collections.Counter(predicted_classes)
    return preds

In [25]:
features, logits, labels, classes = {}, {}, {}, {}

In [26]:
mod = ''
model_name = f"vit_base_patch16_224{mod}"

In [27]:
outputs_load_path = f'/mnt/SHARED/Julian/feature_density_ood/notebooks/outputs/{model_name}'

In [28]:
def load_outputs(*dset_names):
    """Load saved model outputs for the given datasets into the module-level dicts.

    For every name in ``dset_names`` the archive
    ``<outputs_load_path>/<name>.npz`` is opened and its 'features', 'logits',
    'labels' and 'classes' arrays are stored under that dataset name in the
    global dictionaries ``features``, ``logits``, ``labels`` and ``classes``.
    Existing entries for the same dataset name are overwritten.
    """
    for dset_name in dset_names:
        archive_path = os.path.join(outputs_load_path, f'{dset_name}.npz')
        # NOTE(security): allow_pickle=True unpickles object arrays, which can
        # execute arbitrary code; only load archives from a trusted source.
        # The context manager closes the underlying file handle once the four
        # arrays have been read.
        with np.load(archive_path, allow_pickle=True) as outputs:
            for d in ['features', 'logits', 'labels', 'classes']:
                globals()[d][dset_name] = outputs[d]

In [29]:
load_outputs('raccoons', 'imagenet_val', 'Food101', 'textures', 'OOD_Places_MOS', 'iNaturalist')

In [30]:
predictions_IN_val = get_predictions(logits['imagenet_val'])

In [31]:
predictions_IN_val['predicted_classes']

array([  0,   0,   0, ..., 846, 333, 999])

In [65]:
def standard_accuracy(logits, targets):
    """Return the fraction of samples whose predicted class equals its target.

    The predicted class of a sample is taken from
    ``get_predictions(logits)['predicted_classes']``.
    """
    predicted = get_predictions(logits)['predicted_classes']
    is_correct = predicted == targets
    return is_correct.mean()

In [33]:
def classwise_accuracies(logits, targets):
    """Compute the accuracy separately for every label occurring in ``targets``.

    Returns a dict that maps each distinct target label to the fraction of
    the samples carrying that label whose predicted class equals the label.
    """
    predicted = get_predictions(logits)['predicted_classes']
    return {
        label: (predicted[np.where(targets == label)] == label).mean()
        for label in set(targets)
    }

In [34]:
def worst_class_accuracy(logits, targets):
    """Return ``(label, accuracy)`` for the class with the lowest class-wise accuracy.

    Ties are resolved in favour of the label that comes first in the
    iteration order of the class-wise accuracy dict.
    """
    per_class = classwise_accuracies(logits, targets)
    worst_label = min(per_class, key=per_class.get)
    return worst_label, per_class[worst_label]

In [35]:
def worst_balanced_n_classes_accuracy(logits, targets, n):
    """Average the class-wise accuracies of the ``n`` worst classes.

    Returns a tuple ``(mean_accuracy, n_worst)`` where ``n_worst`` is the list
    of the ``n`` lowest ``(label, accuracy)`` pairs in ascending order of
    accuracy and ``mean_accuracy`` is the mean of their accuracies.
    """
    per_class = classwise_accuracies(logits, targets)
    ranked = list(per_class.items())
    ranked.sort(key=lambda entry: entry[1])
    n_worst = ranked[:n]
    worst_values = np.array([accuracy for _, accuracy in n_worst])
    return worst_values.mean(), n_worst

In [36]:
def worst_balanced_two_class_binary_accuracy(logits, targets):
    """Find the pair of classes that is hardest to separate one-vs-one.

    For every unordered pair of labels ``(i, j)`` occurring in ``targets``
    only the logit columns ``i`` and ``j`` are compared: a sample labelled
    ``i`` counts as correct when its logit ``i`` is strictly greater than its
    logit ``j``, and symmetrically for samples labelled ``j``. The balanced
    accuracy of the pair is the mean of the two per-class fractions.

    Labels are used directly as column indices into ``logits``.

    Returns
    -------
    tuple
        ``(worst_accuracy, worst_entry, sorted_entries, binary_accuracies)``
        where ``binary_accuracies`` maps ``(i, j)`` to the pair's balanced
        accuracy, ``sorted_entries`` is its items sorted by ascending
        accuracy, ``worst_entry`` is the first of those items and
        ``worst_accuracy`` its accuracy.

    Raises
    ------
    IndexError
        If ``targets`` holds fewer than two distinct labels, so there is no
        pair to rank.
    """
    classes = list(set(targets))
    # Gather the rows of every class once. Doing the `targets == c` scan
    # inside the pair loop repeats it for each of the O(C^2) pairs although
    # it only depends on the class, not on the pair. This keeps one copy of
    # each class's logit rows alive for the duration of the loop.
    logits_by_class = {c: logits[np.where(targets == c)] for c in classes}
    binary_accuracies = {}
    for i, j in itertools.combinations(classes, 2):
        i_labelled = logits_by_class[i]
        j_labelled = logits_by_class[j]
        i_correct = np.greater(i_labelled[:, i], i_labelled[:, j]).mean()
        j_correct = np.greater(j_labelled[:, j], j_labelled[:, i]).mean()
        binary_accuracies[(i, j)] = (i_correct + j_correct) / 2
    sorted_binary_accuracies = sorted(binary_accuracies.items(), key=lambda item: item[1])
    worst_entry = sorted_binary_accuracies[0]
    return worst_entry[1], worst_entry, sorted_binary_accuracies, binary_accuracies

In [37]:
def worst_balanced_superclass_accuracy(logits, targets, superclasses):
    """Return the superclass with the lowest mean class-wise accuracy.

    Each element of ``superclasses`` is an iterable of class labels and is
    also used as a dict key, so it has to be hashable (e.g. a ``range``).
    The score of a superclass is the unweighted mean of the class-wise
    accuracies of its members. Returns ``(superclass, mean_accuracy)`` for
    the lowest-scoring one.
    """
    per_class = classwise_accuracies(logits, targets)
    superclass_means = {}
    for group in superclasses:
        member_accuracies = np.array([per_class[label] for label in group])
        superclass_means[group] = member_accuracies.mean()
    worst_group = min(superclass_means, key=superclass_means.get)
    return worst_group, superclass_means[worst_group]

In [97]:
def inter_superclass_accuracies(logits, targets, superclasses):
    """Accuracy inside each superclass with predictions restricted to its members.

    For every superclass ``s`` only the samples whose target is a member of
    ``s`` are kept, and only the logit columns of those members are
    considered. The prediction is the member with the largest logit, and the
    score is the fraction of kept samples predicted correctly.

    Parameters
    ----------
    logits : np.ndarray
        2-D array of logits, indexed as (sample, class).
    targets : np.ndarray
        1-D array of integer labels, one per row of ``logits``.
    superclasses : iterable
        Each element is a sequence of class labels. It is used as a dict key
        and therefore has to be hashable (e.g. a ``range`` or a tuple).

    Returns
    -------
    dict
        Maps each superclass to its restricted accuracy. A superclass with no
        samples in ``targets`` maps to NaN.
    """
    isa = {}
    for s in superclasses:
        members = np.asarray(s)
        internal_samples = np.isin(targets, members)
        if not internal_samples.any():
            # Nothing to score; report NaN rather than failing on empty input.
            isa[s] = float('nan')
            continue
        internal_targets = targets[internal_samples]
        internal_logits = logits[internal_samples][:, members]
        # argmax gives a position inside `members`; look it up to get back the
        # original class label (array lookup instead of a per-element lambda).
        internal_preds = members[internal_logits.argmax(axis=1)]
        isa[s] = (internal_preds == internal_targets).mean()
    return isa

In [64]:
def worst_balanced_inter_superclass_accuracy(logits, targets, superclasses):
    """Placeholder: not implemented yet.

    Always raises ``NotImplementedError``; none of the arguments are used.
    """
    raise NotImplementedError()

In [62]:
def worst_class_specifity(logits, targets):
    """Rank classes by how often a prediction of that class is wrong.

    For every label occurring in ``targets`` the score is
    ``(# samples predicted as the label whose target differs) /
    (# samples predicted as the label)``, i.e. the share of erroneous
    predictions among all predictions of that class. A label that is never
    predicted gets the score 0.

    Returns a list of ``(label, score)`` pairs sorted by descending score.
    """
    predicted = get_predictions(logits)['predicted_classes']
    rates = {}
    for label in set(targets):
        assigned = predicted == label
        n_assigned = assigned.sum()
        if not n_assigned:
            # The class is never predicted, so it cannot be predicted wrongly.
            rates[label] = 0
            continue
        wrongly_assigned = assigned & (targets != label)
        rates[label] = wrongly_assigned.sum() / n_assigned
    ranked = list(rates.items())
    ranked.sort(key=lambda entry: entry[1], reverse=True)
    return ranked

In [102]:
def class_confusion(logits, targets):
    """Build the confusion matrix of argmax predictions against the targets.

    Parameters
    ----------
    logits : array-like
        2-D array of logits, indexed as (sample, class).
    targets : array-like
        1-D array of integer labels, one per row of ``logits``.

    Returns
    -------
    np.ndarray
        Square float matrix whose entry ``[t, p]`` counts the samples with
        target ``t`` and predicted class ``p``. Its size is the number of
        logit columns (or ``targets.max() + 1`` if that is larger), so every
        label and every possible prediction has a valid index even when
        ``targets`` covers only a subset of the classes.
    """
    logits = np.asarray(logits)
    targets = np.asarray(targets)
    predicted = logits.argmax(axis=1)
    # Size by the label space, not by the number of distinct targets: the
    # matrix is indexed by label values, so a subset of classes in `targets`
    # would otherwise index out of bounds.
    n_classes = logits.shape[1]
    if targets.size:
        n_classes = max(n_classes, int(targets.max()) + 1)
    confusion = np.zeros((n_classes, n_classes))
    # Unbuffered add so that repeated (target, prediction) pairs accumulate.
    np.add.at(confusion, (targets, predicted), 1)
    return confusion

In [38]:
predictions_IN_val['predicted_classes'][np.where(labels['imagenet_val'] == 0)]

array([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 133,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0])

In [67]:
standard_accuracy(logits['imagenet_val'], labels['imagenet_val'])

0.84528

In [66]:
cwa = classwise_accuracies(logits['imagenet_val'], labels['imagenet_val'])

In [42]:
np.array(list(cwa.values())).mean()

0.8452799999999999

In [43]:
worst_class_accuracy(logits['imagenet_val'], labels['imagenet_val'])

(282, 0.24)

In [44]:
classes['imagenet_val'][282]

('tiger cat',)

In [45]:
worst_balanced_n_classes_accuracy(logits['imagenet_val'], labels['imagenet_val'], 10)

(0.31399999999999995,
 [(282, 0.24),
  (240, 0.28),
  (782, 0.28),
  (620, 0.3),
  (638, 0.3),
  (657, 0.3),
  (482, 0.32),
  (810, 0.36),
  (167, 0.38),
  (744, 0.38)])

In [46]:
classes['imagenet_val'][620]

('laptop', 'laptop computer')

In [47]:
classes['imagenet_val'][681]

('notebook', 'notebook computer')

In [48]:
wbtcba = worst_balanced_two_class_binary_accuracy(logits['imagenet_val'], labels['imagenet_val'])

In [49]:
wbtcba[:2]

(0.4, ((620, 681), 0.4))

In [50]:
superclasses = [range(500), range(500,1000)]

In [51]:
worst_balanced_superclass_accuracy(logits['imagenet_val'], labels['imagenet_val'], superclasses)

(range(500, 1000), 0.83228)

In [96]:
inter_superclass_accuracies(logits['imagenet_val'], labels['imagenet_val'], superclasses)

(50000,)
(50000,)
[  0   0   0 ... 499 499 499]
(50000,)
(50000,)
[500 500 500 ... 846 700 999]


{range(0, 500): 0.8832, range(500, 1000): 0.85196}

In [69]:
worst_class_specifity(logits['imagenet_val'], labels['imagenet_val'])[:10]

[(681, 0.696969696969697),
 (620, 0.6590909090909091),
 (638, 0.6341463414634146),
 (657, 0.625),
 (848, 0.6212121212121212),
 (744, 0.5957446808510638),
 (639, 0.5789473684210527),
 (482, 0.5555555555555556),
 (527, 0.547945205479452),
 (73, 0.543859649122807)]

In [104]:
confusion_full = class_confusion(logits['imagenet_val'], labels['imagenet_val'])

In [106]:
np.trace(confusion_full)/np.sum(confusion_full)

0.84528

In [107]:
# Normalize every row (true class) by its own sample count. keepdims=True keeps
# the row sums as a column vector; a flat vector would broadcast across the
# columns and divide entry [i, j] by the count of class j instead of class i.
confusion_normalized = confusion_full / confusion_full.sum(axis=1, keepdims=True)

In [114]:
confusion_normalized.trace()/len(confusion_normalized)

0.8452799999999999

In [114]:
confusion_normalized.trace()/len(confusion_normalized)

0.8452799999999999

In [118]:
confusion_normalized_errors = np.copy(confusion_normalized)
np.fill_diagonal(confusion_normalized_errors, 0)

In [125]:
confusion_normalized_errors.sum()/len(confusion_normalized_errors)  # mean off-diagonal (error) mass per class, i.e. the overall error rate -- not the statistical "standard error"

0.15472000000000002

In [126]:
LA.norm(confusion_normalized_errors, 1)

0.9200000000000002

In [127]:
LA.norm(confusion_normalized_errors, np.inf)  # inf-norm = largest row sum, i.e. the error rate of the worst class when rows are normalized per true class

0.76

In [129]:
LA.norm(confusion_normalized_errors, 2)

0.6006040051358144