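# Bayesian neural net (MNIST)

This notebook evaluates a variational convolutional classifier on the MNIST test set, shows sampled posterior predictions, and uses the posterior predictive entropy to separate in-distribution (MNIST) from out-of-distribution (KMNIST) images.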

In [None]:
# Enable the autoreload extension; mode 2 re-imports modules before each cell runs,
# so edits to the local package (e.g. vartorch) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline below the cells.
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from lightning.pytorch import seed_everything, Trainer

from vartorch import (
    MNISTDataModule,
    ConvVarClassifier,
    anomaly_score,
    plot_post_predictions,
    plot_entropy_histograms
)

In [None]:
# Fix all random number generators (seeded via Lightning) for reproducible runs.
SEED = 111111
_ = seed_everything(SEED)

## MNIST data

In [None]:
# In-distribution data: MNIST with normalization parameters mean=0.5 / std=0.5.
mnist_config = dict(
    data_set='mnist',
    data_dir='../run/data/',
    mean=0.5,
    std=0.5,
    batch_size=32,
)
mnist = MNISTDataModule(**mnist_config)

mnist.prepare_data()  # downloads the data unless it is already present
mnist.setup(stage='test')  # sets up the test split only

In [None]:
# Take a single batch from the MNIST test loader for a quick visual check.
test_loader = mnist.test_dataloader()
first_batch = next(iter(test_loader))
x_batch, y_batch = first_batch

In [None]:
# Show the first 12 images of the batch (undoing the normalization) with their class names.
fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(5, 4.5))
for i, ax in enumerate(axes.flat):
    img = mnist.renormalize(x_batch[i, 0]).numpy()
    label = mnist.test_set.classes[y_batch[i]]
    ax.imshow(img, cmap='gray', vmin=0, vmax=1)
    ax.set(title=label, xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()

## Variational model

In [None]:
ckpt_file = '../run/mnist/version_0/checkpoints/last.ckpt'

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# map_location restores the checkpoint tensors directly onto the target device,
# so a checkpoint saved on a GPU also loads cleanly on a CPU-only machine
# (and no transient copy is made on the original device).
var_model = ConvVarClassifier.load_from_checkpoint(ckpt_file, map_location=device)

var_model = var_model.to(device)  # make sure all parameters/buffers live on `device`
var_model = var_model.train(False)  # evaluation mode (equivalent to .eval())

# Report the current mode flags of the loaded model.
print(f'Train mode: {var_model.training}')
print(f'Sampling: {var_model.sampling}')

In [None]:
# Run Lightning's test loop on the MNIST test split (logging disabled, results printed).
trainer = Trainer(logger=False)
mnist_test_loader = mnist.test_dataloader()
test_metrics = trainer.test(
    model=var_model,
    dataloaders=mnist_test_loader,
    verbose=True,
)

## Example predictions

In [None]:
# Out-of-distribution data: KMNIST, configured exactly like the MNIST datamodule.
kmnist_config = dict(
    data_set='kmnist',
    data_dir='../run/data/',
    mean=0.5,
    std=0.5,
    batch_size=32,
)
kmnist = MNISTDataModule(**kmnist_config)

kmnist.prepare_data()  # downloads the data unless it is already present
kmnist.setup(stage='test')  # sets up the test split only

In [None]:
# "Normal" = MNIST test split, "anomalous" = KMNIST test split.
norm_set, anom_set = mnist.test_set, kmnist.test_set

norm_loader, anom_loader = mnist.test_dataloader(), kmnist.test_dataloader()

# One batch of each for the example predictions below.
x_norm, y_norm = next(iter(norm_loader))
x_anom, y_anom = next(iter(anom_loader))

In [None]:
# Random display order for the example batches. The same indices are used for
# both the MNIST and the KMNIST batch, so they must be valid for both.
# (Uses the global NumPy RNG, which seed_everything has seeded.)
num_plot_candidates = min(len(x_norm), len(x_anom))
plot_ids = np.random.permutation(num_plot_candidates)

In [None]:
num_samples = 500

# Evaluation mode, but with weight sampling switched on so that repeated
# forward passes differ.
var_model.train(False)
var_model.sample(True)

with torch.no_grad():
    target_device = var_model.device

    # in-distribution batch (MNIST)
    sampled_norm_logits = var_model.predict(x_norm.to(target_device), num_samples).cpu()
    # NOTE(review): dim=1 is assumed to be the class axis of predict's output — confirm
    sampled_norm_probs = sampled_norm_logits.softmax(dim=1)

    # out-of-distribution batch (KMNIST)
    sampled_anom_logits = var_model.predict(x_anom.to(target_device), num_samples).cpu()
    sampled_anom_probs = sampled_anom_logits.softmax(dim=1)

In [None]:
# Sampled posterior predictions for the MNIST (in-distribution) examples,
# in the shuffled `plot_ids` order.
norm_images = mnist.renormalize(x_norm[plot_ids])
fig, axes = plot_post_predictions(
    images=norm_images,
    sampled_probs=sampled_norm_probs[plot_ids],
    labels=y_norm[plot_ids],
    names=norm_set.classes,
    nrows=3,
    figsize=(8, 6),
    title='Posterior predictions (in distribution)',
)

In [None]:
# Sampled posterior predictions for the KMNIST (out-of-distribution) examples,
# in the same shuffled `plot_ids` order.
anom_images = kmnist.renormalize(x_anom[plot_ids])
fig, axes = plot_post_predictions(
    images=anom_images,
    sampled_probs=sampled_anom_probs[plot_ids],
    labels=y_anom[plot_ids],
    names=anom_set.classes,
    nrows=3,
    figsize=(8, 6),
    title='Posterior predictions (out of distribution)',
)

## Out-of-distribution detection

In [None]:
# Entropy-based anomaly scores (100 posterior samples) computed over the
# MNIST loader first, then the KMNIST loader.
var_norm_entropy, var_anom_entropy = (
    anomaly_score(var_model, loader, mode='entropy', num_samples=100)
    for loader in (norm_loader, anom_loader)
)

In [None]:
# Histograms of the posterior predictive entropy for in- vs. out-of-distribution images.
entropy_hist_opts = dict(
    figsize=(6, 4),
    range=(0, 2),
    bins=100,
    title='Posterior predictive',
)
fig, ax = plot_entropy_histograms(
    norm_entropy=var_norm_entropy,
    anom_entropy=var_anom_entropy,
    **entropy_hist_opts,
)