# CIFAR10 Bayesian Deep Learning with Class Removal

We will train a neural network on "CIFAR9" (CIFAR10 with the airplane class removed) and use it for classification. Then we will add the airplane class back into the dataset and see how the inference changes!

## Imports and Setup

In [2]:
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import models, transforms, datasets
from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn
from torchmetrics import Accuracy

from utils import train, learning_curves, evaluate, uncertainty_quantification

## Dataset Creation

In [3]:
# Training pipeline: tensor conversion and normalization, then light augmentation
train_trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
])

# Evaluation pipeline: same conversion and normalization, no augmentation
val_trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

root = '../data'
train_ds = datasets.CIFAR10(root=root, train=True, transform=train_trans, download=True)
test_ds = datasets.CIFAR10(root=root, train=False, transform=val_trans, download=True)

# Keep the full class-name list before train_ds is replaced by a Subset below
classes = train_ds.classes
print(classes)

exclude_list = ['airplane']
# Translate the excluded class names into their integer labels
exclude_label = [train_ds.class_to_idx[n] for n in exclude_list]

# Positions of every sample whose label is NOT one of the excluded labels
include_idx = list(np.flatnonzero(np.isin(train_ds.targets, exclude_label, invert=True)))
val_include_idx = list(np.flatnonzero(np.isin(test_ds.targets, exclude_label, invert=True)))

# train_ds / val_ds (excluded class removed) drive the training process.
# test_ds is left unfiltered and is used for uncertainty quantification.
train_ds = torch.utils.data.Subset(train_ds, include_idx)
val_ds = torch.utils.data.Subset(test_ds, val_include_idx)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data\cifar-10-python.tar.gz


 90%|████████▉ | 152666112/170498071 [00:18<00:02, 8313851.14it/s] 


KeyboardInterrupt: 

In [None]:
batch_size = 256
lr = 1e-3
total_classes = len(classes)

# One loader per split; only the training split is shuffled
train_loader = torch.utils.data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val_ds, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset=test_ds, batch_size=batch_size, shuffle=False)

# Lookup of the loaders by split name
dataloads = dict(train=train_loader, val=val_loader, test=test_loader)

: 

## Visualize the Training Images

