In [1]:
import torch
import torchvision.transforms as tvt
from torch.utils.data import DataLoader
from torchmetrics import AUROC  # additional dependency
from torchvision.datasets import CIFAR10
from tqdm.notebook import tqdm

from oodtk import NegativeEnergy, Softmax
from oodtk.dataset.img import Textures, CIFAR10C, CIFAR10P, LSUNCrop, LSUNResize, TinyImageNetResize, TinyImageNetCrop
from oodtk.model import VisionTransformer
from oodtk.utils import is_unknown
from oodtk.transforms import ToRGB

In [2]:
# Configuration: fixed seed plus the input pipeline shared by every dataset below.
torch.manual_seed(123)

# Per-channel normalization to roughly [-1, 1].
# NOTE(review): assumes this matches the preprocessing the pretrained ViT
# checkpoint was trained with — TODO confirm.
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

# ToRGB is applied first. Presumably it converts non-RGB (e.g. grayscale) images
# from the OOD datasets to 3 channels. Images are then resized to the 384x384
# resolution the model is loaded with.
trans = tvt.Compose([ToRGB(), tvt.Resize((384, 384)), tvt.ToTensor(), tvt.Normalize(mean, std)])

# In-distribution test data. download=True lets a fresh Restart & Run All fetch
# CIFAR10 as well; every OOD dataset below already downloads itself.
dataset_in_test = CIFAR10(root="data", train=False, download=True, transform=trans)

# Out-of-distribution test data.
dataset_out_test1 = Textures(root="data", download=True, transform=trans)
dataset_out_test2 = LSUNCrop(root="data", download=True, transform=trans)
dataset_out_test3 = LSUNResize(root="data", download=True, transform=trans)
dataset_out_test4 = TinyImageNetResize(root="data", download=True, transform=trans)
dataset_out_test5 = TinyImageNetCrop(root="data", download=True, transform=trans)

# `+` on torch Datasets builds a ConcatDataset: CIFAR10 first, then each OOD set
# in the order listed.
dataset_test = (
    dataset_in_test
    + dataset_out_test1
    + dataset_out_test2
    + dataset_out_test3
    + dataset_out_test4
    + dataset_out_test5
)

# Not shuffled: AUROC does not depend on sample order, and a fixed order keeps
# runs reproducible.
test_loader = DataLoader(dataset_test, batch_size=8)

In [3]:
# Load the pretrained classifier and move it to the GPU in a single expression.
# The checkpoint name suggests ViT-B/16 fine-tuned on CIFAR10. num_classes and
# image_size are set to match CIFAR10's 10 labels and the 384x384 Resize used in
# the data pipeline.
model = VisionTransformer.from_pretrained(
    "b16-cifar10-tune",
    num_classes=10,
    image_size=(384, 384),
).cuda()

In [4]:
# Build two OOD detectors on top of the same classifier.
energy = NegativeEnergy(model).cuda()
softmax = Softmax(model).cuda()

# One binary AUROC per detector. The positive class is "unknown" (OOD), as
# produced by is_unknown(y) below.
# NOTE(review): torchmetrics >= 0.10 expects AUROC(task="binary"). num_classes=2
# is kept here for the torchmetrics version this notebook was written against.
auroc_energy = AUROC(num_classes=2)
auroc_softmax = AUROC(num_classes=2)
model.eval()

with torch.no_grad():
    for batch in tqdm(test_loader):
        x, y = batch
        x = x.cuda()
        y = y.cuda()

        # Compute the OOD target once per batch (it was computed once per
        # detector before). Cast it to an integer tensor, which torchmetrics
        # requires for targets; presumably is_unknown returns a bool tensor.
        target = is_unknown(y).long()

        # NOTE(review): AUROC treats higher preds as evidence for the positive
        # (unknown) class. If these detectors return confidence-like scores
        # (higher for known samples), the reported value is 1 - AUROC. Confirm
        # the oodtk score convention.
        # NOTE(review): each detector runs its own forward pass through the ViT,
        # which doubles the cost of this loop.
        auroc_energy.update(energy(x), target)
        auroc_softmax.update(softmax(x), target)

print(auroc_softmax.compute())
print(auroc_energy.compute())



  0%|          | 0/6955 [00:00<?, ?it/s]

KeyboardInterrupt: 