In [1]:
import torch
import torchvision.transforms as tvt
from torch.utils.data import ConcatDataset, DataLoader
from torchmetrics import AUROC  # additional dependency
from torchvision.datasets import CIFAR10
from tqdm.notebook import tqdm

from oodtk import NegativeEnergy, Softmax, ODIN
from oodtk.dataset.img import Textures, CIFAR10C, CIFAR10P, LSUNCrop, LSUNResize, TinyImageNetResize, TinyImageNetCrop
from oodtk.model import WideResNet
from oodtk.utils import is_unknown, OODMetrics
from oodtk.transforms import ToRGB


In [2]:
torch.manual_seed(123)

mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]

trans = tvt.Compose([ToRGB(), tvt.Resize((32,32)), tvt.ToTensor(), tvt.Normalize(mean, std)])

# setup data
dataset_train = CIFAR10(root="data", train=True, download=True, transform=trans)
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)
dataset_out_test1 = Textures(root="data", download=True, transform=trans)
dataset_out_test2 = LSUNCrop(root="data", download=True, transform=trans)
dataset_out_test3 = LSUNResize(root="data", download=True, transform=trans)
dataset_out_test4 = TinyImageNetResize(root="data", download=True, transform=trans)
dataset_out_test5 = TinyImageNetCrop(root="data", download=True, transform=trans)
dataset_test = dataset_in_test + dataset_out_test1 + dataset_out_test2 + dataset_out_test3 + dataset_out_test4 + dataset_out_test5
train_loader = DataLoader(dataset_train, batch_size=128, num_workers=20)
test_loader = DataLoader(dataset_test, batch_size=128, num_workers=20)

Files already downloaded and verified




In [3]:
model = WideResNet.from_pretrained("oe-cifar10-tune", num_classes=10).eval().cuda()
method = Softmax(model).cuda()
metrics = OODMetrics()

with torch.no_grad():
    for n, batch in enumerate(test_loader):
        x, y = batch
        x = x.cuda()
        y = y.cuda()

        metrics.update(method.predict(x), y)

print(metrics.compute())
metrics.reset()



{'AUROC': 0.9857486486434937, 'AUPR-IN': 0.9964964389801025, 'AUPR-OUT': 0.9549047946929932, 'ACC95TPR': 0.9451833367347717, 'FPR95TPR': 0.07680000364780426}


In [4]:
model = WideResNet.from_pretrained("er-cifar10-tune", num_classes=10).eval().cuda()
energy = NegativeEnergy(model)
metrics = OODMetrics()

with torch.no_grad():
    for batch in tqdm(test_loader):
        x, y = batch
        x = x.cuda()
        y = y.cuda()
        metrics.update(energy.predict(x), y)

print(metrics.compute())
metrics.reset()

  0%|          | 0/435 [00:00<?, ?it/s]

{'AUROC': 0.9888905882835388, 'AUPR-IN': 0.996794581413269, 'AUPR-OUT': 0.965325653553009, 'ACC95TPR': 0.9515097141265869, 'FPR95TPR': 0.041600000113248825}


In [5]:
model = WideResNet.from_pretrained("cifar10-pt", num_classes=10).eval().cuda()
method = Softmax(model)
metrics = OODMetrics()

with torch.no_grad():
    for batch in tqdm(test_loader):
        x, y = batch
        x = x.cuda()
        y = y.cuda()

        metrics.update(method.predict(x), y)

print(metrics.compute())
metrics.reset()

  0%|          | 0/435 [00:00<?, ?it/s]

{'AUROC': 0.9217170476913452, 'AUPR-IN': 0.9772027730941772, 'AUPR-OUT': 0.79755699634552, 'ACC95TPR': 0.9085010886192322, 'FPR95TPR': 0.2809000015258789}


In [6]:
model = WideResNet.from_pretrained("cifar10-pt", num_classes=10).eval().cuda()
method = NegativeEnergy(model)
metrics = OODMetrics()

with torch.no_grad():
    for batch in tqdm(test_loader):
        x, y = batch
        x = x.cuda()
        y = y.cuda()

        metrics.update(method.predict(x), y)

print(metrics.compute())
metrics.reset()

  0%|          | 0/435 [00:00<?, ?it/s]

{'AUROC': 0.9384509325027466, 'AUPR-IN': 0.9850852489471436, 'AUPR-OUT': 0.767615795135498, 'ACC95TPR': 0.9011861681938171, 'FPR95TPR': 0.3215999901294708}


In [10]:
model = WideResNet.from_pretrained("cifar10-pt", num_classes=10).eval().cuda()
odin = ODIN(model, eps=0.002, norm_std=std)
metrics = OODMetrics()

with torch.no_grad():
    for batch in tqdm(test_loader):
        x, y = batch
        x = x.cuda()
        y = y.cuda()

        metrics.update(method.predict(x).detach(), y)

print(metrics.compute())
metrics.reset()



  0%|          | 0/435 [00:00<?, ?it/s]



{'AUROC': 0.4104202389717102, 'AUPR-IN': 0.7914378643035889, 'AUPR-OUT': 0.14317642152309418, 'ACC95TPR': 0.7848490476608276, 'FPR95TPR': 0.9783999919891357}
