In [None]:
import argparse
import copy

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import yaml
from box import Box
from pytorch_lightning import Trainer
from scipy.special import logsumexp

from dataset_npz import DataModuleFromNPZ
from model_logreg_dn import ModelLogisicRegressionMvn
%matplotlib inline

# Load data

In [None]:
# Seed the Python, NumPy and PyTorch RNGs so initialisation and training are reproducible.
pl.seed_everything(2202)

# Mini-batch data module used to train and test both models.
npz_loader_kwargs = dict(
    data_dir="data_logistic_two_moons",
    feature_labels=["inputs", "targets"],
    batch_size=256,
    num_workers=4,
    shuffle_training=False,
)
dm = DataModuleFromNPZ(**npz_loader_kwargs)

# Train model

In [None]:
dm.prepare_data()
dm.setup(stage="fit")

# Common initial feature map (2 -> 256 -> 2 MLP). Each model receives its OWN deep copy.
# Passing the same nn.Sequential instance to both models shares one set of feature-map
# parameters. Training the diagonal model would then overwrite the feature map the
# multivariate model was trained with. NOTE(review): this assumes ModelLogisicRegressionMvn
# does not copy the module internally; confirm. Copying also gives both models an
# identical starting point, which keeps the comparison fair.
# WARNING: this feature map only transforms in 2d; you can use a d-dimensional output
# and modify the regression model to work on d dimensions for better performance.
feature_map = nn.Sequential(nn.Linear(2, 256), nn.LeakyReLU(), nn.Linear(256, 2))

# Full (non-diagonal) Gaussian over the regression weights.
model_mvn = ModelLogisicRegressionMvn(
        2,
        dm.size_train(),
        feature_map=copy.deepcopy(feature_map),
        is_diagonal=False,
        scale_prior=10.0,
        optimizer_name="RMSprop",
        optimizer_lr=0.1,
        save_path="runs/models/multivariate")
trainer = Trainer(max_epochs=50)
trainer.fit(model_mvn, dm)
trainer.test(model_mvn, dm)
model_mvn.eval()

# Diagonal Gaussian over the regression weights, same hyper-parameters otherwise.
model_diag = ModelLogisicRegressionMvn(
        2,
        dm.size_train(),
        feature_map=copy.deepcopy(feature_map),
        is_diagonal=True,
        scale_prior=10.0,
        optimizer_name="RMSprop",
        optimizer_lr=0.1,
        save_path="runs/models/diagonal")
trainer = Trainer(max_epochs=50)
trainer.fit(model_diag, dm)
trainer.test(model_diag, dm)
model_diag.eval()

# Load all training and testing data for plotting

In [None]:
# Second data module over the same files, used only to fetch every sample for plotting.
# batch_size=-1 presumably means "whole split in one batch" for DataModuleFromNPZ
# (TODO confirm). The concatenation below is correct either way. The previous loop
# kept only the LAST batch when a split came in several batches.
dm_plotting = DataModuleFromNPZ(
    data_dir="data_logistic_two_moons",
    feature_labels=["inputs", "targets"],
    batch_size=-1,
    num_workers=4,
    shuffle_training=False
)


def _concat_batches(loader):
    """Concatenate all (features, labels) batches yielded by `loader` into two tensors.

    Raises:
        ValueError: if the loader yields no batches (previously a later NameError).
    """
    batches = list(loader)
    if not batches:
        raise ValueError("data loader yielded no batches")
    features, labels = zip(*batches)
    return torch.cat(features), torch.cat(labels)


dm_plotting.prepare_data()
dm_plotting.setup(stage="fit")
features_train, labels_train = _concat_batches(dm_plotting.train_dataloader())
dm_plotting.setup(stage="test")
features_test, labels_test = _concat_batches(dm_plotting.test_dataloader())

# Compute class probabilities for plotting

In [None]:
# Regular grid over [-3, 3) x [-3, 3) with spacing 0.025 on which both classifiers are evaluated.
x, y = np.meshgrid(np.arange(-3, 3, 0.025), np.arange(-3, 3, 0.025))
features_plot = np.concatenate([x.reshape((-1, 1)), y.reshape((-1, 1))], axis=-1)
features_plot_tensor = torch.tensor(features_plot, dtype=torch.float32)

