# "Mahalanobis" out-of-distribution detection

- Try rejecting inputs based on the empirical distribution of Mahalanobis distances (e.g. reject a new sample if its Mahalanobis distance falls beyond the 99th percentile of the distances observed on the training data).

## Imports

In [None]:
import plotly.express as px
# import tensorflow_datasets as tfds
import numpy as np
import sys
sys.path.append("../src/")
import gda
import scipy.stats
import torchvision
import torch
import pytorch_lightning as pl
from model import Classifier

## Test data

In [None]:
# Synthetic 2-D Gaussian sample: 10000 points with unit variance per axis
# (np.random.normal's default scale), centred on `center`.
center = [1, 0]
size = (10000, 2)
data = np.random.normal(loc=np.asarray(center), size=size)

In [None]:
# Scatter plot of the synthetic sample: first column on x, second column on y.
# The figure is the cell's displayed output.
px.scatter(x=data[:,0], y=data[:,1])

In [None]:
# Per-column sample mean and the sample covariance matrix of `data`
# (rows are observations, columns are variables).
mean = np.mean(data, axis=0)
covariance = np.cov(data, rowvar=False)

In [None]:
# NOTE: disabled. `tfds` is never imported (the `tensorflow_datasets` import in
# the imports cell is commented out), so running this cell raised NameError.
# `train` / `test` are not referenced by the visible cells; MNIST is loaded
# through torchvision in the "Non-synthetic data" section instead.
# train, test = tfds.load('mnist', split=['train', 'test'], data_dir='../data', as_supervised=True)

In [None]:
# Mahalanobis distances of the synthetic sample with respect to its own
# estimated mean and covariance, computed by the project-local `gda` module.
# NOTE(review): presumably returns one (unsquared) distance per row of `data`
# — confirm against gda.mahalanobis.
distances = gda.mahalanobis(data, mean, covariance)

In [None]:
# Histogram of the distances (empirical distribution), shown as the cell output.
px.histogram(distances)

In [None]:
# Frozen chi distribution with 2 degrees of freedom, used as the reference
# curve for the distance histogram.
# NOTE(review): this is the matching law only if gda.mahalanobis returns the
# unsquared distance of 2-D Gaussian data — confirm.
distribution = scipy.stats.chi(df=2)

In [None]:
# Evaluate the reference density on 100 evenly spaced points in [0, 4.5].
x_vals = np.linspace(0, 4.5, 100)
y_vals = distribution.pdf(x_vals)

In [None]:
# Line plot of the reference density over the grid, shown as the cell output.
px.line(x=x_vals, y=y_vals)

## Non-synthetic data

In [None]:
# Directory handed as `root` to the torchvision dataset constructors below
# (path is relative to the notebook's working directory).
data_root = "../data/"

In [None]:
# Per-sample preprocessing pipeline used for the MNIST datasets.
transforms = torchvision.transforms.Compose(
    [
        # Convert the loaded image to a tensor.
        torchvision.transforms.ToTensor(),
        # Repeat the leading axis three times (the other two axes unchanged).
        # NOTE(review): assumes a (1, H, W) tensor, giving (3, H, W) — confirm.
        torchvision.transforms.Lambda(lambda img: img.repeat(3, 1, 1)),
        # Rescale to 64x64.
        torchvision.transforms.Resize(size=(64, 64)),
    ]
)

In [None]:
# MNIST train / test splits, stored under `data_root` (downloaded if missing),
# with the `transforms` pipeline applied to every sample.
mnist_train = torchvision.datasets.MNIST(
    root=data_root,
    train=True,
    transform=transforms,
    download=True,
)
mnist_test = torchvision.datasets.MNIST(
    root=data_root,
    train=False,
    transform=transforms,
    download=True,
)

In [None]:
# KMNIST train / test splits, stored under `data_root` (downloaded if missing).
# The same `transforms` pipeline as for MNIST is applied so that samples drawn
# from these datasets have the same type and shape as the MNIST samples;
# without it, indexing the dataset returns untransformed images.
kmnist_train = torchvision.datasets.KMNIST(root=data_root,
                                           train=True,
                                           transform=transforms,
                                           download=True)
kmnist_test = torchvision.datasets.KMNIST(root=data_root,
                                          train=False,
                                          transform=transforms,
                                          download=True)

## Classifier model

In [None]:
# ResNet-18 from torchvision, initialised with its pretrained weights.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=` — confirm against the installed version before changing.
model = torchvision.models.resnet18(pretrained=True)

In [None]:
# Raw MNIST training images as floats, bypassing the dataset's transform:
# insert a new axis at position 1, then repeat three times along that axis.
# NOTE(review): presumably (N, 28, 28) -> (N, 3, 28, 28) — confirm.
data = mnist_train.data.float().unsqueeze(1)
data = data.tile((1, 3, 1, 1))

In [None]:
# Smoke test: run the first 10 samples through the model and display the
# shape of its output.
model(data[:10]).shape

In [None]:
def fit(model, data, opt, loss, epochs=10):
    """Train *model* with a plain mini-batch loop.

    Args:
        model: Callable module; ``model(inputs)`` yields the predictions.
        data: Re-iterable of ``(inputs, targets)`` pairs, e.g. a DataLoader.
        opt: Optimizer bound to the model's parameters.
        loss: Callable ``loss(predictions, targets)`` returning a scalar tensor.
        epochs: Number of full passes over *data*.

    Returns:
        A list with the mean batch loss of every epoch (``nan`` for an epoch
        in which *data* yielded no batches).
    """
    model.train()
    history = []
    for _ in range(epochs):
        total, n_batches = 0.0, 0
        for inputs, targets in data:
            # Gradients accumulate by default, so reset them for each batch.
            opt.zero_grad()
            batch_loss = loss(model(inputs), targets)
            batch_loss.backward()
            opt.step()
            total += batch_loss.item()
            n_batches += 1
        history.append(total / n_batches if n_batches else float("nan"))
    return history

In [None]:
# Rebind `model` to the project's own `Classifier` (imported from `model`),
# replacing the torchvision ResNet created earlier.
model = Classifier()

In [None]:
# Smoke test: forward the first 10 samples of `data` through the classifier
# and display the raw output.
model(data[:10])

## Model training

In [None]:
# Mini-batches of 16 transformed MNIST training samples (default ordering),
# and a Lightning trainer limited to a single epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=mnist_train,
    batch_size=16,
)
trainer = pl.Trainer(
    max_epochs=1,
)

In [None]:
# Train the classifier on the MNIST training loader.
# The loader is passed positionally: the keyword for this argument was renamed
# from `train_dataloader` to `train_dataloaders` across PyTorch Lightning
# releases (and the old name was later removed), whereas the second positional
# parameter is the training loader in both APIs.
trainer.fit(model, train_loader)

In [None]:
# Take the first (batch, labels) pair from the training loader for inspection.
# This replaces a loop-and-break that contained a no-op self-assignment.
batch, labels = next(iter(train_loader))

In [None]:
# Rebind `data` to the first 10 raw images of the MNIST training set.
# NOTE(review): unlike the earlier `data`, this is neither cast to float nor
# given a channel axis — confirm that is intended before feeding it to a model.
data = mnist_train.data[:10]

In [None]:
# Shape of the softmax-normalised model output (normalised along dim 1)
# for the inspected batch.
model(batch).softmax(dim=1).shape

In [None]:
# Shape of the label tensor belonging to the inspected batch.
labels.shape

In [None]:
# Display the labels of the inspected batch.
labels

In [None]:
# Softmax-normalised model output (normalised along dim 1) for the inspected
# batch, displayed as the cell output.
model(batch).softmax(dim=1)