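# Bayesian neural net: half-moons

A single-layer network (logistic regression on degree-5 polynomial features) is fitted to the half-moons data twice: once as a point estimate with standard training, and once with variational inference over its weights. The resulting predictions and predictive uncertainties are then compared.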

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.append('..')
sys.path.append('../../torchutils')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from torchutils import Classification

from vartorch import (
    plot_data_2d,
    plot_function_2d,
    VariationalLinear,
    VariationalClassification
)

In [None]:
# Seed every RNG the notebook draws from so that Restart & Run All is reproducible:
#  - PyTorch's default generator: parameter initialization, loader shuffling and
#    (presumably) the variational weight sampling;
#  - NumPy's global RNG: the sklearn calls below pass no random_state.
_ = torch.manual_seed(987654321)  # bind to `_` so the returned Generator is not echoed
np.random.seed(123456789)

## Half-moons data

In [None]:
num_samples = 500
noise_level = 0.15

# two interleaving half circles with Gaussian noise; no random_state is given,
# so the draw comes from the (seeded) global NumPy RNG
X, y = make_moons(n_samples=num_samples, shuffle=True, noise=noise_level)

# widen the gap between the moons: move class 0 up and class 1 down
X[y == 0, 1] += 0.15
X[y == 1, 1] -= 0.15

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))

# scatter all samples; the colors tuple is presumably one color per class, in label order
plot_data_2d(X, y, colors=(plt.cm.Set1(1), plt.cm.Set1(0)), ax=ax)

ax.set_xlim(-2, 3)
ax.set_ylim(-2, 2.5)
ax.set_aspect('equal', adjustable='box')
ax.legend(loc='upper right')

# light background grid drawn underneath the data points
ax.grid(visible=True, which='both', color='gray', alpha=0.2, linestyle='-')
ax.set_axisbelow(True)

fig.tight_layout()

In [None]:
# hold out 20% of the samples for validation; with no random_state the split
# is drawn from NumPy's global RNG
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    test_size=0.2,
)

In [None]:
# standardize the raw 2-D inputs using statistics of the training split only,
# then apply the same shift/scale to the validation split
original_scaler = StandardScaler().fit(X_train)

X_train_normalized = original_scaler.transform(X_train)
X_val_normalized = original_scaler.transform(X_val)

In [None]:
# expand the standardized inputs into every monomial of total degree 1..5
# (pure powers and cross terms), without the constant column
polynomial_features = PolynomialFeatures(degree=5, interaction_only=False, include_bias=False)

X_train_poly = polynomial_features.fit_transform(X_train_normalized)
X_val_poly = polynomial_features.transform(X_val_normalized)

num_features = X_train_poly.shape[1]
print(f'No. features: {num_features}')
print(f'\nFeatures: {polynomial_features.get_feature_names_out()}')

In [None]:
# re-standardize after the expansion, since monomials of different degree
# end up on different scales; fitted on the training split only
polynomial_scaler = StandardScaler().fit(X_train_poly)

X_train_final = polynomial_scaler.transform(X_train_poly)
X_val_final = polynomial_scaler.transform(X_val_poly)

In [None]:
# features as float32, labels as int64 class indices
# NOTE(review): nn.BCEWithLogitsLoss expects float targets shaped like the logits;
# presumably the torchutils/vartorch wrappers cast/reshape the labels — confirm
train_set = TensorDataset(
    torch.tensor(X_train_final, dtype=torch.float32),
    torch.tensor(y_train, dtype=torch.int64),
)
val_set = TensorDataset(
    torch.tensor(X_val_final, dtype=torch.float32),
    torch.tensor(y_val, dtype=torch.int64),
)

# full-batch loaders: each yields its whole split as a single batch
train_loader = DataLoader(train_set, batch_size=len(train_set), shuffle=True)
val_loader = DataLoader(val_set, batch_size=len(val_set), shuffle=False)

## Standard training

In [None]:
# point-estimate baseline: logistic regression, i.e. one linear layer
# mapping the polynomial features to a single logit
model1 = nn.Linear(num_features, 1)

In [None]:
# binary cross-entropy on raw logits, averaged over the batch
criterion = nn.BCEWithLogitsLoss(reduction='mean')
optimizer = torch.optim.Adam(model1.parameters(), lr=0.1)

# bundle model, loss, optimizer and loaders into the torchutils training wrapper
point_model = Classification(model1, criterion, optimizer, train_loader, val_loader)

In [None]:
# 500 full-batch epochs; log_interval=None presumably turns off intermediate logging
point_history = point_model.training(num_epochs=500, log_interval=None)

In [None]:
# final loss and accuracy of the point estimate on both splits
point_train_loss, point_train_acc = point_model.test(train_loader)
point_val_loss, point_val_acc = point_model.test(val_loader)

print(f'Train loss: {point_train_loss:.4f}')
print(f'Val. loss: {point_val_loss:.4f}')
print(f'\nTrain acc.: {point_train_acc:.4f}')
print(f'Val. acc.: {point_val_acc:.4f}')

In [None]:
fig, ax = plt.subplots(figsize=(6, 4))

# learning curves of the point-estimate model
ax.plot(np.asarray(point_history['train_loss']), alpha=0.7, label='train')
ax.plot(np.asarray(point_history['val_loss']), alpha=0.7, label='val.')

ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_xlim(0, point_history['num_epochs'])
ax.legend()
ax.grid(visible=True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)

fig.tight_layout()

## Variational inference

In [None]:
# variational counterpart of model1: the same single-logit linear layer, with a
# distribution over its weights (weight_std presumably sets the prior scale)
model2 = VariationalLinear(in_features=num_features, out_features=1, weight_std=5)

In [None]:
# NOTE: rebinds the global `optimizer`; point_model was constructed with the earlier one
optimizer = torch.optim.Adam(model2.parameters(), lr=0.1)

# variational classifier with a Bernoulli likelihood on the single logit
post_model = VariationalClassification(model2, likelihood_type='Bernoulli')
post_model.compile_for_training(optimizer, train_loader, val_loader)

In [None]:
# 500 epochs; num_samples presumably sets the Monte Carlo weight draws per ELBO estimate
post_history = post_model.training(num_epochs=500, num_samples=20, log_interval=None)

In [None]:
# evaluate in the same order as before (train loss, train acc, val loss, val acc),
# since evaluation may consume random weight samples
post_train_loss = post_model.test_loss(train_loader)
post_train_acc = post_model.test_acc(train_loader)
post_val_loss = post_model.test_loss(val_loader)
post_val_acc = post_model.test_acc(val_loader)

print(f'Train loss: {post_train_loss:.4f}')
print(f'Val. loss: {post_val_loss:.4f}')
print(f'\nTrain acc.: {post_train_acc:.4f}')
print(f'Val. acc.: {post_val_acc:.4f}')

In [None]:
fig, ax = plt.subplots(figsize=(6, 4))

# the recorded loss is negated for display, i.e. it is presumably the negative ELBO
ax.plot(-np.asarray(post_history['train_loss']), alpha=0.7, label='train')
ax.plot(-np.asarray(post_history['val_loss']), alpha=0.7, label='val.')

ax.set_xlabel('epoch')
ax.set_ylabel('ELBO')
ax.set_xlim(0, post_history['num_epochs'])
ax.legend()
ax.grid(visible=True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)

fig.tight_layout()

## Example predictions

In [None]:
def transform_features(x):
    '''Map inputs in the original 2-D coordinates to model inputs.

    Applies the fitted preprocessing chain used for the training data
    (standardize -> degree-5 polynomial expansion -> standardize) and
    returns the result as a float32 tensor.
    '''
    for fitted_step in (original_scaler, polynomial_features, polynomial_scaler):
        x = fitted_step.transform(x)
    return torch.tensor(x, dtype=torch.float32)

@torch.no_grad()
def point_prediction(x):
    '''Class-1 probabilities from the point-estimate model.

    `x` holds inputs in the original 2-D data coordinates; the return value is
    the sigmoid of the model's logits as a NumPy array.
    '''
    inputs = transform_features(x).to(point_model.device)
    point_model.train(False)  # evaluation mode
    return torch.sigmoid(point_model.predict(inputs)).cpu().numpy()

@torch.no_grad()
def posterior_mean(x):
    '''Class-1 probabilities from one pass with the posterior mean weights.

    Weight sampling is switched off, so the prediction is deterministic; the
    result is the sigmoid of the logits as a NumPy array.
    '''
    inputs = transform_features(x)

    # mean weights instead of random draws, evaluation mode
    post_model.sample(False)
    post_model.train(False)

    logits = post_model.predict(inputs.to(post_model.device))
    return torch.sigmoid(logits).cpu().numpy()

@torch.no_grad()
def posterior_predictive(x, num_samples=1000):
    '''Predict according to the posterior predictive distribution.

    Monte Carlo estimate of the posterior predictive class-1 probability: the
    model is run with `num_samples` weight draws, each draw's logits are
    turned into probabilities, and the probabilities (not the logits) are
    averaged.

    Parameters
    ----------
    x : array-like
        Inputs in the original 2-D data coordinates.
    num_samples : int
        Number of posterior weight samples to draw.

    Returns
    -------
    numpy.ndarray
        Sample-averaged probabilities, with the trailing sample axis reduced.
    '''
    x_tensor = transform_features(x)

    # draw random weights, evaluation mode
    post_model.sample(True)
    post_model.train(False)

    sampled_logits = post_model.predict(x_tensor.to(post_model.device), num_samples)
    sampled_probs = torch.sigmoid(sampled_logits)

    # assumes predict stacks the weight samples along the last dimension
    post_mean = sampled_probs.mean(dim=-1)
    return post_mean.cpu().numpy()

@torch.no_grad()
def posterior_uncertainty(x, num_samples=1000):
    '''Compute the uncertainty associated with the posterior predictive.

    Runs the model with `num_samples` posterior weight draws, converts each
    draw's logits to class-1 probabilities, and returns their standard
    deviation across draws.

    Parameters
    ----------
    x : array-like
        Inputs in the original 2-D data coordinates.
    num_samples : int
        Number of posterior weight samples to draw.

    Returns
    -------
    numpy.ndarray
        Per-input standard deviation of the sampled probabilities. Because the
        probabilities lie in [0, 1], this is at most about 0.5.
    '''
    x_tensor = transform_features(x)

    # draw random weights, evaluation mode
    post_model.sample(True)
    post_model.train(False)

    sampled_logits = post_model.predict(x_tensor.to(post_model.device), num_samples)
    sampled_probs = torch.sigmoid(sampled_logits)

    # spread over the sample dimension (assumed last); torch's default
    # correction=1 gives the Bessel-corrected sample std
    post_std = sampled_probs.std(dim=-1)
    return post_std.cpu().numpy()

In [None]:
fig, ax = plt.subplots(figsize=(6, 4.5))
plot_data_2d(X_train, y_train, colors=(plt.cm.Set1(1), plt.cm.Set1(0)), ax=ax)

# limits are set before contouring; presumably plot_function_2d evaluates on a grid spanning them
ax.set_xlim(-2, 3)
ax.set_ylim(-2, 2.5)
plot_function_2d(point_prediction, levels=(0.3, 0.5, 0.7), ax=ax)

ax.set_aspect('equal', adjustable='box')
ax.set_title('Point predictions')
ax.legend(loc='upper left')
fig.tight_layout()

In [None]:
fig, ax = plt.subplots(figsize=(6, 4.5))
plot_data_2d(X_train, y_train, colors=(plt.cm.Set1(1), plt.cm.Set1(0)), ax=ax)

# limits are set before contouring; presumably plot_function_2d evaluates on a grid spanning them
ax.set_xlim(-2, 3)
ax.set_ylim(-2, 2.5)
plot_function_2d(posterior_mean, levels=(0.3, 0.5, 0.7), ax=ax)

ax.set_aspect('equal', adjustable='box')
ax.set_title('Posterior mean')
ax.legend(loc='upper left')
fig.tight_layout()

In [None]:
fig, ax = plt.subplots(figsize=(6, 4.5))
plot_data_2d(X_train, y_train, colors=(plt.cm.Set1(1), plt.cm.Set1(0)), ax=ax)

# limits are set before contouring; presumably plot_function_2d evaluates on a grid spanning them
ax.set_xlim(-2, 3)
ax.set_ylim(-2, 2.5)
plot_function_2d(posterior_predictive, levels=(0.1, 0.3, 0.5, 0.7, 0.9), ax=ax)

ax.set_aspect('equal', adjustable='box')
ax.set_title('Posterior predictions')
ax.legend(loc='upper left')
fig.tight_layout()

In [None]:
fig, ax = plt.subplots(figsize=(6, 4.5))
plot_data_2d(X_train, y_train, colors=(plt.cm.Set1(1), plt.cm.Set1(0)), ax=ax)
ax.set(xlim=(-2, 3), ylim=(-2, 2.5))

# The standard deviation of probabilities in [0, 1] cannot exceed ~0.5, so
# contour levels above 0.5 would never be drawn. Place the levels over the
# reachable range instead: 0.05, 0.10, ..., 0.50.
uncertainty_levels = np.linspace(0.05, 0.5, 10)
plot_function_2d(posterior_uncertainty, levels=uncertainty_levels, ax=ax)

ax.set_aspect('equal', adjustable='box')
ax.set_title('Posterior uncertainty')
ax.legend(loc='upper left')
fig.tight_layout()