# Variational Bayesian inference (MNIST)

In [None]:
# Notebook setup: auto-reload edited modules and configure the matplotlib backend.
%load_ext autoreload
%autoreload 2
# NOTE(review): the `notebook` (nbagg) backend is not available in JupyterLab /
# Notebook 7; `%matplotlib widget` (ipympl) is the usual replacement there.
%matplotlib notebook
import sys
# Make the local project packages importable (presumably vartorch from the parent
# directory and torchutils from a sibling checkout) — TODO confirm repository layout.
sys.path.append('..')
sys.path.append('../../torchutils')

In [None]:
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import PIL
import torch
import torch.nn as nn
import torch.distributions as dist
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from torchutils import Classification, confusion_matrix
from vartorch import \
    VariationalClassification, \
    VariationalLinear, \
    anomaly_score

## Data import

In [None]:
#%% transformations
# Only convert PIL images to float tensors; no normalization is applied.
preprocessor = transforms.Compose([transforms.ToTensor()])

In [None]:
#%% datasets
# MNIST train/test splits, stored under ~/Data and downloaded if missing.
data_path = pathlib.Path.home().joinpath('Data')
train_set, test_set = (
    datasets.MNIST(data_path, train=is_train, transform=preprocessor, download=True)
    for is_train in (True, False)
)
print('No. train images:', len(train_set))
print('No. test images:', len(test_set))

In [None]:
#%% data loaders
batch_size = 128
# Both splits are shuffled; the test loader is also used further down to draw
# random example batches via next(iter(...)).
train_loader, test_loader = (
    DataLoader(subset, batch_size=batch_size, shuffle=True)
    for subset in (train_set, test_set)
)
print('No. train batches:', len(train_loader))
print('No. test batches:', len(test_loader))

In [None]:
#%% example images
# Take one shuffled training batch for a quick visual/shape check.
images, labels = next(iter(train_loader))
for _caption, _tensor in (('Images shape:', images), ('Labels shape:', labels)):
    print(_caption, _tensor.shape)

In [None]:
#%% plot: example images
# Show the first eight images of the batch with their class names.
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(5,3))
for ax, image_tensor, label in zip(axes.ravel(), images[:, 0], labels):
    ax.imshow(image_tensor.numpy().clip(0,1), cmap='gray')
    ax.set_title(train_set.classes[label])
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()
fig.show()

## Standard training

In [None]:
#%% model (logistic regression)
# model1 = nn.Sequential(
#     nn.Flatten(),
#     nn.Linear(in_features=28*28, out_features=10),
# )
# print(model1)

In [None]:
#%% model (small CNN with linear classifier)
# Each conv keeps the spatial size (5x5 kernel, padding 2) and each max-pool
# halves it: 1x28x28 -> 4x14x14 -> 8x7x7, flattened to 7*7*8 features.
model1 = nn.Sequential(
    # conv block 1
    nn.Conv2d(1, 4, kernel_size=(5,5), padding=2),
    nn.LeakyReLU(),
    nn.MaxPool2d(kernel_size=2),
    # conv block 2
    nn.Conv2d(4, 8, kernel_size=(5,5), padding=2),
    nn.LeakyReLU(),
    nn.MaxPool2d(kernel_size=2),
    # linear classifier head producing 10 class logits
    nn.Flatten(),
    nn.Linear(7*7*8, 10),
)
print(model1)

In [None]:
#%% standard model
# Point-estimate classifier: mean cross-entropy loss, Adam with lr=0.01.
criterion = nn.CrossEntropyLoss(reduction='mean')
optimizer = torch.optim.Adam(params=model1.parameters(), lr=0.01)
point_model = Classification(
    model1,
    criterion,
    optimizer,
    train_loader,
    test_loader,
)

In [None]:
#%% training
# Train the point model; the returned history is plotted in the next cell.
point_history = point_model.training(
    no_epochs=10,
    log_interval=10,
)

