In [None]:
import argparse
from box import Box
import yaml

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from model_logreg_mvn import ModelLogisicRegressionMvn
from dataset_npz import DataModuleFromNPZ

import torch
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

# Load data

In [None]:
# Seed the random number generators through Lightning before anything is built.
pl.seed_everything(2202)

# Configuration of the data module used for training and evaluation, collected
# in one place so it can be read at a glance.
dm_kwargs = dict(
    data_dir="data_logistic_regression_2d",
    feature_labels=["inputs", "targets"],
    batch_size=256,
    num_workers=4,
    shuffle_training=False,
)
dm = DataModuleFromNPZ(**dm_kwargs)

# Train or load a model

In [None]:
# Set to True to train and save a model within this notebook; switch to False
# to load an existing checkpoint and only evaluate it.
TRAIN_MODEL = True

# Both paths need the training split first: the constructor and the checkpoint
# loader are each given the size of the training set.
dm.prepare_data()
dm.setup(stage="fit")

if TRAIN_MODEL:
    model_mvn = ModelLogisicRegressionMvn(
        2,
        dm.size_train(),
        scale_prior=10.0,
        optimizer_name="RMSprop",
        optimizer_lr=0.1,
        save_path="runs/models/multivariate",
    )
    trainer = Trainer(max_epochs=100)
    trainer.fit(model_mvn, dm)
else:
    size_data_train = dm.size_train()

    # Change this to your own saved checkpoint in the same directory!
    SAVE_PATH = "runs/models/multivariate/loss_val-epoch=160-step=644.ckpt"
    model_mvn = ModelLogisicRegressionMvn.load_from_checkpoint(SAVE_PATH, size_data=size_data_train)
    dm.setup(stage="test")
    trainer = Trainer()

# Common tail of both paths: evaluate on the test split, then put the model in
# evaluation mode. The trailing semicolon keeps the cell from echoing the module.
trainer.test(model_mvn, dm)
model_mvn.eval();


# Load all training and testing data for plotting

In [None]:
# Second data module, used only to pull tensors for plotting.
# NOTE(review): batch_size=-1 presumably makes each loader yield the whole split
# as a single batch -- confirm against DataModuleFromNPZ.
plotting_kwargs = dict(
    data_dir="data_logistic_regression_2d",
    feature_labels=["inputs", "targets"],
    batch_size=-1,
    num_workers=4,
    shuffle_training=False,
)
dm_plotting = DataModuleFromNPZ(**plotting_kwargs)
dm_plotting.prepare_data()

# Training split: after the loop the names hold the last batch the loader yielded.
dm_plotting.setup(stage="fit")
for batch in dm_plotting.train_dataloader():
    features_train, labels_train = batch

# Test split, handled the same way.
dm_plotting.setup(stage="test")
for batch in dm_plotting.test_dataloader():
    features_test, labels_test = batch

# Compute class probabilities for plotting

In [None]:
# Regular grid over the input plane on which the model output is evaluated.
grid_axis = np.arange(-1.1, 1.1, 0.025)
x, y = np.meshgrid(grid_axis, grid_axis)

# One row per grid point: column 0 is the x coordinate, column 1 the y coordinate.
features_plot = np.concatenate([x.reshape((-1,1)), y.reshape((-1,1))], axis=-1)

# Only forward values are needed for plotting, so run the model without
# autograd bookkeeping instead of building a graph and detaching afterwards.
with torch.no_grad():
    p_grid = model_mvn(torch.tensor(features_plot, dtype=torch.float32))

# Back to NumPy, reshaped to the grid so it can be passed to contourf.
p_plot_mvn = p_grid.detach().cpu().numpy().reshape(x.shape)

# Plot

In [None]:
def plot_points_on_surface(axis, grid_x, grid_y, surface, features, labels, title):
    """Draw a filled contour of `surface` and overlay the labelled points.

    Parameters
    ----------
    axis : matplotlib axis to draw on.
    grid_x, grid_y, surface : arrays passed straight to `contourf`.
    features : indexable data; column 0 is plotted against column 1.
    labels : labels compared against 0.5; entries above are drawn in red,
        entries below in blue (a label of exactly 0.5 is not drawn).
    title : title set on `axis`.
    """
    # The index names deliberately avoid `In`, which IPython uses for its
    # input history; assigning to it would shadow that built-in name.
    idx_pos = np.argwhere(labels[:] > 0.5)
    idx_neg = np.argwhere(labels[:] < 0.5)
    axis.contourf(grid_x, grid_y, surface, 50, cmap=plt.get_cmap("gray"))
    axis.plot(features[idx_pos, 0], features[idx_pos, 1], ".", color="red")
    axis.plot(features[idx_neg, 0], features[idx_neg, 1], ".", color="blue")
    axis.set_title(title)


# Same predicted surface in both panels; only the overlaid points differ.
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
plot_points_on_surface(ax[0], x, y, p_plot_mvn, features_train, labels_train,
                       "Multivariate model: Train data")
plot_points_on_surface(ax[1], x, y, p_plot_mvn, features_test, labels_test,
                       "Multivariate model: Test data")

# Print model parameters

In [None]:
# Factor returned by the model's weights_chol(), moved to NumPy.
L = model_mvn.weights_chol().detach().cpu().numpy()

