
mselezniova/GradPCA


GradPCA: Leveraging NTK Alignment for Reliable Out-of-Distribution Detection

This repository provides an implementation of GradPCA, a method for detecting out-of-distribution (OOD) data using gradient-based representations. It includes tools for benchmarking GradPCA alongside a variety of baseline OOD detectors on standard image classification datasets.
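The general idea of scoring inputs by their gradient representations can be illustrated with a small JAX sketch. This is not the GradPCA implementation from this repository. The toy linear model, the use of the largest logit as the differentiated quantity, the centering step and the residual-norm score are all assumptions made for illustration, and `logits_fn`, `grad_feature`, `fit_subspace` and `ood_score` are hypothetical names. Per-sample gradients of in-distribution (ID) data are flattened into feature vectors, PCA (via SVD) extracts a low-dimensional subspace, and a test input is scored by how far its gradient falls outside that subspace.

```python
import jax
import jax.numpy as jnp


def logits_fn(params, x):
    # Hypothetical toy classifier: logits = x @ W + b.
    return x @ params["W"] + params["b"]


def grad_feature(params, x):
    # Gradient of the largest logit w.r.t. all parameters (an assumption for
    # this sketch), flattened into a single feature vector.
    g = jax.grad(lambda p: jnp.max(logits_fn(p, x)))(params)
    return jnp.concatenate([leaf.ravel() for leaf in jax.tree_util.tree_leaves(g)])


def fit_subspace(params, x_id, n_components):
    # Stack per-sample gradient features of in-distribution data: shape (n, p).
    feats = jax.vmap(lambda x: grad_feature(params, x))(x_id)
    mean = feats.mean(axis=0)
    # PCA via SVD of the centered features; rows of vt are principal directions.
    _, _, vt = jnp.linalg.svd(feats - mean, full_matrices=False)
    return mean, vt[:n_components]


def ood_score(params, x, mean, components):
    # Residual norm after projecting onto the ID gradient subspace:
    # larger values mean the gradient lies further outside that subspace.
    f = grad_feature(params, x) - mean
    proj = components.T @ (components @ f)
    return jnp.linalg.norm(f - proj)
```

In this sketch, a higher score suggests the input is less consistent with the ID gradient subspace, and `n_components` controls how tight that subspace is.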


🚀 Usage

We provide the following example notebooks:

  • example.ipynb:
    A minimal, self-contained demonstration of the GradPCA class across a range of datasets.

  • benchmark.ipynb:
    A benchmark comparing GradPCA against several baseline OOD detectors on multiple OOD datasets.


🗂️ Datasets

The following datasets are supported and can be used for both in-distribution (ID) and out-of-distribution (OOD) evaluation:

CIFAR-10, CIFAR-100, SVHN, Places, LSUN_r, LSUN_c, iSUN, Textures, ImageNet-1k, ImageNet-V2, iNaturalist (MOS), Places (MOS), SUN (MOS)

📥 Manual Downloads Required

Some datasets must be downloaded manually:

⚠️ Place the extracted folders in the expected location (e.g., data/datasets/) for compatibility with dataset loaders.
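A quick way to confirm the layout before instantiating the loaders is to check that each manually downloaded folder exists under data/datasets/. The helper below is a hypothetical sketch, not part of the repository, and the folder names you pass to it are placeholders: use the names your dataset loaders actually expect.

```python
from pathlib import Path

# Root directory named in this README; adjust it if your loaders expect another path.
DATA_ROOT = Path("data/datasets")


def missing_dataset_dirs(names, root=DATA_ROOT):
    """Return the expected dataset folder names that are not present under root."""
    root = Path(root)
    return [name for name in names if not (root / name).is_dir()]
```

An empty list means every folder you listed was found.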

✅ Auto-loaded Datasets

All other datasets (e.g., CIFAR, SVHN, LSUN, iSUN) are automatically downloaded when their respective loaders are instantiated.


🏛️ Models

The framework supports the following backbone architectures:

✅ ResNetV2 (BiT Models)

We use Big Transfer (BiT) pretrained models from the official BiT repository. The following weights must be downloaded manually:

📥 Place the .npz files in models/pretrained/ or the appropriate directory.
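To check that a downloaded .npz file is intact and readable before loading it into a model, you can list the arrays it contains with NumPy. This is a hypothetical helper, not part of the repository, and any file name you pass it (for example a path under models/pretrained/) is a placeholder.

```python
import numpy as np


def list_npz_arrays(path):
    """Return {array_name: shape} for every array stored in an .npz file."""
    # np.load returns an NpzFile for .npz archives; the context manager closes it.
    with np.load(path) as data:
        return {name: data[name].shape for name in data.files}
```

A truncated or corrupted download will typically fail here with an error instead of returning array shapes.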

✅ ResNet (from TIMM, ported to JAX)

We provide ResNet-34 models originally from the TIMM library, converted to JAX and fine-tuned. Pretrained JAX weights for CIFAR-100 are included in the repository.

🧩 Extensibility

The framework, and our GradPCA implementation in particular, can be extended to new model architectures. To add a model, implement it in JAX (and optionally in PyTorch for comparison with the baselines), then pass its TrainState directly to GradPCA. Make sure the model's parameter names match the structure GradPCA expects.
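The expected parameter-name structure is not spelled out in this README. One practical check is to list the parameter names of your model's TrainState (its `params` field is an ordinary JAX pytree) and compare them with those of a model known to work, such as the provided ResNet-34. The sketch below uses `jax.tree_util.tree_flatten_with_path`. `param_names` is a hypothetical helper, not part of the GradPCA API, and it assumes nested-dict parameters.

```python
import jax
import jax.numpy as jnp


def param_names(params):
    """Return '/'-joined key paths for every leaf of a nested-dict parameter pytree."""
    leaves, _ = jax.tree_util.tree_flatten_with_path(params)
    names = []
    for path, _leaf in leaves:
        # Dict entries appear as DictKey objects with a .key attribute;
        # other key types fall back to their string form.
        names.append("/".join(str(getattr(k, "key", k)) for k in path))
    return names
```

Calling `param_names(state.params)` on a working model and on the new one makes differences in the naming pattern (for example, how layers and kernels are keyed) easy to spot.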


🧰 Setup

We provide a setup script compatible with:

  • macOS (arm64)
  • Linux with CUDA 11.8

To install the required dependencies, run:

bash setup.sh

This will create a dedicated conda environment named gradpca_env with the required packages.

⚠️ Note: This project integrates JAX, PyTorch, and TensorFlow within a single environment. Installation may require manual adjustments depending on your system configuration, CUDA version, and Python environment.
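Because the three frameworks share one environment, it can help to confirm that each one imports cleanly after running setup.sh. The helper below is a hypothetical sanity check, not something the setup script provides. Run it inside the activated gradpca_env environment.

```python
import importlib


def framework_versions(modules=("jax", "torch", "tensorflow")):
    """Import each module and report its version, or None if the import fails."""
    versions = {}
    for name in modules:
        try:
            module = importlib.import_module(name)
            versions[name] = getattr(module, "__version__", "unknown")
        except Exception:  # ImportError, or a framework failing during its own init
            versions[name] = None
    return versions
```

`framework_versions()` returns a dict keyed by `jax`, `torch` and `tensorflow`. A `None` value points to the framework whose installation needs manual adjustment.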

About

Source code for the paper "GradPCA: Leveraging NTK Alignment for Reliable Out-of-Distribution Detection" by Mariia Seleznova et al., ICLR 2026
