# Mean Variance Estimation

In [None]:
%%capture
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git

## Theoretic Foundation

The Gaussian model, also referred to as Mean Variance Estimation (MVE), was first studied in [Nix, 1994](https://ieeexplore.ieee.org/abstract/document/374138) and further used in [Sluijterman 2023](https://arxiv.org/abs/2302.08875). It is a deterministic model that predicts the parameters of a Gaussian distribution

$$
    f_{\theta}(x^{\star}) = (\mu_{\theta}(x^\star),\sigma_{\theta}(x^\star))
$$

in a single forward pass, where the standard deviation $\sigma_{\theta}(x^\star)$ can be used as a measure of data uncertainty. To this end, the network outputs two parameters and is trained with the Gaussian negative log-likelihood (NLL) as the loss objective [Kendall, 2017](https://proceedings.neurips.cc/paper_files/paper/2017/file/2650d6089a6d640c5e85b2b88265dc2b-Paper.pdf), which is given by

$$
    \mathcal{L}(\theta, (x^{\star}, y^{\star})) = \frac{1}{2}\text{ln}\left(2\pi\sigma_{\theta}(x^{\star})^2\right) + \frac{1}{2\sigma_{\theta}(x^{\star})^2}\left(\mu_{\theta}(x^{\star})-y^{\star}\right)^2.
$$

Correspondingly, the model prediction consists of the parameters of a Gaussian distribution: a predictive mean, $\mu_{\theta}(x^\star)$, and a predictive uncertainty, in this case the standard deviation $\sigma_{\theta}(x^\star)$.
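To make the loss above concrete, here is a minimal sketch that evaluates the per-sample Gaussian NLL in plain Python. The function name `gaussian_nll` is our own for illustration and is not part of the library. It shows that, for a fixed prediction error, both an overconfident (too small) and an overly cautious (too large) $\sigma$ are penalized.

```python
import math


def gaussian_nll(mu: float, sigma: float, y: float) -> float:
    """Per-sample Gaussian NLL for a predicted mean `mu` and std `sigma`.

    Illustrative helper (hypothetical name), mirroring the formula above.
    """
    if sigma <= 0:
        raise ValueError("sigma must be strictly positive")
    var = sigma**2
    return 0.5 * math.log(2 * math.pi * var) + (mu - y) ** 2 / (2 * var)


# Same prediction error (mu=0, y=1), three different predicted standard deviations.
for sigma in (0.1, 1.0, 10.0):
    print(f"sigma={sigma:>4}: NLL={gaussian_nll(mu=0.0, sigma=sigma, y=1.0):.3f}")
```

The first term grows with $\sigma$, while the second shrinks, so the loss is smallest when the predicted spread matches the size of the actual error.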


## Imports

In [None]:
import os
import tempfile
from functools import partial

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger

from lightning_uq_box.datamodules import ToyHeteroscedasticDatamodule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import MVERegression
from lightning_uq_box.viz_utils import (
    plot_calibration_uq_toolbox,
    plot_predictions_regression,
    plot_toy_regression_data,
    plot_training_metrics,
)

plt.rcParams["figure.figsize"] = [14, 5]

%load_ext autoreload
%autoreload 2

In [None]:
seed_everything(0)  # seed everything for reproducibility

We define a temporary directory to look at some training metrics and results.

In [None]:
my_temp_dir = tempfile.mkdtemp()

## Datamodule

To demonstrate the method, we will make use of a Toy Regression Example that is defined as a [Lightning Datamodule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html). While this might seem like overkill for a small toy problem, we think it is more helpful to see how the individual pieces of the library fit together, so you can train models on more complex tasks.

In [None]:
dm = ToyHeteroscedasticDatamodule()

X_train, Y_train, train_loader, X_test, Y_test, test_loader, X_gtext, Y_gtext = (
    dm.X_train,
    dm.Y_train,
    dm.train_dataloader(),
    dm.X_test,
    dm.Y_test,
    dm.test_dataloader(),
    dm.X_gtext,
    dm.Y_gtext,
)

In [None]:
fig = plot_toy_regression_data(X_train, Y_train, X_gtext, Y_gtext)

## Model

For our Toy Regression problem, we will use a simple Multi-layer Perceptron (MLP) that you can configure to your needs. For the documentation of the MLP see [here](https://readthedocs.io/en/stable/api/models.html#MLP).

In [None]:
network = MLP(n_inputs=1, n_hidden=[50, 50, 50], n_outputs=2, activation_fn=nn.Tanh())
network

With an underlying neural network, we can now use our desired UQ-Method as a sort of wrapper. All UQ-Methods are implemented as a [LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html), which allows us to organize the code concisely and remove as much boilerplate as possible.

We can first train with the MSE loss for `burnin_epochs` epochs, adapting only the parameters of the mean prediction for more stable training, before switching to the NLL loss.
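The burn-in idea can be sketched as a loss that switches on the epoch counter. This is a simplified illustration and not the library's implementation: the function `mve_loss` and the exact boundary (`epoch < burnin_epochs`) are assumptions made for this example.

```python
import math


def mve_loss(mu: float, sigma: float, y: float, epoch: int, burnin_epochs: int) -> float:
    """Hypothetical schedule: squared error during burn-in, Gaussian NLL afterwards."""
    squared_error = (mu - y) ** 2
    if epoch < burnin_epochs:
        # Burn-in: the predicted sigma is ignored, so only the mean is fitted.
        return squared_error
    var = sigma**2
    return 0.5 * math.log(2 * math.pi * var) + squared_error / (2 * var)


# During burn-in the loss does not depend on sigma; afterwards it does.
print(mve_loss(0.0, 0.5, 2.0, epoch=0, burnin_epochs=5))
print(mve_loss(0.0, 0.5, 2.0, epoch=5, burnin_epochs=5))
```

Starting with the squared error lets the mean settle before the variance output begins to influence the gradients.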

In [None]:
mve_model = MVERegression(
    model=network, optimizer=partial(torch.optim.Adam, lr=1e-3), burnin_epochs=5
)

## Trainer

Now that we have a LightningDataModule and a UQ-Method as a LightningModule, we can conduct training with a [Lightning Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html). It has tons of options to make your life easier, so we encourage you to check the documentation.

In [None]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    max_epochs=250,  # number of epochs we want to train
    accelerator="cpu",  # run training on the CPU
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=1,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)

Training our model is now easy:

In [None]:
trainer.fit(mve_model, dm)

## Training Metrics

To get some insights into how the training went, we can use the utility function to plot the training loss and RMSE metric.

In [None]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainRMSE"]
)

## Evaluate Predictions

The constructed Data Module contains two possible test variables. `X_test` contains IID samples from the same noise distribution as the training data, while `X_gtext` ("X ground truth extended") contains dense inputs from the underlying "ground truth" function without any noise. It also extends the input range to either side, so we can visualize the method's UQ tendencies when extrapolating beyond the training data range. Thus, we will use `X_gtext` for visualization purposes, but use `X_test` to compute uncertainty and calibration metrics, because we want to analyse how well the method has learned the noisy data distribution.

In [None]:
preds = mve_model.predict_step(X_gtext)

fig = plot_predictions_regression(
    X_train,
    Y_train,
    X_gtext,
    Y_gtext,
    preds["pred"],
    preds["pred_uct"].squeeze(-1),
    aleatoric=preds["aleatoric_uct"],
    title="Mean Variance Estimation Network",
    show_bands=False,
)

For some additional metrics relevant to UQ, we can use the great [uncertainty-toolbox](https://uncertainty-toolbox.github.io/), which gives us some insight into the calibration of our predictions. We evaluate these on our held-out IID test set.

In [None]:
preds = mve_model.predict_step(X_test)

fig = plot_calibration_uq_toolbox(
    preds["pred"].cpu().numpy(),
    preds["pred_uct"].cpu().numpy(),
    Y_test.cpu().numpy(),
    X_test.cpu().numpy(),
)
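As a rough complement to the calibration plot above, one simple check is the empirical coverage of a Gaussian prediction interval: the fraction of targets that fall inside the central interval implied by the predicted mean and standard deviation. The sketch below uses only the standard library. The helper `empirical_coverage` is our own illustration and not a function of the library or of uncertainty-toolbox, and the numbers are made up rather than taken from the model trained here.

```python
from statistics import NormalDist


def empirical_coverage(means, stds, targets, level=0.95):
    """Fraction of targets inside the central `level` Gaussian prediction interval."""
    # Two-sided quantile of the standard normal, about 1.96 for level=0.95.
    z = NormalDist().inv_cdf(0.5 + level / 2)
    inside = [abs(y - m) <= z * s for m, s, y in zip(means, stds, targets)]
    return sum(inside) / len(inside)


# Made-up numbers for illustration only, not outputs of the model above.
means = [0.0, 1.0, 2.0, 3.0]
stds = [0.5, 0.5, 1.0, 1.0]
targets = [0.2, 2.5, 1.5, 3.1]
print(empirical_coverage(means, stds, targets, level=0.95))
```

For a well-calibrated model, the coverage on a sufficiently large IID test set should be close to the chosen `level`.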