In [1]:
!pip install pytorch-ood
!pip install pandas scikit-learn pandas

Collecting pytorch-ood
  Downloading pytorch_ood-0.1.6-py3-none-any.whl (123 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m123.3/123.3 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=1.0.0 (from pytorch-ood)
  Downloading torchmetrics-1.1.1-py3-none-any.whl (763 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m763.4/763.4 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics>=1.0.0->pytorch-ood)
  Downloading lightning_utilities-0.9.0-py3-none-any.whl (23 kB)
Installing collected packages: lightning-utilities, torchmetrics, pytorch-ood
Successfully installed lightning-utilities-0.9.0 pytorch-ood-0.1.6 torchmetrics-1.1.1


In [2]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, FashionMNIST

from pytorch_ood.dataset.img import (
    LSUNCrop,
    LSUNResize,
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    Places365,
    TinyImageNet,
)

from pytorch_ood.detector import (
    ODIN,
    Mahalanobis,
    MaxSoftmax,
)

from pytorch_ood.model import WideResNet
from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

fix_random_seed(123)

trans = WideResNet.transform_for("cifar100-pt")
norm_std = WideResNet.norm_std_for("cifar100-pt")

In [3]:
in_dataset = CIFAR100(root="data", train=False, transform=trans, download=True)
test_datasets = [
    CIFAR100,
    CIFAR10,
    MNIST,
    FashionMNIST,
    LSUNCrop,
    LSUNResize,
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    Places365,
    TinyImageNet,
]
datasets = {}
for test_dataset in test_datasets:
    dataset_out_test = test_dataset(root="data", transform=trans, target_transform=ToUnknown(), download=True)
    test_loader = DataLoader(in_dataset + dataset_out_test, batch_size=256, num_workers=12)
    datasets[test_dataset.__name__] = test_loader

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:01<00:00, 102482504.47it/s]


Extracting data/cifar-100-python.tar.gz to data
Files already downloaded and verified




Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 88403498.55it/s]


Extracting data/cifar-10-python.tar.gz to data
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 119569388.59it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 27574708.36it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 31934335.94it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 7526878.22it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz





Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 15848182.70it/s]


Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 270175.93it/s]


Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 4959506.06it/s]


Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 5987874.93it/s]


Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading https://uc20291e877d8f394e58fc6c1ded.dl.dropboxusercontent.com/cd/0/inline2/CCwPxJucP0WhfrsqykTm-lH0mS2ebZ8BXGJdrWNIbcAIJDAl7r4XYVGN4ubz8_bAs-RzQMExfrZaMACGWntYR0EIQxMaHGJrHK4sE3xpNgwKNI2bHUUwiaRlVtvcJh9Ghk1Y9m88O0ZWr8avTP8E0RpgVH0taWnQFWi3208ifppw6mFACQl6QedmnZFRtvX5zs4Gs8zD9YtLxQHg7ZJM3bhFLUJWZobeKD0Yul4skMJbH057Fsk2-nOTe-7AJXOxk_5p6XtoRJFTrl0_BnaC1M0wmBfo9k6O17zWQXQwy0GFNR394ZdJhVxRfTdnAgin6AW3a1ZIoE6bmK00XzAzrSA8pSQJ1BICETSvxPkAbhAZS1jGxVfESQUjdDpkmm171lI/file to data/LSUN.tar.gz


100%|██████████| 17309383/17309383 [00:00<00:00, 31526464.93it/s]


Extracting data/LSUN.tar.gz to data
Downloading https://uce1743cbcf2489c19c9e0a576ea.dl.dropboxusercontent.com/cd/0/inline2/CCzbdsh7GZUd1-sc_ajE-mmDTHrhuWiLLSDrQdo4i1za_ls1IuDuzdR8hj43825vktXCW5seOlt_BlAK7d-XBqZnNgCo3fjUbds5H3gqQeJqUN1GvZUy25yJrAJW46yuyzdKg6RDZH-l-JoSKU8TvlF8i8XKpGuS9dWhgJnhWkJwiTZH4s01LP12IlvXz4F1fVgPNAqX-ShBywJOf445gge4iaRg0AbEufl5panpFR6U7Q45yD7rBejJnhM3qvZ2i2qG4HH3-gpLppHArS-6Gk8YXmmuPUdWyfdCLqGstO-0GwwypveDLyRxekmMojcX4QfwXcKengAPWc87awUN1hu3d_W-3u_OxEGU76-6ChPY4zBCNhhgpvsWsD-PE2YAuf8/file to data/LSUN_resize.tar.gz


100%|██████████| 4688973/4688973 [00:00<00:00, 11235441.93it/s]


Extracting data/LSUN_resize.tar.gz to data
Downloading https://thor.robots.ox.ac.uk/datasets/dtd/dtd-r1.0.1.tar.gz to data/textures-r1_0_1.tar.gz


100%|██████████| 625239812/625239812 [00:21<00:00, 29412653.29it/s]


Extracting data/textures-r1_0_1.tar.gz to data
Downloading https://uc23b8b280f107d041f0f1b282e8.dl.dropboxusercontent.com/cd/0/inline2/CCw7-Q6f-9EkTyr_Tt-R7ZNg9hUBY652fyl56IdmXfA7sDNFcz_WPrNO0j32f8Vr8emXt-vFdG4TPXLU3OTrzLbgsSZjgcO_n8_S3xRalch6RKQlxooBQ5xCFKllFjSVqeJgGVsPboVBzTRHX5us-DKY8F0IpFuBP2-Goh-2guqp_uNCw6qwGgjOsDE1ZQTMu-SB-I3Cei2ZzkzpFn0_w3g3vy8PfqXtsMRmXUvFCeHZHVEUDCeixzQNGHJVo-vMtqyfwbkaWHw-Anb-qJ8cJzmbjJai0UDJIviQNU0NXfnkAqGNoAKvD4U_2jmbKkMVqdIOXaINQCzF0C1M4KYhAAmg2bcf3d7s3Q0cec5iHgoraeNbxNIX9LWsWaKjwbhA-uU/file to data/Imagenet.tar.gz


