# Variational inference (MNIST)

In [None]:
# Reload edited modules automatically so changes to the local packages are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Interactive matplotlib figures in the classic Jupyter notebook.
%matplotlib notebook
import sys
# Make the local packages (`vartorch`, `torchutils`) importable. The paths are
# relative to the notebook's working directory; presumably `vartorch` lives in
# the parent directory and `torchutils` in a sibling checkout - confirm.
sys.path.append('..')
sys.path.append('../../torchutils')

In [None]:
#%% import
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import PIL
import torch
import torch.nn as nn
import torch.distributions as dist
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from torchutils import ClassifierTraining
from vartorch import \
    VariationalClassifier, \
    VariationalLinear, \
    accuracy_vs_confidence

## Data import

In [None]:
#%% transformations
# Only convert images to float tensors; no normalization is applied.
preprocessor = transforms.Compose([transforms.ToTensor()])

In [None]:
#%% datasets
# MNIST is stored under ~/Data and downloaded there on first use.
data_path = pathlib.Path.home() / 'Data'
# Build the train split first, then the test split.
train_set, test_set = (
    datasets.MNIST(data_path, train=is_train, transform=preprocessor, download=True)
    for is_train in (True, False)
)
print('No. train images:', len(train_set))
print('No. test images:', len(test_set))

In [None]:
#%% data loaders
batch_size = 128
# Both loaders shuffle, so the first batch of either one is a random sample.
train_loader, test_loader = (
    DataLoader(subset, batch_size=batch_size, shuffle=True)
    for subset in (train_set, test_set)
)
print('No. train batches:', len(train_loader))
print('No. test batches:', len(test_loader))

In [None]:
#%% example images
# Keep the iterator so further batches can be drawn with next().
data_iterator = iter(train_loader)
images, labels = next(data_iterator)
print(f'Images shape: {images.shape}')
print(f'Labels shape: {labels.shape}')

In [None]:
#%% plot: example images
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(5, 3))
# One image of the batch per panel, titled with its class name.
for idx, ax in enumerate(axes.flat):
    ax.imshow(images[idx, 0].numpy().clip(0, 1), cmap='gray')
    ax.set_title(train_set.classes[labels[idx]])
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()
fig.show()

## Standard training

In [None]:
#%% model (logistic regression)
# model = nn.Sequential(
#     nn.Flatten(),
#     nn.Linear(in_features=28*28, out_features=10),
# )
# print(model)

In [None]:
#%% model (small CNN with linear classifier)
# Two conv/pool stages followed by one linear layer that outputs 10 logits.
# Each 5x5 convolution with padding=2 keeps the spatial size, and each
# MaxPool2d(2) halves it: 28x28 -> 14x14 -> 7x7. With 8 output channels this
# gives the 7*7*8 input features of the linear layer (assumes 1x28x28 inputs).
model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(5,5), padding=2),
    nn.LeakyReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(in_channels=4, out_channels=8, kernel_size=(5,5), padding=2),
    nn.LeakyReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(in_features=7*7*8, out_features=10)
)
print(model)

In [None]:
#%% standard model
# Point-estimate training: cross-entropy on the logits, optimized with Adam.
criterion = nn.CrossEntropyLoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# ClassifierTraining (torchutils) bundles the model, loss, optimizer and both
# loaders; presumably the test loader is used for per-epoch evaluation.
point_model = ClassifierTraining(model, criterion, optimizer, train_loader, test_loader)

In [None]:
#%% training
# Train for 10 epochs; log_interval presumably sets how often progress is logged.
point_history = point_model.training(no_epochs=10, log_interval=10)

In [None]:
#%% testing
# Loss and accuracy of the point-estimate model on both splits.
point_train_loss, point_train_acc = point_model.test(train_loader)
point_test_loss, point_test_acc = point_model.test(test_loader)
print('\n'.join([
    f'Train loss: {point_train_loss:.4f}',
    f'Test loss: {point_test_loss:.4f}',
    f'Train acc.: {point_train_acc:.4f}',
    f'Test acc.: {point_test_acc:.4f}',
]))

