# Variational Bayes Last Layer (VBLL) Classification

In [None]:
# install the library used throughout this notebook into the running kernel's environment
%pip install lightning-uq-box

In [None]:
import os
import tempfile

from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger

from lightning_uq_box.datamodules import TwoMoonsDataModule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import VBLLClassification
from lightning_uq_box.viz_utils import (
    plot_predictions_classification,
    plot_training_metrics,
    plot_two_moons_data,
)

# autoreload mode 2: re-import all modules before each cell is executed, so
# edits to imported source files are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
# seed the random number generators so repeated runs of the notebook match
seed_everything(2)
# temporary directory that receives the training logs; mkdtemp() creates it
# but never deletes it, so it has to be removed manually when no longer needed
my_temp_dir = tempfile.mkdtemp()

## Datamodule

In [None]:
# datamodule for the two-moons toy dataset, configured with a batch size of 128
dm = TwoMoonsDataModule(batch_size=128)

In [None]:
# pull the train/test splits and the evaluation grid off the datamodule
X_train = dm.X_train
y_train = dm.y_train
X_test = dm.X_test
y_test = dm.y_test
test_grid_points = dm.test_grid_points

In [None]:
# visualize the train and test splits before fitting anything
fig = plot_two_moons_data(X_train, y_train, X_test, y_test)

## Model

In [None]:
# MLP backbone: 2 inputs, hidden sizes [50, 50], 2 outputs
network = MLP(n_inputs=2, n_hidden=[50, 50], n_outputs=2)
# bare expression as the last line of the cell so the notebook displays the module
network

In [None]:
# wrap the MLP defined above in the VBLL classification method
vbll_model = VBLLClassification(
    model=network,
    num_targets=2,  # same value as the network's n_outputs
    # 1 / X_train.shape[0]; presumably axis 0 is the sample axis, i.e. this is
    # one over the number of training samples — confirm against the datamodule
    regularization_weight=1 / X_train.shape[0],
    parameterization="diagonal",
    prior_scale=1.0,
)

## Trainer

In [None]:
# CSV logger writing below the temporary directory created earlier
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    max_epochs=150,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=1,  # log on every training step
    enable_checkpointing=False,  # do not write model checkpoints
    enable_progress_bar=False,  # no progress bar in the cell output
    default_root_dir=my_temp_dir,  # same temporary directory as the logger
)

In [None]:
# fit the VBLL model using the data provided by the datamodule
trainer.fit(vbll_model, dm)

In [None]:
# plot the logged "train_loss" and "trainAcc" series; "lightning_logs" is the
# sub-directory CSVLogger uses by default when no name is given
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainAcc"]
)

## Predictions

We can plot the predictions for a grid of test points spanning the extent of the input data and visualize the decision boundaries and corresponding uncertainty.

In [None]:
# predict on the evaluation grid by calling predict_step directly on the
# module instead of going through Trainer.predict
# NOTE(review): the Trainer is not involved here, so confirm that predict_step
# itself takes care of eval mode / gradient-free inference
preds = vbll_model.predict_step(test_grid_points)

In [None]:
# hand the test data, the grid predictions and their uncertainty to the plotting helper
fig = plot_predictions_classification(
    X_test,
    y_test,
    preds["pred"].argmax(-1),  # index of the largest entry along the last axis
    test_grid_points,
    preds["pred_uct"].cpu().numpy(),  # uncertainty moved to CPU as a NumPy array
)