In [None]:
#%% plot: training history
# Training and test loss curves recorded by point_model.training().
fig, ax = plt.subplots(figsize=(6,4))
ax.plot(np.array(point_history['train_loss']), label='training', alpha=0.7)
ax.plot(np.array(point_history['test_loss']), label='testing', alpha=0.7)
ax.set(xlabel='epoch', ylabel='loss')
ax.set_xlim([0, point_history['no_epochs']])
ax.legend()
# The visibility flag is passed positionally: the `b=` keyword was deprecated in
# Matplotlib 3.5 (renamed `visible`) and raises a TypeError once removed.
ax.grid(True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()
fig.show()

In [None]:
#%% final loss/accuracy
# Evaluate the trained point model on both splits; test() yields (loss, accuracy).
point_train_loss, point_train_acc = point_model.test(train_loader)
point_test_loss, point_test_acc = point_model.test(test_loader)
for _template, _value in (('Train loss: {:.4f}', point_train_loss),
                          ('Test loss: {:.4f}', point_test_loss),
                          ('Train acc.: {:.4f}', point_train_acc),
                          ('Test acc.: {:.4f}', point_test_acc)):
    print(_template.format(_value))

In [None]:
#%% confusion matrix
# Confusion matrix of the point model on the test split.
confmat = confusion_matrix(
    point_model,
    test_loader,
)
print('Confusion matrix:\n{}'.format(confmat))

## Variational inference

In [None]:
#%% model (variational logistic regression)
# model2 = nn.Sequential(
#     nn.Flatten(),
#     VariationalLinear(in_features=28*28, out_features=10),
# )
# print(model2)

In [None]:
#%% model (small CNN with variational linear classifier)
# Same convolutional feature extractor as model1 (1x28x28 -> 8x7x7); only the
# final classifier is replaced by a VariationalLinear layer.
model2 = nn.Sequential(
    # conv block 1
    nn.Conv2d(1, 4, kernel_size=(5,5), padding=2),
    nn.LeakyReLU(),
    nn.MaxPool2d(kernel_size=2),
    # conv block 2
    nn.Conv2d(4, 8, kernel_size=(5,5), padding=2),
    nn.LeakyReLU(),
    nn.MaxPool2d(kernel_size=2),
    # variational classifier head producing 10 class logits
    nn.Flatten(),
    VariationalLinear(in_features=7*7*8, out_features=10),
)
print(model2)

In [None]:
#%% variational inference
# Wrap model2 in a variational classifier with a categorical likelihood and
# optimize all of model2's parameters with Adam (lr=0.01).
post_model = VariationalClassification(
    model2,
    likelihood_type='Categorical',
)
optimizer = torch.optim.Adam(params=model2.parameters(), lr=0.01)
post_model.compile_for_training(
    optimizer,
    train_loader,
    test_loader,
)

In [None]:
#%% training
# Variational training; no_samples is presumably the number of weight draws
# per step for the ELBO estimate — TODO confirm against vartorch.
post_history = post_model.training(
    no_epochs=20,
    no_samples=10,
    log_interval=10,
)

In [None]:
#%% plot: training history
# The recorded losses are negated so the curves show the ELBO.
fig, ax = plt.subplots(figsize=(6,4))
ax.plot(-np.array(post_history['train_loss']), label='training', alpha=0.7)
ax.plot(-np.array(post_history['test_loss']), label='testing', alpha=0.7)
ax.set(xlabel='epoch', ylabel='ELBO')
ax.set_xlim([0, post_history['no_epochs']])
ax.legend()
# The visibility flag is passed positionally: the `b=` keyword was deprecated in
# Matplotlib 3.5 (renamed `visible`) and raises a TypeError once removed.
ax.grid(True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()
fig.show()

In [None]:
#%% final loss/accuracy
# Evaluate the variational model on both splits (call order kept as-is, since
# evaluation may draw random weight samples).
post_train_loss = post_model.test_loss(train_loader)
post_train_acc = post_model.test_acc(train_loader)
post_test_loss = post_model.test_loss(test_loader)
post_test_acc = post_model.test_acc(test_loader)
for _template, _value in (('Train loss: {:.4f}', post_train_loss),
                          ('Test loss: {:.4f}', post_test_loss),
                          ('Train acc.: {:.4f}', post_train_acc),
                          ('Test acc.: {:.4f}', post_test_acc)):
    print(_template.format(_value))

In [None]:
#%% confusion matrix
# Confusion matrix of the variational model on the test split, using 100 samples.
confmat = confusion_matrix(
    post_model,
    test_loader,
    no_samples=100,
)
print('Confusion matrix:\n{}'.format(confmat))

## Example predictions

In [None]:
#%% datasets and loaders
# In-distribution data: the MNIST test split and its loader.
norm_set, norm_loader = test_set, test_loader
# Out-of-distribution data: the KMNIST test split. Alternatives:
#   FashionMNIST: datasets.FashionMNIST(data_path, train=False, transform=preprocessor, download=True)
#   random noise: TensorDataset(torch.rand(batch_size, 1, 28, 28),
#                               torch.zeros((batch_size,), dtype=torch.int64))
anom_set = datasets.KMNIST(
    data_path, train=False, transform=preprocessor, download=True,
)
anom_loader = DataLoader(anom_set, batch_size=batch_size, shuffle=True)

In [None]:
#%% example data
# One random batch from each loader (in-distribution first, then OOD).
(norm_images, norm_labels), (anom_images, anom_labels) = (
    next(iter(loader)) for loader in (norm_loader, anom_loader)
)

In [None]:
#%% standard point predictions
# Class probabilities of the point model for both example batches, plus the
# entropy of each predicted categorical distribution.
point_model.train(False)
with torch.no_grad():
    point_norm_probs = point_model.predict_proba(
        norm_images.to(point_model.device)
    ).cpu()
    point_anom_probs = point_model.predict_proba(
        anom_images.to(point_model.device)
    ).cpu()
    point_norm_entropy = dist.Categorical(probs=point_norm_probs).entropy()
    point_anom_entropy = dist.Categorical(probs=point_anom_probs).entropy()

In [None]:
#%% posterior sample predictions
# Monte Carlo posterior predictive: softmax each sampled set of logits over the
# class axis (dim 1), then average over the sample axis (last dim). The layout
# appears to be (batch, class, sample) — TODO confirm against vartorch.
no_samples = 500
post_model.sample(True)  # presumably enables weight sampling — TODO confirm
post_model.train(False)
with torch.no_grad():
    # in-distribution batch
    sampled_norm_logits = post_model.predict(norm_images.to(post_model.device), no_samples).cpu()
    sampled_norm_probs = sampled_norm_logits.softmax(dim=1)
    post_norm_probs = sampled_norm_probs.mean(dim=-1)
    post_norm_entropy = dist.Categorical(probs=post_norm_probs).entropy()
    # out-of-distribution batch
    sampled_anom_logits = post_model.predict(anom_images.to(post_model.device), no_samples).cpu()
    sampled_anom_probs = sampled_anom_logits.softmax(dim=1)
    post_anom_probs = sampled_anom_probs.mean(dim=-1)
    post_anom_entropy = dist.Categorical(probs=post_anom_probs).entropy()

In [None]:
#%% plot: point predictions (in distribution)
# Random example indices, sized from the batch actually plotted here so they
# always index valid rows of norm_images / point_norm_probs.
plot_ids = np.random.permutation(len(norm_images)) # random
# plot_ids = torch.argsort(point_norm_entropy, descending=False).data.numpy() # lowest entropy
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(5,6))
# One row per example: the image and the point model's class probabilities.
for idx, (ax1, ax2) in zip(plot_ids[:axes.shape[0]], axes):
    # image
    image = norm_images[idx,0].numpy()
    ax1.imshow(image.clip(0,1), cmap='gray')
    ax1.set_title('{}'.format(norm_set.classes[norm_labels[idx]])
                  if hasattr(norm_set, 'classes') else 'random')
    ax1.set(xticks=[], yticks=[], xlabel='', ylabel='')
    # probabilities
    ax2.bar(np.arange(10), point_norm_probs.data.numpy()[idx])
    # raw string: '\p' and '\h' are invalid escapes in a normal string literal
    ax2.set_title(r'$\pi(c|x,\hat{w})$')
    ax2.set(xticks=np.arange(10), ylim=[0,1], xlabel='c')
    # ax2.text(0, 0.75, 'entropy: {:.2f}'.format(point_norm_entropy[idx]), alpha=0.5)
fig.suptitle('point predictions (in distribution)')
fig.tight_layout()
fig.show()

In [None]:
#%% plot: point predictions (out of distribution)
# Random example indices, sized from the batch actually plotted here so they
# always index valid rows of anom_images / point_anom_probs.
plot_ids = np.random.permutation(len(anom_images)) # random
# plot_ids = torch.argsort(point_anom_entropy, descending=False).data.numpy() # lowest entropy
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(5,6))
# One row per example: the OOD image and the point model's class probabilities.
for idx, (ax1, ax2) in zip(plot_ids[:axes.shape[0]], axes):
    # image
    image = anom_images[idx,0].numpy()
    ax1.imshow(image.clip(0,1), cmap='gray')
    ax1.set_title('{}'.format(anom_set.classes[anom_labels[idx]])
                  if hasattr(anom_set, 'classes') else 'random')
    ax1.set(xticks=[], yticks=[], xlabel='', ylabel='')
    # probabilities
    ax2.bar(np.arange(10), point_anom_probs.data.numpy()[idx])
    # raw string: '\p' and '\h' are invalid escapes in a normal string literal
    ax2.set_title(r'$\pi(c|x,\hat{w})$')
    ax2.set(xticks=np.arange(10), ylim=[0,1], xlabel='c')
    # ax2.text(0, 0.75, 'entropy: {:.2f}'.format(point_anom_entropy[idx]), alpha=0.5)
