# Variational Bayesian inference (half-moons)

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook

import sys
sys.path.append('..')
sys.path.append('../../torchutils')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchutils import Classification

from vartorch import VariationalClassification, VariationalLinear

## Plotting functions

In [None]:
def plot_data(X, y, colors=None, ax=None):
    '''
    Plot two sampled classes on a two-dimensional plane.

    Parameters
    ----------
    X : array
        Feature matrix, indexed as X[mask, column]; columns 0 and 1
        are used as the horizontal and vertical coordinates.
    y : array
        Labels; the points with y == 0 and y == 1 are drawn as two groups.
    colors : sequence of two colors, optional
        Colors of class 0 and class 1, in that order.
        Defaults to (plt.cm.Set1(1), plt.cm.Set1(0)).
    ax : matplotlib axes, optional
        Axes to draw on. A new figure with one axes is created if omitted.

    Returns
    -------
    ax
        The axes that were drawn on.
    '''
    if colors is None:
        # resolved per call, so no mutable default object is shared between calls
        colors = (plt.cm.Set1(1), plt.cm.Set1(0))

    if ax is None:
        _, ax = plt.subplots()

    for label in (0, 1):
        is_member = (y == label)
        ax.scatter(
            X[is_member, 0], X[is_member, 1],
            color=colors[label], alpha=0.7, edgecolors='none',
            label='y={}'.format(label)
        )

    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    return ax

In [None]:
def plot_function(function,
                  levels=(0.1, 0.3, 0.5, 0.7, 0.9),
                  x_limits=None,
                  y_limits=None,
                  colorbar=True,
                  ax=None):
    '''
    Plot a function of two features on the plane.

    The function is evaluated on a regular 201 x 201 grid and shown as a
    translucent grayscale image overlaid with labelled contour lines.

    Parameters
    ----------
    function : callable
        Receives an array of shape (n_points, 2) holding the grid coordinates
        and must return an array that can be reshaped to the grid shape.
    levels : sequence of float
        Contour levels to draw.
    x_limits, y_limits : pair of float, optional
        Range of the grid. The current limits of the axes are used if omitted.
    colorbar : bool
        Whether to add a colorbar for the image.
    ax : matplotlib axes, optional
        Axes to draw on. A new figure with one axes is created if omitted.

    Returns
    -------
    ax
        The axes that were drawn on.
    '''
    if ax is None:
        _, ax = plt.subplots()

    if x_limits is None:
        x_limits = ax.get_xlim()

    if y_limits is None:
        y_limits = ax.get_ylim()

    x_values = np.linspace(*x_limits, num=201)
    y_values = np.linspace(*y_limits, num=201)
    X_values, Y_values = np.meshgrid(x_values, y_values)

    # evaluate all grid points in a single call, then restore the grid shape
    grid_points = np.stack((X_values.ravel(), Y_values.ravel()), axis=1)
    Z_values = function(grid_points).reshape(X_values.shape)

    image = ax.imshow(Z_values, origin='lower', extent=(*x_limits, *y_limits),
                      interpolation='bicubic', cmap='Greys', alpha=0.4)
    contours = ax.contour(X_values, Y_values, Z_values,
                          levels=levels, colors='black', alpha=0.6)

    if colorbar:
        # attach to the figure owning `ax`; plt.colorbar would target
        # whichever figure happens to be current in pyplot
        ax.figure.colorbar(image, ax=ax)

    # label through the axes for the same reason
    ax.clabel(contours, fmt='%1.2f')

    return ax

## Half-moons data

In [None]:
# sample the noisy two-class half-moons data
no_samples, noise_level = 500, 0.15
X, y = make_moons(n_samples=no_samples, shuffle=True, noise=noise_level)

# shift class 0 up and class 1 down by 0.15 along the second feature
X[y == 0, 1] += 0.15
X[y == 1, 1] -= 0.15

In [None]:
# scatter plot of the complete data set
fig, ax = plt.subplots()
plot_data(X, y, ax=ax)

# fixed window with equal scaling of both features
ax.set_xlim(-2, 3)
ax.set_ylim(-2, 2.5)
ax.set_aspect('equal', adjustable='box')

# light grid drawn underneath the points
ax.grid(visible=True, which='both', color='gray', alpha=0.2, linestyle='-')
ax.set_axisbelow(True)

ax.legend(loc='upper right')
fig.tight_layout()
fig.show()

In [None]:
# hold out 20% of the samples for validation
(X_train, X_val,
 y_train, y_val) = train_test_split(X, y, test_size=0.2)

In [None]:
# standardize the raw inputs; the statistics come from the training split only
original_scaler = StandardScaler()
X_train_normalized = original_scaler.fit_transform(X_train)