# Inference only: no_grad avoids building an autograd graph over every grid point.
with torch.no_grad():
    p_plot_mvn = model_mvn(features_plot_tensor).detach().cpu().numpy().reshape(x.shape)
    p_plot_diag = model_diag(features_plot_tensor).detach().cpu().numpy().reshape(x.shape)

# Plot predicted class probabilities with the train and test data

In [None]:
def plot_decision_surface(ax, grid_x, grid_y, probs, features, labels, title):
    """Draw the model output `probs` on the (grid_x, grid_y) mesh with labelled points on top.

    Points with label > 0.5 are drawn red, points with label < 0.5 blue.
    Assumes one label per sample (shape (N,) or (N, 1)).
    """
    features = torch.as_tensor(features).detach().cpu().numpy()
    labels = torch.as_tensor(labels).detach().cpu().numpy().reshape(-1)
    ax.contourf(grid_x, grid_y, probs, 50, cmap=plt.get_cmap("gray"))
    for mask, color in ((labels > 0.5, "red"), (labels < 0.5, "blue")):
        ax.plot(features[mask, 0], features[mask, 1], ".", color=color)
    ax.set_title(title)
    ax.set_xlabel("input 1")
    ax.set_ylabel("input 2")


# One figure per model: training data on the left, test data on the right.
# (Avoids the old `In = ...` assignment, which shadowed IPython's input history.)
for probs, model_name in ((p_plot_mvn, "Multivariate"), (p_plot_diag, "Diagonal")):
    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    plot_decision_surface(ax[0], x, y, probs, features_train, labels_train,
                          f"{model_name} model: Train data")
    plot_decision_surface(ax[1], x, y, probs, features_test, labels_test,
                          f"{model_name} model: Test data")



# Print model parameters

In [None]:
def print_weight_distribution(name, model):
    """Print the learned mean and covariance (L @ L.T from `weights_chol()`) of `model`'s weights."""
    mean = model.weights_loc.detach().cpu().numpy()
    chol = model.weights_chol().detach().cpu().numpy()
    # Label each printout; previously both blocks had the identical header.
    print(f"Learned distribution parameters ({name} model)")
    print("weights mean")
    print(mean)
    print("weights covariance")
    print(chol @ chol.T)
    print()


for model_label, trained_model in (("multivariate", model_mvn), ("diagonal", model_diag)):
    print_weight_distribution(model_label, trained_model)

# Plot the Bayesian posterior distribution of the weights as training samples are added

In [None]:
def log_prob_sigma(label_sign, features, w):
    """Logistic log-likelihood log(sigmoid(label_sign * <features, w>)) for one sample.

    Args:
        label_sign: +1 / -1 sign of the label (scalar or array broadcastable to the result).
        features: feature vector of a single sample; flattened to a row before use.
        w: candidate weight vector(s); the dot product is taken over the last axis.

    Returns:
        Log-probabilities, one per weight vector in `w`.
    """
    z = label_sign * np.sum(features.reshape((1, -1)) * w, axis=-1)
    # log(sigmoid(z)) = -log(1 + exp(-z)). logaddexp computes this without
    # overflowing exp(-z) when z is large and negative.
    return -np.logaddexp(0.0, -z)
    
# Exact (unnormalised) log posterior over the weight vector, evaluated on a 101 x 101 grid
# spanning +-3 prior standard deviations per coordinate, updated one training sample at a time.
# NOTE(review): log_prob_sigma is applied to the RAW inputs, not to feature_map(inputs)
# as in the trained models. This posterior is therefore only directly comparable with
# the learned approximations if the feature map is (close to) the identity. Confirm
# that this is intended.
prior_scale = model_diag.scale_prior
w1, w2 = np.meshgrid(np.linspace(-3*prior_scale, 3*prior_scale, 101),
                     np.linspace(-3*prior_scale, 3*prior_scale, 101))
shape_mesh = w1.shape

w_plot = np.concatenate([w1.reshape((-1, 1)), w2.reshape((-1, 1))], axis=-1)
dim = w_plot.shape[-1]
# Isotropic Gaussian prior N(0, prior_scale^2 * I) in `dim` dimensions.
log_prob_prior = -0.5*np.sum(w_plot**2, axis=-1)/(prior_scale**2)
log_prob_prior = log_prob_prior - 0.5*dim*np.log(2*np.pi) - dim*np.log(prior_scale)