# One print call with newline separators produces the same lines as printing
# each item on its own.
print(
    "Learned distribution parameters",
    "weights mean",
    model_mvn.weights_loc.detach().cpu().numpy(),
    "weights covariance",
    L @ L.T,  # covariance rebuilt from the factor as L L^T
    sep="\n",
)

# Plot exact Bayesian posterior distribution of the weights

In [None]:
def log_prob_sigma(label_sign, features, w):
    """Return log(sigmoid(z)) with z = label_sign * <features, w> per row of w.

    Parameters
    ----------
    label_sign : scalar (or array broadcastable against the row scores) that
        multiplies the score; the caller in this notebook passes 2*label - 1.
    features : array exposing `reshape`; it is flattened into a single row.
    w : array whose last axis matches the number of features, one weight
        vector per row.

    Returns
    -------
    ndarray with one log-probability per row of `w`.
    """
    z = label_sign * np.sum(features.reshape((1, -1)) * w, axis=-1)
    # log(sigmoid(z)) = -log(1 + exp(-z)) = -logaddexp(0, -z).
    # logaddexp evaluates this without overflowing exp(-z) for very negative z,
    # where the naive formula would return -inf.
    return -np.logaddexp(0.0, -z)
    
# Square grid in weight space reaching three prior scales either side of zero.
w_limit = 3*model_mvn.scale_prior
w_axis = np.linspace(-w_limit, w_limit, 101)
w1, w2 = np.meshgrid(w_axis, w_axis)
shape_mesh = w1.shape

# Grid flattened to one (w1, w2) pair per row.
w_plot = np.stack([w1.ravel(), w2.ravel()], axis=-1)

# Log density of an isotropic 2-D Gaussian with standard deviation scale_prior:
# quadratic term first, then the two normalisation constants.
log_norm_2pi = 0.5*2.0*np.log(2*np.pi)
log_norm_scale = 0.5*2.0*2.0*np.log(model_mvn.scale_prior)
log_prob_prior = -0.5*np.sum(w_plot**2, axis=-1)/(model_mvn.scale_prior**2)
log_prob_prior = log_prob_prior - log_norm_2pi - log_norm_scale

# Start from the prior and add one training point's log-likelihood per step.
log_prob_current = log_prob_prior.reshape(shape_mesh)
train_points = zip(features_train.cpu().numpy(), labels_train.cpu().numpy())
for n_seen, (point, target) in enumerate(train_points):
    # Labels are mapped to a sign via 2*label - 1 before scoring.
    log_prob_likelihood = log_prob_sigma(2*target-1, point, w_plot).reshape(shape_mesh)

    # Show the first ten updates, then every 200th one.
    if n_seen < 10 or n_seen % 200 == 0:
        fig, ax = plt.subplots(1,3, figsize=(24,8))
        panels = (
            (log_prob_current, f"size_data: {n_seen}"),
            (log_prob_likelihood, "New Likelihood"),
            (log_prob_current + log_prob_likelihood, "New posterior"),
        )
        for axis, (surface, title) in zip(ax, panels):
            axis.contourf(w1, w2, surface, 50, cmap=plt.get_cmap("gray"))
            axis.set_title(title)

    log_prob_current = log_prob_current + log_prob_likelihood
    

# Plot true Bayesian posterior distribution vs approximation

In [None]:
import scipy as sp
# Import the function explicitly: `import scipy` alone does not guarantee that
# the `scipy.special` submodule is loaded on every SciPy release.
from scipy.special import logsumexp

# Set to False to compare the log densities without renormalising them on the grid.
NORMALIZE_ON_GRID = True

w_loc = model_mvn.weights_loc.detach().cpu().numpy().reshape((1,-1))
L     = model_mvn.weights_chol().detach().cpu().numpy()

# Log density of the Gaussian with mean w_loc and covariance L L^T at every grid
# point: multiplying the centred points by inv(L).T turns the quadratic form into
# a plain sum of squares.
# NOTE(review): the sum(log(diag(L))) normaliser assumes L is a triangular factor
# with positive diagonal, as the name weights_chol suggests -- confirm.
log_prob_approx_mvn = -0.5*np.sum(np.matmul(w_plot-w_loc, np.linalg.inv(L).T)**2, axis=-1)
log_prob_approx_mvn = log_prob_approx_mvn - 0.5*2.0*np.log(2*np.pi) - 0.5*2.0*np.sum(np.log(np.diag(L)))

if NORMALIZE_ON_GRID:
    # Area of one grid cell; after subtracting the log of the summed mass, the
    # exponentiated values times dw sum to one over the plotted grid.
    dw = (w1[0,1] - w1[0,0])*(w2[1,0] - w2[0,0])
    log_prob_approx_mvn  = log_prob_approx_mvn - logsumexp(log_prob_approx_mvn + np.log(dw))
    log_prob_current     = log_prob_current - logsumexp(log_prob_current + np.log(dw))

# Left: posterior accumulated on the grid; right: the fitted Gaussian.
fig, ax = plt.subplots(1,2, figsize=(20,10))
ax[0].contourf(w1, w2, log_prob_current, 50, cmap=plt.get_cmap("gray"))
ax[0].set_title("Bayesian posterior")
ax[1].contourf(w1, w2, log_prob_approx_mvn.reshape(w1.shape), 50, cmap=plt.get_cmap("gray"))
ax[1].set_title("Multivariate approximation");