In [None]:
#%% plot: training history
# Training and test loss of the point-estimate model over the epochs.
fig, ax = plt.subplots()
ax.plot(np.array(point_history['train_loss']), label='training', alpha=0.7)
ax.plot(np.array(point_history['test_loss']), label='testing', alpha=0.7)
ax.set(xlabel='epoch', ylabel='loss')
ax.set_xlim([0, point_history['no_epochs']])
ax.legend()
# Pass the on/off flag positionally: Matplotlib renamed the `b` keyword to
# `visible` (3.5) and later removed `b`; the positional form works everywhere.
ax.grid(True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()
fig.show()

## Variational inference

In [None]:
#%% model (variational logistic regression)
# model = nn.Sequential(
#     nn.Flatten(),
#     VariationalLinear(in_features=28*28, out_features=10),
# )
# print(model)

In [None]:
#%% model (small CNN with variational linear classifier)
# Same architecture as the standard CNN, except that the final classifier is a
# VariationalLinear layer; the convolutional layers are ordinary nn.Conv2d.
# Spatial size: 28x28 -> 14x14 -> 7x7 (padding=2 keeps size, pooling halves it),
# hence 7*7*8 input features (assumes 1x28x28 inputs).
model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(5,5), padding=2),
    nn.LeakyReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(in_channels=4, out_channels=8, kernel_size=(5,5), padding=2),
    nn.LeakyReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    VariationalLinear(in_features=7*7*8, out_features=10)
)
print(model)

In [None]:
#%% variational inference
# Wrap the network in a VariationalClassifier with a categorical likelihood and
# optimize all of the network's parameters with Adam.
post_model = VariationalClassifier(model, likelihood_type='Categorical')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
post_model.compile_for_training(optimizer, train_loader, test_loader)

In [None]:
#%% training
# 20 epochs; no_samples presumably sets the number of weight samples used per
# Monte Carlo estimate of the training objective - confirm in vartorch.
post_history = post_model.training(no_epochs=20, no_samples=10, log_interval=10)

In [None]:
#%% testing
# Evaluate in the original order (train loss, train acc., test loss, test acc.);
# evaluation may draw random weight samples, so the order is kept as is.
post_train_loss = post_model.test_loss(train_loader)
post_train_acc = post_model.test_acc(train_loader)
post_test_loss = post_model.test_loss(test_loader)
post_test_acc = post_model.test_acc(test_loader)
print('\n'.join([
    f'Train loss: {post_train_loss:.4f}',
    f'Test loss: {post_test_loss:.4f}',
    f'Train acc.: {post_train_acc:.4f}',
    f'Test acc.: {post_test_acc:.4f}',
]))

