This repository provides an implementation of GradPCA, a method for detecting out-of-distribution (OOD) data using gradient-based representations. It includes tools for benchmarking GradPCA alongside a variety of baseline OOD detectors on standard image classification datasets.
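To illustrate scoring inputs by where they fall in a feature space, the sketch below fits a PCA subspace to in-distribution feature vectors and scores new inputs by the size of their residual outside that subspace. The feature vectors could be, for example, flattened per-sample gradients. This is a generic PCA-residual detector written with NumPy. The function names are hypothetical, and it is not the GradPCA scoring rule implemented in this repository.

```python
import numpy as np


def fit_pca_subspace(id_feats: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
    """Fit a rank-k PCA subspace to in-distribution feature vectors (one per row)."""
    mean = id_feats.mean(axis=0)
    centered = id_feats - mean
    # Rows of vt are orthonormal principal directions, ordered by singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return mean, vt[:k]


def pca_residual_score(feats: np.ndarray, mean: np.ndarray, basis: np.ndarray) -> np.ndarray:
    """Norm of each row's component outside the PCA subspace; larger means more OOD-like."""
    centered = feats - mean
    projected = (centered @ basis.T) @ basis
    return np.linalg.norm(centered - projected, axis=1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Toy "ID" features lie in a 2-D subspace of a 5-D space.
    id_feats = np.zeros((100, 5))
    id_feats[:, :2] = rng.normal(size=(100, 2))
    mean, basis = fit_pca_subspace(id_feats, k=2)
    # Toy "OOD" features point along a direction the ID data never uses.
    ood_feats = np.zeros((10, 5))
    ood_feats[:, 4] = 3.0
    print("max ID score: ", pca_residual_score(id_feats, mean, basis).max())
    print("min OOD score:", pca_residual_score(ood_feats, mean, basis).min())
```

The residual norm is used because ID inputs should be well explained by the dominant directions of the ID data, while OOD inputs tend to have energy elsewhere.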
We provide the following example notebooks:
- `example.ipynb`: A minimal, self-contained demonstration of the `GradPCA` class across a range of datasets.
- `benchmark.ipynb`: A benchmark comparing GradPCA against several baseline OOD detectors on multiple OOD datasets.
The following datasets are supported and can be used for both in-distribution (ID) and out-of-distribution (OOD) evaluation:
CIFAR-10, CIFAR-100, SVHN, Places, LSUN_r, LSUN_c, iSUN, Textures, ImageNet-1k, ImageNet-V2, iNaturalist (MOS), Places (MOS), SUN (MOS)
Some datasets must be downloaded manually:
- ImageNet-1k: Download from the official site.
- MOS datasets (iNaturalist, Places, and SUN, used for the ImageNet OOD benchmarks).

⚠️ Place the extracted folders in the expected location (e.g., `data/datasets/`) so that the dataset loaders can find them.
All other datasets (e.g., CIFAR, SVHN, LSUN, iSUN) are automatically downloaded when their respective loaders are instantiated.
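Before running the ImageNet benchmarks, you can confirm that the manually downloaded folders are where the loaders look for them. The helper below is a hypothetical sketch. `data/datasets/` is the example location given above, but the folder names (`imagenet`, `iNaturalist`, `Places`, `SUN`) are assumptions; match them to the names the dataset loaders actually expect.

```python
from pathlib import Path
from typing import Union


def missing_dataset_dirs(root: Union[str, Path], names: list[str]) -> list[str]:
    """Return the dataset folder names that are not present as directories under root."""
    root = Path(root)
    return [name for name in names if not (root / name).is_dir()]


if __name__ == "__main__":
    # Hypothetical folder names; adjust to the names the dataset loaders expect.
    expected = ["imagenet", "iNaturalist", "Places", "SUN"]
    missing = missing_dataset_dirs("data/datasets", expected)
    if missing:
        print("Missing manually downloaded datasets:", ", ".join(missing))
    else:
        print("All manually downloaded datasets found.")
```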
The framework supports the following backbone architectures: Big Transfer (BiT) ResNets and ResNet-34.
We use Big Transfer (BiT) pretrained models from the official BiT repository. The following weights must be downloaded manually:
- BiT-M-R50x1 for CIFAR-10: Download
- BiT-M-R101x1 for CIFAR-100: Download
- BiT-S-R50x1 for ImageNet-1k: Download
- BiT-M-R101x1 for ImageNet-1k: Download

📥 Place the `.npz` files in `models/pretrained/` or the appropriate directory.
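You can check that a downloaded BiT checkpoint is readable by opening it with NumPy, because the `.npz` weight files are archives of named arrays. The sketch below assumes the file keeps its release name (e.g. `BiT-M-R50x1.npz`) and sits in `models/pretrained/`. `load_npz_weights` is a hypothetical helper and is not part of this repository's API.

```python
from pathlib import Path
from typing import Union

import numpy as np


def load_npz_weights(path: Union[str, Path]) -> dict[str, np.ndarray]:
    """Load every array stored in an .npz checkpoint into a plain dict keyed by array name."""
    path = Path(path)
    if not path.is_file():
        raise FileNotFoundError(f"No checkpoint at {path}; download the weights first.")
    # np.load returns a lazy NpzFile for .npz archives; the context manager closes it.
    with np.load(path) as npz:
        return {name: npz[name] for name in npz.files}


if __name__ == "__main__":
    # Assumed location and file name; adjust to where you saved the weights.
    ckpt = Path("models/pretrained") / "BiT-M-R50x1.npz"
    if ckpt.is_file():
        weights = load_npz_weights(ckpt)
        print(len(weights), "arrays,", sum(w.size for w in weights.values()), "values in total")
    else:
        print(f"{ckpt} not found; download it first.")
```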
We provide ResNet-34 models originally from the TIMM library, converted to JAX and finetuned. The pretrained JAX weights for CIFAR-100 are included with the repository.
The framework, and our GradPCA implementation in particular, is designed to be easily extended to new model architectures. To add a model, implement it in JAX (and optionally in PyTorch, for the baseline comparisons) and pass its `TrainState` directly to GradPCA. Make sure the model's parameter names follow the structure GradPCA expects.
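You can check parameter-name compatibility before handing a new model's `TrainState` to GradPCA. The sketch below flattens a nested parameter mapping, such as the `params` dict of a Flax `TrainState`, into `/`-separated paths and reports which required paths are absent. The required names in the example are hypothetical placeholders. The structure GradPCA actually expects is defined by its implementation in this repository.

```python
from collections.abc import Mapping
from typing import Any


def flatten_param_names(params: Mapping[str, Any], prefix: str = "") -> list[str]:
    """Return '/'-joined paths to every leaf of a nested parameter mapping."""
    names = []
    for key, value in params.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            # Recurse into sub-modules (e.g. a layer's {"kernel": ..., "bias": ...}).
            names.extend(flatten_param_names(value, path))
        else:
            names.append(path)
    return names


def missing_param_names(params: Mapping[str, Any], required: list[str]) -> list[str]:
    """List the required parameter paths that do not appear in params."""
    present = set(flatten_param_names(params))
    return [name for name in required if name not in present]


if __name__ == "__main__":
    # Toy nested params; in practice this would be state.params of a Flax TrainState.
    params = {"backbone": {"conv1": {"kernel": 0.0}}, "head": {"kernel": 0.0, "bias": 0.0}}
    # Hypothetical required names; replace them with what GradPCA actually expects.
    print(missing_param_names(params, ["head/kernel", "head/bias", "fc/kernel"]))
```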
We provide a setup script compatible with:
- macOS (arm64)
- Linux with CUDA 11.8
To install the required dependencies, run:
```bash
bash setup.sh
```
This will create a dedicated conda environment named `gradpca_env` with the required packages.
⚠️ Note: This project integrates JAX, PyTorch, and TensorFlow within a single environment. Installation may require manual adjustments depending on your system configuration, CUDA version, and Python environment.
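Because JAX, PyTorch, and TensorFlow share one environment, a quick import check after running the setup script can catch a broken install early. This is a minimal sketch that uses only the standard library. It checks that each top-level package can be located, not that GPU support is working.

```python
import importlib.util


def find_missing_packages(names: list[str]) -> list[str]:
    """Return the top-level packages that cannot be located in the current environment."""
    # find_spec returns None for a top-level name that is not installed, without importing it.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Import names of the three frameworks this project combines.
    missing = find_missing_packages(["jax", "torch", "tensorflow"])
    print("Missing packages:", ", ".join(missing) if missing else "none")
```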