In [None]:
%load_ext autoreload
%env TFDS_DATA_DIR=/Users/paul/datasets/tensorflow_datasets
%env PYTHONPATH=../

import sys
sys.path.append("../")

import os
print(os.environ["PYTHONPATH"])

import numpy as np
from tensorflow import keras
from oodeel.methods import MLS, DKNN, ODIN
from oodeel.methods import DKNN
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

from oodeel.eval.metrics import bench_metrics, get_curve
from oodeel.datasets import DataHandler

from sklearn.metrics import *


## Two-dataset experiment (MNIST in-distribution vs. Fashion-MNIST out-of-distribution)

In [None]:
%autoreload 2

def normalize(x):
    return x/255
    
data_handler = DataHandler()
ds1 = data_handler.load_tfds('mnist', preprocess=True, preprocessing_fun=normalize)
ds2 = data_handler.load_tfds('fashion_mnist', preprocess=True, preprocessing_fun=normalize)
x_id = ds1["test"]
x_ood = ds2["test"]
x_train = ds1["train"]

x_test = data_handler.merge_tfds(x_id, x_ood, shuffle=False)

In [None]:

# Pre-trained MNIST classifier used by the two-dataset experiments below.
MNIST_MODEL_PATH = "../saved_models/mnist_model"
model = tf.keras.models.load_model(MNIST_MODEL_PATH)

### MLS

In [None]:

%autoreload 2


oodmodel = MLS()
oodmodel.fit(model)
scores = oodmodel.score(x_test.batch(100))
labels = data_handler.get_ood_labels(x_test)

fpr, tpr = get_curve(scores, labels)

metrics = bench_metrics(
    scores, labels, 
    metrics = ["auroc", "fpr95tpr", accuracy_score, roc_auc_score], 
    threshold = -5
    )

print(metrics)
plt.plot(fpr, tpr)
plt.show()


### DKNN

In [None]:
%autoreload 2

## This time need a dataset to fit KNN score
x_test = data_handler.merge_tfds(x_id.take(1000), x_ood.take(1000))

oodmodel = DKNN()
oodmodel.fit(model, x_train.take(10000).batch(100))
scores = oodmodel.score(x_test.batch(100))
labels = data_handler.get_ood_labels(x_test)

fpr, tpr = get_curve(scores, labels)
metrics = bench_metrics(
    scores, labels, 
    metrics = ["auroc", "fpr95tpr", accuracy_score, roc_auc_score], 
    threshold = -5
    )

print(metrics)
plt.plot(fpr, tpr)
plt.show()

### ODIN

In [None]:

%autoreload 2
from oodeel.methods import ODIN

x_test = data_handler.merge_tfds(x_id, x_ood)
labels = data_handler.get_ood_labels(x_test)


oodmodel = ODIN()
oodmodel.fit(model)
scores_id = oodmodel.score(x_id.batch(100))
scores_ood = oodmodel.score(x_ood.batch(100))

scores = np.append(scores_id, scores_ood)
#labels = data_handler.get_ood_labels(x_test)

fpr, tpr = get_curve(scores, labels)
metrics = bench_metrics(
    scores, labels, 
    metrics = ["auroc", "fpr95tpr", accuracy_score, roc_auc_score], 
    threshold = -5
    )

print(metrics)
plt.plot(fpr, tpr)
plt.show()

## Single dataset experiment

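Leave-$k$-classes-out setup: MNIST digits `[0, 1, 2, 3, 4]` are kept as in-distribution data, and the other classes are treated as out-of-distribution.
First, we split the dataset accordingly. Then we train a convnet (`train_convnet`) on the in-distribution training split only.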

In [None]:
%autoreload 2
inc_labels = [0, 1, 2, 3, 4]
data_handler = DataHandler()
ds = data_handler.load_tfds('mnist', preprocess=True, preprocessing_fun=normalize)
x_id, x_ood = data_handler.filter_tfds(ds["test"], inc_labels = inc_labels )
x_train_id, _ = data_handler.filter_tfds(ds["train"], inc_labels = inc_labels )

In [None]:
%autoreload 2
from oodeel.models.training_funs import train_convnet

train_config = {
    "batch_size": 128,
    "epochs": 5
}

model = train_convnet(x_train_id, **train_config)

### MLS

In [None]:
%autoreload 2


x_test = data_handler.merge_tfds(x_id, x_ood, shuffle=False)

oodmodel = MLS()
oodmodel.fit(model)
scores = oodmodel.score(x_test.batch(100))
labels = data_handler.get_ood_labels(x_test)

fpr, tpr = get_curve(scores, labels)
metrics = bench_metrics(
    scores, labels, 
    metrics = ["auroc", "fpr95tpr", accuracy_score, roc_auc_score], 
    threshold = -5
    )

print(metrics)
plt.plot(fpr, tpr)
plt.show()

### DKNN

In [None]:
%autoreload 2

x_test = data_handler.merge_tfds(x_id.take(1000), x_ood.take(1000))

oodmodel = DKNN()
oodmodel.fit(model, x_train_id.take(10000).batch(100))
scores = oodmodel.score(x_test.batch(100))
labels = data_handler.get_ood_labels(x_test)

fpr, tpr = get_curve(scores, labels)
metrics = bench_metrics(
    scores, labels, 
    metrics = ["auroc", "fpr95tpr", accuracy_score, roc_auc_score], 
    threshold = -5
    )

print(metrics)
plt.plot(fpr, tpr)
plt.show()

### ODIN

In [None]:
%autoreload 2


x_test = data_handler.merge_tfds(x_id, x_ood, shuffle=False)
#x_test, y_id = data_handler.convert_to_numpy(x_id)

oodmodel = ODIN()
oodmodel.fit(model)
scores = oodmodel.score(x_test.batch(100))
labels = data_handler.get_ood_labels(x_test)

fpr, tpr = get_curve(scores, labels)
metrics = bench_metrics(
    scores, labels, 
    metrics = ["auroc", "fpr95tpr", accuracy_score, roc_auc_score], 
    threshold = -5
    )

print(metrics)
plt.plot(fpr, tpr)
plt.show()