# the validation split reuses the fitted scaler
X_val_normalized = original_scaler.transform(X_val)

In [None]:
# expand the two inputs into all polynomial terms up to degree 5 (no constant term)
polynomial_features = PolynomialFeatures(
    degree=5,
    interaction_only=False,
    include_bias=False
)
X_train_poly = polynomial_features.fit_transform(X_train_normalized)
X_val_poly = polynomial_features.transform(X_val_normalized)

# one model input per polynomial term
no_features = X_train_poly.shape[-1]
print('No. features:', no_features)
print('\nFeatures:', polynomial_features.get_feature_names_out())

In [None]:
# standardize the polynomial features, again fitted on the training split only
polynomial_scaler = StandardScaler()
X_train_final = polynomial_scaler.fit_transform(X_train_poly)

# the validation split reuses the fitted scaler
X_val_final = polynomial_scaler.transform(X_val_poly)

In [None]:
# features become float32 tensors
X_train_tensor = torch.tensor(X_train_final, dtype=torch.float32)
X_val_tensor = torch.tensor(X_val_final, dtype=torch.float32)

# labels become int64 tensors
y_train_tensor = torch.tensor(y_train, dtype=torch.int64)
y_val_tensor = torch.tensor(y_val, dtype=torch.int64)

In [None]:
# pair features with labels
train_set = TensorDataset(X_train_tensor, y_train_tensor)
val_set = TensorDataset(X_val_tensor, y_val_tensor)

# each loader yields its complete data set as a single batch
train_loader = DataLoader(train_set, batch_size=len(train_set), shuffle=True)
val_loader = DataLoader(val_set, batch_size=len(val_set), shuffle=False)

## Standard training

In [None]:
# logistic regression: a single linear layer producing one logit
model1 = nn.Linear(no_features, 1)

In [None]:
# loss on logits and optimizer for the point-estimate model
criterion = nn.BCEWithLogitsLoss(reduction='mean')
optimizer = torch.optim.Adam(model1.parameters(), lr=0.1)

# bundle model, loss, optimizer and data loaders
point_model = Classification(model1, criterion, optimizer, train_loader, val_loader)

In [None]:
# train the point-estimate model and keep the returned history
point_history = point_model.training(
    no_epochs=500,
    log_interval=None
)

In [None]:
# evaluate the point-estimate model on both splits
(point_train_loss, point_train_acc), (point_val_loss, point_val_acc) = (
    point_model.test(train_loader),
    point_model.test(val_loader),
)

# losses first, accuracies second
print('Train loss: {:.4f}'.format(point_train_loss))
print('Val. loss: {:.4f}'.format(point_val_loss))
print('\nTrain acc.: {:.4f}'.format(point_train_acc))
print('Val. acc.: {:.4f}'.format(point_val_acc))

In [None]:
# loss curves of the point-estimate model
fig, ax = plt.subplots(figsize=(6, 4))

train_curve = np.array(point_history['train_loss'])
val_curve = np.array(point_history['val_loss'])
ax.plot(train_curve, label='train', alpha=0.7)
ax.plot(val_curve, label='val.', alpha=0.7)

ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_xlim((0, point_history['no_epochs']))

