In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torch.backends.cudnn as cudnn
from datasets.utils import build_dataset, get_ood_dataloaders
from models.utils import build_model
from utils.eval import *
from utils.loss import LogitNormLoss


gpu = 0          # GPU id; run on CPU if None
seed = 1         # random seed
num_to_avg = 10  # number of trials with random images from the OOD datasets used to average the performance scores

model_type = 'wrn'             # model name
loss_function = 'logitnorm'    # loss function used in training: 'normal' or 'logitnorm'
dataset = 'cifar10'            # training (in-distribution) dataset

# file name under which the weights and training information were saved
method_name = '_'.join([dataset, model_type, loss_function, 'standard'])
print("method_name: " + method_name)

score_function = 'energy'     # post-hoc score function used for OOD detection
test_bs = 200                 # batch size for evaluation
num_classes = 10              # number of classes of the training dataset
input_size = 32               # input image size for the model
input_channels = 3            # number of input image channels for the model
mean = (0.492, 0.482, 0.446)  # per-channel mean for normalization of input images
std = (0.247, 0.244, 0.262)   # per-channel standard deviation for normalization of input images
prefetch_threads = 4          # number of data-loading workers (used as DataLoader num_workers)
save_path = './snapshots/'    # folder holding the saved weights and training information

# Initial temperature, chosen per training loss; passed to get_all_measures as `init_temp`.
if loss_function == 'normal':
    init_temp = 1.5
elif loss_function == 'logitnorm':
    init_temp = 0.1
else:
    # Fail fast here: otherwise init_temp stays undefined and the evaluation
    # cell would later crash with a confusing NameError.
    raise ValueError(
        "unknown loss_function {!r}; expected 'normal' or 'logitnorm'".format(loss_function))

method_name: cifar10_wrn_logitnorm_standard


In [2]:
# Select the compute device, honouring the `gpu` id from the configuration cell:
# cuda:<gpu> when an id is given and CUDA is available, otherwise the CPU.
if gpu is not None and torch.cuda.is_available():
    device = torch.device('cuda:{}'.format(int(gpu)))
else:
    device = torch.device('cpu')

# Seed the RNGs; torch.manual_seed seeds the CPU and all CUDA generators.
torch.manual_seed(seed)
np.random.seed(seed)

# Create the model and load its weights from `save_path` / `method_name`.
# NOTE(review): in the saved run this raised "AssertionError: could not resume",
# presumably because no checkpoint exists under save_path — confirm the snapshot is present.
net = build_model(model_type, num_classes, device, load=True, path=save_path, filename=method_name)

net.eval()               # switch the model to evaluation mode for inference
cudnn.benchmark = True   # let cuDNN autotune convolution algorithms (input size is fixed)

# Load the in-distribution test set; no shuffling so the evaluation order is fixed.
test_data = build_dataset(dataset, mode="test", size=input_size, channels=input_channels,
                          mean=mean, std=std)
# Pinned host memory only speeds up host-to-GPU copies, so enable it only on CUDA.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=test_bs, shuffle=False,
                                          num_workers=prefetch_threads,
                                          pin_memory=(device.type == 'cuda'))


AssertionError: could not resume

In [None]:
# ood_num_examples is one fifth of the in-distribution test-set size
# (2000 for this run, per the printed output below).
ood_num_examples = len(test_data) // 5
print('ood_num_examples:', ood_num_examples)

# OOD datasets to evaluate against, and one dataloader per dataset.
ood_data_list = ["Textures", "SVHN", 'LSUN-C', 'LSUN-R', "iSUN", "Places365"]
ood_loader_dict = get_ood_dataloaders(
    ood_data_list,
    input_size=input_size,
    input_channels=input_channels,
    mean=mean,
    std=std,
    test_bs=test_bs,
    prefetch_threads=prefetch_threads,
    seed=seed,
)

# Compute and print the ID measures (error rate, ECE) and the per-dataset and
# mean OOD measures (FPR95, AUROC, AUPR) for this model.
# NOTE(review): temp is hard-coded to 0.1 regardless of loss_function/init_temp — confirm intended.
all_measures = get_all_measures(
    net,
    test_loader,
    ood_loader_dict,
    device,
    temp=0.1,
    init_temp=init_temp,
    score_function=score_function,
    recall_level=0.95,
    ood_num_examples=ood_num_examples,
    test_bs=test_bs,
    to_string=True,
    num_to_avg=num_to_avg,
    method_name=method_name,
)
(error_rate, ece_error, auroc, aupr, fpr,
 auroc_list, aupr_list, fpr_list) = all_measures

ood_num_examples: 2000
Error Rate 5.50
				cifar10_wrn_logitnorm_standard
FPR95:			65.09
AUROC: 			87.35
AUPR:  			99.14


ECE Error
ECE Error 1.43


Textures Detection
				cifar10_wrn_logitnorm_standard
FPR95:			94.44	+/- 0.55
AUROC: 			62.99	+/- 0.42
AUPR:  			69.83	+/- 0.43


SVHN Detection
				cifar10_wrn_logitnorm_standard
FPR95:			99.47	+/- 0.11
AUROC: 			65.00	+/- 0.25
AUPR:  			76.23	+/- 0.14


LSUN-C Detection
				cifar10_wrn_logitnorm_standard
FPR95:			99.30	+/- 0.19
AUROC: 			60.21	+/- 0.42
AUPR:  			71.71	+/- 0.23


LSUN-R Detection
				cifar10_wrn_logitnorm_standard
FPR95:			99.21	+/- 0.11
AUROC: 			74.04	+/- 0.26
AUPR:  			82.15	+/- 0.21


iSUN Detection
				cifar10_wrn_logitnorm_standard
FPR95:			99.51	+/- 0.10
AUROC: 			71.96	+/- 0.21
AUPR:  			80.91	+/- 0.13


Places365 Detection
				cifar10_wrn_logitnorm_standard
FPR95:			87.26	+/- 0.48
AUROC: 			78.22	+/- 0.35
AUPR:  			81.39	+/- 0.28


Mean Test Results
				cifar10_wrn_logitnorm_standard
FPR95:			96.53
AUROC: 			68.