# Spectral Normalized Gaussian Process (SNGP) Regression

In [None]:
%%capture
# Install lightning-uq-box straight from the default branch of its GitHub
# repository; the %%capture cell magic hides the pip output in the notebook.
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git

In [None]:
import os
import tempfile
from functools import partial

import matplotlib.pyplot as plt
import torch
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger
from torch.optim import Adam

from lightning_uq_box.datamodules import ToyDUE
from lightning_uq_box.models.fc_resnet import FCResNet
from lightning_uq_box.uq_methods import SNGPRegression
from lightning_uq_box.viz_utils import (
    plot_calibration_uq_toolbox,
    plot_predictions_regression,
    plot_toy_regression_data,
    plot_training_metrics,
)

plt.rcParams["figure.figsize"] = [14, 5]

%load_ext autoreload
%autoreload 2

In [None]:
# Scratch directory for logs and other run artifacts; it is created fresh
# and is not cleaned up automatically.
my_temp_dir = tempfile.mkdtemp()

# Seed the Python, NumPy and torch RNGs so the run is reproducible.
seed_everything(seed=42)

## Datamodule

In [None]:
dm = ToyDUE(batch_size=256, n_samples=1000)
# Alternative toy dataset (would need its own import):
# dm = ToyHeteroscedasticDatamodule(batch_size=256, n_points=1000)

# Expose the datamodule's raw tensors and loaders as notebook-level names
# for plotting and evaluation (same evaluation order as a single unpack).
X_train, Y_train = dm.X_train, dm.Y_train
X_test, Y_test = dm.X_test, dm.Y_test
train_loader = dm.train_dataloader()
# NOTE: despite its name, ``test_loader`` holds the *validation* dataloader.
test_loader = dm.val_dataloader()
# Dense inputs/targets of the noise-free ground-truth function on an
# extended input range (used for visualization below).
X_gtext, Y_gtext = dm.X_gtext, dm.Y_gtext

In [None]:
# Plot the training and test splits to get a first look at the toy dataset.
fig = plot_toy_regression_data(X_train, Y_train, X_test, Y_test)

## Model

In [None]:
# Fully-connected residual network that SNGP uses as its feature extractor
# (1-D input, width 64, depth 4 — presumably 4 residual blocks).
feature_extractor = FCResNet(depth=4, features=64, input_dim=1)

In [None]:
# Spectral Normalized Gaussian Process regression model built on top of the
# FCResNet feature extractor. The optimizer is given as a partially-applied
# factory (learning rate fixed here), presumably instantiated by the module
# with its own parameters.
sngp = SNGPRegression(
    optimizer=partial(Adam, lr=0.003),
    loss_fn=torch.nn.MSELoss(),
    feature_extractor=feature_extractor,
)

## Trainer

In [None]:
# Write metrics as CSV under <my_temp_dir>/lightning_logs so they can be
# plotted once training has finished.
logger = CSVLogger(save_dir=my_temp_dir)

trainer = Trainer(
    default_root_dir=my_temp_dir,
    logger=logger,
    log_every_n_steps=1,  # record every optimization step
    max_epochs=100,
    # Keep the notebook output and temp dir lean: no checkpoints, no bar.
    enable_checkpointing=False,
    enable_progress_bar=False,
)

In [None]:
# Train; the datamodule is passed by keyword so Lightning draws the
# train/val dataloaders from it.
trainer.fit(model=sngp, datamodule=dm)

In [None]:
# CSVLogger stores its runs under "<save_dir>/lightning_logs".
log_dir = os.path.join(my_temp_dir, "lightning_logs")
fig = plot_training_metrics(log_dir, ["train_loss", "trainRMSE"])

## Evaluate Predictions

The constructed Data Module contains two possible test variables. `X_test` are IID samples from the same noisy data distribution as the training data, while `X_gtext` ("X ground truth extended") are dense inputs from the underlying noise-free "ground truth" function that also extend the input range to either side, so we can visualize the method's UQ behavior when extrapolating beyond the training data range. Thus, we use `X_gtext` for visualization purposes, but `X_test` to compute uncertainty and calibration metrics, because we want to analyse how well the method has learned the noisy data distribution.

In [None]:
# predict_step is called by hand here, outside the Trainer, so nothing
# switches the module to inference mode for us; after ``trainer.fit`` it may
# still be in training mode (e.g. dropout active, layers updating internal
# state). Switch explicitly before predicting.
sngp.eval()

# Predict on the dense, extended ground-truth inputs to visualize the fit and
# how the uncertainty behaves when extrapolating beyond the training range.
preds = sngp.predict_step(X_gtext)

fig = plot_predictions_regression(
    X_train,
    Y_train,
    X_gtext,
    Y_gtext,
    preds["pred"].squeeze(),
    preds["pred_uct"],
    epistemic=preds["epistemic_uct"],
    title="SNGP",
    show_bands=False,
)

In [None]:
# Calibration is measured on the noisy IID test split. Put the module in
# inference mode first so this cell is correct regardless of which cells
# ran before it (after ``trainer.fit`` it may still be in training mode).
sngp.eval()
preds = sngp.predict_step(X_test)

# The uncertainty-toolbox based helper is fed NumPy arrays on the CPU.
fig = plot_calibration_uq_toolbox(
    preds["pred"].cpu().numpy(),
    preds["pred_uct"].cpu().numpy(),
    Y_test.cpu().numpy(),
    X_test.cpu().numpy(),
)