fig.suptitle('point predictions (out of distribution)')
fig.tight_layout()
fig.show()

In [None]:
#%% plot: posterior predictions (in distribution)
# NOTE: reuses `plot_ids` from the most recently executed point-prediction cell.
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(8,6))
# One row per example: image, histograms of the sampled probabilities of the
# three most likely classes, and the averaged posterior predictive.
for idx, (ax1, ax2, ax3) in zip(plot_ids[:axes.shape[0]], axes):
    # image
    image = norm_images[idx,0].numpy()
    ax1.imshow(image.clip(0,1), cmap='gray')
    ax1.set_title('{}'.format(norm_set.classes[norm_labels[idx]])
                  if hasattr(norm_set, 'classes') else 'random noise')
    ax1.set(xticks=[], yticks=[], xlabel='', ylabel='')
    # violin plot
    # ax2.violinplot(sampled_norm_probs[idx,:,:], positions=np.arange(10))
    # ax2.set_title(r'$\pi(c|x,w)$, $w$ from $\pi(w|\mathcal{D})$')
    # ax2.set(xticks=np.arange(10), ylim=[0,1], xlabel='c')
    # histogram (three classes with the highest posterior predictive probability)
    highest_ids = post_norm_probs[idx].data.numpy().argsort()[::-1][:3]
    for highest_idx in highest_ids:
        ax2.hist(sampled_norm_probs[idx,highest_idx,:].data.numpy(), bins=50,
                 range=[0,1], density=True, histtype='stepfilled', alpha=0.5)
    # raw strings: '\p', '\m', '\i' are invalid escapes in normal string literals
    ax2.set_title(r'$\pi(c|x,w)$, $w$ from $\pi(w|\mathcal{D})$')
    ax2.set_xlim([0,1])
    ax2.legend(['c={}'.format(c) for c in highest_ids], loc='upper center')
    # positional flag: the `b=` keyword of Axes.grid is deprecated/removed
    ax2.grid(True, which='both', color='lightgray', linestyle='-')
    ax2.set_axisbelow(True)
    # posterior predictive
    ax3.bar(np.arange(10), post_norm_probs[idx].data.numpy())
    ax3.set_title(r'$\pi(c|x,\mathcal{D}) = \int \pi(c|x,w) \pi(w|\mathcal{D}) dw$')
    ax3.set(xticks=np.arange(10), ylim=[0,1], xlabel='c')
    # ax3.text(0, 0.75, 'entropy: {:.2f}'.format(post_norm_entropy[idx]), alpha=0.5)