log_prob_current = log_prob_prior.reshape(shape_mesh)
for i, (feature, label) in enumerate(zip(features_train.cpu().numpy(), labels_train.cpu().numpy())):
    # Map the label (presumably 0/1) to the -1/+1 sign used by log_prob_sigma.
    log_prob_likelihood = log_prob_sigma(2*label - 1, feature, w_plot).reshape(shape_mesh)

    # Plot each of the first 10 updates, then every 100th one.
    if i < 10 or i % 100 == 0:
        fig, ax = plt.subplots(1, 3, figsize=(24, 8))
        ax[0].contourf(w1, w2, log_prob_current, 50, cmap=plt.get_cmap("gray"))
        ax[0].set_title(f"size_data: {i}")
        ax[1].contourf(w1, w2, log_prob_likelihood, 50, cmap=plt.get_cmap("gray"))
        ax[1].set_title(f"likelihood of sample {i}")
        ax[2].contourf(w1, w2, log_prob_current + log_prob_likelihood, 50, cmap=plt.get_cmap("gray"))
        ax[2].set_title(f"size_data: {i + 1}")
        # Render now so figures do not pile up open inside the loop.
        plt.show()

    log_prob_current = log_prob_current + log_prob_likelihood
    

    

# Plot the Bayesian posterior distribution vs. the learned Gaussian approximations

In [None]:
# Log densities of the two Gaussian approximations, on the same grid `w_plot` as the
# exact posterior. `logsumexp` comes from scipy.special, imported at the top of the
# notebook. A bare `import scipy` does not guarantee `scipy.special` is loaded.
dim = w_plot.shape[-1]

# Full covariance Sigma = L @ L.T:
#   Mahalanobis term  = ||L^{-1} (w - mu)||^2
#   0.5 * log det Sigma = sum(log |diag(L)|)
w_loc = model_mvn.weights_loc.detach().cpu().numpy().reshape((1, -1))
L = model_mvn.weights_chol().detach().cpu().numpy()
whitened = np.linalg.solve(L, (w_plot - w_loc).T)
log_prob_approx_mvn = -0.5*np.sum(whitened**2, axis=0)
log_prob_approx_mvn = (log_prob_approx_mvn - 0.5*dim*np.log(2*np.pi)
                       - np.sum(np.log(np.abs(np.diag(L)))))

# Diagonal covariance with per-coordinate standard deviations exp(weights_scale_logdiag).
w_loc = model_diag.weights_loc.detach().cpu().numpy().reshape((1, -1))
L_diag = torch.exp(model_diag.weights_scale_logdiag).detach().cpu().numpy().reshape((1, -1))
log_prob_approx_diag = -0.5*np.sum(((w_plot - w_loc)/L_diag)**2, axis=-1)
log_prob_approx_diag = log_prob_approx_diag - 0.5*dim*np.log(2*np.pi) - np.sum(np.log(L_diag))

# Renormalise each log density so it integrates to 1 over the plotted grid
# (Riemann sum with cell area dw). This makes the three panels directly comparable.
NORMALIZE_ON_GRID = True
if NORMALIZE_ON_GRID:
    log_dw = np.log((w1[0, 1] - w1[0, 0])*(w2[1, 0] - w2[0, 0]))
    log_prob_approx_mvn = log_prob_approx_mvn - logsumexp(log_prob_approx_mvn + log_dw)
    log_prob_approx_diag = log_prob_approx_diag - logsumexp(log_prob_approx_diag + log_dw)
    log_prob_current = log_prob_current - logsumexp(log_prob_current + log_dw)

fig, ax = plt.subplots(1, 3, figsize=(24, 8))
panels = (
    (log_prob_current, "Bayesian posterior"),
    (log_prob_approx_mvn, "Multivariate approximation"),
    (log_prob_approx_diag, "Diagonal approximation"),
)
for panel_ax, (log_prob, title) in zip(ax, panels):
    panel_ax.contourf(w1, w2, log_prob.reshape(w1.shape), 50, cmap=plt.get_cmap("gray"))
    panel_ax.set_title(title)
    panel_ax.set_xlabel("w1")
    panel_ax.set_ylabel("w2")
