In [1]:
import typing
import io
import os

from tqdm import tqdm
import torch
import numpy as np
!pip install opencv-python
import cv2
import matplotlib.pyplot as plt

from urllib.request import urlretrieve

from PIL import Image
from torchvision import transforms

from models.modeling import VisionTransformer, CONFIGS



In [41]:
class args:
    img_size = 224
    pretrained_dir = "output_models/cifar10/ViT-L_32"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dataset = "cifar10"
    local_rank = -1
    train_batch_size = 64
    eval_batch_size = 64
    

config_vit_b_16 = CONFIGS["ViT-B_16"]
config_vit_b_32 = CONFIGS["ViT-B_32"]
config_vit_l_16 = CONFIGS["ViT-L_16"]
config_vit_l_32 = CONFIGS["ViT-L_32"]

models = []

for i in range(1,4):
    model = VisionTransformer(config_vit_l_32, args.img_size, zero_head=True, num_classes=10, vis=True)
    model.load_state_dict(torch.load(args.pretrained_dir + f'/Run{i}/{args.dataset}_vit_l_32_seed_{i}_checkpoint.bin'))
    model.to(args.device)
    models.append(model)
# model = VisionTransformer(config_vit_b_16, args.img_size, zero_head=True, num_classes=10, vis=True)
# model.load_state_dict(torch.load(args.pretrained_dir + f'/Run1/{args.dataset}_vit_b_16_seed_1_checkpoint.bin'))
# model.to(args.device)
# num_params = count_parameters(model)

In [42]:
# Load the data loaders (CIFAR-10 test loader and SVHN test loader)
from utils.data_utils import get_loader
from ipywidgets import IntProgress

cifar10_train_loader, cifar10_test_loader = get_loader(args)
args.dataset = "cifar100"
cifar100_train_loader, cifar100_test_loader = get_loader(args)
args.dataset = "svhn"
svhn_train_loader, svhn_test_loader = get_loader(args)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


In [43]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
import torch.nn.functional as F

def get_logits_labels(model, data_loader, device):
    """
    Utility function to get logits and labels.
    """
    model.eval()
    logits = []
    labels = []
    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(device)
            label = label.to(device)

            logit, _ = model(data)
            logits.append(logit)
            labels.append(label)
    logits = torch.cat(logits, dim=0)
    labels = torch.cat(labels, dim=0)
    return logits, labels

def test_classification_net_softmax(softmax_prob, labels):
    """
    This function reports classification accuracy and confusion matrix given softmax vectors and
    labels from a model.
    """
    labels_list = []
    predictions_list = []
    confidence_vals_list = []

    confidence_vals, predictions = torch.max(softmax_prob, dim=1)
    labels_list.extend(labels.cpu().numpy())
    predictions_list.extend(predictions.cpu().numpy())
    confidence_vals_list.extend(confidence_vals.cpu().numpy())
    accuracy = accuracy_score(labels_list, predictions_list)
    return (
        confusion_matrix(labels_list, predictions_list),
        accuracy,
        labels_list,
        predictions_list,
        confidence_vals_list,
    )

def test_classification_net_logits(logits, labels):
    """
    This function reports classification accuracy and confusion matrix given logits and labels
    from a model.
    """
    softmax_prob = F.softmax(logits, dim=1)
    return test_classification_net_softmax(softmax_prob, labels)


def test_classification_net(model, data_loader, device):
    """
    This function reports classification accuracy and confusion matrix over a dataset.
    """
    logits, labels = get_logits_labels(model, data_loader, device)
    return test_classification_net_logits(logits, labels)

device = "cuda" if torch.cuda.is_available() else "cpu"

accs = []
for model in models:
    _, accuracy, _, _, _ = test_classification_net(model, cifar10_test_loader, device)
    accs.append(accuracy)

accs = torch.tensor(accs)
print (torch.mean(accs))
print (torch.std(accs))

tensor(0.9902, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)