fig.suptitle('posterior predictions (in distribution)')
fig.tight_layout()
fig.show()

In [None]:
#%% plot: posterior predictions (out of distribution)
# NOTE: reuses `plot_ids` from the most recently executed point-prediction cell.
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(8,6))
# One row per example: OOD image, histograms of the sampled probabilities of the
# three most likely classes, and the averaged posterior predictive.
for idx, (ax1, ax2, ax3) in zip(plot_ids[:axes.shape[0]], axes):
    # image
    image = anom_images[idx,0].numpy()
    ax1.imshow(image.clip(0,1), cmap='gray')
    ax1.set_title('{}'.format(anom_set.classes[anom_labels[idx]])
                  if hasattr(anom_set, 'classes') else 'random noise')
    ax1.set(xticks=[], yticks=[], xlabel='', ylabel='')
    # violin plot
    # ax2.violinplot(sampled_anom_probs[idx,:,:], positions=np.arange(10))
    # ax2.set_title(r'$\pi(c|x,w)$, $w$ from $\pi(w|\mathcal{D})$')
    # ax2.set(xticks=np.arange(10), ylim=[0,1], xlabel='c')
    # histogram (three classes with the highest posterior predictive probability)
    highest_ids = post_anom_probs[idx].data.numpy().argsort()[::-1][:3]
    for highest_idx in highest_ids:
        ax2.hist(sampled_anom_probs[idx,highest_idx,:].data.numpy(), bins=50,
                 range=[0,1], density=True, histtype='stepfilled', alpha=0.5)
    # raw strings: '\p', '\m', '\i' are invalid escapes in normal string literals
    ax2.set_title(r'$\pi(c|x,w)$, $w$ from $\pi(w|\mathcal{D})$')
    ax2.set_xlim([0,1])
    ax2.legend(['c={}'.format(c) for c in highest_ids], loc='upper center')
    # positional flag: the `b=` keyword of Axes.grid is deprecated/removed
    ax2.grid(True, which='both', color='lightgray', linestyle='-')
    ax2.set_axisbelow(True)
    # posterior predictive
    ax3.bar(np.arange(10), post_anom_probs[idx].data.numpy())
    ax3.set_title(r'$\pi(c|x,\mathcal{D}) = \int \pi(c|x,w) \pi(w|\mathcal{D}) dw$')
    ax3.set(xticks=np.arange(10), ylim=[0,1], xlabel='c')
    # ax3.text(0, 0.75, 'entropy: {:.2f}'.format(post_anom_entropy[idx]), alpha=0.5)