ax.grid(visible=True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
ax.legend()

fig.tight_layout()
fig.show()

## Variational inference

In [None]:
# variational logistic regression: one variational linear layer with a single output
model2 = VariationalLinear(
    in_features=no_features,
    out_features=1,
    weight_std=5
)

In [None]:
# optimizer over the parameters of the variational layer
optimizer = torch.optim.Adam(model2.parameters(), lr=0.1)

# wrap the layer with a Bernoulli likelihood and attach optimizer and loaders
post_model = VariationalClassification(model2, likelihood_type='Bernoulli')
post_model.compile_for_training(optimizer, train_loader, val_loader)

In [None]:
# train the variational model and keep the returned history
post_history = post_model.training(
    no_epochs=500,
    no_samples=20,
    log_interval=None
)

In [None]:
# evaluate the variational model on the training split ...
post_train_loss = post_model.test_loss(train_loader)
post_train_acc = post_model.test_acc(train_loader)

# ... and on the validation split
post_val_loss = post_model.test_loss(val_loader)
post_val_acc = post_model.test_acc(val_loader)

# losses first, accuracies second
print('Train loss: {:.4f}'.format(post_train_loss))
print('Val. loss: {:.4f}'.format(post_val_loss))
print('\nTrain acc.: {:.4f}'.format(post_train_acc))
print('Val. acc.: {:.4f}'.format(post_val_acc))

In [None]:
# training curves of the variational model; the negated loss is shown as the ELBO
fig, ax = plt.subplots(figsize=(6, 4))

train_elbo = -np.array(post_history['train_loss'])
val_elbo = -np.array(post_history['val_loss'])
ax.plot(train_elbo, label='train', alpha=0.7)
ax.plot(val_elbo, label='val.', alpha=0.7)

ax.set_xlabel('epoch')
ax.set_ylabel('ELBO')
ax.set_xlim((0, post_history['no_epochs']))

ax.grid(visible=True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
ax.legend()

fig.tight_layout()
fig.show()

## Example predictions

In [None]:
def transform_features(x):
    '''Apply the fitted preprocessing steps to raw inputs and return a float32 tensor.'''
    # same order as during training: input scaling, polynomial expansion, feature scaling
    features = x
    for step in (original_scaler, polynomial_features, polynomial_scaler):
        features = step.transform(features)
    return torch.tensor(features, dtype=torch.float32)

def point_prediction(x):
    '''Return the sigmoid of the point-estimate model's predictions as a NumPy array.'''
    features = transform_features(x)
    point_model.train(False) # leave training mode
    with torch.no_grad():
        logits = point_model.predict(features.to(point_model.device))
        probs = torch.sigmoid(logits.cpu())
    return probs.detach().numpy()

def posterior_mean(x):
    '''Return the sigmoid of the variational model's predictions with sampling turned off.'''
    features = transform_features(x)
    post_model.sample(False) # turn weight sampling off
    post_model.train(False) # leave training mode
    with torch.no_grad():
        logits = post_model.predict(features.to(post_model.device))
        probs = torch.sigmoid(logits.cpu())
    return probs.detach().numpy()

def posterior_predictive(x, no_samples=1000):
    '''Average the sampled sigmoid outputs of the variational model over the last axis.'''
    features = transform_features(x)
    post_model.sample(True) # turn weight sampling on
    post_model.train(False) # leave training mode
    with torch.no_grad():
        logits = post_model.predict(features.to(post_model.device), no_samples)
        probs = torch.sigmoid(logits.cpu())
    # the last axis is reduced by the mean
    return probs.mean(dim=-1).detach().numpy()

def posterior_uncertainty(x, no_samples=1000):
    '''Return the standard deviation of the sampled sigmoid outputs over the last axis.'''
    features = transform_features(x)
    post_model.sample(True) # turn weight sampling on
    post_model.train(False) # leave training mode
    with torch.no_grad():
        logits = post_model.predict(features.to(post_model.device), no_samples)
        probs = torch.sigmoid(logits.cpu())
    # the last axis is reduced by the standard deviation
    return probs.std(dim=-1).detach().numpy()

In [None]:
# training data with the predictions of the point-estimate model
fig, ax = plt.subplots()
plot_data(X_train, y_train, ax=ax)

# the limits are set first because plot_function reads them from the axes
ax.set_xlim(-2, 3)
ax.set_ylim(-2, 2.5)
plot_function(point_prediction, levels=(0.3, 0.5, 0.7), ax=ax)

ax.set_aspect('equal', adjustable='box')
ax.set_title('point predictions')
ax.legend(loc='upper left')
fig.tight_layout()
fig.show()

In [None]:
# training data with the predictions of posterior_mean
fig, ax = plt.subplots()
plot_data(X_train, y_train, ax=ax)

# the limits are set first because plot_function reads them from the axes
ax.set_xlim(-2, 3)
ax.set_ylim(-2, 2.5)
plot_function(posterior_mean, levels=(0.3, 0.5, 0.7), ax=ax)

ax.set_aspect('equal', adjustable='box')
ax.set_title('posterior mean')
ax.legend(loc='upper left')
fig.tight_layout()
fig.show()

In [None]:
# training data with the predictions of posterior_predictive
fig, ax = plt.subplots()
plot_data(X_train, y_train, ax=ax)

# the limits are set first because plot_function reads them from the axes
ax.set_xlim(-2, 3)
ax.set_ylim(-2, 2.5)
plot_function(posterior_predictive, levels=(0.1, 0.3, 0.5, 0.7, 0.9), ax=ax)

ax.set_aspect('equal', adjustable='box')
ax.set_title('posterior predictive')
ax.legend(loc='upper left')
fig.tight_layout()
fig.show()

In [None]:
# training data with the values of posterior_uncertainty
fig, ax = plt.subplots()
plot_data(X_train, y_train, ax=ax)

# the limits are set first because plot_function reads them from the axes
ax.set_xlim(-2, 3)
ax.set_ylim(-2, 2.5)
plot_function(posterior_uncertainty, levels=np.linspace(0.1, 0.9, 9), ax=ax)

ax.set_aspect('equal', adjustable='box')
ax.set_title('posterior uncertainty')
ax.legend(loc='upper left')
fig.tight_layout()
fig.show()