In [None]:
#%% plot: training history
# The recorded losses are negated for display, so the curves show the ELBO
# (presumably the loss is the negative ELBO - confirm in vartorch).
fig, ax = plt.subplots()
ax.plot(-np.array(post_history['train_loss']), label='training', alpha=0.7)
ax.plot(-np.array(post_history['test_loss']), label='testing', alpha=0.7)
ax.set(xlabel='epoch', ylabel='ELBO')
ax.set_xlim([0, post_history['no_epochs']])
ax.legend()
# Positional on/off flag: the `b` keyword was renamed to `visible` in
# Matplotlib 3.5 and later removed.
ax.grid(True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()
fig.show()

## Model predictions

In [None]:
#%% normal MNIST
# Predict on the regular MNIST test split.
data_set, data_loader = test_set, test_loader

In [None]:
#%% another MNIST
# data_set = datasets.KMNIST(data_path, train=False, transform=preprocessor, download=True)
# # data_set = datasets.FashionMNIST(data_path, train=False, transform=preprocessor, download=True)
# data_loader = DataLoader(data_set, batch_size=batch_size, shuffle=True)

In [None]:
#%% random noise
# data_set = TensorDataset(torch.rand(batch_size, 1, 28, 28), torch.zeros((batch_size,), dtype=torch.int64))
# data_loader = DataLoader(data_set, batch_size=batch_size, shuffle=True)

In [None]:
#%% example data
# Draw one batch from whichever data set was selected in the cells above.
data_iterator = iter(data_loader)
images, labels = next(data_iterator)

In [None]:
#%% standard point predictions
# train(False) presumably switches the wrapped model to evaluation mode;
# no_grad skips autograd bookkeeping during inference.
point_model.train(False)
with torch.no_grad():
    point_logits = point_model.predict(images)
    # class probabilities along dim 1 (assumes logits are (batch, classes))
    point_probs = torch.softmax(point_logits, dim=1)
    # highest probability and the corresponding class index per image
    point_top_prob, point_top_class = torch.topk(point_probs, k=1, dim=1)
    # entropy of each image's predicted class distribution
    point_entropy = dist.Categorical(probs=point_probs).entropy()

In [None]:
#%% posterior mean point predictions
# post_model.sample(False)
# post_model.train(False)
# with torch.no_grad():
#     point_logits = post_model.predict(images)
#     point_probs = torch.softmax(point_logits, dim=1)
#     point_top_prob, point_top_class = torch.topk(point_probs, k=1, dim=1)
#     point_entropy = dist.Categorical(probs=point_probs).entropy()

In [None]:
#%% posterior samples predictions
# Number of weight samples used for the Monte Carlo posterior predictive.
no_samples = 500
post_model.sample(True)
post_model.train(False)
with torch.no_grad():
    # Sampled outputs are indexed as (batch, class, sample) in the plotting
    # code below (sampled_probs[idx, class, :]).
    sampled_logits = post_model.predict(images, no_samples)
    sampled_probs = torch.softmax(sampled_logits, dim=1)
    # top class of every individual weight sample: (batch, 1, sample)
    sampled_top_prob, sampled_top_class = torch.topk(sampled_probs, k=1, dim=1)
    # posterior predictive estimate: average the class probabilities over samples
    post_probs = torch.mean(sampled_probs, dim=-1)
    post_top_prob, post_top_class = torch.topk(post_probs, k=1, dim=1)
    # fraction of samples whose top class agrees with the posterior predictive
    is_consistent = sampled_top_class == post_top_class.unsqueeze(-1)
    # Squeeze only the singleton class axis, so a batch of one image keeps its
    # batch dimension instead of collapsing to a scalar.
    post_consistency = torch.mean(is_consistent.float(), dim=-1).squeeze(1)
    post_entropy = dist.Categorical(probs=post_probs).entropy()

In [None]:
#%% plot ids
# Order in which images are shown; pick one of the alternatives below.
plot_ids = np.random.permutation(len(images)) # random
# plot_ids = torch.argsort(point_entropy, descending=False).data.numpy() # lowest entropy
# plot_ids = torch.argsort(post_consistency, descending=False).data.numpy() # lowest consistency

In [None]:
#%% plot: point predictions
# One row per selected image: the image and the point model's class probabilities.
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(5,6))
for idx, (ax1, ax2) in zip(plot_ids[:axes.shape[0]], axes):
    # image
    image = images[idx,0].numpy()
    ax1.imshow(image.clip(0,1), cmap='gray')
    ax1.set_title('{}'.format(data_set.classes[labels[idx]])
                  if hasattr(data_set, 'classes') else 'random')
    ax1.set(xticks=[], yticks=[], xlabel='', ylabel='')
    # probabilities
    ax2.bar(np.arange(10), point_probs.data.numpy()[idx])
    # Raw string: '\p' and '\h' are invalid escape sequences in a normal
    # literal (SyntaxWarning since Python 3.12); the displayed text is unchanged.
    ax2.set_title(r'$\pi(c|x,\hat{w})$')
    ax2.set(xticks=np.arange(10), ylim=[0,1], xlabel='c')
    # ax2.text(0, 0.75, 'entropy: {:.2f}'.format(point_entropy[idx]), alpha=0.5)
fig.tight_layout()
fig.show()

In [None]:
#%% plot: posterior predictions
# One row per selected image: the image, histograms of the sampled class
# probabilities, and the posterior predictive class probabilities.
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(8,6))
for idx, (ax1, ax3, ax4) in zip(plot_ids[:axes.shape[0]], axes):
    # image
    image = images[idx,0].numpy()
    ax1.imshow(image.clip(0,1), cmap='gray')
    ax1.set_title('{}'.format(data_set.classes[labels[idx]])
                  if hasattr(data_set, 'classes') else 'random noise')
    ax1.set(xticks=[], yticks=[], xlabel='', ylabel='')
    # violin plot (disabled; re-enabling it needs ncols=4 and an ax2 in the unpacking)
    # ax2.violinplot(sampled_probs[idx,:,:], positions=np.arange(10))
    # ax2.set_title(r'$\pi(c|x,w)$, $w$ from $\pi(w|\mathcal{D})$')
    # ax2.set(xticks=np.arange(10), ylim=[0,1], xlabel='c')
    # ax2.text(0, 0.75, 'consistency: {:.2f}'.format(post_consistency[idx]), alpha=0.5)
    # histogram of the sampled probabilities of the three most probable classes
    highest_ids = post_probs[idx].data.numpy().argsort()[::-1][:3]
    for highest_idx in highest_ids:
        ax3.hist(sampled_probs[idx,highest_idx,:], bins=50, range=[0,1],
                 density=True, histtype='stepfilled', alpha=0.5)
    # Raw strings: '\p', '\m' and '\i' are invalid escape sequences in normal
    # literals (SyntaxWarning since Python 3.12); the displayed text is unchanged.
    ax3.set_title(r'$\pi(c|x,w)$, $w$ from $\pi(w|\mathcal{D})$')
    ax3.set_xlim([0,1])
    ax3.legend(['c={}'.format(c) for c in highest_ids], loc='upper center')
    # positional flag: Matplotlib renamed `b` to `visible` (3.5) and removed `b` later
    ax3.grid(True, which='both', color='lightgray', linestyle='-')
    ax3.set_axisbelow(True)
    # posterior predictive
    ax4.bar(np.arange(10), post_probs[idx].data.numpy())
    ax4.set_title(r'$\pi(c|x,\mathcal{D}) = \int \pi(c|x,w) \pi(w|\mathcal{D}) dw$')
    ax4.set(xticks=np.arange(10), ylim=[0,1], xlabel='c')
    # ax4.text(0, 0.75, 'entropy: {:.2f}'.format(post_entropy[idx]), alpha=0.5)
