# Masksembles

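Masksembles were proposed by [Durasov et al. 2021](https://arxiv.org/abs/2012.08334). The method replaces the random masks of MC Dropout with a fixed set of binary masks, so that a single network behaves like an ensemble of overlapping sub-networks.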

In [None]:
%%capture
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git

## Imports

In [None]:
import os
import tempfile
from functools import partial

import matplotlib.pyplot as plt
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger
from torch.optim import Adam

from lightning_uq_box.datamodules import ToyHeteroscedasticDatamodule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import NLL, MasksemblesRegression
from lightning_uq_box.viz_utils import (
    plot_calibration_uq_toolbox,
    plot_predictions_regression,
    plot_toy_regression_data,
    plot_training_metrics,
)

plt.rcParams["figure.figsize"] = [14, 5]

In [None]:
seed_everything(0)  # seed everything for reproducibility

In [None]:
my_temp_dir = tempfile.mkdtemp()

## Datamodule

In [None]:
dm = ToyHeteroscedasticDatamodule(batch_size=32)

X_train, Y_train, train_loader, X_test, Y_test, test_loader, X_gtext, Y_gtext = (
    dm.X_train,
    dm.Y_train,
    dm.train_dataloader(),
    dm.X_test,
    dm.Y_test,
    dm.test_dataloader(),
    dm.X_gtext,
    dm.Y_gtext,
)

In [None]:
fig = plot_toy_regression_data(X_train, Y_train, X_test, Y_test)

## Model

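We first create a deterministic MLP with two outputs, as in a Mean-Variance Estimation (MVE) network: one output for the predictive mean and one that parameterizes the predictive variance. The network is trained with the negative log-likelihood (NLL) loss.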

In [None]:
network = MLP(n_inputs=1, n_hidden=[50, 50], n_outputs=2, activation_fn=nn.Tanh())
network
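
As a rough illustration of what training a two-output network with the NLL loss means, the sketch below computes the Gaussian negative log likelihood for a single target. It assumes that the first output is the predictive mean and the second is the log-variance. That parameterization and the helper name `gaussian_nll` are assumptions made for illustration, not necessarily how `NLL` in lightning-uq-box is implemented.

```python
import math


def gaussian_nll(y: float, mean: float, log_var: float) -> float:
    """Negative log likelihood of y under a Gaussian N(mean, exp(log_var)).

    Assumption: the network's second output is the log-variance, which keeps
    the variance positive without constraining the raw output.
    """
    var = math.exp(log_var)
    return 0.5 * (math.log(2.0 * math.pi) + log_var + (y - mean) ** 2 / var)


# A confident but wrong prediction is penalized more than an uncertain one.
confident_wrong = gaussian_nll(y=3.0, mean=0.0, log_var=0.0)
uncertain_wrong = gaussian_nll(y=3.0, mean=0.0, log_var=math.log(9.0))
print(confident_wrong > uncertain_wrong)
```

The loss trades off the squared error against the predicted variance, which is what lets the network learn input-dependent (heteroscedastic) noise.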

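When the `MasksemblesRegression` module is initialized, it converts the model into a Masksemble by replacing its layers with masked ensemble layers. The `num_estimators` argument sets the number of masks, and `scale` controls how strongly the masks overlap.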

In [None]:
masksemble = MasksemblesRegression(
    model=network,
    num_estimators=5,
    scale=3,
    loss_fn=NLL(),
    optimizer=partial(Adam, lr=3e-3),
)
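
To give an intuition for what a masked ensemble layer does, here is a minimal pure-Python sketch. It assumes a scheme in which the batch is split into `num_estimators` consecutive groups and each group is multiplied by its own fixed binary mask. The helpers `make_masks` and `apply_masks` are hypothetical, the masks are drawn at random instead of with the paper's mask-generation procedure, and the layers used by lightning-uq-box may handle batching differently.

```python
import random


def make_masks(num_estimators, num_features, keep_prob=0.5, seed=0):
    """Draw one fixed binary mask per estimator (simplified stand-in generator)."""
    rng = random.Random(seed)
    return [
        [1.0 if rng.random() < keep_prob else 0.0 for _ in range(num_features)]
        for _ in range(num_estimators)
    ]


def apply_masks(batch, masks):
    """Multiply each consecutive group of rows by the mask of its estimator."""
    num_estimators = len(masks)
    if len(batch) % num_estimators != 0:
        raise ValueError("batch size must be divisible by the number of masks")
    group_size = len(batch) // num_estimators
    return [
        [value * keep for value, keep in zip(row, masks[i // group_size])]
        for i, row in enumerate(batch)
    ]


masks = make_masks(num_estimators=5, num_features=8)
batch = [[1.0] * 8 for _ in range(10)]  # 10 rows, so 2 rows per mask
masked = apply_masks(batch, masks)
print(masked[0] == masks[0], masked[9] == masks[4])
```

Because the masks are fixed after creation, the same sub-networks are used during training and prediction, unlike the freshly sampled masks of dropout.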

## Trainer

In [None]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    max_epochs=300,  # number of epochs we want to train
    logger=logger,
    log_every_n_steps=1,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)

In [None]:
trainer.fit(masksemble, dm)

## Training Metrics

In [None]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainRMSE"]
)

## Evaluate Predictions

The datamodule provides two possible test sets. `X_test` contains IID samples drawn from the same noisy distribution as the training data. `X_gtext` ("X ground truth extended") contains dense, noise-free inputs from the underlying ground-truth function and extends the input range on both sides, so we can see how the method's uncertainty estimates behave when extrapolating beyond the training range. We therefore use `X_gtext` for visualization, and `X_test` to compute uncertainty and calibration metrics, because those metrics should measure how well the method has learned the noisy data distribution.

In [None]:
preds = masksemble.predict_step(X_gtext)
fig = plot_predictions_regression(
    X_train,
    Y_train,
    X_gtext,
    Y_gtext,
    preds["pred"],
    preds["pred_uct"],
    epistemic=preds["epistemic_uct"],
    aleatoric=preds["aleatoric_uct"],
    title="Masksembles",
    show_bands=False,
)
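
The prediction dictionary separates total, epistemic, and aleatoric uncertainty. One common way to obtain such a split from an ensemble of Gaussian predictions is the law of total variance, sketched below for a single input. The helper `decompose_uncertainty` is hypothetical, and whether lightning-uq-box computes these quantities in exactly this way (for example, with a biased or unbiased variance estimate) is an assumption.

```python
import math
from statistics import fmean, pvariance


def decompose_uncertainty(member_means, member_variances):
    """Combine per-member Gaussian predictions for one input.

    Epistemic variance: spread of the member means (population variance).
    Aleatoric variance: average of the member-predicted variances.
    """
    epistemic_var = pvariance(member_means)
    aleatoric_var = fmean(member_variances)
    return {
        "pred": fmean(member_means),
        "epistemic_uct": math.sqrt(epistemic_var),
        "aleatoric_uct": math.sqrt(aleatoric_var),
        "pred_uct": math.sqrt(epistemic_var + aleatoric_var),
    }


# Three hypothetical members that disagree on the mean but agree on the noise.
result = decompose_uncertainty([1.0, 2.0, 3.0], [0.25, 0.25, 0.25])
print(result)
```

Under this split, disagreement between the masked sub-networks shows up as epistemic uncertainty, while the noise level each sub-network predicts shows up as aleatoric uncertainty.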

In [None]:
preds = masksemble.predict_step(X_test)
fig = plot_calibration_uq_toolbox(
    preds["pred"].cpu().numpy(),
    preds["pred_uct"].cpu().numpy(),
    Y_test.cpu().numpy(),
    X_test.cpu().numpy(),
)
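
As a simple numeric complement to the calibration plot, the sketch below checks how often the targets fall inside a central Gaussian prediction interval. The helper `empirical_coverage` is hypothetical and hand-rolled for illustration, and the numbers are made up. It is not part of lightning-uq-box or of the toolbox behind the plot.

```python
from statistics import NormalDist


def empirical_coverage(y_true, pred_mean, pred_std, level=0.9):
    """Fraction of targets inside the central `level` Gaussian interval."""
    # Half-width of the interval in units of the predicted standard deviation.
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    inside = [
        abs(y - mean) <= z * std
        for y, mean, std in zip(y_true, pred_mean, pred_std)
    ]
    return sum(inside) / len(inside)


# Hypothetical values: two targets fall inside the 90% interval, one does not.
coverage = empirical_coverage(
    y_true=[0.0, 0.0, 5.0],
    pred_mean=[0.0, 1.0, 0.0],
    pred_std=[1.0, 1.0, 1.0],
    level=0.9,
)
print(coverage)
```

For a well-calibrated model, the coverage on `X_test` should be close to the chosen `level`. Values far below it suggest overconfident uncertainties, and values far above it suggest overly wide ones.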