# Benchmark of the ODIN method

This notebook aims to evaluate the **ODIN method** and to reproduce the results of the
original paper.

The authors compared two network architectures on two in-distribution (ID) datasets and
six out-of-distribution (OOD) datasets:

- **Network architectures**
  - DenseNet-BC-100
  - Wide-ResNet-28-10
- **ID datasets**
  - CIFAR-10
  - CIFAR-100
- **OOD datasets**
  - TinyImageNet cropped
  - TinyImageNet resized
  - LSUN cropped
  - LSUN resized
  - Uniform noise
  - Gaussian noise

Here, we focus on a DenseNet network trained on CIFAR-10. This model is challenged on the
LSUN cropped and TinyImageNet cropped OOD datasets; the cells below walk through the LSUN case.

**Reference**  
_Enhancing The Reliability of Out-of-distribution Image Detection in Neural Networks_  
Liang, Shiyu and Li, Yixuan and Srikant, R.  
International Conference on Learning Representations, 2018  
<https://openreview.net/forum?id=H1VGkIxRZ>


In [None]:
%load_ext autoreload
%autoreload 2

## 1. Load CIFAR-10 dataset and pretrained DenseNet model

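The CIFAR-10 dataset is loaded and preprocessed (pixel values scaled to [0, 1]). This is our
in-distribution dataset.

A pretrained DenseNet-121 model, which reaches 87.4 % accuracy on the CIFAR-10 test set, is
loaded from `./saved_models/cifar10/`. If no saved model is found there, a ResNet-18 is
trained on CIFAR-10 instead.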


In [None]:
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Silence TensorFlow info and warning messages

import tensorflow as tf
from oodeel.datasets import OODDataset

input_scaling = 1 / 255  # maps uint8 pixel values to [0, 1]
batch_size = 128

# In-distribution data: CIFAR-10 test split (evaluation) and train split (training fallback)
oods_in = OODDataset("cifar10", split="test", input_key="image")
oods_train = OODDataset("cifar10", split="train", input_key="image")


def preprocess_fn(*inputs):
    """Scale the image (first element) to [0, 1]; pass other elements (e.g. labels) through."""
    x = tf.cast(inputs[0], tf.float32) * input_scaling
    return tuple([x] + list(inputs[1:]))


ds_in = oods_in.prepare(batch_size=batch_size, preprocess_fn=preprocess_fn)

In [None]:
from oodeel.models.training_funs import train_keras_app

try:
    # Reuse the pretrained model if it has already been saved
    model = tf.keras.models.load_model("./saved_models/cifar10/")
except OSError:
    # No saved model found: train a ResNet-18 on CIFAR-10 from scratch
    train_config = {
        "input_shape": (32, 32, 3),
        "num_classes": 10,
        "batch_size": batch_size,
        "epochs": 60,
        "save_dir": "./saved_models/cifar10/",
        "validation_data": oods_in.get_dataset(),  # CIFAR-10 test set
    }

    model = train_keras_app(oods_train.get_dataset(), "resnet18", **train_config)

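The cells above do not recompute the test accuracy quoted in Section 1. As a minimal sketch, top-1 accuracy is the fraction of samples whose highest output matches the label; `top1_accuracy` below is a hypothetical NumPy helper (not part of oodeel), shown on toy logits. On the real data one would pass the concatenated `model.predict` outputs and the CIFAR-10 test labels.

```python
import numpy as np


def top1_accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Fraction of samples whose highest logit matches the integer label."""
    preds = np.argmax(logits, axis=-1)
    return float(np.mean(preds == labels))


# Toy logits for 4 samples and 3 classes
logits = np.array([[2.0, 0.1, 0.3],
                   [0.2, 1.5, 0.1],
                   [0.1, 0.2, 3.0],
                   [1.0, 2.0, 0.5]])
labels = np.array([0, 1, 2, 0])
print(top1_accuracy(logits, labels))  # 0.75 (the last sample is misclassified)
```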

## 2. Load LSUN cropped dataset

Next, the LSUN cropped dataset is loaded and preprocessed in the same way. This is our
out-of-distribution dataset.

**Note**: the OOD datasets of the original paper can be downloaded from
https://github.com/facebookresearch/odin. Some of them contain 36x36 images in which a
2-pixel black frame surrounds the 32x32 content; such images should be cropped to their
central 32x32 pixels rather than resized.

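Removing a 2-pixel frame is a plain slice along the height and width axes. A minimal NumPy sketch with a hypothetical helper `crop_frame`, assuming channels-last batches of shape (N, H, W, C). In a `tf.data` pipeline the per-image equivalent would be `tf.image.crop_to_bounding_box(image, 2, 2, 32, 32)`, applied to images loaded at their native 36x36 size (e.g. `image_size=(36, 36)`) so that they are not resized first.

```python
import numpy as np


def crop_frame(images: np.ndarray, border: int = 2) -> np.ndarray:
    """Remove a `border`-pixel frame from a batch of images shaped (N, H, W, C)."""
    if border == 0:
        return images  # images[:, 0:-0] would be empty, so handle this case explicitly
    return images[:, border:-border, border:-border, :]


# Toy batch: 36x36 images with a black 2-pixel frame around white 32x32 content
toy_batch = np.zeros((4, 36, 36, 3), dtype=np.uint8)
toy_batch[:, 2:-2, 2:-2, :] = 255
cropped = crop_frame(toy_batch)
print(cropped.shape)  # (4, 32, 32, 3)
print(cropped.min())  # 255: no frame pixels remain
```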

In [None]:
LSUN_root = os.path.join(os.path.expanduser("~"), "datasets_oodeel", "LSUN")
# Note: `image_size` resizes images (bilinear by default); it does not crop a black frame
lsun_ds = tf.keras.utils.image_dataset_from_directory(
    LSUN_root,
    image_size=(32, 32),
    shuffle=False,
    batch_size=None,
)
# OOD dataset, batched and scaled like the ID data
ds_out = OODDataset(lsun_ds).prepare(batch_size=batch_size, preprocess_fn=preprocess_fn)

ood_name = "LSUN"


## 3. OOD detection

Here the ODIN method is applied to both ID and OOD images. An OOD score is computed for
each image, and standard metrics are measured to evaluate the performance of the OOD
detector.

In a nutshell, the ODIN method consists of the following three steps:

1. Perturb the input image with a small gradient step that increases the calibrated
   probability of the predicted class (i.e. a gradient descent step on the negative log
   of that probability).
2. Compute the calibrated probability of the predicted class for the perturbed image.
   This is the OOD score.
3. If the OOD score is below a threshold, the image is classified as OOD.

The _calibrated probability_ is the softmax of the logits divided by a temperature, which is
a hyper-parameter of the method. The magnitude of the perturbation step applied to the image
is a second hyper-parameter.

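For reference, the same three steps in the notation of the original paper, with $f_i(x)$ the network output for class $i$ among $N$ classes, $T$ the temperature, $\varepsilon$ the perturbation magnitude and $\delta$ the detection threshold:

$$
S_i(x; T) = \frac{\exp\left(f_i(x)/T\right)}{\sum_{j=1}^{N} \exp\left(f_j(x)/T\right)}, \qquad \hat{y} = \arg\max_i S_i(x; T)
$$

$$
\tilde{x} = x - \varepsilon \, \operatorname{sign}\left(-\nabla_x \log S_{\hat{y}}(x; T)\right)
$$

$$
x \text{ is classified as OOD} \iff \max_i S_i(\tilde{x}; T) \le \delta
$$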

In [None]:
# Compute OOD scores for ODIN method

from oodeel.methods import ODIN

oodmodel = ODIN(temperature=1000)
oodmodel.fit(model)

scores_id = oodmodel.score(ds_in)
scores_ood = oodmodel.score(ds_out)

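`oodmodel.score` hides the gradient computation. To make the three ODIN steps concrete, here is a NumPy sketch on a toy linear classifier $f(x) = Wx + b$, for which the gradient of the log calibrated probability has a closed form, so no automatic differentiation is needed. The model, the helper names `softmax` and `odin_score`, the threshold and the value ε = 0.0014 are illustrative assumptions; this is not oodeel's implementation.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def odin_score(x, W, b, temperature=1000.0, epsilon=0.0014):
    """ODIN score for a toy linear classifier f(x) = x @ W.T + b.

    x: (N, D) inputs, W: (C, D) weights, b: (C,) biases.
    Returns the max calibrated softmax of the perturbed inputs, shape (N,).
    """
    # Calibrated probabilities and predicted class of the original inputs
    s = softmax((x @ W.T + b) / temperature)   # (N, C)
    y_hat = s.argmax(axis=-1)                  # (N,)
    # Step 1: gradient of log S_yhat(x; T) w.r.t. x; for a linear model it is
    # (W[y_hat] - sum_j S_j W[j]) / T
    grad = (W[y_hat] - s @ W) / temperature    # (N, D)
    # x_tilde = x - eps * sign(-grad) = x + eps * sign(grad)
    x_tilde = x + epsilon * np.sign(grad)
    # Step 2: calibrated score of the perturbed inputs
    return softmax((x_tilde @ W.T + b) / temperature).max(axis=-1)


rng = np.random.default_rng(0)
W, b = rng.normal(size=(10, 32)), rng.normal(size=10)
x = rng.normal(size=(5, 32))

plain = odin_score(x, W, b, epsilon=0.0)  # no perturbation: max calibrated softmax
scores = odin_score(x, W, b)              # ODIN score with the input perturbation
# Step 3: threshold the scores (hypothetical delta; in practice tuned on held-out data)
delta = 0.1
is_ood = scores <= delta
print(scores - plain)
```

For a small enough ε the perturbation raises the calibrated probability of the predicted class, so every entry of `scores - plain` is positive here. The absolute change is tiny because T = 1000 makes the calibrated softmax nearly uniform; the original paper argues that what matters for detection is that ID inputs tend to gain more from the perturbation than OOD inputs.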

In [None]:
# Compute and display OOD metrics

import numpy as np
import matplotlib.pyplot as plt
from oodeel.eval.metrics import bench_metrics, get_curve


def compute_ood_metrics(scores_id, scores_ood):
    # Compute the ROC curve (FPR, TPR) and the AUROC and FPR@95%TPR metrics.
    scores = np.concatenate([scores_id, scores_ood])
    labels = np.array([0] * len(scores_id) + [1] * len(scores_ood))  # ID = 0, OOD = 1
    fpr, tpr = get_curve(scores, labels)
    metrics = bench_metrics(
        scores,
        labels,
        metrics=["auroc", "fpr95tpr"],
    )
    return metrics, fpr, tpr


def display_ROC(scores_id, scores_ood, fpr, tpr):
    plt.figure(figsize=(9, 3))
    plt.subplot(121)
    plt.hist((scores_id, scores_ood), bins=30, label=("id", "ood"))
    plt.xlabel("Score")
    plt.title("OOD scores")
    plt.legend()
    plt.subplot(122)
    plt.plot(fpr, tpr)
    plt.xlabel("False positive rate")
    plt.ylabel("True positive rate")
    plt.title("ROC")
    plt.show()


print(f"In-distribution: CIFAR-10 / Out-of-distribution: {ood_name}")
metrics, fpr, tpr = compute_ood_metrics(scores_id, scores_ood)
print(metrics)

display_ROC(scores_id, scores_ood, fpr, tpr)

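The `bench_metrics` call hides how the two metrics are defined. As a cross-check, here is a minimal NumPy sketch of both, assuming (as the labels in `compute_ood_metrics` suggest) that OOD is the positive class and that higher scores mean "more OOD". The helpers `auroc` and `fpr_at_95_tpr` are hypothetical and not part of oodeel, whose own implementation may handle thresholds and ties differently.

```python
import math

import numpy as np


def auroc(scores_id, scores_ood):
    """AUROC as P(score_ood > score_id), counting ties as 1/2 (Mann-Whitney form)."""
    s_id = np.asarray(scores_id, dtype=float)[:, None]
    s_ood = np.asarray(scores_ood, dtype=float)[None, :]
    return float(np.mean(s_ood > s_id) + 0.5 * np.mean(s_ood == s_id))


def fpr_at_95_tpr(scores_id, scores_ood):
    """Fraction of ID samples flagged as OOD at a threshold detecting >= 95% of OOD samples."""
    s_ood = np.sort(np.asarray(scores_ood, dtype=float))[::-1]  # descending
    k = math.ceil(0.95 * len(s_ood))  # number of OOD samples that must be detected
    threshold = s_ood[k - 1]          # k-th highest OOD score
    return float(np.mean(np.asarray(scores_id, dtype=float) >= threshold))


scores_id_toy = [0.1, 0.2, 0.3, 0.4]
scores_ood_toy = [0.35, 0.5, 0.6, 0.7]
print(auroc(scores_id_toy, scores_ood_toy))          # 0.9375
print(fpr_at_95_tpr(scores_id_toy, scores_ood_toy))  # 0.25
```

The pairwise comparison in `auroc` uses memory proportional to the product of the two set sizes, which is fine for toy data; for two sets of 10,000 scores a rank-based formula would be preferable.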

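## BONUS: effect of the input perturbation

The cells below inspect what the ODIN perturbation does to one batch of OOD images: the size
of the perturbation, and how the calibrated softmax outputs and predicted classes change.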

In [None]:
# Compare calibrated softmax outputs from original and perturbed images (single batch)
img = next(iter(ds_out.take(1)))[0]  # images of the first OOD batch (labels dropped)
img_perturbed = oodmodel._input_perturbation(img)

# Per-image L2 norm of the perturbation
perturbations = img_perturbed - img
norms = tf.norm(tf.reshape(perturbations, [perturbations.shape[0], -1]), axis=-1)
print(norms.numpy())

# Outputs for original images
preds = tf.nn.softmax(model.predict(img) / oodmodel.temperature).numpy()
maxx = preds.max(axis=-1)
argmaxx = preds.argmax(axis=-1).squeeze()

# Outputs for perturbed images
preds_per = tf.nn.softmax(model.predict(img_perturbed) / oodmodel.temperature).numpy()
maxx_per = preds_per.max(axis=-1)
argmaxx_per = preds_per.argmax(axis=-1).squeeze()

print(argmaxx)
print()
print(argmaxx_per)
print()

print(maxx)
print()
print(maxx_per)
print()

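The printed `norms` are per-image L2 norms. Assuming the perturbation is a pure signed-gradient step of magnitude ε as in the ODIN paper, with no clipping or per-channel rescaling (oodeel's `_input_perturbation` may differ), every pixel value moves by exactly ε, so the L∞ norm is ε and the L2 norm is ε·√d with d = 32·32·3 = 3072. A NumPy check on random stand-in gradients, with an illustrative ε:

```python
import numpy as np

eps = 0.0014                              # illustrative perturbation magnitude
rng = np.random.default_rng(0)
grad = rng.normal(size=(2, 32, 32, 3))    # stand-in for input gradients (2 images)
perturbation = eps * np.sign(grad)        # signed-gradient step

flat = perturbation.reshape(len(perturbation), -1)
l2 = np.linalg.norm(flat, axis=-1)        # per-image L2 norm, like `norms` above
linf = np.abs(flat).max(axis=-1)          # per-image L-infinity norm
print(l2, eps * np.sqrt(flat.shape[1]))   # the two agree
print(linf)                               # equals eps for every image
```

If the norms printed by the notebook differ from ε·√3072, the library's perturbation likely rescales or clips the step.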

In [None]:
plt.figure(figsize=(12, 4))
plt.plot(maxx, "*--")
plt.plot(maxx_per, "*--")
plt.xlabel("Image index in batch")
plt.title("Max of calibrated softmax outputs")
plt.legend(["Original", "Perturbed"])
plt.show()

plt.figure(figsize=(12, 4))
plt.plot(argmaxx, "*--")
plt.plot(argmaxx_per, "*--")
plt.xlabel("Image index in batch")
plt.ylabel("Predicted class")
plt.title("Predicted class (argmax) before and after perturbation")
plt.legend(["Original", "Perturbed"])
plt.show()

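The two plots can be summarized by two numbers: the fraction of images whose predicted class is unchanged by the perturbation, and the mean change of the max calibrated softmax. A sketch with a hypothetical NumPy helper, shown on made-up values; with the notebook's arrays one would call `perturbation_summary(argmaxx, argmaxx_per, maxx, maxx_per)`.

```python
import numpy as np


def perturbation_summary(argmax_orig, argmax_pert, max_orig, max_pert):
    """Fraction of unchanged predictions and mean change of the max calibrated softmax."""
    unchanged = float(np.mean(np.asarray(argmax_orig) == np.asarray(argmax_pert)))
    mean_gain = float(np.mean(np.asarray(max_pert) - np.asarray(max_orig)))
    return unchanged, mean_gain


# Made-up values for 4 images
unchanged, gain = perturbation_summary(
    [3, 5, 5, 1], [3, 5, 2, 1],
    [0.101, 0.102, 0.100, 0.103], [0.102, 0.104, 0.101, 0.103],
)
print(unchanged, gain)
```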

In [None]:
# Compare temperature-scaled outputs for a single image (original and perturbed)
one_img = tf.expand_dims(img[17], axis=0)
one_img_perturbed = oodmodel._input_perturbation(one_img)
print("L1 norm of the perturbation:", np.abs((one_img - one_img_perturbed).numpy()).sum())
print()

net = oodmodel.feature_extractor.model
# Network outputs (assumed to be logits) divided by the temperature
scaled_orig = net(one_img, training=False) / oodmodel.temperature
scaled_pert = net(one_img_perturbed, training=False) / oodmodel.temperature

print("scaled outputs, original", scaled_orig.numpy())
print("argmax, original", scaled_orig.numpy().argmax())
print()
print("scaled outputs, perturbed", scaled_pert.numpy())
print("argmax, perturbed", scaled_pert.numpy().argmax())


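The next cell checks whether running the network (on a rescaled batch, then inside a `GradientTape`) changes the moving statistics of one BatchNormalization layer. As background, Keras' `BatchNormalization` updates its moving mean only when called with `training=True`, using `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)` with `momentum=0.99` by default; in inference mode the statistics stay untouched. The NumPy sketch below mimics that rule for the mean only; `bn_moving_mean_update` is a hypothetical helper, not a Keras API.

```python
import numpy as np


def bn_moving_mean_update(moving_mean, batch, momentum=0.99, training=False):
    """Sketch of the BatchNormalization moving-mean rule (mean only, channels last)."""
    if not training:
        return moving_mean  # inference mode: statistics are left unchanged
    batch_mean = batch.mean(axis=tuple(range(batch.ndim - 1)))  # mean over all but channels
    return momentum * moving_mean + (1.0 - momentum) * batch_mean


mm = np.zeros(3)
x_big = np.full((8, 4, 4, 3), 100.0)  # a large-valued batch, like `100 * batch`
print(bn_moving_mean_update(mm, x_big, training=False))  # [0. 0. 0.]
print(bn_moving_mean_update(mm, x_big, training=True))   # [1. 1. 1.]
```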
In [None]:
# Check whether calling the model updates the BatchNorm moving statistics
bn_layer = oodmodel.feature_extractor.model.get_layer("conv2_block2_1_bn")
print(bn_layer.moving_mean.numpy()[:20])

batch = next(iter(ds_out.take(1)))[0]  # images only; the dataset yields (image, label) tuples
# Call on a rescaled batch without an explicit `training` argument
oodmodel.feature_extractor.model(100 * batch)

with tf.GradientTape() as tape:
    tape.watch(batch)
    tape.watch(bn_layer.variables)
    preds = oodmodel.feature_extractor.model(batch, training=None)
# Gradient with respect to the BN layer's variables
_ = tape.gradient(preds, bn_layer.variables)

print(bn_layer.moving_mean.numpy()[:20])
