# Variational Bayesian Last Layer (VBLL) with SNGP Regression

In [None]:
# Install lightning-uq-box (imported below) and the vbll package.
%pip install lightning-uq-box
%pip install vbll

In [None]:
import os
import tempfile
from functools import partial

import torch
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger

from lightning_uq_box.datamodules import ToyHeteroscedasticDatamodule
from lightning_uq_box.models.fc_resnet import FCResNet
from lightning_uq_box.uq_methods import VBLLRegression
from lightning_uq_box.uq_methods.sngp import RandomFourierFeatures
from lightning_uq_box.uq_methods.spectral_normalized_layers import (
    collect_input_sizes,
    spectral_normalize_model_layers,
)
from lightning_uq_box.viz_utils import (
    plot_predictions_regression,
    plot_toy_regression_data,
    plot_training_metrics,
)

# Enable IPython's autoreload extension; mode 2 reloads modules before each
# cell executes, so edits to library code are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Scratch directory for this run; it is reused later as the CSV logger's save
# location and as the Trainer's default root dir.
# NOTE(review): mkdtemp() does not delete the directory afterwards — remove it
# manually if the logs should not persist.
my_temp_dir = tempfile.mkdtemp()

# Seed via Lightning's seed_everything (seed 42) for reproducibility.
seed_everything(42)

## Datamodule

In [None]:
# Toy heteroscedastic regression datamodule, configured with batch size 64.
datamodule = ToyHeteroscedasticDatamodule(batch_size=64)

# Train/test inputs and targets exposed by the datamodule; later cells use
# them for plotting, for the regularization weight, and for prediction.
X_train = datamodule.X_train
y_train = datamodule.y_train
X_test = datamodule.X_test
y_test = datamodule.y_test

# NOTE: `test_loader` is built from the datamodule's *validation* dataloader.
train_loader = datamodule.train_dataloader()
test_loader = datamodule.val_dataloader()

In [None]:
# Plot the train and test splits of the toy dataset; the return value is kept in `fig`.
fig = plot_toy_regression_data(X_train, y_train, X_test, y_test)

## Model

In [None]:
# Width of the feature vector handed from the residual network to the random
# Fourier feature layer; both layers must agree on it, so it is defined once.
feature_dim = 64

# Fully connected residual network used as the feature extractor.
feature_extractor = FCResNet(
    input_dim=1,
    features=64,
    depth=4,
    num_outputs=feature_dim,
    dropout_rate=0.0,
    activation="elu",
)

# Apply spectral normalization to the extractor's layers. The helper needs the
# input sizes gathered by collect_input_sizes.
# NOTE(review): the literal 1 presumably mirrors `input_dim` above — confirm.
input_dims = collect_input_sizes(feature_extractor, 1)
feature_extractor = spectral_normalize_model_layers(
    feature_extractor, input_dimensions=input_dims, n_power_iterations=1
)

# Random Fourier feature layer consuming the extractor's output.
rff_features = RandomFourierFeatures(in_dim=feature_dim, num_random_features=128)

# Backbone passed to the VBLL module: normalized extractor followed by RFF layer.
model = nn.Sequential(feature_extractor, rff_features)

In [None]:
# Regularization weight is 2 / N, where N = X_train.shape[0].
n_train = X_train.shape[0]
reg_weight = (1 / n_train) * 2

# Adam with lr=4e-3; the module receives the factory, not an optimizer instance.
adam_factory = partial(torch.optim.Adam, lr=4e-3)

vbll_model = VBLLRegression(
    model=model,
    replace_ll=False,  # append the VBLL layer instead of replacing the last layer
    num_targets=1,
    regularization_weight=reg_weight,
    prior_scale=1.0,
    wishart_scale=0.1,
    optimizer=adam_factory,
)

## Trainer

In [None]:
# Metrics are written by a CSV logger rooted at the temporary directory.
logger = CSVLogger(my_temp_dir)

trainer = Trainer(
    default_root_dir=my_temp_dir,
    logger=logger,  # keep training metrics for later evaluation
    log_every_n_steps=1,
    max_epochs=400,  # training length in epochs
    gradient_clip_val=1.0,
    enable_checkpointing=False,
    enable_progress_bar=True,
)

In [None]:
# Train the VBLL module on the toy datamodule using the Trainer built earlier.
trainer.fit(vbll_model, datamodule)

In [None]:
# Plot the logged training curves from the "lightning_logs" folder of the
# temporary directory.
log_dir = os.path.join(my_temp_dir, "lightning_logs")
metric_names = ["train_loss", "trainRMSE"]
fig = plot_training_metrics(log_dir, metric_names)

## Prediction

In [None]:
# Run the module's predict_step directly on the test inputs. The result is
# indexed with the "pred" and "pred_uct" keys when plotting.
preds = vbll_model.predict_step(X_test)

In [None]:
# The same squeezed uncertainty tensor is passed twice: positionally right
# after the prediction, and again as the `epistemic` keyword.
pred_values = preds["pred"]
pred_uct = preds["pred_uct"].squeeze(-1)

fig = plot_predictions_regression(
    X_train,
    y_train,
    X_test,
    y_test,
    pred_values,
    pred_uct,
    epistemic=pred_uct,
    title="VBLL Regression with SNGP Feature Extractor",
    show_bands=False,
)