Added some datasets from OpenOOD #18

Merged
merged 1 commit on Jul 18, 2023
3 changes: 2 additions & 1 deletion README.rst
@@ -114,7 +114,8 @@ The package can be installed via PyPI:

* ``libmr`` for the OpenMax Detector [#OpenMax]_ . The library is currently broken and unlikely to be repaired.
You will have to install ``cython`` and ``libmr`` manually afterwards.
* ``scikit`` for ViM
* ``scikit-learn`` for ViM
* ``gdown`` to download some datasets and model weights
* ``pandas`` for the `examples <https://pytorch-ood.readthedocs.io/en/latest/auto_examples/index.html>`_.
* ``segmentation-models-pytorch`` to run the examples for anomaly segmentation

4 changes: 2 additions & 2 deletions examples/benchmarks/interface/README.rst
@@ -2,6 +2,6 @@
Benchmark
-------------------------

Ready-to-use benchmarks provide a simple interface to replicate the experiments
used in some publications. While they are convenient, this comes at the price of less flexibility.
Ready-to-use benchmarks provide a simple interface to (approximately) replicate the experiments
of other publications. While they are convenient, this comes at the price of less flexibility.
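
A minimal usage sketch (model and transform choices follow the CIFAR10
examples in this directory; treat it as illustrative rather than normative):

.. code-block:: python

    from pytorch_ood.benchmark import CIFAR10_ODIN
    from pytorch_ood.detector import MaxSoftmax
    from pytorch_ood.model import WideResNet

    # model and transform as used in the CIFAR10 examples
    model = WideResNet(num_classes=10, pretrained="cifar10-pt").eval().to("cuda:0")
    trans = WideResNet.transform_for("cifar10-pt")

    benchmark = CIFAR10_ODIN(root="data", transform=trans)
    results = benchmark.evaluate(MaxSoftmax(model), device="cuda:0")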

6 changes: 3 additions & 3 deletions examples/benchmarks/interface/cifar100_odin.py
@@ -1,10 +1,10 @@
"""

ODIN - CIFAR 100
ODIN - CIFAR100
==================

Reproduces the ODIN benchmark for OOD detection.

Reproduces the ODIN benchmark for OOD detection, from the paper
*Enhancing the reliability of out-of-distribution image detection in neural networks*.

"""
import pandas as pd # additional dependency, used here for convenience
5 changes: 3 additions & 2 deletions examples/benchmarks/interface/cifar10_odin.py
@@ -1,9 +1,10 @@
"""

ODIN - CIFAR 10
ODIN - CIFAR10
==================

Reproduces the ODIN benchmark for OOD detection.
Reproduces the ODIN benchmark for OOD detection, from the paper
*Enhancing the reliability of out-of-distribution image detection in neural networks*.


"""
68 changes: 68 additions & 0 deletions examples/benchmarks/interface/cifar10_openood.py
@@ -0,0 +1,68 @@
"""

OpenOOD - CIFAR10
==================

Reproduces the OpenOOD benchmark for OOD detection, using the WideResNet
model from the Hendrycks baseline paper.

.. warning :: This is currently incomplete, see :class:`CIFAR10-OpenOOD <pytorch_ood.benchmark.CIFAR10_OpenOOD>`.

"""
import pandas as pd # additional dependency, used here for convenience
import torch

from pytorch_ood.benchmark import CIFAR10_OpenOOD
from pytorch_ood.detector import ODIN, MaxSoftmax
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import fix_random_seed

fix_random_seed(123)

device = "cuda:0"
loader_kwargs = {"batch_size": 64}

# %%
model = WideResNet(num_classes=10, pretrained="cifar10-pt").eval().to(device)
trans = WideResNet.transform_for("cifar10-pt")
norm_std = WideResNet.norm_std_for("cifar10-pt")

# %%
detectors = {
    "MSP": MaxSoftmax(model),
}
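
# %%
# ``ODIN`` is imported above and ``norm_std`` is already computed but unused;
# a hedged sketch for also evaluating ODIN, assuming it accepts ``norm_std``
# and ``eps`` keyword arguments (values illustrative; check the detector
# documentation), left commented out so the table below stays valid::
#
#   detectors["ODIN"] = ODIN(model, norm_std=norm_std, eps=0.002)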

# %%
results = []
benchmark = CIFAR10_OpenOOD(root="data", transform=trans)

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        res = benchmark.evaluate(detector, loader_kwargs=loader_kwargs, device=device)
        for r in res:
            r.update({"Detector": detector_name})
        results += res

df = pd.DataFrame(results)
print((df.set_index(["Dataset", "Detector"]) * 100).to_csv(float_format="%.2f"))

# %%
# This should produce the following table:
#
# +--------------+----------+-------+---------+----------+----------+
# | Dataset | Detector | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
# +==============+==========+=======+=========+==========+==========+
# | CIFAR100 | MSP | 87.83 | 85.20 | 88.42 | 43.08 |
# +--------------+----------+-------+---------+----------+----------+
# | TinyImageNet | MSP | 87.06 | 85.05 | 86.82 | 51.27 |
# +--------------+----------+-------+---------+----------+----------+
# | MNIST | MSP | 92.66 | 90.29 | 94.33 | 22.47 |
# +--------------+----------+-------+---------+----------+----------+
# | FashionMNIST | MSP | 94.95 | 93.36 | 96.18 | 15.59 |
# +--------------+----------+-------+---------+----------+----------+
# | Textures | MSP | 88.51 | 78.50 | 92.99 | 40.86 |
# +--------------+----------+-------+---------+----------+----------+
# | Places365 | MSP | 88.24 | 95.61 | 71.17 | 44.65 |
# +--------------+----------+-------+---------+----------+----------+
#
70 changes: 70 additions & 0 deletions examples/benchmarks/interface/imagenet_openood.py
@@ -0,0 +1,70 @@
"""
OpenOOD - ImageNet
===================

Reproduces the OpenOOD benchmark for OOD detection, using a pre-trained ResNet 50.

.. warning :: This is currently incomplete, see :class:`ImageNet-OpenOOD <pytorch_ood.benchmark.ImageNet_OpenOOD>`.

"""
import pandas as pd # additional dependency, used here for convenience
import torch
from torchvision.models import resnet50
from torchvision.models.resnet import ResNet50_Weights

from pytorch_ood.benchmark import ImageNet_OpenOOD
from pytorch_ood.detector import MaxSoftmax
from pytorch_ood.utils import fix_random_seed

fix_random_seed(123)

device = "cuda:0"
loader_kwargs = {"batch_size": 16, "num_workers": 12}

# %%
model = resnet50(ResNet50_Weights.IMAGENET1K_V1).eval().to(device)
trans = ResNet50_Weights.IMAGENET1K_V1.transforms()

print(trans)

# %%
# If you want to test more detectors, you can just add them here
detectors = {
    "MSP": MaxSoftmax(model),
}
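
# %%
# For example (a sketch; assuming these detectors exist in
# ``pytorch_ood.detector`` and only require the model)::
#
#   from pytorch_ood.detector import EnergyBased, MaxLogit
#   detectors["EnergyBased"] = EnergyBased(model)
#   detectors["MaxLogit"] = MaxLogit(model)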

# %%
# The ImageNet root should contain at least the validation tar, the dev kit tar, and the meta.bin
# that is generated by the torchvision ImageNet implementation.
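# The expected layout is then roughly (file names are an assumption based on
# the standard torchvision ImageNet setup)::
#
#   data/imagenet-2012/
#       ILSVRC2012_img_val.tar
#       ILSVRC2012_devkit_t12.tar.gz
#       meta.bin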
results = []
benchmark = ImageNet_OpenOOD(root="data", image_net_root="data/imagenet-2012/", transform=trans)


with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        res = benchmark.evaluate(detector, loader_kwargs=loader_kwargs, device=device)
        for r in res:
            r.update({"Detector": detector_name})
        results += res

df = pd.DataFrame(results)
print((df.set_index(["Dataset", "Detector"]) * 100).to_csv(float_format="%.2f"))

# %%
# This should produce the following table:
#
# +-------------+----------+-------+---------+----------+----------+
# | Dataset | Detector | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
# +=============+==========+=======+=========+==========+==========+
# | ImageNetO | MSP | 28.64 | 2.52 | 94.85 | 91.20 |
# +-------------+----------+-------+---------+----------+----------+
# | OpenImagesO | MSP | 84.98 | 62.61 | 94.67 | 49.95 |
# +-------------+----------+-------+---------+----------+----------+
# | Textures | MSP | 80.46 | 37.50 | 96.80 | 67.75 |
# +-------------+----------+-------+---------+----------+----------+
# | SVHN | MSP | 97.62 | 95.56 | 98.77 | 11.58 |
# +-------------+----------+-------+---------+----------+----------+
# | MNIST | MSP | 90.04 | 90.45 | 89.88 | 39.03 |
# +-------------+----------+-------+---------+----------+----------+
#
26 changes: 24 additions & 2 deletions src/pytorch_ood/benchmark/__init__.py
@@ -9,7 +9,7 @@
API
==================

Each benchmark implements a common API.
Each benchmark implements a common interface.

.. note :: This is currently a draft and likely subject to change in the
   future.
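
In its current form, each benchmark exposes roughly the following methods
(a sketch distilled from the implementations in this package, not a frozen
contract):

.. code-block:: python

    class Benchmark:
        def train_set(self) -> Dataset: ...
        def test_sets(self, known=True, unknown=True) -> List[Dataset]: ...
        def evaluate(self, detector, loader_kwargs=None, device="cpu") -> List[Dict]: ...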
@@ -44,6 +44,12 @@
   :members:


OpenOOD Benchmark
-----------------

.. autoclass:: pytorch_ood.benchmark.CIFAR10_OpenOOD
   :members:


CIFAR 100
^^^^^^^^^^^
@@ -54,7 +60,23 @@
.. autoclass:: pytorch_ood.benchmark.CIFAR100_ODIN
   :members:

OpenOOD Benchmark
-----------------

.. autoclass:: pytorch_ood.benchmark.CIFAR100_OpenOOD
   :members:


ImageNet
^^^^^^^^^^^

OpenOOD Benchmark
-----------------

.. autoclass:: pytorch_ood.benchmark.ImageNet_OpenOOD
   :members:


"""
from .base import Benchmark
from .img import CIFAR10_ODIN, CIFAR100_ODIN
from .img import CIFAR10_ODIN, CIFAR100_ODIN, CIFAR10_OpenOOD, CIFAR100_OpenOOD, ImageNet_OpenOOD
5 changes: 3 additions & 2 deletions src/pytorch_ood/benchmark/img/__init__.py
@@ -1,2 +1,3 @@
from .cifar10 import CIFAR10_ODIN
from .cifar100 import CIFAR100_ODIN
from .cifar10 import CIFAR10_ODIN, CIFAR10_OpenOOD
from .cifar100 import CIFAR100_ODIN, CIFAR100_OpenOOD
from .imagenet import ImageNet_OpenOOD
124 changes: 118 additions & 6 deletions src/pytorch_ood/benchmark/img/cifar10.py
@@ -4,11 +4,13 @@
from typing import Dict, List

from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST
from torchvision.transforms import Compose

from pytorch_ood.api import Detector
from pytorch_ood.dataset.img import LSUNCrop, LSUNResize, TinyImageNetCrop, TinyImageNetResize, GaussianNoise, UniformNoise
from pytorch_ood.utils import OODMetrics, ToUnknown
from pytorch_ood.dataset.img import LSUNCrop, LSUNResize, TinyImageNetCrop, TinyImageNetResize, GaussianNoise, \
    UniformNoise, TinyImageNet, Textures, Places365
from pytorch_ood.utils import OODMetrics, ToUnknown, ToRGB
from pytorch_ood.benchmark import Benchmark


@@ -55,7 +57,7 @@ def __init__(self, root, transform):
        ]

        self.ood_names: List[str] = []  #: OOD Dataset names
        self.ood_names = ["TinyImageNetCrop", "TinyImageNetResize", "LSUNResize", "LSUNCrop", "Uniform", "Gaussian"]
        self.ood_names = [type(d).__name__ for d in self.test_oods]

    def train_set(self) -> Dataset:
        """
@@ -79,12 +81,122 @@ def test_sets(self, known=True, unknown=True) -> List[Dataset]:
            return [self.train_in]

        if not known and unknown:
            return self.ood_datasets
            return self.test_oods

        raise ValueError()

    def evaluate(
        self, detector: Detector, loader_kwargs: Dict = None, device: str = "cpu"
    ) -> List[Dict]:
        """
        Evaluates the given detector on all datasets and returns a list with the results

        :param detector: the detector to evaluate
        :param loader_kwargs: keyword arguments to give to the data loader
        :param device: the device to move batches to
        """
        if loader_kwargs is None:
            loader_kwargs = {}

        metrics = []

        for name, dataset in zip(self.ood_names, self.test_sets()):
            loader = DataLoader(dataset=dataset, **loader_kwargs)

            m = OODMetrics()

            for x, y in loader:
                m.update(detector(x.to(device)), y)

            r = m.compute()
            r.update({"Dataset": name})

            metrics.append(r)

        return metrics


class CIFAR10_OpenOOD(Benchmark):
"""
Aims to replicate the benchmark proposed in *OpenOOD: Benchmarking Generalized Out-of-Distribution Detection*.

:see Paper: `OpenOOD <https://openreview.net/pdf?id=gT6j4_tskUt>`__

Outlier datasets are

* CIFAR100
* TinyImageNet
* MNIST
* FashionMNIST
* Textures
* Places365

.. warning :: This currently does not reproduce the benchmark accurately, as it does not exclude images with
overlap with CIFAR10.

"""

    def __init__(self, root, transform):
        """
        :param root: where to store datasets
        :param transform: transform to apply to images
        """
        # OOD datasets additionally get ToRGB, since some of them are grayscale
        self.transform = Compose([ToRGB(), transform])
        self.train_in = CIFAR10(root, download=True, transform=transform, train=True)
        self.test_in = CIFAR10(root, download=True, transform=transform, train=False)

        self.test_oods = [
            CIFAR100(
                root, download=True, transform=self.transform, target_transform=ToUnknown(), train=False
            ),
            TinyImageNet(
                root, download=True, transform=self.transform, target_transform=ToUnknown(), subset="val"
            ),
            MNIST(
                root, download=True, transform=self.transform, target_transform=ToUnknown(), train=False
            ),
            FashionMNIST(
                root, download=True, transform=self.transform, target_transform=ToUnknown(), train=False
            ),
            Textures(
                root, download=True, transform=self.transform, target_transform=ToUnknown()
            ),
            Places365(
                root, download=True, transform=self.transform, target_transform=ToUnknown()
            )
        ]

        self.ood_names: List[str] = []  #: OOD Dataset names
        self.ood_names = [type(d).__name__ for d in self.test_oods]

    def train_set(self) -> Dataset:
        """
        Training dataset
        """
        return self.train_in

    def test_sets(self, known=True, unknown=True) -> List[Dataset]:
        """
        List of the different test datasets.
        If known and unknown are true, each dataset contains IN and OOD data.

        :param known: include IN
        :param unknown: include OOD
        """

        if known and unknown:
            # concatenates the IN test set with each OOD set
            return [self.test_in + other for other in self.test_oods]

        if known and not unknown:
            return [self.train_in]

        if not known and unknown:
            return self.test_oods

        raise ValueError()

    def evaluate(
        self, detector: Detector, loader_kwargs: Dict = None, device: str = "cpu"
    ) -> List[Dict]:
        """
        Evaluates the given detector on all datasets and returns a list with the results