In [44]:
# model = VisionTransformer(config_vit_b_16, args.img_size, zero_head=True, num_classes=10, vis=True).cuda()
# model.load_state_dict(torch.load('/data/home/mukhotij/ViT/Code/output/cifar10-100_500_checkpoint.bin'))

In [45]:
# conf_matrix, acc, _, _, _ = test_classification_net(model, cifar10_test_loader, device)

# print (conf_matrix)
# print (acc)

# # for data, label in cifar10_test_loader:
# #     print (data.shape)
# #     print (data.min())
# #     print (data.max())
# #     print (label.shape)
# #     break

In [46]:
# import math
# import torch
# import numpy as np
# from torch import nn
# from torch.nn import functional as F

# import matplotlib.pyplot as plt

# plt.rcParams.update({"font.size": 20})


# # Some keys used for the following dictionaries
# COUNT = "count"
# CONF = "conf"
# ACC = "acc"
# BIN_ACC = "bin_acc"
# BIN_CONF = "bin_conf"


# def _bin_initializer(num_bins=10):
#     bin_dict = {}
#     for i in range(num_bins):
#         bin_dict[i] = {}
#         bin_dict[i][COUNT] = 0
#         bin_dict[i][CONF] = 0
#         bin_dict[i][ACC] = 0
#         bin_dict[i][BIN_ACC] = 0
#         bin_dict[i][BIN_CONF] = 0

#     return bin_dict


# def _populate_bins(confs, preds, labels, num_bins=10):

#     bin_dict = _bin_initializer(num_bins)
#     num_test_samples = len(confs)

#     for i in range(0, num_test_samples):
#         confidence = confs[i]
#         prediction = preds[i]
#         label = labels[i]
#         binn = int(math.ceil(((num_bins * confidence) - 1)))
#         bin_dict[binn][COUNT] = bin_dict[binn][COUNT] + 1
#         bin_dict[binn][CONF] = bin_dict[binn][CONF] + confidence
#         bin_dict[binn][ACC] = bin_dict[binn][ACC] + (1 if (label == prediction) else 0)

#     for binn in range(0, num_bins):
#         if bin_dict[binn][COUNT] == 0:
#             bin_dict[binn][BIN_ACC] = 0
#             bin_dict[binn][BIN_CONF] = 0
#         else:
#             bin_dict[binn][BIN_ACC] = float(bin_dict[binn][ACC]) / bin_dict[binn][COUNT]
#             bin_dict[binn][BIN_CONF] = bin_dict[binn][CONF] / float(bin_dict[binn][COUNT])
#     return bin_dict


# def expected_calibration_error(confs, preds, labels, num_bins=10):
#     bin_dict = _populate_bins(confs, preds, labels, num_bins)
#     num_samples = len(labels)
#     ece = 0
#     for i in range(num_bins):
#         bin_accuracy = bin_dict[i][BIN_ACC]
#         bin_confidence = bin_dict[i][BIN_CONF]
#         bin_count = bin_dict[i][COUNT]
#         ece += (float(bin_count) / num_samples) * abs(bin_accuracy - bin_confidence)
#     return ece


# # Calibration error scores in the form of loss metrics
# class ECELoss(nn.Module):
#     """
#     Compute ECE (Expected Calibration Error)
#     """

#     def __init__(self, n_bins=15):
#         super(ECELoss, self).__init__()
#         bin_boundaries = torch.linspace(0, 1, n_bins + 1)
#         self.bin_lowers = bin_boundaries[:-1]
#         self.bin_uppers = bin_boundaries[1:]

#     def forward(self, logits, labels):
#         softmaxes = F.softmax(logits, dim=1)
#         confidences, predictions = torch.max(softmaxes, 1)
#         accuracies = predictions.eq(labels)

