## Import libraries

In [7]:
import os
import sys
from tqdm import tqdm
from time import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from sklearn import datasets
from sklearn.manifold import TSNE

sys.path.append("../")

# Network architectures
from net.resnet import resnet50
from net.vgg import vgg16

from data_utils.ood_detection import cifar10,svhn,cifar100,fer2013,lsun,mnist,tiny_imagenet
import metrics.uncertainty_confidence as uncertainty_confidence
from utils.gmm_utils import get_embeddings, gmm_fit, gmm_evaluate
# Import metrics to compute
from metrics.classification_metrics import (test_classification_net, test_classification_net_logits, test_classification_net_ensemble)
from metrics.calibration_metrics import expected_calibration_error
from metrics.uncertainty_confidence import entropy, logsumexp, confidence, sumexp, max
from metrics.ood_metrics import get_roc_auc, get_roc_auc_logits, get_roc_auc_ensemble

In [13]:
# Dataset params
# Number of classes per dataset. Keys must match the keys of `dataset_loader` so that
# `dataset_num_classes[dataset]` resolves (the Tiny ImageNet key was previously misspelled).
dataset_num_classes = {"cifar10": 10, "cifar100": 100, "svhn": 10, "lsun": 10, "tiny_imagenet": 200}

# Dataset name -> data_utils.ood_detection module providing the loader factories.
dataset_loader = {
    "cifar10": cifar10,
    "cifar100": cifar100,
    "svhn": svhn,
    "fer2013": fer2013,
    "mnist": mnist,
    "lsun": lsun,
    "tiny_imagenet": tiny_imagenet,
}
dataset = "cifar10"            # in-distribution dataset (train / val / test loaders)
ood_dataset = "tiny_imagenet"  # out-of-distribution dataset (test loader only)
model = "resnet50"
batch_size = 512
models = {"resnet50": resnet50, "vgg16": vgg16}
# Embedding size per architecture; passed as `num_dim` to get_embeddings in later cells.
model_to_num_dim = {"resnet18": 512, "resnet50": 2048, "resnet101": 2048, "resnet152": 2048, "wide_resnet": 640, "vgg16": 512, "vit": 512}

# Seed torch and numpy RNGs for reproducibility.
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the second GPU as before, but degrade gracefully on machines with fewer devices
# instead of failing at the first `.to(device)`.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Derived from the in-distribution dataset instead of hard-coded (still 10 for cifar10).
num_classes = dataset_num_classes[dataset]

train_loader, val_loader = dataset_loader[dataset].get_train_valid_loader(
    root="../data/",
    batch_size=batch_size,
    val_seed=SEED,
    augment=False,
    val_size=0.1,
)
test_loader = dataset_loader[dataset].get_test_loader(batch_size, root="../data/")
ood_test_loader = dataset_loader[ood_dataset].get_test_loader(batch_size, root="../data/")

tiny-imagenet test:10000


## Load trained models and evaluate GMM-based OOD detection

In [14]:
# Model "contrastive+": checkpoint from run18 (name tagged `contrastive1`).
saved_model_name = "../saved_models/run18/resnet50_sn_3.0_mod_seed_1_contrastive1/2024_06_14_17_53_27/resnet50_sn_3.0_mod_seed_1_contrastive1_best.model"
print(f"load {saved_model_name}")
# Constructor flags must match the checkpoint (loaded with strict=True); the file name encodes `sn_3.0_mod`.
net = models[model](
    spectral_normalization=True,
    mod=3.0,
    num_classes=num_classes,
    temp=1.0,
).to(device)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (recent PyTorch versions accept `weights_only=True` for plain state dicts).
_ = net.load_state_dict(torch.load(saved_model_name, map_location=device), strict=True)
_ = net.eval()

# In-distribution classification performance on the test set.
(
    conf_matrix,
    accuracy,
    labels_list,
    predictions,
    confidences,
) = test_classification_net(net, test_loader, device)
print(f"contrastive+:accuracy:{float(accuracy):.4f}")

print("GMM Model")
# Training-set embeddings and their labels are used to fit the GMM.
embeddings, train_labels, norm_threshold = get_embeddings(
    net,
    train_loader,
    num_dim=model_to_num_dim[model],
    dtype=torch.double,
    device=device,
    storage_device=device,
)

gaussians_model, jitter_eps = gmm_fit(embeddings=embeddings, labels=train_labels, num_classes=num_classes)
# GMM logits for in-distribution test samples...
logits, labels = gmm_evaluate(
    net,
    gaussians_model,
    test_loader,
    device=device,
    num_classes=num_classes,
    storage_device=device,
)