In [None]:
def imshow(inp, ax, title=None, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Draw a normalized, channels-first image tensor on a matplotlib axis.

    Args:
        inp (torch tensor image): Input image with the channel axis first.
        ax (matplotlib axis): Axis the image is drawn on.
        title (str, optional): title for image. Defaults to None.
        mean (sequence of float, optional): Per-channel mean used to undo the
            normalization. Defaults to 0.5 for each of three channels.
        std (sequence of float, optional): Per-channel standard deviation used
            to undo the normalization. Defaults to 0.5 for each of three channels.
    """
    # detach() and cpu() let tensors that track gradients or live on the GPU
    # be converted; both are no-ops for a plain CPU tensor.
    # Put channels last for matplotlib
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    mean = np.asarray(mean)
    std = np.asarray(std)
    # Reverse normalize
    inp = std * inp + mean
    # Keep values inside the displayable [0, 1] range
    inp = np.clip(inp, 0, 1)
    ax.imshow(inp)
    if title is not None:
        ax.set_title(title)

: 

In [None]:
fig, axes = plt.subplots(3, 4, figsize=(5, 5))
# One training sample per grid cell; train_ds no longer holds the excluded classes
for i, ax in enumerate(axes.ravel()):
    img, true_class = train_ds[i]
    name = classes[true_class]
    imshow(img, ax, name)
    ax.axis('off')

fig.suptitle('Sample Training Images')
fig.subplots_adjust(top=0.88)
plt.show()

: 

## Build Model

We use a ResNet18 base model here. We are using default ImageNet weights with a 10 output node classifier head. To create the Flipout model, we use Bayesian Torch and their "dnn_to_bnn" function. 

In [None]:
# Settings handed to dnn_to_bnn for the Bayesian conversion
const_bnn_prior_parameters = dict(
    prior_mu=0.0,
    prior_sigma=1.0,
    posterior_mu_init=0.0,
    posterior_rho_init=-3.0,
    type="Flipout",  # Flipout or Reparameterization
    moped_enable=True,  # True to initialize mu/sigma from the pretrained dnn weights
    moped_delta=0.5,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.resnet18(weights='DEFAULT')

# Swap the classifier head for one with as many outputs as we have classes
model.fc = nn.Linear(model.fc.in_features, total_classes)

# Convert to Flipout model
dnn_to_bnn(model, const_bnn_prior_parameters)

# Optional (disabled): freeze everything except the classifier head
# for param in model.parameters():
#     param.requires_grad = False

# for param in model.fc.parameters():
#     param.requires_grad = True

model.to(device)

: 

## Train Model

We are using an Adam optimizer here and CrossEntropyLoss(), which combines the effect of Softmax and NLLLoss(). This simpler implementation requires an additional step at inference time: adding Softmax back in.

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

# A separate Accuracy tracker per split, followed by the labelling entries
metrics = {
    split: Accuracy(task='multiclass', num_classes=total_classes).to(device)
    for split in ('train', 'val', 'test')
}
metrics.update(short='acc', name='Accuracy')

: 

In [None]:
model_path = '/content/drive/MyDrive/ml_projects/cifar_no_planes/resnet18.pth'

# Optional (disabled): resume from a previously saved checkpoint
# if os.path.exists(model_path):
#     model.load_state_dict(torch.load(model_path))

results = train(
    model=model,
    dataloads=dataloads,
    criterion=criterion,
    optimizer=optimizer,
    metrics=metrics,
    epochs=50,
    val_patience=10,
    model_path=model_path,
    bayesian=True,
)

: 

In [None]:
# Draw the learning curves from the training results onto a fresh figure
fig = plt.figure(layout='constrained')
fig.suptitle('Learning Curves')
learning_curves(results, fig, metrics)
plt.show()

: 

## Model Evaluation

For Bayesian models, there are several facets to consider. 
- NLL (without KL divergence)
- Calibrated Uncertainty
- Uncertainty Quantification
- Individual Image Analysis

### Make MC Inferences

With a deterministic model, we will get the same outputs every time we run inference. Bayesian models have a distribution instead of a single value for each weight. Each inference samples from that distribution, producing different results each time. This is paramount for Bayesian models, as it allows analysis of the variance and entropy of the samples. Here we run 25 separate inferences to capture how the distribution over the weights carries through to the output of the model. More is always better, but can be computationally expensive.

In [None]:
num_mc_inferences = 25

# Buffer laid out as (MC pass, test sample, class)
y_preds = torch.zeros(
    num_mc_inferences,
    len(dataloads['test'].dataset),
    total_classes,
)
# Repeat inference over the whole test loader, storing one pass per slot
for i in range(num_mc_inferences):
    y_preds[i] = evaluate(model, dataloads['test'], criterion, metrics)

# One-hot encoding of the test labels
y_true = torch.nn.functional.one_hot(
    torch.tensor(test_ds.targets),
    num_classes=total_classes,
)

: 

In [None]:
# Summarize the stacked MC predictions. Later cells read the 'aleatoric',
# 'epistemic' and 'max_pred' entries of the returned mapping.
stats = uncertainty_quantification(y_preds)

: 

### Aleatoric and Epistemic Uncertainty

Aleatoric uncertainty is the uncertainty inherent to the data or problem itself. It cannot be reduced further without changing the model architecture or transforming the data before sending it through the model. Epistemic uncertainty is uncertainty that results from insufficient training data. If a model is unfamiliar with a particular type of data (say, an airplane), the epistemic uncertainty will be high and can be reduced by adding more images of planes to the training set. 

For this problem, the images of planes have no more inherent noise than the rest of the images, so I would expect comparable aleatoric uncertainty, but high epistemic uncertainty. 

This methodology works best with highly accurate models. The decomposition of uncertainty is not perfect; there is an epistemic component to aleatoric uncertainty. Being able to drive down aleatoric uncertainty with a high-performing model makes the differences more stark.

In [None]:
# Split the per-sample uncertainties into excluded-class and included-class groups
exclude_id = [test_ds.class_to_idx[cls] for cls in exclude_list]

# Each result is a one-element tuple holding the matching test-set positions
exclude_idx = np.nonzero(np.isin(test_ds.targets, exclude_id))
include_idx = np.nonzero(np.isin(test_ds.targets, exclude_id, invert=True))

# Dropping one group's positions leaves the values of the other group
aleatoric_include = np.delete(stats['aleatoric'], exclude_idx)
aleatoric_exclude = np.delete(stats['aleatoric'], include_idx)
epistemic_include = np.delete(stats['epistemic'], exclude_idx)
epistemic_exclude = np.delete(stats['epistemic'], include_idx)

# The mean over the samples is a poor metric - each sample is unique and requires specific analysis,
# but this gives a starting point to understanding the benefits of bayesian networks
for label, values in (
    ('Aleatoric Include', aleatoric_include),
    ('Aleatoric Exclude', aleatoric_exclude),
    ('Epistemic Include', epistemic_include),
    ('Epistemic Exclude', epistemic_exclude),
):
    print(f'{label}: {np.mean(values):.4f}')

: 

## Visualizing Model Predictions

Let's start by visualizing some of the images and our predictions.

In [None]:
fig, axes = plt.subplots(3, 4, figsize=(10, 10))
# First twelve test images, titled "<predicted>\n(<actual>)"
for i, ax in enumerate(axes.flat):
    img, true_class = test_ds[i]
    true_class = classes[true_class]
    pred_id = stats['max_pred'][i]
    pred_class = classes[pred_id]
    name = '\n'.join((pred_class, f'({true_class})'))
    imshow(img, ax, name)
    ax.axis('off')

fig.suptitle('Predicted Images')
fig.subplots_adjust(top=0.88)
plt.show()

: 

Here we start reviewing the logits from our MC inferences alongside each image and its classification. The more tightly clustered the logits of each class are, the more confident the model is in its prediction. If there is a lot of variance within a class, or more than one class has high logit outputs, then that's an indication that the model isn't sure what the image actually is.

In [None]:
fig, axes = plt.subplots(4, 4, figsize=(20, 10))
# Class name repeated once per MC pass, lining up with logits.T.flatten() below
x = np.repeat(test_ds.classes, num_mc_inferences)
panels = axes.ravel()
i = 100
# Consecutive axes form a pair: the image on the left, its MC logits on the right
for img_ax, logit_ax in zip(panels[0::2], panels[1::2]):
    logits = y_preds[:, i, :].numpy()
    # Predicted class = largest logit after averaging over the MC passes
    max_pred = np.argmax(np.mean(logits, axis=0))
    img, true_class = test_ds[i]
    true_class = classes[true_class]
    pred_class = classes[max_pred]
    name = f'{pred_class}\n({true_class})'
    imshow(img, img_ax, name)
    img_ax.axis('off')

    logit_ax.scatter(x=x, y=logits.T.flatten())
    logit_ax.tick_params(axis='x', labelrotation=45)
    i += 1
plt.xticks(ha='right')
fig.suptitle('Predicted Images')
plt.tight_layout()
plt.show()

: 

This is the same plot as above, but only looking at the excluded class. I would expect a high degree of uncertainty here; I want the model to self-report that it doesn't know what to do.

In [None]:
fig, axes = plt.subplots(4,4, figsize=(20,10))
# Class name repeated once per MC pass, lining up with logits.T.flatten() below
x = np.repeat(classes,num_mc_inferences)
# Keep ONLY the MC predictions of the excluded-class samples, in the same order
# as exclude_idx[0], so that y_pred_exclude[:, i, :] belongs to the image
# test_ds[exclude_idx[0][i]] shown next to it.
# (np.delete(y_preds, exclude_idx, axis=1) would do the opposite: it removes the
# excluded samples and keeps the included ones.)
# exclude_idx is a one-element tuple, hence the [0].
y_pred_exclude = y_preds[:, torch.as_tensor(exclude_idx[0]), :]
i = 100
# Axes alternate: even positions show the image, odd positions its MC logits
for j, axs in enumerate(axes.ravel()):
    if j % 2 == 0:
        logits = y_pred_exclude[:,i,:].numpy()
        # Predicted class = largest logit after averaging over the MC passes
        max_pred = np.argmax(np.mean(logits, axis=0))
        pred_class = classes[max_pred]
        # Get the image in the test dataset at index exclude_idx[0][i]
        # We have to do [0] because its a one element tuple
        img, true_class = test_ds[exclude_idx[0][i]]
        true_class = classes[true_class]
        name = f'Predict: {pred_class}\n Actual: {true_class}'
        imshow(img, axs, name)
        axs.axis('off')
    else:
        axs.scatter(x=x, y=logits.T.flatten())
        axs.tick_params(axis='x',labelrotation=45)
        # Move on to the next excluded-class sample once its pair is drawn
        i += 1
plt.xticks(ha='right')
fig.suptitle('Predicted Images')
plt.tight_layout()
plt.show()

: 

: 