In [None]:
import argparse
from box import Box
import yaml

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from model_logreg_mvn import ModelLogisicRegressionMvn
from dataset_npz import DataModuleFromNPZ

import torch
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

# Seed the RNGs and set up the data module

In [None]:
# Seed every RNG Lightning knows about before anything stochastic runs.
pl.seed_everything(seed=2202)

# Data module over the .npz files in data_logistic_regression_2d. Training
# batches are served in file order (shuffle_training=False).
dm = DataModuleFromNPZ(
    data_dir="data_logistic_regression_2d",
    feature_labels=["inputs", "targets"],
    shuffle_training=False,
    batch_size=128, num_workers=4,  # num_workers presumably forwarded to the DataLoader — confirm
)

# Train or load a model

In [None]:
# Set TRAIN_MODEL = True to train and checkpoint a fresh model inside this
# notebook; leave it False to evaluate the checkpoint at SAVE_PATH instead.
TRAIN_MODEL = False
# Checkpoint evaluated when TRAIN_MODEL is False. Point this at a checkpoint
# written by your own training run (saved under runs/models/debug).
SAVE_PATH = "runs/models/debug/loss_val-epoch=41-step=336.ckpt"

dm.prepare_data()
# The training split must be set up before dm.size_train() is queried below.
dm.setup(stage="fit")

if TRAIN_MODEL:
    model = ModelLogisicRegressionMvn(
        2,  # presumably the input/feature dimension (2-D data) — confirm
        dm.size_train(),
        scale_prior=10.0,
        optimizer_name="RMSprop",
        optimizer_lr=0.1,
        save_path="runs/models/debug",
    )
    trainer = Trainer(max_epochs=50)
    trainer.fit(model, dm)
else:
    model = ModelLogisicRegressionMvn.load_from_checkpoint(SAVE_PATH, size_data=dm.size_train())
    trainer = Trainer()

# Evaluate on the held-out split in both modes, then switch to inference mode
# for the plotting cells that follow.
dm.setup(stage="test")
trainer.test(model, dm)
model.eval()
    

# Load all training and testing data for plotting

In [None]:
def collect_batches(dataloader):
    """Gather every (features, labels) batch yielded by ``dataloader``.

    If the loader yields exactly one batch, that batch is returned unchanged.
    With batch_size=-1 this is presumably the expected case (one batch holding
    the whole split) — TODO confirm against DataModuleFromNPZ. If it yields
    several batches, they are concatenated along the first dimension, so no
    data is silently dropped.

    Returns:
        (features, labels) tuple.

    Raises:
        ValueError: if the loader yields no batches at all.
    """
    feature_batches, label_batches = [], []
    for features, labels in dataloader:
        feature_batches.append(features)
        label_batches.append(labels)
    if not feature_batches:
        raise ValueError("dataloader yielded no batches")
    if len(feature_batches) == 1:
        return feature_batches[0], label_batches[0]
    # Assumes batches are torch tensors (the DataLoader default collate output).
    return torch.cat(feature_batches), torch.cat(label_batches)


# A second data module with batch_size=-1, used only to pull whole splits
# for plotting.
dm_plotting = DataModuleFromNPZ(
    data_dir="data_logistic_regression_2d",
    feature_labels=["inputs", "targets"],
    batch_size=-1,
    num_workers=4,
    shuffle_training=False
)
dm_plotting.prepare_data()
dm_plotting.setup(stage="fit")
features_train, labels_train = collect_batches(dm_plotting.train_dataloader())
dm_plotting.setup(stage="test")
features_test, labels_test = collect_batches(dm_plotting.test_dataloader())

# Compute class probabilities for plotting

In [None]:
# Regular grid over [-1.1, 1.1) with step 0.025 in both feature dimensions.
# The model is evaluated on this grid to draw its probability surface.
grid_axis = np.arange(-1.1, 1.1, 0.025)
x, y = np.meshgrid(grid_axis, grid_axis)
# One row per grid point: column 0 holds the x coordinate, column 1 the y coordinate.
features_plot = np.column_stack([x.ravel(), y.ravel()])

# Inference only. no_grad avoids building an autograd graph. The input is
# created on the model's device, so this also works when the model is not
# on the CPU.
with torch.no_grad():
    grid_inputs = torch.tensor(features_plot, dtype=torch.float32, device=model.device)
    # The reshape assumes the forward pass yields one value per grid point.
    p_plot = model(grid_inputs).cpu().numpy().reshape(x.shape)

# Plot the predicted class probabilities with the train and test points

In [None]:
def plot_split(axis, features, labels, title, grid_x, grid_y, probs):
    """Draw one data split over the model's predicted-probability surface.

    Args:
        axis: matplotlib Axes to draw on.
        features: array or tensor whose columns 0 and 1 are plotted as x and y.
        labels: one label per row of ``features``. It is flattened, so shape
            (N,) and (N, 1) both work.
        title: panel title.
        grid_x, grid_y: meshgrid coordinates of the probability surface.
        probs: probabilities on that grid, same shape as ``grid_x``.
    """
    # Flatten before masking. np.argwhere on an (N, 1) array would return
    # (row, col) index pairs and select the wrong feature rows.
    labels_flat = labels.reshape(-1)
    positive = labels_flat > 0.5
    negative = labels_flat < 0.5
    axis.contourf(grid_x, grid_y, probs, 50, cmap="gray")
    axis.plot(features[positive, 0], features[positive, 1], ".", color="red", label="label > 0.5")
    axis.plot(features[negative, 0], features[negative, 1], ".", color="blue", label="label < 0.5")
    axis.set_title(title)
    axis.set_xlabel("feature 0")
    axis.set_ylabel("feature 1")
    axis.legend(loc="upper right")


# Note: no variable named `In` here; that would shadow IPython's input history.
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
plot_split(ax[0], features_train, labels_train, "Train data", x, y, p_plot)
plot_split(ax[1], features_test, labels_test, "Test data", x, y, p_plot)
plt.show()



# Print the learned weight distribution (mean and covariance)

In [None]:
# Report the learned weight distribution: the mean vector from weights_loc,
# and a covariance rebuilt as chol_factor @ chol_factor.T from the factor
# that weights_chol() returns.
print("Learned distribution parameters")
print("weights mean")
weights_mean = model.weights_loc.detach().cpu().numpy()
print(weights_mean)
print("weights covariance")
chol_factor = model.weights_chol().detach().cpu().numpy()
print(chol_factor @ chol_factor.T)