In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torch.backends.cudnn as cudnn
from datasets.utils import build_dataset, get_ood_dataloaders
from models.utils import build_model
from utils.eval import *
from utils.loss import LogitNormLoss


# ---- Runtime settings ------------------------------------------------------
gpu = 0          # GPU id; set to None to run on the CPU
seed = 1         # seed for the random number generators
num_to_avg = 10  # number of trials over random OOD images averaged into the reported scores

# ---- Experiment identity ---------------------------------------------------
model_type = "wrn"        # model name
loss_function = "normal"  # loss function
dataset = "cifar10"       # training dataset

# Base file name of the saved weights and training information.
method_name = "_".join((dataset, model_type, loss_function, "standard"))
print("method_name: {}".format(method_name))

# ---- Evaluation settings ---------------------------------------------------
score_function = "MSP"  # post-hoc score function used for OOD detection
test_bs = 200           # batch size of the evaluation data loaders
num_classes = 10        # number of classes in the training dataset
input_size = 32         # side length of the images fed to the model
input_channels = 3      # number of channels of the images fed to the model

# Per-channel statistics used to normalize the input images.
mean = (0.492, 0.482, 0.446)
std = (0.247, 0.244, 0.262)

prefetch_threads = 4        # worker count for input image preprocessing
save_path = "./snapshots/"  # folder holding the weights and training information


method_name: cifar10_wrn_normal_standard


In [2]:
# Select the compute device.
# A GPU id may be configured on a machine without CUDA (e.g. a CPU-only
# environment). Fall back to the CPU in that case instead of handing an
# unusable 'cuda:N' device to the model and data pipeline.
use_cuda = gpu is not None and torch.cuda.is_available()
if gpu is not None and not use_cuda:
    print("CUDA is not available; falling back to CPU.")

if use_cuda:
    device = torch.device('cuda:{}'.format(int(gpu)))
    torch.cuda.manual_seed(seed)
else:
    device = torch.device('cpu')

# Seed the CPU-side generators so repeated runs draw the same random numbers.
torch.manual_seed(seed)
np.random.seed(seed)

# Create the model and restore its weights from the snapshot identified by
# save_path / method_name.
net = build_model(model_type, num_classes, device, load=True, path=save_path, filename=method_name)

# Evaluation only: switch layers such as dropout / batch norm to inference mode.
net.eval()
cudnn.benchmark = True  # let cuDNN pick the fastest kernels for the fixed input size

# Load the in-distribution (ID) test dataset.
test_data = build_dataset(dataset, mode="test", size=input_size, channels=input_channels,
                          mean=mean, std=std)
# Pinned host memory only helps host-to-GPU copies, so request it only when a
# CUDA device is actually in use.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=test_bs, shuffle=False,
                                          num_workers=prefetch_threads, pin_memory=use_cuda)


Model restored! Epoch: 199
Files already downloaded and verified


In [3]:
# Build one data loader per out-of-distribution (OOD) test set.
ood_num_examples = len(test_data) // 5  # one fifth of the ID test-set size
ood_data_list = ["MNIST", "GTSRB", "Textures", "SVHN", "Places365"]  # OOD datasets to load
ood_loader_dict = get_ood_dataloaders(
    ood_data_list,
    input_size=input_size,
    input_channels=input_channels,
    mean=mean,
    std=std,
    test_bs=test_bs,
    prefetch_threads=prefetch_threads,
)

# Compute and print every ID and OOD detection performance measure of the model.
# NOTE(review): num_to_avg from the configuration cell is not forwarded to this
# call -- confirm that get_all_measures' own default matches the intended value.
(error_rate, ece_error, auroc, aupr, fpr,
 auroc_list, aupr_list, fpr_list) = get_all_measures(
    net,
    test_loader,
    ood_loader_dict,
    device,
    temp='optimal',
    score_function=score_function,
    recall_level=0.95,
    ood_num_examples=ood_num_examples,
    test_bs=test_bs,
    to_string=True,
    method_name=method_name,
)

Before temperature - NLL: 0.205, ECE: 0.031
Optimal temperature: 1.541
After temperature - NLL: 0.168, ECE: 0.010
Error Rate 5.26
				cifar10_wrn_normal_standard
FPR95:			34.22
AUROC: 			92.85
AUPR:  			99.50


ECE Error
ECE Error 1.00
Before temperature - NLL: 0.205, ECE: 0.031
Optimal temperature: 1.541
After temperature - NLL: 0.168, ECE: 0.010


MNIST Detection
				cifar10_wrn_normal_standard
FPR95:			37.72	+/- 0.85
AUROC: 			94.58	+/- 0.12
AUPR:  			98.90	+/- 0.03


GTSRB Detection
				cifar10_wrn_normal_standard
FPR95:			54.56	+/- 1.11
AUROC: 			91.00	+/- 0.22
AUPR:  			98.06	+/- 0.06


Textures Detection
				cifar10_wrn_normal_standard
FPR95:			55.62	+/- 0.64
AUROC: 			87.30	+/- 0.30
AUPR:  			96.34	+/- 0.15


SVHN Detection
				cifar10_wrn_normal_standard
FPR95:			29.72	+/- 0.94
AUROC: 			95.24	+/- 0.12
AUPR:  			98.95	+/- 0.05


Places365 Detection
				cifar10_wrn_normal_standard
FPR95:			55.56	+/- 1.38
AUROC: 			87.14	+/- 0.46
AUPR:  			96.42	+/- 0.21


Mean Test Results
				