#         ece = torch.zeros(1, device=logits.device)
#         for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
#             # Calculated |confidence - accuracy| in each bin
#             in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
#             prop_in_bin = in_bin.float().mean()
#             if prop_in_bin.item() > 0:
#                 accuracy_in_bin = accuracies[in_bin].float().mean()
#                 avg_confidence_in_bin = confidences[in_bin].mean()
#                 ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

#         return ece


# # Methods for plotting reliability diagrams and bin-strength plots
# def reliability_plot(confs, preds, labels, num_bins=15, model_name='model'):
#     """
#     Method to draw a reliability plot from a model's predictions and confidences.
#     """
#     bin_dict = _populate_bins(confs, preds, labels, num_bins)
#     bns = [(i / float(num_bins)) for i in range(num_bins)]
#     y = []
#     for i in range(num_bins):
#         y.append(bin_dict[i][BIN_ACC])
#     plt.figure(figsize=(10, 8))  # width:20, height:3
#     plt.bar(bns, bns, align="edge", width=0.03, color="pink", label="Expected")
#     plt.bar(bns, y, align="edge", width=0.03, color="blue", alpha=0.5, label="Actual")
#     plt.ylabel("Accuracy", fontsize=30)
#     plt.xlabel("Confidence", fontsize=30)
#     plt.xticks(fontsize=30)
#     plt.yticks(fontsize=30)
#     plt.legend(fontsize=30, loc='upper left')
#     plt.savefig(f'./reliability_plot_{model_name}.pdf')
#     plt.savefig(f'./reliability_plot_{model_name}.png')
#     plt.show()

In [47]:
# ece = expected_calibration_error(confidences, predictions, labels, num_bins=15)
# print (ece)
# reliability_plot(confidences, predictions, labels, num_bins=25)

In [48]:
import torch
import torch.nn.functional as F


def entropy(logits):
    p = F.softmax(logits, dim=1)
    logp = F.log_softmax(logits, dim=1)
    plogp = p * logp
    entropy = -torch.sum(plogp, dim=1)
    return entropy


def logsumexp(logits):
    return torch.logsumexp(logits, dim=1, keepdim=False)


def confidence(logits):
    p = F.softmax(logits, dim=1)
    confidence, _ = torch.max(p, dim=1)
    return confidence


def entropy_prob(probs):
    p = probs
    eps = 1e-12
    logp = torch.log(p + eps)
    plogp = p * logp
    entropy = -torch.sum(plogp, dim=1)
    return entropy


def mutual_information_prob(probs):
    mean_output = torch.mean(probs, dim=0)
    predictive_entropy = entropy_prob(mean_output)

    # Computing expectation of entropies
    p = probs
    eps = 1e-12
    logp = torch.log(p + eps)
    plogp = p * logp
    exp_entropies = torch.mean(-torch.sum(plogp, dim=2), dim=0)

    # Computing mutual information
    mi = predictive_entropy - exp_entropies
    return mi

In [49]:
import torch
import torch.nn.functional as F
from sklearn import metrics


def get_roc_auc(net, test_loader, ood_test_loader, uncertainty, device, confidence=False):
    logits, _ = get_logits_labels(net, test_loader, device)
    ood_logits, _ = get_logits_labels(net, ood_test_loader, device)

    return get_roc_auc_logits(logits, ood_logits, uncertainty, device, confidence=confidence)


def get_roc_auc_logits(logits, ood_logits, uncertainty, device, confidence=False):
    uncertainties = uncertainty(logits)
    ood_uncertainties = uncertainty(ood_logits)

    # In-distribution
    bin_labels = torch.zeros(uncertainties.shape[0]).to(device)
    in_scores = uncertainties

    # OOD
    bin_labels = torch.cat((bin_labels, torch.ones(ood_uncertainties.shape[0]).to(device)))

    if confidence:
        bin_labels = 1 - bin_labels
    ood_scores = ood_uncertainties  # entropy(ood_logits)
    scores = torch.cat((in_scores, ood_scores))

    fpr, tpr, thresholds = metrics.roc_curve(bin_labels.cpu().numpy(), scores.cpu().numpy())
    precision, recall, prc_thresholds = metrics.precision_recall_curve(bin_labels.cpu().numpy(), scores.cpu().numpy())
    auroc = metrics.roc_auc_score(bin_labels.cpu().numpy(), scores.cpu().numpy())
    auprc = metrics.average_precision_score(bin_labels.cpu().numpy(), scores.cpu().numpy())

    return (fpr, tpr, thresholds), (precision, recall, prc_thresholds), auroc, auprc


