In this notebook, we demonstrate an implementation of fDBD and evaluate it on the OpenOOD benchmark [1], using the KNN postprocessor as a baseline.

Ref: [1] https://github.com/Jingkang50/OpenOOD/tree/main

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Tue May 28 04:07:25 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   65C    P8              10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

1. Install OpenOOD with pip and make the necessary preparations

In [2]:
!pip show OpenOOD

In [3]:
!pip install git+https://github.com/Jingkang50/OpenOOD

Collecting git+https://github.com/Jingkang50/OpenOOD
  Cloning https://github.com/Jingkang50/OpenOOD to /tmp/pip-req-build-7yz7l7p3
  Running command git clone --filter=blob:none --quiet https://github.com/Jingkang50/OpenOOD /tmp/pip-req-build-7yz7l7p3
  Resolved https://github.com/Jingkang50/OpenOOD to commit 18c6f5174a2f518e2a8e819ffb1cd1914bcf12e0
  Preparing metadata (setup.py) ... done
Collecting json5 (from openood==1.5)
  Downloading json5-0.9.25-py3-none-any.whl (30 kB)
Collecting pre-commit (from openood==1.5)
  Downloading pre_commit-3.7.1-py2.py3-none-any.whl (204 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m204.3/204.3 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
Collecting diffdist>=0.1 (from openood==1.5)
  Downloading diffdist-0.1.tar.gz (4.6 kB)
  Preparing metadata (setup.py) ... done
Collecting faiss-gpu>=1.7.2 (from openood==1.5)
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8

In [4]:
# necessary imports
import torch

from openood.evaluation_api import Evaluator
from openood.networks import ResNet18_32x32  # ResNet-18 adapted to 32x32 inputs (CIFAR-10)
from openood.networks.ash_net import ASHNet

In [5]:
# download our pre-trained CIFAR-10 classifier
!gdown 1byGeYxM_PlLjT72wZsMQvP6popJeWBgt
!unzip cifar10_res18_v1.5.zip

Downloading...
From (original): https://drive.google.com/uc?id=1byGeYxM_PlLjT72wZsMQvP6popJeWBgt
From (redirected): https://drive.google.com/uc?id=1byGeYxM_PlLjT72wZsMQvP6popJeWBgt&confirm=t&uuid=af74a0b7-fa53-4990-b12d-5fe08edf39a8
To: /content/cifar10_res18_v1.5.zip
100% 375M/375M [00:09<00:00, 38.6MB/s]
Archive:  cifar10_res18_v1.5.zip
   creating: cifar10_resnet18_32x32_base_e100_lr0.1_default/
   creating: cifar10_resnet18_32x32_base_e100_lr0.1_default/s2/
  inflating: cifar10_resnet18_32x32_base_e100_lr0.1_default/s2/best_epoch99_acc0.9450.ckpt  
  inflating: cifar10_resnet18_32x32_base_e100_lr0.1_default/s2/config.yml  
  inflating: cifar10_resnet18_32x32_base_e100_lr0.1_default/s2/best.ckpt  
  inflating: cifar10_resnet18_32x32_base_e100_lr0.1_default/s2/last_epoch100_acc0.9420.ckpt  
  inflating: cifar10_resnet18_32x32_base_e100_lr0.1_default/s2/log.txt  
   creating: cifar10_resnet18_32x32_base_e100_lr0.1_default/s1/
  inflating: cifar10_resnet18_32x32_base_e100_lr0.1_default

In [6]:
# load the model
net = ResNet18_32x32(num_classes=10)
net.load_state_dict(
    torch.load('./cifar10_resnet18_32x32_base_e100_lr0.1_default/s0/best.ckpt')
)
net.cuda()
net.eval()

ResNet18_32x32(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1,

2. SOTA Baseline: KNN

In [7]:
knn_evaluator = Evaluator(
    net,
    id_name='cifar10',                     # the target ID dataset
    data_root='./data',                    # change if necessary
    config_root=None,                      # None: use the default configs shipped with OpenOOD
    preprocessor=None,                     # default preprocessing for the target ID dataset
    postprocessor_name='knn',              # the postprocessor to use
    postprocessor=None,                    # if you want to use your own postprocessor
    batch_size=200,                        # for certain methods the results can be slightly affected by batch size
    shuffle=False,
    num_workers=2)                         # could use more num_workers outside Colab

Downloading...
From (original): https://drive.google.com/uc?id=1XKzBdWCqg3vPoj-D32YixJyJJ0hL63gP
From (redirected): https://drive.google.com/uc?id=1XKzBdWCqg3vPoj-D32YixJyJJ0hL63gP&confirm=t&uuid=a20bab61-f7ec-484a-a460-44409751cea2
To: /content/data/benchmark_imglist.zip
100%|██████████| 27.7M/27.7M [00:01<00:00, 18.7MB/s]


cifar10 needs download:
./data/images_classic/cifar10


Downloading...
From (original): https://drive.google.com/uc?id=1Co32RiiWe16lTaiOU6JMMnyUYS41IlO1
From (redirected): https://drive.google.com/uc?id=1Co32RiiWe16lTaiOU6JMMnyUYS41IlO1&confirm=t&uuid=5178b001-0fab-422c-9259-d2c6b4bffb51
To: /content/data/images_classic/cifar10/cifar10.zip
100%|██████████| 143M/143M [00:02<00:00, 53.3MB/s]


cifar100 needs download:
./data/images_classic/cifar100


Downloading...
From (original): https://drive.google.com/uc?id=1PGKheHUsf29leJPPGuXqzLBMwl8qMF8_
From (redirected): https://drive.google.com/uc?id=1PGKheHUsf29leJPPGuXqzLBMwl8qMF8_&confirm=t&uuid=5d396f46-8a73-4301-8ba8-491b5ef054e9
To: /content/data/images_classic/cifar100/cifar100.zip
100%|██████████| 141M/141M [00:01<00:00, 80.7MB/s]


tin needs download:
./data/images_classic/tin


Downloading...
From (original): https://drive.google.com/uc?id=1PZ-ixyx52U989IKsMA2OT-24fToTrelC
From (redirected): https://drive.google.com/uc?id=1PZ-ixyx52U989IKsMA2OT-24fToTrelC&confirm=t&uuid=d8ae713f-763b-4237-8458-bf7beb3e1fc4
To: /content/data/images_classic/tin/tin.zip
100%|██████████| 237M/237M [00:02<00:00, 106MB/s] 


mnist needs download:
./data/images_classic/mnist


Downloading...
From (original): https://drive.google.com/uc?id=1CCHAGWqA1KJTFFswuF9cbhmB-j98Y1Sb
From (redirected): https://drive.google.com/uc?id=1CCHAGWqA1KJTFFswuF9cbhmB-j98Y1Sb&confirm=t&uuid=8ec930b0-031a-4594-926d-8618ce074915
To: /content/data/images_classic/mnist/mnist.zip
100%|██████████| 47.2M/47.2M [00:01<00:00, 40.7MB/s]


svhn needs download:
./data/images_classic/svhn


Downloading...
From: https://drive.google.com/uc?id=1DQfc11HOtB1nEwqS4pWUFp8vtQ3DczvI
To: /content/data/images_classic/svhn/svhn.zip
100%|██████████| 19.0M/19.0M [00:00<00:00, 136MB/s] 


texture needs download:
./data/images_classic/texture


Downloading...
From (original): https://drive.google.com/uc?id=1OSz1m3hHfVWbRdmMwKbUzoU8Hg9UKcam
From (redirected): https://drive.google.com/uc?id=1OSz1m3hHfVWbRdmMwKbUzoU8Hg9UKcam&confirm=t&uuid=3331b3ae-ec7e-4fc4-bb58-e2b27115df1b
To: /content/data/images_classic/texture/texture.zip
100%|██████████| 626M/626M [00:08<00:00, 74.8MB/s]


places365 needs download:
./data/images_classic/places365


Downloading...
From (original): https://drive.google.com/uc?id=1Ec-LRSTf6u5vEctKX9vRp9OA6tqnJ0Ay
From (redirected): https://drive.google.com/uc?id=1Ec-LRSTf6u5vEctKX9vRp9OA6tqnJ0Ay&confirm=t&uuid=3af181f2-93f4-4ee1-bfaa-fd1f7f55b656
To: /content/data/images_classic/places365/places365.zip
100%|██████████| 497M/497M [00:07<00:00, 63.5MB/s]
Setup: 100%|██████████| 250/250 [00:45<00:00,  5.48it/s]


Starting automatic parameter search...


100%|██████████| 5/5 [00:02<00:00,  2.26it/s]
100%|██████████| 5/5 [00:02<00:00,  2.30it/s]


Hyperparam: [50], auroc: 0.905803


100%|██████████| 5/5 [00:02<00:00,  2.38it/s]
100%|██████████| 5/5 [00:03<00:00,  1.66it/s]


Hyperparam: [100], auroc: 0.9042669999999999


100%|██████████| 5/5 [00:03<00:00,  1.34it/s]
100%|██████████| 5/5 [00:02<00:00,  2.14it/s]


Hyperparam: [200], auroc: 0.902765


100%|██████████| 5/5 [00:02<00:00,  2.19it/s]
100%|██████████| 5/5 [00:02<00:00,  2.06it/s]


Hyperparam: [500], auroc: 0.9000844999999998


100%|██████████| 5/5 [00:02<00:00,  2.05it/s]
100%|██████████| 5/5 [00:03<00:00,  1.50it/s]

Hyperparam: [1000], auroc: 0.8976035
Final hyperparam: 50
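The `knn` postprocessor scores each test sample by its distance to the K-th nearest training feature; as far as we can tell, the values swept above (50 to 1000) are candidate values of K. As a rough illustration of that idea (not OpenOOD's code, which uses a faiss index), here is a brute-force NumPy sketch. `l2_normalize` and `knn_score` are hypothetical helper names, and we assume features are L2-normalized before the search and that the score is the negated K-th distance, so higher means more ID-like.

```python
import numpy as np

def l2_normalize(x, eps=1e-10):
    # Scale each row (one feature vector) to unit L2 norm.
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

def knn_score(train_feats, test_feats, k=50):
    """Negated distance to the k-th nearest normalized training feature.

    Higher score = closer to the training data = more ID-like.
    Assumes k <= number of training features.
    """
    bank = l2_normalize(np.asarray(train_feats, dtype=float))
    queries = l2_normalize(np.asarray(test_feats, dtype=float))
    # Pairwise Euclidean distances, shape (num_test, num_train).
    dists = np.linalg.norm(queries[:, None, :] - bank[None, :, :], axis=-1)
    # k-th smallest distance per test sample (k is 1-based).
    kth = np.partition(dists, k - 1, axis=1)[:, k - 1]
    return -kth
```

The brute-force distance matrix needs memory proportional to num_test x num_train, which is fine for toy inputs but not for a 50,000-sample CIFAR-10 feature bank; that is where an index such as faiss comes in.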





In [8]:
metrics = knn_evaluator.eval_ood(fsood=False)

Performing inference on cifar10 test set...


100%|██████████| 45/45 [00:19<00:00,  2.26it/s]

Processing near ood...
Performing inference on cifar100 dataset...



100%|██████████| 45/45 [00:18<00:00,  2.49it/s]

Computing metrics on cifar100 dataset...
FPR@95: 37.91, AUROC: 89.75 AUPR_IN: 90.11, AUPR_OUT: 88.37
──────────────────────────────────────────────────────────────────────

Performing inference on tin dataset...



100%|██████████| 39/39 [00:17<00:00,  2.28it/s]

Computing metrics on tin dataset...
FPR@95: 31.18, AUROC: 91.65 AUPR_IN: 93.35, AUPR_OUT: 89.06
──────────────────────────────────────────────────────────────────────

Computing mean metrics...
FPR@95: 34.54, AUROC: 90.70 AUPR_IN: 91.73, AUPR_OUT: 88.71
──────────────────────────────────────────────────────────────────────

Processing far ood...
Performing inference on mnist dataset...



100%|██████████| 350/350 [02:23<00:00,  2.43it/s]

Computing metrics on mnist dataset...
FPR@95: 20.62, AUROC: 94.41 AUPR_IN: 82.75, AUPR_OUT: 99.01
──────────────────────────────────────────────────────────────────────

Performing inference on svhn dataset...



100%|██████████| 131/131 [00:51<00:00,  2.54it/s]

Computing metrics on svhn dataset...
FPR@95: 20.83, AUROC: 92.89 AUPR_IN: 88.98, AUPR_OUT: 96.21
──────────────────────────────────────────────────────────────────────

Performing inference on texture dataset...



100%|██████████| 29/29 [00:34<00:00,  1.19s/it]

Computing metrics on texture dataset...
FPR@95: 24.56, AUROC: 93.02 AUPR_IN: 96.04, AUPR_OUT: 87.10
──────────────────────────────────────────────────────────────────────

Performing inference on places365 dataset...



100%|██████████| 176/176 [01:45<00:00,  1.68it/s]

Computing metrics on places365 dataset...
FPR@95: 29.50, AUROC: 92.10 AUPR_IN: 81.27, AUPR_OUT: 97.37
──────────────────────────────────────────────────────────────────────

Computing mean metrics...
FPR@95: 23.88, AUROC: 93.11 AUPR_IN: 87.26, AUPR_OUT: 94.92
──────────────────────────────────────────────────────────────────────




ID Acc Eval: 100%|██████████| 45/45 [00:09<00:00,  4.88it/s]

           FPR@95  AUROC  AUPR_IN  AUPR_OUT   ACC
cifar100    37.91  89.75    90.11     88.37 95.22
tin         31.18  91.65    93.35     89.06 95.22
nearood     34.54  90.70    91.73     88.71 95.22
mnist       20.62  94.41    82.75     99.01 95.22
svhn        20.83  92.89    88.98     96.21 95.22
texture     24.56  93.02    96.04     87.10 95.22
places365   29.50  92.10    81.27     97.37 95.22
farood      23.88  93.11    87.26     94.92 95.22
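For reference, the two headline metrics in these tables can be computed from raw scores as follows. This is a minimal NumPy sketch of the standard definitions, assuming higher scores mean more ID-like and treating ID as the positive class; OpenOOD's own implementation may differ in details such as tie handling. `auroc` and `fpr_at_95_tpr` are hypothetical names, and AUPR_IN/AUPR_OUT are not sketched.

```python
import numpy as np

def auroc(id_scores, ood_scores):
    """Probability that a random ID sample outscores a random OOD sample
    (ties count as 1/2); equivalent to the area under the ROC curve."""
    id_scores = np.asarray(id_scores, dtype=float)
    ood_scores = np.asarray(ood_scores, dtype=float)
    # Pairwise comparison uses O(n_id * n_ood) memory; fine for illustration.
    greater = (id_scores[:, None] > ood_scores[None, :]).mean()
    ties = (id_scores[:, None] == ood_scores[None, :]).mean()
    return float(greater + 0.5 * ties)

def fpr_at_95_tpr(id_scores, ood_scores):
    """Fraction of OOD samples still accepted as ID at the threshold that
    keeps (roughly) 95% of ID samples."""
    id_scores = np.asarray(id_scores, dtype=float)
    ood_scores = np.asarray(ood_scores, dtype=float)
    threshold = np.percentile(id_scores, 5)  # about 95% of ID scores are >= this
    return float(np.mean(ood_scores >= threshold))
```

Multiplying by 100 gives the percentages reported above; lower FPR@95 and higher AUROC are better.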





5. fDBD Implementation and Evaluation

In [16]:
import numpy as np

# weight (num_classes x feat_dim) and bias (num_classes,) of the final linear layer
w = net.fc.weight.data.cpu().numpy()
b = net.fc.bias.data.cpu().numpy()
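Why only the last-layer weights are needed: for a linear classifier, the distance from a feature to the decision boundary between the predicted class and any other class has a closed form. In the notation below (our own), z is the penultimate feature (`feature` in the postprocessor), w_c and b_c are row c of `w` and entry c of `b`, p is the predicted class, and mu_train is `train_mean`.

```latex
f_c(z) = w_c^\top z + b_c, \qquad p = \arg\max_c f_c(z)

d_{p,c}(z) = \frac{\lvert (w_p - w_c)^\top z + (b_p - b_c) \rvert}{\lVert w_p - w_c \rVert_2}
           = \frac{\lvert f_p(z) - f_c(z) \rvert}{\lVert w_p - w_c \rVert_2}, \qquad c \neq p

S(z) = \frac{\sum_{c \neq p} d_{p,c}(z)}{\lVert z - \mu_{\mathrm{train}} \rVert_2}
```

The boundary between p and c is the hyperplane where f_p = f_c, so d_{p,c} is the usual point-to-hyperplane distance. Its numerator is just a logit gap (bias included), so `b` never has to be used explicitly and only `w` enters the denominators. The postprocessor sums over classes; averaging over the C-1 other classes would only rescale S and leave the ranking unchanged.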

In [17]:
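# denominator_matrix[p, c] = ||w_p - w_c||: turns the logit gap between the
# predicted class p and class c into a distance to their decision boundary
denominator_matrix = np.zeros((10,10))
for p in range(10):
  w_p = w - w[p,:]
  denominator = np.linalg.norm(w_p, axis=1)
  denominator[p] = 1  # avoid 0/0; the matching logit gap is 0 anyway
  denominator_matrix[p, :] = denominator

denominator_matrix = torch.tensor(denominator_matrix).cuda()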
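The loop above builds `denominator_matrix` row by row. A minimal sketch of the same computation with NumPy broadcasting, assuming `w` has shape `(num_classes, feat_dim)` as `nn.Linear` weights do; `boundary_denominators` is a hypothetical helper, not part of OpenOOD:

```python
import numpy as np

def boundary_denominators(w):
    """Return D with D[p, c] = ||w_p - w_c||_2 and D[p, p] = 1.

    w: (num_classes, feat_dim) weight matrix of the final linear layer.
    The diagonal is set to 1 to avoid 0/0; the matching numerator
    |logit_p - logit_p| is 0, so that entry contributes nothing.
    """
    w = np.asarray(w, dtype=float)
    diffs = w[:, None, :] - w[None, :, :]   # (C, C, D): entry [p, c] is w_p - w_c
    denom = np.linalg.norm(diffs, axis=-1)  # (C, C), symmetric
    np.fill_diagonal(denom, 1.0)
    return denom
```

`torch.tensor(boundary_denominators(w)).cuda()` would reproduce the cell above without the hard-coded 10. Note that `torch.tensor` on a float64 NumPy array gives a float64 tensor; appending `.float()` would keep the later division in float32.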

In [18]:
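# mean training feature, reused from the KNN postprocessor's feature bank
# (worth checking whether activation_log holds raw or L2-normalized features,
# since the fDBD score below compares it against raw features)
train_mean = np.mean(knn_evaluator.postprocessor.activation_log, axis=0)
train_mean_tensor = torch.from_numpy(train_mean).cuda()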

In [19]:
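from typing import Any

import torch
import torch.nn as nn

from openood.postprocessors import BasePostprocessor

class fDBDPostprocessor(BasePostprocessor):
    """Sum of feature-space distances from a sample to the decision boundaries
    between its predicted class and every other class, divided by the sample's
    distance to the mean training feature. Higher score = more ID-like.
    Uses the globals denominator_matrix and train_mean_tensor defined above."""

    def __init__(self, config):
        super().__init__(config)
        self.APS_mode = False  # no hyperparameters to search

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        output, feature = net(data, return_feature=True)
        values, nn_idx = output.max(1)  # top logit and predicted class
        # |f_pred(z) - f_c(z)| for every class c (0 for c == pred)
        logits_sub = torch.abs(output - values.unsqueeze(1))
        score = torch.sum(logits_sub / denominator_matrix[nn_idx], dim=1) \
            / torch.norm(feature - train_mean_tensor, dim=1)
        return nn_idx, score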
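To check the scoring logic without a GPU or the OpenOOD pipeline, here is a NumPy mirror of `postprocess` on plain arrays. It is a sketch under our own assumptions: `fdbd_score` is a hypothetical name, and it takes the logits, penultimate features, last-layer weight, and training mean explicitly instead of reading globals.

```python
import numpy as np

def fdbd_score(logits, features, w, train_mean):
    """NumPy mirror of fDBDPostprocessor.postprocess (sketch).

    logits:     (N, C) outputs of the final linear layer
    features:   (N, D) penultimate features fed into that layer
    w:          (C, D) weight of the final linear layer
    train_mean: (D,)   mean training feature
    Returns (pred, score); higher score = more ID-like.
    """
    logits = np.asarray(logits, dtype=float)
    features = np.asarray(features, dtype=float)
    w = np.asarray(w, dtype=float)
    pred = logits.argmax(axis=1)
    top = logits[np.arange(len(logits)), pred]
    # Numerators: |f_pred(z) - f_c(z)| for every class c (0 when c == pred).
    logit_gaps = np.abs(logits - top[:, None])
    # Denominators: ||w_pred - w_c||, with 1 on the diagonal to avoid 0/0.
    denom = np.linalg.norm(w[:, None, :] - w[None, :, :], axis=-1)
    np.fill_diagonal(denom, 1.0)
    dist_sum = (logit_gaps / denom[pred]).sum(axis=1)
    # Divide by the feature's distance to the training mean.
    return pred, dist_sum / np.linalg.norm(features - train_mean, axis=1)
```

For example, with w = [[1, 0], [-1, 0]], bias [1, -1], and z = [2, 1], the boundary between the two classes is the line x = -1, which is 3 units from z; with a training mean of [2, 0] the denominator is 1, so the score is 3.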

In [20]:
fdbd_evaluator = Evaluator(
    net,
    id_name='cifar10',                     # the target ID dataset
    data_root='./data',                    # change if necessary
    config_root=None,                      # None: use the default configs shipped with OpenOOD
    preprocessor=None,                     # default preprocessing for the target ID dataset
    postprocessor=fDBDPostprocessor('./configs'),  # our custom postprocessor (instead of postprocessor_name)
    batch_size=200,                        # for certain methods the results can be slightly affected by batch size
    shuffle=False,
    num_workers=2)                         # could use more num_workers outside Colab

In [21]:
metrics = fdbd_evaluator.eval_ood(fsood=False)

Performing inference on cifar10 test set...


100%|██████████| 45/45 [00:12<00:00,  3.69it/s]

Processing near ood...
Performing inference on cifar100 dataset...



100%|██████████| 45/45 [00:06<00:00,  6.97it/s]

Computing metrics on cifar100 dataset...
FPR@95: 41.14, AUROC: 89.36 AUPR_IN: 88.80, AUPR_OUT: 88.26
──────────────────────────────────────────────────────────────────────

Performing inference on tin dataset...



100%|██████████| 39/39 [00:09<00:00,  4.05it/s]

Computing metrics on tin dataset...
FPR@95: 31.71, AUROC: 91.48 AUPR_IN: 92.49, AUPR_OUT: 89.26
──────────────────────────────────────────────────────────────────────

Computing mean metrics...
FPR@95: 36.43, AUROC: 90.42 AUPR_IN: 90.64, AUPR_OUT: 88.76
──────────────────────────────────────────────────────────────────────

Processing far ood...
Performing inference on mnist dataset...



100%|██████████| 350/350 [01:00<00:00,  5.74it/s]


Computing metrics on mnist dataset...
FPR@95: 20.28, AUROC: 94.59 AUPR_IN: 78.55, AUPR_OUT: 99.10
──────────────────────────────────────────────────────────────────────

Performing inference on svhn dataset...


100%|██████████| 131/131 [00:23<00:00,  5.49it/s]

Computing metrics on svhn dataset...
FPR@95: 24.18, AUROC: 91.89 AUPR_IN: 84.48, AUPR_OUT: 95.90
──────────────────────────────────────────────────────────────────────

Performing inference on texture dataset...



100%|██████████| 29/29 [00:28<00:00,  1.03it/s]

Computing metrics on texture dataset...
FPR@95: 24.98, AUROC: 92.83 AUPR_IN: 95.65, AUPR_OUT: 87.28
──────────────────────────────────────────────────────────────────────

Performing inference on places365 dataset...



100%|██████████| 176/176 [01:07<00:00,  2.61it/s]

Computing metrics on places365 dataset...
FPR@95: 26.40, AUROC: 92.70 AUPR_IN: 82.24, AUPR_OUT: 97.60
──────────────────────────────────────────────────────────────────────

Computing mean metrics...
FPR@95: 23.96, AUROC: 93.00 AUPR_IN: 85.23, AUPR_OUT: 94.97
──────────────────────────────────────────────────────────────────────




ID Acc Eval: 100%|██████████| 45/45 [00:06<00:00,  6.66it/s]

           FPR@95  AUROC  AUPR_IN  AUPR_OUT   ACC
cifar100    41.14  89.36    88.80     88.26 95.22
tin         31.71  91.48    92.49     89.26 95.22
nearood     36.43  90.42    90.64     88.76 95.22
mnist       20.28  94.59    78.55     99.10 95.22
svhn        24.18  91.89    84.48     95.90 95.22
texture     24.98  92.83    95.65     87.28 95.22
places365   26.40  92.70    82.24     97.60 95.22
farood      23.96  93.00    85.23     94.97 95.22
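A quick way to compare the two runs is to difference their AUROC columns. The values below are copied by hand from the KNN and fDBD tables in this notebook, and the snippet only does arithmetic on them; `knn_auroc`, `fdbd_auroc`, and `delta` are our own names.

```python
# AUROC (%) copied from the KNN and fDBD result tables above.
knn_auroc = {"cifar100": 89.75, "tin": 91.65, "nearood": 90.70,
             "mnist": 94.41, "svhn": 92.89, "texture": 93.02,
             "places365": 92.10, "farood": 93.11}
fdbd_auroc = {"cifar100": 89.36, "tin": 91.48, "nearood": 90.42,
              "mnist": 94.59, "svhn": 91.89, "texture": 92.83,
              "places365": 92.70, "farood": 93.00}

# Positive delta = fDBD scores higher than KNN on that split.
delta = {name: round(fdbd_auroc[name] - knn_auroc[name], 2) for name in knn_auroc}
for name, d in delta.items():
    print(f"{name:10s} {d:+.2f}")
```

On these numbers fDBD trails KNN slightly on the near-OOD and far-OOD means (-0.28 and -0.11 AUROC points) and leads on mnist and places365. Both runs use a single checkpoint (s0), so gaps this small should be read with caution.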