fig.suptitle('posterior predictions (out of distribution)')
fig.tight_layout()
fig.show()

## Out-of-distribution detection

In [None]:
#%% datasets and loaders
# In-distribution: the MNIST test loader. Out-of-distribution: KMNIST test split.
# Alternatives for the OOD set:
#   FashionMNIST: datasets.FashionMNIST(data_path, train=False, transform=preprocessor, download=True)
#   random noise: TensorDataset(torch.rand(batch_size, 1, 28, 28),
#                               torch.zeros((batch_size,), dtype=torch.int64))
norm_loader = test_loader
anom_set = datasets.KMNIST(
    data_path, train=False, transform=preprocessor, download=True,
)
anom_loader = DataLoader(anom_set, batch_size=batch_size, shuffle=True)

In [None]:
#%% anomaly scores
# Entropy-based anomaly scores over the full in-distribution and OOD loaders,
# for the point model and for the posterior predictive (100 weight samples).
point_norm_entropy = anomaly_score(point_model, norm_loader, mode='entropy')
point_anom_entropy = anomaly_score(point_model, anom_loader, mode='entropy')
post_norm_entropy, post_anom_entropy = (
    anomaly_score(post_model, loader, mode='entropy', no_samples=100)
    for loader in (norm_loader, anom_loader)
)

In [None]:
#%% plot: point entropy histogram
# The entropy of a 10-class categorical distribution is at most ln(10) ~ 2.303,
# so the histogram range must reach that far. A cap at 2 silently drops the most
# uncertain (typically out-of-distribution) predictions from the density. A 1%
# margin keeps float32 round-off at the maximum from being clipped.
# NOTE(review): assumes anomaly_score reports entropy in nats, as
# torch.distributions.Categorical.entropy does — confirm in vartorch.
entropy_range = (0, 1.01 * np.log(10))
fig, ax = plt.subplots(figsize=(6,4))
ax.hist(point_norm_entropy, bins=100, range=entropy_range, density=True,
        histtype='stepfilled', alpha=0.7, label='in distribution')
ax.hist(point_anom_entropy, bins=100, range=entropy_range, density=True,
        histtype='stepfilled', alpha=0.7, label='out of distribution')
ax.set(xlim=entropy_range, xlabel='entropy', ylabel='density')
ax.set_title('point predictions')
ax.legend(loc='upper right')
# positional flag: the `b=` keyword of Axes.grid is deprecated/removed
ax.grid(True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()
fig.show()

In [None]:
#%% plot: posterior entropy histogram
# The entropy of a 10-class categorical distribution is at most ln(10) ~ 2.303,
# so the histogram range must reach that far. A cap at 2 silently drops the most
# uncertain (typically out-of-distribution) predictions from the density. A 1%
# margin keeps float32 round-off at the maximum from being clipped.
# NOTE(review): assumes anomaly_score reports entropy in nats, as
# torch.distributions.Categorical.entropy does — confirm in vartorch.
entropy_range = (0, 1.01 * np.log(10))
fig, ax = plt.subplots(figsize=(6,4))
ax.hist(post_norm_entropy, bins=100, range=entropy_range, density=True,
        histtype='stepfilled', alpha=0.7, label='in distribution')
ax.hist(post_anom_entropy, bins=100, range=entropy_range, density=True,
        histtype='stepfilled', alpha=0.7, label='out of distribution')
ax.set(xlim=entropy_range, xlabel='entropy', ylabel='density')
ax.set_title('posterior predictive')
ax.legend(loc='upper right')
# positional flag: the `b=` keyword of Axes.grid is deprecated/removed
ax.grid(True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()
fig.show()