# Bayesian neural net (half-moons)

In [None]:
# Re-import edited Python modules (e.g. a locally developed package) before each cell runs,
# and render matplotlib figures inline in the notebook
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from lightning.pytorch import seed_everything, Trainer

from vartorch import (
    MoonsDataModule,
    DenseVarClassifier,
    plot_data_2d,
    plot_data_and_preds_2d,
    post_mean,
    post_predictive,
    post_uncertainty
)

In [None]:
# Fix the global random seeds so the random draws below are reproducible;
# the returned seed is bound to `_` to keep the cell output quiet.
_ = seed_everything(seed=111111)

## Half-moons data

In [None]:
# Configuration of the half-moons data module: split sizes, noise level,
# offsets and mini-batch size
moons_config = dict(
    num_train=500,
    num_val=100,
    num_test=1000,
    noise_level=0.15,
    offsets=(0.15, -0.15),
    batch_size=32,
)
moons = MoonsDataModule(**moons_config)

moons.prepare_data()  # draw the numerical samples
moons.setup(stage='test')  # build the test set

In [None]:
# Scatter the training points, with one Set1 color per label value
# (presumably labels 0 and 1 — binary classification)
class_colors = (plt.cm.Set1(1), plt.cm.Set1(0))
axis_range = (-2.5, 2.5)

fig, ax = plt.subplots(figsize=(5, 5))
plot_data_2d(moons.x_train, moons.y_train, colors=class_colors, ax=ax)

ax.set(xlim=axis_range, ylim=axis_range)
ax.set_aspect('equal', adjustable='box')
ax.legend(loc='upper right')
ax.grid(visible=True, which='both', color='gray', alpha=0.2, linestyle='-')
ax.set_axisbelow(True)  # keep grid lines underneath the data points
fig.tight_layout()

## Variational model

In [None]:
# Restore the trained variational classifier from its last checkpoint,
# move it to the first GPU when available (CPU otherwise) and switch it
# to evaluation mode.
ckpt_file = '../run/moons/version_0/checkpoints/last.ckpt'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

var_model = DenseVarClassifier.load_from_checkpoint(ckpt_file).to(device).eval()

# Report the mode flags that govern the model's behavior
print(f'Train mode: {var_model.training}')
print(f'Sampling: {var_model.sampling}')

In [None]:
# Evaluate the loaded model on the test set. No logger is attached;
# the metrics are printed (verbose=True) and kept in `test_metrics`.
trainer = Trainer(logger=False)

test_metrics = trainer.test(
    model=var_model,
    dataloaders=moons.test_dataloader(),
    verbose=True
)

# Lightning's strategy teardown moves the module back to the CPU after
# `trainer.test`, and some Lightning versions also switch it back to train
# mode at the end of the evaluation loop. Re-apply the device and eval mode
# so the prediction cells below use the same setup as printed earlier.
var_model = var_model.to(device).train(False)

## Example predictions

In [None]:
# Posterior mean of the variational model over the input plane,
# shown together with the training points (levels 0.3, 0.5, 0.7)
def predict_post_mean(x):
    return post_mean(var_model, x)


fig, ax = plot_data_and_preds_2d(
    x_data=moons.x_train,
    y_data=moons.y_train,
    pred_function=predict_post_mean,
    xlim=(-2.5, 2.5),
    ylim=(-2.5, 2.5),
    figsize=(6, 4.5),
    levels=(0.3, 0.5, 0.7),
    title='Posterior mean'
)

In [None]:
# Posterior predictive distribution over the input plane, shown together
# with the training points (levels 0.1 to 0.9 in steps of 0.2)
def predict_post_predictive(x):
    return post_predictive(var_model, x)


fig, ax = plot_data_and_preds_2d(
    x_data=moons.x_train,
    y_data=moons.y_train,
    pred_function=predict_post_predictive,
    xlim=(-2.5, 2.5),
    ylim=(-2.5, 2.5),
    figsize=(6, 4.5),
    levels=(0.1, 0.3, 0.5, 0.7, 0.9),
    title='Posterior predictions'
)

In [None]:
# Posterior uncertainty over the input plane, shown together with the
# training points (nine evenly spaced levels from 0.1 to 0.9)
def predict_post_uncertainty(x):
    return post_uncertainty(var_model, x)


fig, ax = plot_data_and_preds_2d(
    x_data=moons.x_train,
    y_data=moons.y_train,
    pred_function=predict_post_uncertainty,
    xlim=(-2.5, 2.5),
    ylim=(-2.5, 2.5),
    figsize=(6, 4.5),
    levels=np.linspace(0.1, 0.9, 9),
    title='Posterior uncertainty'
)