# MC-Dropout Classification

In [None]:
%%capture
# Install lightning-uq-box straight from its GitHub repository. No git ref is given,
# so this installs whatever the default branch currently contains (unpinned).
# %%capture (first line of the cell) suppresses the pip output.
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git

## Imports

In [None]:
import os
import tempfile
from functools import partial

import matplotlib.pyplot as plt
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger
from torch.optim import Adam

from lightning_uq_box.datamodules import TwoMoonsDataModule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import MCDropoutClassification
from lightning_uq_box.viz_utils import (
    plot_predictions_classification,
    plot_training_metrics,
    plot_two_moons_data,
)

plt.rcParams["figure.figsize"] = [14, 5]

%load_ext autoreload
%autoreload 2

In [None]:
# Seed the Python, NumPy and PyTorch RNGs (via Lightning) so the run is repeatable.
seed_everything(seed=0)

In [None]:
# Fresh scratch directory for this run, used as the CSVLogger save_dir and the
# Trainer's default_root_dir. It is not removed automatically.
my_temp_dir: str = tempfile.mkdtemp()

## Datamodule

In [None]:
# Two-moons toy dataset served in mini-batches of 128.
dm = TwoMoonsDataModule(batch_size=128)

# Expose the raw splits and the evaluation grid as notebook-level names for plotting.
X_train, Y_train = dm.X_train, dm.Y_train
X_test, Y_test = dm.X_test, dm.Y_test
test_grid_points = dm.test_grid_points

In [None]:
# Plot the training and test splits of the two-moons data; the figure is kept as `fig`.
fig = plot_two_moons_data(X_train, Y_train, X_test, Y_test)

## Model

In [None]:
# Multilayer perceptron: 2 input features -> three hidden layers of 50 units with
# ReLU activations and dropout probability 0.2 -> 2 outputs, one score per class.
network = MLP(
    n_inputs=2,
    n_hidden=[50] * 3,
    n_outputs=2,
    dropout_p=0.2,
    activation_fn=nn.ReLU(),
)
network  # echo the architecture as the cell output

In [None]:
# Wrap the network in the MC-Dropout classification method. The optimizer is given
# as a factory (Adam, learning rate 0.01), presumably instantiated by the module with
# the model's parameters. num_mc_samples is the number of stochastic forward passes
# drawn per prediction (how they are aggregated is up to the library - TODO confirm).
mc_dropout_module = MCDropoutClassification(
    model=network,
    loss_fn=nn.CrossEntropyLoss(),
    optimizer=partial(Adam, lr=0.01),
    num_mc_samples=25,
)

## Trainer

In [None]:
# Write training metrics as CSV files under my_temp_dir so they can be plotted later.
logger = CSVLogger(save_dir=my_temp_dir)
trainer = Trainer(
    default_root_dir=my_temp_dir,
    logger=logger,
    log_every_n_steps=1,  # log a metrics row at every optimisation step
    max_epochs=100,
    enable_checkpointing=False,  # no checkpoints are needed for this demo
    enable_progress_bar=False,
)

In [None]:
# Train the MC-Dropout module on the two-moons datamodule.
trainer.fit(mc_dropout_module, datamodule=dm)

## Training Metrics

In [None]:
# Plot the logged "train_loss" and "trainAcc" curves. The log directory is read from
# the CSVLogger itself (its save_dir joined with its experiment name) rather than
# re-spelling the default name "lightning_logs", so the path stays correct if the
# logger's save_dir or name is changed.
fig = plot_training_metrics(logger.root_dir, ["train_loss", "trainAcc"])

## Prediction

In [None]:
# save predictions
# Evaluate on the test split. The datamodule is handed to the Trainer (as in
# trainer.fit) so Lightning runs the datamodule's setup hooks before it requests
# test_dataloader(), instead of calling the dataloader method directly.
# NOTE(review): the original comment said this saves predictions; the method's test
# hooks may write per-sample predictions under the trainer's default_root_dir - confirm.
trainer.test(mc_dropout_module, datamodule=dm)

In [None]:
# Query the trained module on the grid of points that the prediction plot below uses.
# predict_step is called directly (not through trainer.predict); the returned mapping
# is indexed below with the keys "pred" and "pred_uct".
preds = mc_dropout_module.predict_step(test_grid_points)

## Evaluate Predictions

In [None]:
# Plot the test data together with the model's predictions and uncertainty over the grid.
fig = plot_predictions_classification(
    X_test,
    Y_test,
    preds["pred"].argmax(-1),  # hard label: index of the largest entry on the last axis
    test_grid_points,
    preds["pred_uct"],  # uncertainty values returned by predict_step
)