100%|██████████| 26501958/26501958 [00:00<00:00, 40959283.07it/s]


Extracting data/Imagenet.tar.gz to data
Downloading https://ucae71e9375670b324862df54a07.dl.dropboxusercontent.com/cd/0/inline2/CCyAMO-6fQdbU3sinf1lkrBLBSLtYH3w7cTyiMSpMQRhPDLHi_5tYUtpw8Sy7lMIoCkqxICj2s929A7YmWK1kIohGwdFZYiQzFA5Mu94NP-ngpiOJotr9pPYZfhN4u-N3pqg9emxGYXIku6VCNRoqgAnRLpq18IT3f2sWs0i4MLw-QkBeR5P8JOWjWCtXLYPikInMjuFoaWafE64m6FgZyscoeULURbN47X11-LWJA90m1vQMyaxS4Bbg0_3uR7NyR05QaGuEZJENaLBCCDI2g8q_krBNrFyVmc0vegsEo40kFu7BFsquz7MUtVRsyRZyirm1j6ONkhc5IAFsJywRmzvl46O8ShivvzBH6rQgxVd6DLq0inVhIZlr_qGS2Rjd5I/file to data/Imagenet_resize.tar.gz


100%|██████████| 4550980/4550980 [00:00<00:00, 45748385.52it/s]


Extracting data/Imagenet_resize.tar.gz to data


Downloading...
From: https://drive.google.com/uc?id=1Ec-LRSTf6u5vEctKX9vRp9OA6tqnJ0Ay
To: /content/data/places365.zip
100%|██████████| 497M/497M [00:06<00:00, 80.6MB/s]


Downloading http://cs231n.stanford.edu/tiny-imagenet-200.zip to data/tiny-imagenet-200.zip


100%|██████████| 248100043/248100043 [00:05<00:00, 46671787.23it/s]


Extracting data/tiny-imagenet-200.zip to data


In [4]:
model = WideResNet(num_classes=100, pretrained="cifar100-pt").eval().to(device)

detectors = {}
detectors["MSP"] = MaxSoftmax(model)
detectors["Mahalanobis"] = Mahalanobis(model.features)
detectors["ODIN"] = ODIN(model, norm_std=norm_std, eps=0.002)

print(f"- Fitting {len(detectors)} detectors -")
loader_in_train = DataLoader(CIFAR100(root="data", train=True, transform=trans), batch_size=256, num_workers=12)
for name, detector in detectors.items():
    print(f"  > Fitting {name}")
    detector.fit(loader_in_train, device=device)

Downloading: "https://github.com/wetliu/energy_ood/raw/master/CIFAR/snapshots/pretrained/cifar100_wrn_pretrained_epoch_99.pt" to /root/.cache/torch/hub/checkpoints/wrn-cifar100-pt.pt
100%|██████████| 8.66M/8.66M [00:00<00:00, 101MB/s]


- Fitting 3 detectors -
  > Fitting MSP
  > Fitting Mahalanobis
  > Fitting ODIN


In [5]:
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"- Evaluating {detector_name} -")
        for dataset_name, loader in datasets.items():
            print(f"  {dataset_name}")
            metrics = OODMetrics()
            for x, y in loader:
                metrics.update(detector(x.to(device)), y.to(device))

            r = {"Detector": detector_name, "Dataset": dataset_name}
            r.update(metrics.compute())
            results.append(r)

- Evaluating MSP -
  CIFAR100
  CIFAR10
  MNIST
  FashionMNIST
  LSUNCrop
  LSUNResize
  Textures
  TinyImageNetCrop
  TinyImageNetResize
  Places365
  TinyImageNet
- Evaluating Mahalanobis -
  CIFAR100
  CIFAR10
  MNIST
  FashionMNIST
  LSUNCrop
  LSUNResize
  Textures
  TinyImageNetCrop
  TinyImageNetResize
  Places365
  TinyImageNet
- Evaluating ODIN -
  CIFAR100
  CIFAR10
  MNIST
  FashionMNIST
  LSUNCrop
  LSUNResize
  Textures
  TinyImageNetCrop
  TinyImageNetResize
  Places365
  TinyImageNet


In [7]:
df = pd.DataFrame(results).set_index(['Dataset', 'Detector'])
print(df)
df.to_csv("./result/result.csv")

                                   AUROC   AUPR-IN  AUPR-OUT  FPR95TPR
Dataset            Detector                                           
CIFAR100           MSP          0.344688  0.738764  0.133239    0.9542
CIFAR10            MSP          0.755213  0.924956  0.500790    0.6427
MNIST              MSP          0.746989  0.929356  0.535628    0.5735
FashionMNIST       MSP          0.893815  0.976642  0.737556    0.3661
LSUNCrop           MSP          0.855878  0.843563  0.874041    0.4713
LSUNResize         MSP          0.741252  0.699077  0.773360    0.6496
Textures           MSP          0.735491  0.575050  0.831475    0.7139
TinyImageNetCrop   MSP          0.863217  0.848073  0.882328    0.4335
TinyImageNetResize MSP          0.739807  0.701496  0.767825    0.6605
Places365          MSP          0.739189  0.894454  0.500872    0.6991
TinyImageNet       MSP          0.764159  0.963127  0.327413    0.6524
CIFAR100           Mahalanobis  0.439412  0.803425  0.148923    0.9549
CIFAR1