# Spectral Normalized Gaussian Process (SNGP) Regression

In [None]:
import torch
import os
from functools import partial
from torch.optim import Adam
from lightning_uq_box.uq_methods import SNGPRegression
from lightning_uq_box.datamodules import ToyDUE
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger

from lightning_uq_box.models.fc_resnet import FCResNet
from lightning_uq_box.viz_utils import (
    plot_toy_regression_data,
    plot_predictions_regression,
    plot_training_metrics,
)

import tempfile

%load_ext autoreload
%autoreload 2

In [None]:
# Scratch directory that receives the CSV logs and Lightning's run artifacts.
# Note: tempfile.mkdtemp leaves deletion to the caller, so it persists after
# the notebook finishes.
my_temp_dir = tempfile.mkdtemp()

# Fix the random seeds so the training run below is reproducible.
seed_everything(42)

## Datamodule

In [None]:
# Toy regression datamodule. batch_size is used by its dataloaders;
# n_samples presumably sets how many points are generated (TODO confirm).
datamodule = ToyDUE(batch_size=128, n_samples=1000)

# Raw tensors, used for plotting and for the final prediction step.
X_train = datamodule.X_train
y_train = datamodule.y_train
X_test = datamodule.X_test
y_test = datamodule.y_test

# NOTE(review): `test_loader` is built from the datamodule's *validation*
# loader; confirm that ToyDUE's val split is the same data as X_test/y_test.
train_loader = datamodule.train_dataloader()
test_loader = datamodule.val_dataloader()

In [None]:
# Visualise the training and test points of the toy dataset.
fig = plot_toy_regression_data(
    X_train,
    y_train,
    X_test,
    y_test,
)

## Model

In [None]:
# FCResNet backbone whose output feeds the SNGP module below: scalar inputs
# (input_dim=1), 128 hidden features, depth 4 (presumably the number of
# residual layers — confirm against FCResNet).
feature_extractor = FCResNet(depth=4, features=128, input_dim=1)

In [None]:
# Wrap the backbone in the SNGP regression module, trained with an MSE loss.
# The optimizer is given as a partial (Adam with lr 1e-3), presumably so the
# module can bind it to its own parameters when configuring optimizers.
sngp = SNGPRegression(
    optimizer=partial(Adam, lr=0.001),
    loss_fn=torch.nn.MSELoss(),
    feature_extractor=feature_extractor,
)

## Trainer

In [None]:
# Write metrics as CSV into the scratch directory so the training curves can
# be read back and plotted once fitting is done.
logger = CSVLogger(save_dir=my_temp_dir)

trainer = Trainer(
    default_root_dir=my_temp_dir,
    logger=logger,
    log_every_n_steps=1,  # record metrics at every training step
    max_epochs=500,
    enable_progress_bar=True,
    enable_checkpointing=False,  # this demo never reloads a checkpoint
)

In [None]:
# Train the SNGP model on the toy datamodule (its train/val dataloaders).
trainer.fit(model=sngp, datamodule=datamodule)

In [None]:
# Plot the logged training curves. The log directory is derived from the
# logger itself (save_dir/name) rather than hard-coding CSVLogger's default
# "lightning_logs" name. With the default logger this is the same path, but
# the plot keeps working if the logger is given a different name.
fig = plot_training_metrics(
    os.path.join(logger.save_dir, logger.name),
    ["train_loss", "trainRMSE"],
)

## Prediction

In [None]:
# Run the trained model directly on the test inputs. The returned dict holds
# the "pred", "pred_uct" and "epistemic_uct" entries plotted in the next cell.
preds = sngp.predict_step(X_test)

In [None]:
# Overlay the SNGP predictions and their uncertainty on the data. The
# epistemic uncertainty is passed separately and shaded bands are turned off.
fig = plot_predictions_regression(
    X_train,
    y_train,
    X_test,
    y_test,
    preds["pred"],
    preds["pred_uct"],
    title="SNGP",
    show_bands=False,
    epistemic=preds["epistemic_uct"],
)

: 