In [1]:
import inspect

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
from sklearn import decomposition
from scipy.stats import norm
import scipy.stats
from sklearn.manifold import TSNE
import seaborn as sns
from pathlib import Path
import collections
import itertools

In [4]:
def get_predictions(logits):
    preds = {'logits': logits,
             'predicted_classes': logits.argmax(axis=1),
             'class_probabilities': scipy.special.softmax(logits, axis=1),
             'confidences_classifier': scipy.special.softmax(logits, axis=1).max(axis=1),
            }
    preds['number_of_class_predictions'] = collections.Counter(preds['predicted_classes'])
    return preds

In [5]:
features, logits, labels, classes = {}, {}, {}, {}

In [6]:
mod = ''
model_name = f"vit_base_patch16_224{mod}"

In [7]:
outputs_load_path = f'/mnt/SHARED/Julian/feature_density_ood/notebooks/outputs/{model_name}'

In [8]:
def load_outputs(*dset_names):
    for dset_name in dset_names:
        outputs = np.load(os.path.join(outputs_load_path, f'{dset_name}.npz'), allow_pickle=True)
        for d in ['features', 'logits', 'labels', 'classes']:
           globals()[d][dset_name] = outputs[d]

In [9]:
load_outputs('raccoons', 'imagenet_val', 'Food101', 'textures', 'OOD_Places_MOS', 'iNaturalist')

In [10]:
predictions_IN_val = get_predictions(logits['imagenet_val'])

In [11]:
predictions_IN_val['predicted_classes']

array([  0,   0,   0, ..., 846, 333, 999])

In [12]:
def standard_accuracy(logits, targets):
    preds = get_predictions(logits)
    accuracy = np.equal(preds['predicted_classes'], targets).mean()
    return accuracy

In [13]:
def classwise_accuracies(logits, targets):
    preds = get_predictions(logits)
    clw_acc = {}
    for i in set(targets):
        clw_acc[i] = np.equal(preds['predicted_classes'][np.where(targets == i)], i).mean()
    return clw_acc

In [14]:
def worst_class_accuracy(logits, targets):
    cwa = classwise_accuracies(logits, targets)
    return min(cwa.items(), key=lambda x: x[1])

In [15]:
def worst_balanced_n_classes_accuracy(logits, targets, n):
    cwa = classwise_accuracies(logits, targets)
    sorted_cwa =  sorted(cwa.items(), key=lambda item: item[1])
    n_worst = sorted_cwa[:n]
    return np.array([x[1] for x in n_worst]).mean(), n_worst

In [16]:
def worst_balanced_two_class_binary_accuracy(logits, targets):
    classes = list(set(targets))
    binary_accuracies = {}
    for i,j in itertools.combinations(classes, 2):
        i_labelled = logits[np.where(targets == i)]
        j_labelled = logits[np.where(targets == j)]
        i_correct = np.greater(i_labelled[:,i], i_labelled[:,j]).mean()
        j_correct = np.greater(j_labelled[:,j], j_labelled[:,i]).mean()
        binary_accuracies[(i,j)] = (i_correct + j_correct)/2
    sorted_binary_accuracies = sorted(binary_accuracies.items(), key=lambda item: item[1])
    return sorted_binary_accuracies[0][1], sorted_binary_accuracies[0], sorted_binary_accuracies, binary_accuracies

In [17]:
defdef worst_balanced_superclass_accuracy(logits, targets, superclasses):
:
    cwa = classwise_accuracies(logits, targets)
    scwa = {s: np.array([cwa[c] for c in s]).mean() for s in superclasses}
    return min(scwa.items(), key=lambda x: x[1])

SyntaxError: invalid syntax (2466461043.py, line 1)

In [None]:
predictions_IN_val['predicted_classes'][np.where(labels['imagenet_val'] == 0)]

In [None]:
standard_accuracy(logits['imagenet_val'], labels['imagenet_val'])

In [None]:
cwa = classwise_accuracies(logits['imagenet_val'], labels['imagenet_val'])

In [None]:
np.array(list(cwa.values())).mean()

In [None]:
worst_class_accuracy(logits['imagenet_val'], labels['imagenet_val'])

In [None]:
classes['imagenet_val'][282]

In [None]:
worst_balanced_n_classes_accuracy(logits['imagenet_val'], labels['imagenet_val'], 10)

In [None]:
classes['imagenet_val'][620]

In [None]:
classes['imagenet_val'][681]

In [None]:
wbtcba = worst_balanced_two_class_binary_accuracy(logits['imagenet_val'], labels['imagenet_val'])

In [None]:
wbtcba[:2]

In [None]:
superclasses = [range(500), range(500,1000)]

In [None]:
worst_balanced_superclass_accuracy(logits['imagenet_val'], labels['imagenet_val'], superclasses)