# ZigZag: Universal Sampling-free Uncertainty Estimation

ZigZag was proposed by [Durasov et al. (2024)](https://openreview.net/forum?id=QSvb6jBXML).

The paper evaluates the method on several out-of-distribution (OOD) tasks and describes how it addresses the two types of uncertainty as follows:

> "In other words, there are two scenarios when reconstruction fails: 1) when (x, y) is OOD because x is OOD, addressing epistemic uncertainty and OOD samples, 2) when (x, y) is OOD because y is OOD / erroneous. In this case, the reconstruction issue is due to y, our uncertainty measure is high, we cover aleatoric uncertainty connected to predicted target."
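To make "reconstruction" concrete, here is a sketch of the two-step scheme in our own notation (not copied from the paper): $f$ is the network, $\varnothing$ is a placeholder value written into the target slot when no target is available, and $\ell$ is the training loss.

$$
\hat{y}_0 = f(x, \varnothing), \qquad \hat{y}_1 = f(x, \hat{y}_0), \qquad u(x) = \lVert \hat{y}_1 - \hat{y}_0 \rVert
$$

$$
\mathcal{L}(x, y) = \ell\big(f(x, \varnothing), y\big) + \ell\big(f(x, y), y\big)
$$

Training pushes $f(x, y)$ towards $y$, so for in-distribution pairs the second pass reproduces the first prediction and $u(x)$ stays small. When $x$, or the target fed back into the second slot, lies outside what the network has seen, this reconstruction breaks down and $u(x)$ grows.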

## Imports

In [None]:
import os
import tempfile
from functools import partial

import matplotlib.pyplot as plt
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger
from torch.optim import Adam

from lightning_uq_box.datamodules import ToyHeteroscedasticDatamodule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import ZigZagRegression
from lightning_uq_box.viz_utils import (
    plot_calibration_uq_toolbox,
    plot_predictions_regression,
    plot_toy_regression_data,
    plot_training_metrics,
)

plt.rcParams["figure.figsize"] = [14, 5]

In [None]:
seed_everything(0)  # seed everything for reproducibility

In [None]:
my_temp_dir = tempfile.mkdtemp()

## Datamodule

In [None]:
dm = ToyHeteroscedasticDatamodule(batch_size=32)

X_train, y_train, train_loader, X_test, y_test, test_loader = (
    dm.X_train,
    dm.y_train,
    dm.train_dataloader(),
    dm.X_test,
    dm.y_test,
    dm.test_dataloader(),
)

In [None]:
fig = plot_toy_regression_data(X_train, y_train, X_test, y_test)

## Model

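Here we create a deterministic MLP with two inputs: one for the one-dimensional feature x and one extra slot for the target. ZigZag trains the network to predict y both when this slot holds a placeholder value and when it holds the ground-truth target. At prediction time it uses a two-step forward pass: the prediction from the first pass is concatenated to the original input and fed through the network again, and the discrepancy between the two outputs serves as the uncertainty estimate.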
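The training objective and two-step inference described above can be sketched in plain PyTorch, independently of lightning_uq_box. This is a minimal sketch under our own assumptions: the helper names (`with_hint`, `zigzag_loss`, `zigzag_predict`), the placeholder value `BLANK` and the toy sine data are hypothetical and do not reproduce the library's internals.

```python
import torch
import torch.nn as nn

# Hypothetical placeholder written into the target slot when no target is
# available; chosen (by assumption) to lie outside the toy target range [-1, 1].
BLANK = -10.0


def with_hint(x: torch.Tensor, hint: torch.Tensor) -> torch.Tensor:
    """Concatenate the input with a target hint along the feature axis."""
    return torch.cat([x, hint], dim=1)


def zigzag_loss(model, x, y, loss_fn=nn.MSELoss()):
    """Predict y both with an empty target slot and with y itself as the hint."""
    blank = torch.full_like(y, BLANK)
    pred_blank = model(with_hint(x, blank))  # first pass: slot holds placeholder
    pred_hint = model(with_hint(x, y))  # second pass: slot holds the true target
    return loss_fn(pred_blank, y) + loss_fn(pred_hint, y)


@torch.no_grad()
def zigzag_predict(model, x):
    """Two-step inference: feed the first prediction back in as the hint."""
    blank = torch.full((x.shape[0], 1), BLANK)
    pred1 = model(with_hint(x, blank))
    pred2 = model(with_hint(x, pred1))
    return pred1, (pred1 - pred2).abs()  # prediction and uncertainty score


# Usage on toy data with a 2-input network (x plus the target slot).
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(2, 50), nn.Tanh(), nn.Linear(50, 1))
opt = torch.optim.Adam(model.parameters(), lr=3e-3)
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = torch.sin(3 * x)
for _ in range(200):
    opt.zero_grad()
    loss = zigzag_loss(model, x, y)
    loss.backward()
    opt.step()

pred, uct = zigzag_predict(model, x)
```

Summing the two loss terms is one simple way to train both passes jointly; `ZigZagRegression` may weight, schedule or parameterize them differently.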

In [None]:
network = MLP(n_inputs=2, n_hidden=[50, 50, 50], n_outputs=1, activation_fn=nn.Tanh())
network

When initializing the `ZigZagRegression` module, we pass the network together with the loss function and the optimizer; the module takes care of the two-pass training and prediction logic described above.

In [None]:
zigzag = ZigZagRegression(
    model=network, loss_fn=nn.MSELoss(), optimizer=partial(Adam, lr=3e-3)
)

## Trainer

In [None]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    max_epochs=400,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=1,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)

In [None]:
trainer.fit(zigzag, dm)

## Training Metrics

In [None]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainRMSE"]
)

## Prediction

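ZigZag yields a single uncertainty score, so we plot it as both the epistemic and the aleatoric uncertainty. As the quote at the beginning explains, whether a high score reflects epistemic or aleatoric uncertainty depends on whether the input x or the target y is out of distribution.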
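Because the score is not split into two components, one rough way to see where it rises is to compare scores for test inputs inside and outside the training input range. The helper below is a hypothetical sketch: the name `split_scores_by_input_range` and the min/max range check are our assumptions, not part of lightning_uq_box, and it assumes 1-D inputs stored as `(N, 1)` tensors.

```python
import torch


def split_scores_by_input_range(x_train, x_test, uct):
    """Mean uncertainty score for test inputs inside vs. outside the training range.

    Assumes x_train: (N, 1), x_test: (M, 1) and uct: (M, 1) tensors on the CPU.
    """
    lo, hi = x_train.min(), x_train.max()
    inside = ((x_test >= lo) & (x_test <= hi)).squeeze(1)
    uct = uct.squeeze(1)
    mean_in = uct[inside].mean().item() if inside.any() else float("nan")
    mean_out = uct[~inside].mean().item() if (~inside).any() else float("nan")
    return mean_in, mean_out


# Toy tensors standing in for training inputs, test inputs and ZigZag scores.
x_train = torch.tensor([[-1.0], [0.0], [1.0]])
x_test = torch.tensor([[-2.0], [0.5], [2.0]])
uct = torch.tensor([[0.8], [0.1], [0.6]])
mean_in, mean_out = split_scores_by_input_range(x_train, x_test, uct)
print(mean_in, mean_out)
```

A min/max check like this ignores gaps inside the training range. With the notebook's data it could be called as `split_scores_by_input_range(X_train, X_test, preds["pred_uct"])`, assuming those tensors have the shapes above.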

In [None]:
preds = zigzag.predict_step(X_test)
fig = plot_predictions_regression(
    X_train,
    y_train,
    X_test,
    y_test,
    preds["pred"],
    preds["pred_uct"],
    epistemic=preds["pred_uct"],
    aleatoric=preds["pred_uct"],
    title="ZigZag",
    show_bands=False,
)

In [None]:
preds = zigzag.predict_step(X_test)
fig = plot_calibration_uq_toolbox(
    preds["pred"].cpu().numpy(),
    preds["pred_uct"].cpu().numpy(),
    y_test.cpu().numpy(),
    X_test.cpu().numpy(),
)