In [50]:
aurocs = []
auprcs = []
for model in models:
    (_, _, _), (_, _, _), auroc, auprc = get_roc_auc(model, cifar10_test_loader, svhn_test_loader, entropy, device)
    aurocs.append(auroc)
    auprcs.append(auprc)

aurocs = torch.tensor(aurocs)
auprcs = torch.tensor(auprcs)

auroc_mean = torch.mean(aurocs)
auroc_std = torch.std(aurocs)

auprc_mean = torch.mean(auprcs)
auprc_std = torch.std(auprcs)

print (f"AUROC: {auroc_mean}+-{auroc_std}")
print (f"AUPRC: {auprc_mean}+-{auroc_std}")

AUROC: 0.9978352553267772+-0.00020354844715152136
AUPRC: 0.9991356971197624+-0.00020354844715152136


In [51]:
aurocs = []
auprcs = []
for model in models:
    (_, _, _), (_, _, _), auroc, auprc = get_roc_auc(model, cifar10_test_loader, cifar100_test_loader, entropy, device)
    aurocs.append(auroc)
    auprcs.append(auprc)

aurocs = torch.tensor(aurocs)
auprcs = torch.tensor(auprcs)

auroc_mean = torch.mean(aurocs)
auroc_std = torch.std(aurocs)

auprc_mean = torch.mean(auprcs)
auprc_std = torch.std(auprcs)

print (f"AUROC: {auroc_mean}+-{auroc_std}")
print (f"AUPRC: {auprc_mean}+-{auroc_std}")

AUROC: 0.9844781083333335+-0.0004933044593689182
AUPRC: 0.985478343799175+-0.0004933044593689182


In [52]:
aurocs = []
auprcs = []
for model in models:
    (_, _, _), (_, _, _), auroc, auprc = get_roc_auc(model, cifar10_test_loader, svhn_test_loader, confidence, device, confidence=True)
    aurocs.append(auroc)
    auprcs.append(auprc)

aurocs = torch.tensor(aurocs)
auprcs = torch.tensor(auprcs)

auroc_mean = torch.mean(aurocs)
auroc_std = torch.std(aurocs)

auprc_mean = torch.mean(auprcs)
auprc_std = torch.std(auprcs)

print (f"AUROC: {auroc_mean}+-{auroc_std}")
print (f"AUPRC: {auprc_mean}+-{auroc_std}")

AUROC: 0.996324323909035+-0.00037792024791639695
AUPRC: 0.9932660611948467+-0.00037792024791639695


In [53]:
aurocs = []
auprcs = []
for model in models:
    (_, _, _), (_, _, _), auroc, auprc = get_roc_auc(model, cifar10_test_loader, cifar100_test_loader, confidence, device, confidence=True)
    aurocs.append(auroc)
    auprcs.append(auprc)

aurocs = torch.tensor(aurocs)
auprcs = torch.tensor(auprcs)

auroc_mean = torch.mean(aurocs)
auroc_std = torch.std(aurocs)

auprc_mean = torch.mean(auprcs)
auprc_std = torch.std(auprcs)

print (f"AUROC: {auroc_mean}+-{auroc_std}")
print (f"AUPRC: {auprc_mean}+-{auroc_std}")

AUROC: 0.9829110766666668+-0.0005265952466633619
AUPRC: 0.9825370548250124+-0.0005265952466633619