fig.tight_layout()
fig.show()

## Out-of-distribution detection

In [None]:
#%% normal MNIST
# In-distribution reference data: the regular MNIST test split.
norm_set, norm_loader = test_set, test_loader

In [None]:
#%% another MNIST
# Out-of-distribution data: the KMNIST test split (FashionMNIST is an
# alternative), preprocessed and batched like the MNIST data.
anom_set = datasets.KMNIST(data_path, train=False, transform=preprocessor, download=True)
# anom_set = datasets.FashionMNIST(data_path, train=False, transform=preprocessor, download=True)
anom_loader = DataLoader(anom_set, batch_size=batch_size, shuffle=True)

In [None]:
#%% function definitions
def entropy_and_maxprob(predict_proba, data_loader):
    '''Compute predictive entropy and max. confidence over a whole data set.

    Parameters
    ----------
    predict_proba : callable
        Maps a batch of images to class probabilities, assumed to be of
        shape (batch, classes).
    data_loader : iterable
        Yields ``(images, labels)`` pairs; the labels are ignored.

    Returns
    -------
    entropy : torch.Tensor
        Entropy of the predicted class distribution for every sample.
    top_prob : torch.Tensor
        Highest class probability for every sample, shape (N, 1).

    Raises
    ------
    ValueError
        If ``data_loader`` yields no batches.
    '''
    with torch.no_grad():
        probs_list = [predict_proba(images) for images, _ in data_loader]
        if not probs_list:
            raise ValueError('data_loader yielded no batches')
        probs = torch.cat(probs_list, dim=0)
        entropy = dist.Categorical(probs=probs).entropy()
        top_prob = torch.topk(probs, k=1, dim=1).values
    return entropy, top_prob


def point_predict_proba(images):
    '''Class probabilities of the point-estimate model (softmax of its logits).'''
    return torch.softmax(point_model.predict(images), dim=1)


def post_predict_proba(images):
    '''Posterior predictive class probabilities from post_model (no_samples=100).'''
    return post_model.predict_proba(images, no_samples=100)

In [None]:
#%% entropies for normal/anomalous data
# Predictive entropy and max. probability on in-distribution (norm_loader) and
# out-of-distribution (anom_loader) data, for both the point model and the
# posterior predictive.
point_norm_entropy, point_norm_maxprob = entropy_and_maxprob(point_predict_proba, norm_loader)
point_anom_entropy, point_anom_maxprob = entropy_and_maxprob(point_predict_proba, anom_loader)
post_norm_entropy, post_norm_maxprob = entropy_and_maxprob(post_predict_proba, norm_loader)
post_anom_entropy, post_anom_maxprob = entropy_and_maxprob(post_predict_proba, anom_loader)

In [None]:
#%% plot: point entropy histogram
# Overlaid entropy densities of in- vs. out-of-distribution data (point model).
fig, ax = plt.subplots(figsize=(5,3.5))
ax.hist(point_norm_entropy.data.numpy().squeeze(),
        bins=100, range=(0,2), density=True, histtype='stepfilled',
        alpha=0.7, label='in distribution')
ax.hist(point_anom_entropy.data.numpy().squeeze(),
        bins=100, range=(0,2), density=True, histtype='stepfilled',
        alpha=0.7, label='out of distribution')
