In [19]:
import torch
import torchvision.transforms as tvt
from torch.utils.data import ConcatDataset, DataLoader
from torchmetrics import AUROC  # additional dependency
from torchvision.datasets import CIFAR10

from oodtk import Mahalanobis
from oodtk.dataset.img import Textures, CIFAR10C, CIFAR10P, LSUNCrop, LSUNResize, TinyImageNetResize, TinyImageNetCrop
from oodtk.model import WideResNet
from oodtk.utils import is_unknown
from oodtk.transforms import ToRGB

In [20]:
torch.manual_seed(123)

mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]

trans = tvt.Compose([ToRGB(), tvt.Resize((32,32)), tvt.ToTensor(), tvt.Normalize(mean, std)])

# setup data
dataset_train = CIFAR10(root="data", train=True, download=True, transform=trans)
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)
dataset_out_test1 = Textures(root="data", download=True, transform=trans)
dataset_out_test2 = LSUNCrop(root="data", download=True, transform=trans)
dataset_out_test3 = LSUNResize(root="data", download=True, transform=trans)
dataset_out_test4 = TinyImageNetResize(root="data", download=True, transform=trans)
dataset_out_test5 = TinyImageNetCrop(root="data", download=True, transform=trans)
dataset_test = dataset_in_test + dataset_out_test1 + dataset_out_test2 + dataset_out_test3 + dataset_out_test4 + dataset_out_test5
train_loader = DataLoader(dataset_train, batch_size=256)
test_loader = DataLoader(dataset_test, batch_size=256)

Files already downloaded and verified


In [21]:
from torch import nn

model = WideResNet.from_pretrained("cifar10-pt", num_classes=10)

model = nn.Sequential(
     model.conv1,
     model.block1,
     model.block2,
     model.block3,
     model.relu,
     model.bn1,
     nn.AvgPool2d(8),
     nn.Flatten())

model = model.cuda()

In [22]:
maha = Mahalanobis(model, eps=0.002, norm_std=std).cuda()
_ = maha.fit(train_loader)

In [23]:
from tqdm.notebook import tqdm

auroc = AUROC(num_classes=2)
model.eval()

for eps in [0.0, 0.0004, 0.0008, 0.0014, 0.002, 0.0024, 0.0028, 0.0032, 0.0038, 0.0048]:
    maha.eps = eps
    for batch in tqdm(test_loader):
        x, y = batch
        x = x.cuda()
        y = y.cuda()
        auroc.update(maha(x), is_unknown(y))

    print(eps, auroc.compute())
    auroc.reset()

  0%|          | 0/218 [00:00<?, ?it/s]

0.0 tensor(0.8439, device='cuda:0')




  0%|          | 0/218 [00:00<?, ?it/s]

0.0004 tensor(0.8534, device='cuda:0')


  0%|          | 0/218 [00:00<?, ?it/s]

0.0008 tensor(0.8606, device='cuda:0')


  0%|          | 0/218 [00:00<?, ?it/s]

0.0014 tensor(0.8686, device='cuda:0')


  0%|          | 0/218 [00:00<?, ?it/s]

0.002 tensor(0.8737, device='cuda:0')


  0%|          | 0/218 [00:00<?, ?it/s]

0.0024 tensor(0.8753, device='cuda:0')


  0%|          | 0/218 [00:00<?, ?it/s]

0.0028 tensor(0.8753, device='cuda:0')


  0%|          | 0/218 [00:00<?, ?it/s]

0.0032 tensor(0.8737, device='cuda:0')


  0%|          | 0/218 [00:00<?, ?it/s]

0.0038 tensor(0.8680, device='cuda:0')


  0%|          | 0/218 [00:00<?, ?it/s]

0.0048 tensor(0.8522, device='cuda:0')


In [24]:
maha.mu.shape

torch.Size([10, 128])

In [25]:
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)
dataset_out_test = CIFAR10C(root="data", subset="all", download=True, transform=trans)
dataset_test = dataset_in_test +  dataset_out_test
test_loader = DataLoader(dataset_test, batch_size=128)

auroc = AUROC(num_classes=2)
model.eval()


for batch in tqdm(test_loader):
    x, y = batch
    x = x.cuda()
    y = y.cuda()

    auroc.update(maha(x), is_unknown(y))

print(eps, auroc.compute())


  0%|          | 0/79 [00:00<?, ?it/s]

ValueError: No positive samples in targets, true positive value should be meaningless

In [None]:
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)
dataset_out_test = CIFAR10P(root="data", download=True, transform=trans)
dataset_test = dataset_in_test +  dataset_out_test
test_loader = DataLoader(dataset_test, batch_size=128)

model.eval()

for batch in tqdm(test_loader):
    x, y = batch
    x = x.cuda()
    y = y.cuda()

auroc.update(maha(x), is_unknown(y))

print(auroc.compute())

In [None]:
dataset_out_test = CIFAR10C(root="data", subset="all", download=True, transform=trans)