In [1]:
import torch
import torchvision.transforms as tvt
from torch.utils.data import ConcatDataset, DataLoader
from torchmetrics import AUROC  # additional dependency
from torchvision.datasets import CIFAR10

from oodtk import NegativeEnergy, Softmax
from oodtk.dataset.img import Textures, CIFAR10C, CIFAR10P, LSUNCrop, LSUNResize, TinyImageNetResize, TinyImageNetCrop
from oodtk.model import WideResNet
from oodtk.utils import is_unknown
from oodtk.transforms import ToRGB

In [2]:
torch.manual_seed(123)

mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]

trans = tvt.Compose([ToRGB(), tvt.Resize((32,32)), tvt.ToTensor(), tvt.Normalize(mean, std)])

# setup data
dataset_train = CIFAR10(root="data", train=True, download=True, transform=trans)
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)
dataset_out_test1 = Textures(root="data", download=True, transform=trans)
dataset_out_test2 = LSUNCrop(root="data", download=True, transform=trans)
dataset_out_test3 = LSUNResize(root="data", download=True, transform=trans)
dataset_out_test4 = TinyImageNetResize(root="data", download=True, transform=trans)
dataset_out_test5 = TinyImageNetCrop(root="data", download=True, transform=trans)
dataset_test = dataset_in_test + dataset_out_test1 + dataset_out_test2 + dataset_out_test3 + dataset_out_test4 + dataset_out_test5
train_loader = DataLoader(dataset_train, batch_size=128)
test_loader = DataLoader(dataset_test, batch_size=128)

Files already downloaded and verified


In [3]:
model = WideResNet.from_pretrained("oe-cifar10-tune", num_classes=10)
model = model.cuda()

In [4]:
# create some methods
energy = NegativeEnergy(model).cuda()
softmax = Softmax(model).cuda()

# evaluate
auroc_energy = AUROC(num_classes=2)
auroc_softmax = AUROC(num_classes=2)
model.eval()

with torch.no_grad():
    for n, batch in enumerate(test_loader):
        x, y = batch
        x = x.cuda()
        y = y.cuda()

        auroc_energy.update(energy(x), is_unknown(y))
        auroc_softmax.update(softmax(x), is_unknown(y))

print(auroc_softmax.compute())
print(auroc_energy.compute())



tensor(0.9857, device='cuda:0')
tensor(0.9866, device='cuda:0')


In [6]:
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)
dataset_out_test = CIFAR10C(root="data", subset="all", download=True, transform=trans)
dataset_test = dataset_in_test +  dataset_out_test
test_loader = DataLoader(dataset_test, batch_size=128)

auroc_energy = AUROC(num_classes=2)
auroc_softmax = AUROC(num_classes=2)
model.eval()

with torch.no_grad():
    for n, batch in enumerate(test_loader):
        x, y = batch
        x = x.cuda()
        y = y.cuda()

        auroc_energy.update(energy(x), is_unknown(y))
        auroc_softmax.update(softmax(x), is_unknown(y))

print(auroc_softmax.compute())
print(auroc_energy.compute())

Downloading https://zenodo.org/record/2535967/files/CIFAR-10-C.tar to data/CIFAR-10-C.tar


  0%|          | 0/2918471680 [00:00<?, ?it/s]

Extracting data/CIFAR-10-C.tar to data




tensor(0., device='cuda:0')
tensor(0., device='cuda:0')




In [None]:
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)
dataset_out_test = CIFAR10P(root="data", download=True, transform=trans)
dataset_test = dataset_in_test +  dataset_out_test
test_loader = DataLoader(dataset_test, batch_size=128)

auroc_energy = AUROC(num_classes=2)
auroc_softmax = AUROC(num_classes=2)
model.eval()

with torch.no_grad():
    for n, batch in enumerate(test_loader):
        x, y = batch
        x = x.cuda()
        y = y.cuda()

        auroc_energy.update(energy(x), is_unknown(y))
        auroc_softmax.update(softmax(x), is_unknown(y))

print(auroc_softmax.compute())
print(auroc_energy.compute())

In [None]:
dataset_out_test = CIFAR10C(root="data", subset="all", download=True, transform=trans)

In [7]:

dataset_out_test = CIFAR10P(root="data", download=True, transform=trans)

Downloading https://zenodo.org/record/2535967/files/CIFAR-10-P.tar to data/CIFAR-10-P.tar


  0%|          | 0/18278422016 [00:00<?, ?it/s]

Extracting data/CIFAR-10-P.tar to data
