### Benchmarking your OoD detection robustness!
This is a demonstration of running our OoD robustness test on a couple of OoD detectors attached to a DNN image classifier.

In [None]:
import os
import yaml
import torch
import timm
from utils.dataloader import load_dataset
from models.model_utils import InputNormalizer, load_model
from utils.ood_detectors import build_detectors
from utils.attackers import build_attackers
from utils.test_utils import setup_seed, original_data_ood_test, seeds_ood_test, perturbed_samples_ood_test
from pytorch_ood.detector import EnergyBased, ViM

# Step up one directory from where the notebook was launched so that the
# relative paths used below ("dataset/", "models/...", "results/...") resolve.
# NOTE(review): this is not idempotent -- every re-run of this cell moves one
# more level up the directory tree.
parent_dir = os.path.dirname(os.getcwd())
os.chdir(parent_dir)
print("Current working directory: ", os.getcwd())


# Setups
# --- Robustness-test configuration ---
rand_seed = 0        # passed to setup_seed() and embedded in the results path
n_seeds = 1000       # at most this many test samples are drawn as seeds per dataset
n_sampling = 50      # forwarded to perturbed_samples_ood_test()
severity = "avg"     # severity level handed to build_attackers()
# Use the GPU when one is available; otherwise fall back to CPU instead of
# failing later when the model/tensors are moved to a non-existent device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- In-distribution benchmark and data settings ---
benchmark = "Imagenet1k"
n_classes = 1000     # NOTE(review): not referenced elsewhere in this script
datadir = "dataset/"
ood_datasets = ["NINCO", "iNaturalist"]  # evaluated after the ID benchmark
img_size = 224
# Per-channel normalization statistics (the standard ImageNet mean/std).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
batch_size = 20      # used for the ID train loader and the perturbed-sample tests

# --- Model, detectors and perturbations under test ---
model_name = "swin"
variant = "NT"       # variant tag; only used to build the results directory
weight_name = "swin_base_patch4_window7_224"  # checkpoint name under models/<benchmark>/state_dicts
ood_detectors = ["EnergyBased", "ViM"]
perturb_functions = ["rotation", "translation", "scale", "hue", "saturation",
                     "bright_contrast", "blur", "Linf", "L2"]


# Step 1: Load model
# Checkpoint location: models/<benchmark in lower case>/state_dicts/<weight_name>
weight_path = os.path.join(
    "models",
    benchmark.lower(),
    "state_dicts",
    weight_name,
)
model = load_model(
    model_name,
    weight_path,
    benchmark,
    device=device,
)
# Put the classifier in evaluation mode for every test that follows.
model.eval()
# Normalizer built from the configured mean/std; it is handed to the detector
# builder and to every test routine in this script.
input_normalizer = InputNormalizer(mean=mean, std=std)

# Step 2: Build OoD detectors
# 1) Load the ID training split of the configured benchmark. The loader is
#    handed to build_detectors(), presumably to fit detector statistics
#    (e.g. ViM's principal subspace) -- confirm in utils.ood_detectors.
#    `benchmark` is used instead of a hard-coded "Imagenet1k" so that changing
#    the benchmark in the setup cell also changes the ID training data.
id_train_data_set, id_train_data_loader = load_dataset(
    datadir,
    dataset_name=benchmark,
    img_size=img_size,
    benchmark=benchmark,
    split="train",
    batch_size=batch_size,
    normalize=True,
    mean=mean,
    std=std,
)
# 2) Build OoD detectors
# NOTE(review): this loader is created with normalize=True, yet build_detectors
# also receives `input_normalizer` (the test loaders in Step 4 are created
# without normalize=...). Confirm the training images are not normalized twice.
detectors = build_detectors(ood_detectors, model, input_normalizer, id_train_data_loader,
                            device=device)

# Step 3: Build attackers
# One attacker per entry in `perturb_functions`, at the configured severity.
attackers = build_attackers(perturb_functions, severity_level=severity, img_size=img_size)

# Step 4: Start OoD robustness test
# The ID benchmark is evaluated first, then every OoD dataset in turn.
for dataset_name in [benchmark] + ood_datasets:
    print("------------------------------------")
    print("Dataset:", dataset_name)
    save_dir = os.path.join("results", benchmark.lower(), str(rand_seed), model_name, variant, dataset_name, "scores")
    # exist_ok=True creates the directory tree if needed and avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(save_dir, exist_ok=True)

    # Load the test dataset (batch_size=1 for the whole-dataset pass below).
    test_data_set, test_data_loader = load_dataset(datadir, dataset_name, img_size=img_size,
                                                   benchmark=benchmark, split="test", batch_size=1)
    # Select seeds: every test sample is a candidate.
    # NOTE: an earlier variant restricted candidates to correctly classified
    # samples; that filter is not applied here.
    data_set_ = test_data_set
    print("Size:", len(data_set_))
    # Re-seed before each permutation so the seed indices depend only on
    # rand_seed and the dataset size; keep at most n_seeds of them.
    setup_seed(rand_seed)
    idx_temp = torch.randperm(len(data_set_))[:n_seeds]

    # Every dataset other than the ID benchmark is treated as OoD.
    ood = dataset_name != benchmark
    # Test the model and OoD detectors on the whole dataset.
    original_data_ood_test(model, detectors, test_data_loader, input_normalizer, ood, save_dir,
                           device=device)

    # Optional: test the model and OoD detectors on the unperturbed seeds only.
    # seeds_ood_test(model, detectors, data_set_, idx_temp, input_normalizer, ood, save_dir,
    #                batch_size=batch_size, device=device)

    # Generate perturbed samples and test the model and OoD detectors on them.
    # Save the model prediction, confidence & OoD scores of perturbed samples
    perturbed_samples_ood_test(model, detectors, data_set_, idx_temp, input_normalizer, attackers,
                               n_sampling, ood, save_dir, batch_size=batch_size, device=device)