ax.set(xlim=[0,2], xlabel='entropy', ylabel='density')
ax.set_title('point predictions')
ax.legend(loc='upper right')
# Positional on/off flag: the `b` keyword was renamed to `visible` in
# Matplotlib 3.5 and later removed.
ax.grid(True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()
fig.show()

In [None]:
#%% plot: post entropy histogram
# Overlaid entropy densities of in- vs. out-of-distribution data (posterior
# predictive).
fig, ax = plt.subplots(figsize=(5,3.5))
ax.hist(post_norm_entropy.data.numpy().squeeze(),
        bins=100, range=(0,2), density=True, histtype='stepfilled',
        alpha=0.7, label='in distribution')
ax.hist(post_anom_entropy.data.numpy().squeeze(),
        bins=100, range=(0,2), density=True, histtype='stepfilled',
        alpha=0.7, label='out of distribution')
ax.set(xlim=[0,2], xlabel='entropy', ylabel='density')
ax.set_title('posterior predictive')
ax.legend(loc='upper right')
# Positional on/off flag: the `b` keyword was renamed to `visible` in
# Matplotlib 3.5 and later removed.
ax.grid(True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()
fig.show()

## Confidence calibration

In [None]:
#%% normal MNIST
# data_set = test_set
# data_loader = test_loader

In [None]:
#%% rotated MNIST
# Test images randomly rotated within +/- max_rotation degrees, used to probe
# calibration under a distribution shift.
max_rotation = 35
preprocessor_with_noise = transforms.Compose([
    # `interpolation` replaces the deprecated `resample` argument, which newer
    # torchvision releases no longer accept.
    transforms.RandomRotation(degrees=max_rotation,
                              interpolation=transforms.InterpolationMode.BILINEAR,
                              fill=(0,)),
    transforms.ToTensor()
])
data_set = datasets.MNIST(data_path, train=False, transform=preprocessor_with_noise, download=True)
data_loader = DataLoader(data_set, batch_size=batch_size, shuffle=True)

In [None]:
#%% accuracies
# Accuracy of both models on the currently selected data_loader.
point_loss, point_acc = point_model.test(data_loader)
# no_samples=1: presumably a single weight sample per prediction - confirm
post_acc = post_model.test_acc(data_loader, no_samples=1)
print('\n'.join([
    f'Point acc.: {point_acc:.4f}',
    f'Post acc.: {post_acc:.4f}',
]))

In [None]:
#%% accuracy vs. confidence (point predictions)
# Calibration data for the point model; the results are presumably the
# confidence bin edges and the accuracy within each bin.
point_conf_edges, point_bin_accs = accuracy_vs_confidence(point_model,
                                                          data_loader,
                                                          likelihood_type='Categorical')

In [None]:
#%% plot: accuracy vs. confidence (point predictions)
# Reliability diagram: one bar per confidence bin, diagonal = perfect calibration.
fig, ax = plt.subplots()
ax.bar(point_conf_edges[0:-1], point_bin_accs,
       width=np.diff(point_conf_edges), align='edge',
       alpha=0.7, edgecolor='black')
ax.plot([0,1], [0,1], color='gray', linestyle='--')
ax.set(xlim=[0,1], ylim=[0,1], xlabel='confidence', ylabel='accuracy')
ax.set_title('point predictions')
# Positional on/off flag: the `b` keyword was renamed to `visible` in
# Matplotlib 3.5 and later removed.
ax.grid(True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()
fig.show()

In [None]:
#%% accuracy vs. confidence (posterior predictive)
# Calibration data for the posterior predictive (no_samples=100); the results
# are presumably the confidence bin edges and the accuracy within each bin.
post_conf_edges, post_bin_accs = accuracy_vs_confidence(post_model,
                                                        data_loader,
                                                        likelihood_type='Categorical',
                                                        no_samples=100)

In [None]:
#%% plot: accuracy vs. confidence (posterior predictive)
# Reliability diagram: one bar per confidence bin, diagonal = perfect calibration.
fig, ax = plt.subplots()
ax.bar(post_conf_edges[0:-1], post_bin_accs,
       width=np.diff(post_conf_edges), align='edge',
       alpha=0.7, edgecolor='black')
ax.plot([0,1], [0,1], color='gray', linestyle='--')
ax.set(xlim=[0,1], ylim=[0,1], xlabel='confidence', ylabel='accuracy')
ax.set_title('posterior predictive')
# Positional on/off flag: the `b` keyword was renamed to `visible` in
# Matplotlib 3.5 and later removed.
ax.grid(True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()
fig.show()