# ...and for OOD test samples.
ood_logits, ood_labels = gmm_evaluate(
    net,
    gaussians_model,
    ood_test_loader,
    device=device,
    num_classes=num_classes,
    storage_device=device,
)
# Score = logsumexp over the GMM logits; `conf=True` presumably marks the score as a
# confidence (higher = in-distribution) -- TODO confirm against get_roc_auc_logits.
m1_fpr95, m1_auroc, m1_auprc = get_roc_auc_logits(logits, ood_logits, logsumexp, device, conf=True)
print(f"contrastive+:m1_fpr95:{m1_fpr95:.4f},m1_auroc:{m1_auroc:.4f},m1_auprc:{m1_auprc:.4f}")

load ../saved_models/run18/resnet50_sn_3.0_mod_seed_1_contrastive1/2024_06_14_17_53_27/resnet50_sn_3.0_mod_seed_1_contrastive1_best.model
GMM Model
get embeddings from dataloader...


100%|██████████| 88/88 [00:14<00:00,  6.02it/s]


norm threshold:0.05582457035779953


100%|██████████| 20/20 [00:01<00:00, 12.47it/s]
100%|██████████| 20/20 [00:01<00:00, 11.85it/s]

contrastive+:m1_auroc1:0.8089,m1_auprc:0.8221





In [15]:
# Model "contrastive-": checkpoint from run17 (no `contrastive` tag in its name).
# To evaluate VGG16 instead, set `model = "vgg16"` in the config cell and use
# "../saved_models/run17/vgg16_sn_3.0_mod_seed_1/2024_05_27_19_08_26/vgg16_sn_3.0_mod_seed_1_best.model".
saved_model_name = "../saved_models/run17/resnet50_sn_3.0_mod_seed_1/2024_05_21_16_49_32/resnet50_sn_3.0_mod_seed_1_best.model"
print(f"load {saved_model_name}")
# Constructor flags must match the checkpoint (loaded with strict=True); the file name encodes `sn_3.0_mod`.
net = models[model](
    spectral_normalization=True,
    mod=3.0,
    num_classes=num_classes,
    temp=1.0,
).to(device)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (recent PyTorch versions accept `weights_only=True` for plain state dicts).
_ = net.load_state_dict(torch.load(saved_model_name, map_location=device), strict=True)
_ = net.eval()

# Re-evaluate classification so `accuracy`, `predictions`, ... describe THIS model,
# not the one loaded in the previous cell.
(
    conf_matrix,
    accuracy,
    labels_list,
    predictions,
    confidences,
) = test_classification_net(net, test_loader, device)
print(f"contrastive-:accuracy:{float(accuracy):.4f}")

print("GMM Model")
# Training-set embeddings and their labels are used to fit the GMM.
embeddings, train_labels, norm_threshold = get_embeddings(
    net,
    train_loader,
    num_dim=model_to_num_dim[model],
    dtype=torch.double,
    device=device,
    storage_device=device,
)

gaussians_model, jitter_eps = gmm_fit(embeddings=embeddings, labels=train_labels, num_classes=num_classes)
# GMM logits for in-distribution test samples...
logits, labels = gmm_evaluate(
    net,
    gaussians_model,
    test_loader,
    device=device,
    num_classes=num_classes,
    storage_device=device,
)

# ...and for OOD test samples.
ood_logits, ood_labels = gmm_evaluate(
    net,
    gaussians_model,
    ood_test_loader,
    device=device,
    num_classes=num_classes,
    storage_device=device,
)
# Score = logsumexp over the GMM logits; `conf=True` presumably marks the score as a
# confidence (higher = in-distribution) -- TODO confirm against get_roc_auc_logits.
m1_fpr95, m1_auroc, m1_auprc = get_roc_auc_logits(logits, ood_logits, logsumexp, device, conf=True)
print(f"contrastive-:m1_fpr95:{m1_fpr95:.4f},m1_auroc:{m1_auroc:.4f},m1_auprc:{m1_auprc:.4f}")

load ../saved_models/run17/resnet50_sn_3.0_mod_seed_1/2024_05_21_16_49_32/resnet50_sn_3.0_mod_seed_1_best.model
GMM Model
get embeddings from dataloader...


100%|██████████| 88/88 [00:14<00:00,  6.10it/s]


norm threshold:0.01930484175682068


100%|██████████| 20/20 [00:01<00:00, 12.62it/s]
100%|██████████| 20/20 [00:01<00:00, 11.96it/s]

contrastive-:m1_auroc1:0.9561,m1_auprc:0.9622



