# Uncertainty

Quantifying uncertainty to improve decision making is the main benefit of Bayesian Deep Learning (BDL). This is particularly useful when the model encounters data unlike anything it saw during training. We demonstrated this in the [Bayesian Deep Learning Tutorial](bayesian_deep_learning.ipynb), where ensembles are trained on inputs in (-2π, 2π) and tested on inputs in (-8π, 8π).

![](de_regression_plot.png)


From this plot we can see that the predictions for inputs within the training range (between the dotted lines) are closely clustered around the mean, indicating a high level of certainty. Predictions for inputs outside the training range are spread out with much higher variance, indicating high uncertainty.
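
The spread described above can be summarized numerically. The following is a minimal sketch, assuming the ensemble's predictions have already been stacked into an array of shape `(n_members, n_points)`. The function name `ensemble_mean_and_std` and the toy numbers are illustrative and are not taken from the tutorial.

```python
import numpy as np

def ensemble_mean_and_std(member_preds):
    """Return the per-point mean and standard deviation across ensemble members.

    member_preds: array-like of shape (n_members, n_points).
    """
    preds = np.asarray(member_preds, dtype=float)
    return preds.mean(axis=0), preds.std(axis=0)

# Toy example (made-up values): three members agree on the first point
# and disagree on the second.
toy_preds = [[1.0, 0.0],
             [1.0, 2.0],
             [1.0, 4.0]]
mean, std = ensemble_mean_and_std(toy_preds)
```

A point where the members agree gets a standard deviation of zero, while a point where they disagree gets a larger one, which is the pattern visible outside the dotted lines in the plot.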

In this tutorial, we will quantify uncertainty on a classification task.

## Dataset

We will be examining the MNIST dataset, which consists of 60,000 training images and 10,000 test images, each a 28x28 pixel grayscale image of a handwritten digit. We will train on the original images and test on a set of rotated images to see how the model behaves on data outside the training distribution. Our expectation is that the more a test image is rotated, the more **uncertain** the model will become.

![](mnist.png)
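
For a classifier, one common way to put a number on "uncertain" is the entropy of the predicted class distribution. The sketch below is an illustration under that assumption. The tutorial text above does not name a specific metric, and `predictive_entropy` and the example probabilities are hypothetical.

```python
import numpy as np

def predictive_entropy(probs, eps=1e-12):
    """Entropy (in nats) of each row of class probabilities.

    probs: array of shape (n_samples, n_classes); each row sums to 1.
    """
    p = np.asarray(probs, dtype=float)
    # eps guards against log(0) for classes with zero probability.
    return -(p * np.log(p + eps)).sum(axis=-1)

# Made-up predictive distributions over the 10 digit classes.
confident = np.full(10, 0.01)
confident[1] = 0.91          # almost all mass on the digit "1"
uniform = np.full(10, 0.1)   # no preference for any digit

entropy = predictive_entropy(np.stack([confident, uniform]))
```

The uniform distribution attains the maximum entropy for 10 classes, ln(10), while the confident distribution scores much lower.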


In [1]:
import experiments.nns.simplenet.simplenet
from experiments.nns.simplenet.simplenet import SimpleNet
from experiments.nns.cnn.cnn import CNN
from experiments.nns.lenet.lenet import LeNet
import torch.optim as optim
import torch.nn as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

In [5]:
import torch
import copy
import os
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pathlib import Path

# Get the current notebook's path
notebook_path = Path().resolve()

# Define the path to the directory containing MNIST.
# The path is absolute, so it does not depend on the notebook's location.
mnist_directory = os.path.abspath("/usr/data1/vision/data/")


# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Load the MNIST training dataset
train_dataset = datasets.MNIST(root=mnist_directory, train=True, download=False, transform=transform)

In [None]:
# This cell reduces the training set to 1/10th of its original size to decrease training time.
# It only needs to be run once; afterwards the indices are saved in the MNIST directory.


# # Select the first 1/10 of the samples of each class in the training dataset
# train_subset_indices = []

# for c in range(10):
#     class_indices = [i for i in range(len(train_dataset)) if train_dataset[i][1] == c]
#     train_subset_indices.extend(class_indices[:len(class_indices)//10])

# # Save the subset indices inside mnist_directory
# train_idx_path = os.path.join(mnist_directory, "train_indices.pth")
# torch.save(train_subset_indices, train_idx_path)
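
The commented-out loop above indexes `train_dataset[i]` for every sample and every class, which applies the transform to each image repeatedly. As a hedged alternative, the sketch below performs the same per-class selection from the labels alone, assuming they are available as a flat array (torchvision's MNIST dataset exposes them as `targets`). The helper name `first_fraction_per_class` and the toy labels are hypothetical.

```python
import numpy as np

def first_fraction_per_class(labels, keep_one_in=10):
    """Return indices of the first 1/keep_one_in of the samples of each class."""
    labels = np.asarray(labels)
    selected = []
    for c in np.unique(labels):
        class_indices = np.flatnonzero(labels == c)
        # Same rule as the loop above: len(class_indices) // 10 per class.
        keep = len(class_indices) // keep_one_in
        selected.extend(class_indices[:keep].tolist())
    return selected

# Made-up labels: 20 samples of class 0 followed by 30 samples of class 1.
toy_labels = [0] * 20 + [1] * 30
subset = first_fraction_per_class(toy_labels)
```

With the real dataset, the call would look something like `first_fraction_per_class(train_dataset.targets.numpy())`, and the returned list could be saved with `torch.save` exactly as in the cell above.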


In [3]:
train_idx_path = os.path.join(mnist_directory, "train_indices.pth")

# Load the subset indices from mnist_directory
loaded_train_indices = torch.load(train_idx_path)

# Create the subset using the loaded indices
loaded_train_subset = torch.utils.data.Subset(train_dataset, loaded_train_indices)

## Creating Rotated Dataset

In [6]:
import torch
import copy
import os
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from bdl import CustomMNISTDataset

# Select the numbers we would like to create rotated versions of
selected_numbers = [1]
# Include rotation in the transformation
# transform = transforms.Compose([
#     transforms.RandomRotation(degrees=(-180, 180), fill=(0,)),
#     transforms.ToTensor(),
#     transforms.Normalize((0.5,), (0.5,))
# ])

# rotation_angles = [0, 15, 30, 45, 60, 75, 80, 95]
# Twelve rotation angles in degrees, from 0 to 60.5 in steps of 5.5
rotation_angles = [i * 5.5 for i in range(12)]


def get_rotated_mnist_dataset(degrees):
    rotate_transform = transforms.Compose([
            transforms.RandomRotation(degrees=(degrees, degrees), fill=(0,)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
    rotated_dataset = CustomMNISTDataset(root=mnist_directory, numbers=selected_numbers, train=False, transform=rotate_transform)
    rotated_loader = DataLoader(rotated_dataset, batch_size=512, shuffle=False)
    return rotated_dataset, rotated_loader
rotated_datasets = []
rotated_loaders = []

for angle in rotation_angles:
    rotated_dataset, rotated_loader = get_rotated_mnist_dataset(angle)
    rotated_datasets.append(rotated_dataset)
    rotated_loaders.append(rotated_loader)


test_dataset = CustomMNISTDataset(root=mnist_directory, numbers=selected_numbers, train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

In [7]:
# Create data loaders
batch_size = 512
train_loader = DataLoader(loaded_train_subset, batch_size=batch_size, shuffle=True)

## BDL Methods
We will be applying three different BDL methods to compare each method's ability to quantify uncertainty.

### Deep Ensembles
Deep ensembles train the same model from different initializations (random seeds) to obtain a collection of weight settings.
The ensemble of trained models is then treated as an approximation to the posterior.
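
To make the idea concrete, here is a minimal sketch of how predictions from ensemble members can be combined, assuming every member is weighted equally and that each member's logits are already available as an array. The function names and the toy logits are illustrative and do not come from the `push` library.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def ensemble_predict(member_logits):
    """Average member probabilities, weighting every member equally.

    member_logits: array of shape (n_members, n_samples, n_classes).
    Returns (mean_probs, predicted_class).
    """
    probs = softmax(np.asarray(member_logits, dtype=float))
    mean_probs = probs.mean(axis=0)
    return mean_probs, mean_probs.argmax(axis=-1)

# Made-up logits: 3 members, 1 sample, 3 classes.
# Two members favor class 0 and one favors class 1.
toy_logits = np.array([[[2.0, 0.0, 0.0]],
                       [[0.0, 2.0, 0.0]],
                       [[2.0, 0.0, 0.0]]])
mean_probs, predicted = ensemble_predict(toy_logits)
```

Averaging probabilities rather than logits keeps each member's vote a valid distribution, and disagreement between members shows up as a flatter averaged distribution.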

### Multi-SWAG
Multi-SWAG creates a distribution of parameter states by

1. Calculating a running average for each parameter (μ)
2. Tracking the deviation from the mean for each parameter over the last K states, with K=20 by default (σ)

It then approximates the posterior by applying Bayesian model averaging to samples drawn from the Gaussian N(μ, σ²).
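
The two steps above can be sketched as follows. This is a simplified diagonal-Gaussian illustration under stated assumptions, not the `push` implementation: the class name `DiagonalGaussianTracker` is hypothetical, parameters are treated as a flat vector, and σ² is estimated as the mean squared deviation over the stored states. The published SWAG method also includes a low-rank covariance term, which this sketch omits.

```python
from collections import deque
import numpy as np

class DiagonalGaussianTracker:
    """Hypothetical helper: running mean plus spread of the last K deviations."""

    def __init__(self, n_params, max_deviations=20):
        self.mean = np.zeros(n_params)
        self.n_snapshots = 0
        # Only the most recent `max_deviations` deviations are kept (K in the text).
        self.deviations = deque(maxlen=max_deviations)

    def update(self, params):
        params = np.asarray(params, dtype=float)
        self.n_snapshots += 1
        # Step 1: incremental running average of the parameters.
        self.mean += (params - self.mean) / self.n_snapshots
        # Step 2: deviation of this snapshot from the current running mean.
        self.deviations.append(params - self.mean)

    def sample(self, rng):
        # Per-parameter variance estimated from the stored deviations.
        var = np.mean(np.square(np.stack(self.deviations)), axis=0)
        return self.mean + np.sqrt(var) * rng.standard_normal(self.mean.shape)

# Made-up snapshots: the first parameter moves, the second stays fixed.
rng = np.random.default_rng(0)
tracker = DiagonalGaussianTracker(n_params=2)
for snapshot in ([0.0, 1.0], [2.0, 1.0], [4.0, 1.0]):
    tracker.update(snapshot)
samples = np.stack([tracker.sample(rng) for _ in range(5)])
```

In a full pipeline, each sampled parameter vector would be loaded into the network and the resulting predictions averaged, which is the Bayesian model averaging step.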

In [None]:
import torch
from torch.utils.data import DataLoader
import push.bayes.ensemble

epochs = 50

ensemble = push.bayes.ensemble.train_deep_ensemble(
        train_loader,
        torch.nn.CrossEntropyLoss(),
        epochs,
        LeNet,
        num_devices=2,
        num_ensembles=100,
        cache_size=25
    )