# Bayesian neural net (half-moons)

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.append('..')

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from lightning.pytorch import seed_everything

from vartorch import (
    MoonsDataModule,
    DenseVarClassifier,
    plot_data_2d,
    plot_data_and_preds_2d
)

In [None]:
# Single, named place to change the seed used for this notebook's random number generators.
SEED = 111111

# Assigning the return value to `_` keeps it out of the cell output.
_ = seed_everything(SEED) # set random seeds manually

## Half-moons data

In [None]:
# Settings of the half-moons data module, collected in one place
moons_settings = {
    'num_train': 500,
    'num_val': 100,
    'num_test': 100,
    'noise_level': 0.15,
    'offsets': (0.15, -0.15),
    'batch_size': 32,
}

moons = MoonsDataModule(**moons_settings)

# sample numerical data
moons.prepare_data()

# create test set
moons.setup(stage='test')

In [None]:
# Two Set1 colors handed to the scatter plot of the training data
class_colors = (plt.cm.Set1(1), plt.cm.Set1(0))

fig, ax = plt.subplots(figsize=(5, 5))
plot_data_2d(moons.x_train, moons.y_train, colors=class_colors, ax=ax)

# fixed viewport with equal scaling on both axes
ax.set_xlim(-2, 3)
ax.set_ylim(-2, 2.5)
ax.set_aspect('equal', adjustable='box')

# light grid drawn below the data points
ax.grid(visible=True, which='both', color='gray', alpha=0.2, linestyle='-')
ax.set_axisbelow(True)

ax.legend(loc='upper right')
fig.tight_layout()

## Model import

In [None]:
# Checkpoint written by a previous training run (path relative to this notebook)
ckpt_file = '../run/moons/version_0/checkpoints/last.ckpt'

# Use the first GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Passing `map_location` restores the weights directly onto `device`, so loading
# does not depend on the device the checkpoint was saved from being present here.
var_model = DenseVarClassifier.load_from_checkpoint(ckpt_file, map_location=device)

# Evaluation mode for inference; `.to(device)` is kept so the model is
# guaranteed to live on `device` regardless of how it was restored.
var_model = var_model.eval().to(device)

## Example predictions

In [None]:
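# NOTE: disabled point-prediction helper kept for reference; it relies on a
# `model` object that is not defined in the visible cells of this notebook.
# @torch.no_grad()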
# def point_prediction(x):
#     '''Compute normal point predictions.'''
#     x_tensor = torch.tensor(x, dtype=torch.float32)

#     model.train(False)

#     point_logits = model.predict(x_tensor.to(model.device))
#     point_probs = torch.sigmoid(point_logits)
#     return point_probs.cpu().numpy()

@torch.no_grad()
def post_mean(x):
    '''
    Predict with posterior mean weights.

    Parameters
    ----------
    x : array-like
        Inputs, converted to a float32 tensor before being passed on
        (the layout is whatever `var_model.predict` expects).

    Returns
    -------
    numpy.ndarray
        Sigmoid of the logits returned by `var_model.predict`.

    Notes
    -----
    Changes the state of the global `var_model`: sampling is
    switched off and the model is put into evaluation mode.
    '''
    inputs = torch.tensor(x, dtype=torch.float32)

    # deterministic pass: no weight sampling, evaluation mode
    var_model.sample(False)
    var_model.train(False)

    logits = var_model.predict(inputs.to(var_model.device))
    return logits.sigmoid().cpu().numpy()

@torch.no_grad()
def post_predictive(x, num_samples=100):
    '''
    Predict according to the posterior predictive distribution.

    Parameters
    ----------
    x : array-like
        Inputs, converted to a float32 tensor before being passed on.
    num_samples : int
        Number of samples requested from `var_model.predict`.

    Returns
    -------
    numpy.ndarray
        Sampled probabilities averaged over the last axis
        (presumably the sample axis - TODO confirm against `predict`).

    Notes
    -----
    Changes the state of the global `var_model`: sampling is
    switched on and the model is put into evaluation mode.
    '''
    inputs = torch.tensor(x, dtype=torch.float32)

    # stochastic pass: weight sampling on, evaluation mode
    var_model.sample(True)
    var_model.train(False)

    logits = var_model.predict(inputs.to(var_model.device), num_samples)
    probs = logits.sigmoid()

    # average the per-sample probabilities
    mean_probs = probs.mean(dim=-1)
    return mean_probs.cpu().numpy()

@torch.no_grad()
def post_uncertainty(x, num_samples=100):
    '''
    Compute the uncertainty associated with the posterior predictive.

    Parameters
    ----------
    x : array-like
        Inputs, converted to a float32 tensor before being passed on.
    num_samples : int
        Number of samples requested from `var_model.predict`.

    Returns
    -------
    numpy.ndarray
        Standard deviation of the sampled probabilities over the last
        axis (presumably the sample axis - TODO confirm against `predict`).

    Notes
    -----
    Changes the state of the global `var_model`: sampling is
    switched on and the model is put into evaluation mode.
    '''
    inputs = torch.tensor(x, dtype=torch.float32)

    # stochastic pass: weight sampling on, evaluation mode
    var_model.sample(True)
    var_model.train(False)

    logits = var_model.predict(inputs.to(var_model.device), num_samples)
    probs = logits.sigmoid()

    # spread of the per-sample probabilities
    std_probs = probs.std(dim=-1)
    return std_probs.cpu().numpy()

In [None]:
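# NOTE: disabled; `point_prediction` only exists if the commented-out
# helper in the cell above is restored.
# fig, ax = plot_data_and_preds_2d(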
#     x_data=moons.x_train,
#     y_data=moons.y_train,
#     pred_function=point_prediction,
#     figsize=(6, 4.5),
#     xlim=(-2, 3),
#     ylim=(-2, 2.5),
#     levels=(0.3, 0.5, 0.7),
#     title='Point predictions'
# )

In [None]:
# Training data together with the predictions made by `post_mean`
fig, ax = plot_data_and_preds_2d(
    pred_function=post_mean,
    x_data=moons.x_train,
    y_data=moons.y_train,
    title='Posterior mean',
    levels=(0.3, 0.5, 0.7),
    xlim=(-2, 3),
    ylim=(-2, 2.5),
    figsize=(6, 4.5),
)

In [None]:
# Training data together with the sample-averaged predictions of `post_predictive`
fig, ax = plot_data_and_preds_2d(
    pred_function=post_predictive,
    x_data=moons.x_train,
    y_data=moons.y_train,
    title='Posterior predictions',
    levels=(0.1, 0.3, 0.5, 0.7, 0.9),
    xlim=(-2, 3),
    ylim=(-2, 2.5),
    figsize=(6, 4.5),
)

In [None]:
# Training data together with the standard deviation returned by `post_uncertainty`
# NOTE(review): the plotted quantity is a standard deviation of values in [0, 1],
# which cannot be much larger than 0.5, so the levels above that are presumably
# never reached - confirm whether a narrower range was intended.
fig, ax = plot_data_and_preds_2d(
    pred_function=post_uncertainty,
    x_data=moons.x_train,
    y_data=moons.y_train,
    title='Posterior uncertainty',
    levels=np.linspace(0.1, 0.9, 9),
    xlim=(-2, 3),
    ylim=(-2, 2.5),
    figsize=(6, 4.5),
)