In [None]:
import os, sys
import numpy as np
import pandas as pd

# Plotting
import matplotlib
import matplotlib.pyplot as plt
import corner

# Sampling from the BNN posterior
import torch
from h0rton.configs import TrainValConfig, TestConfig
import h0rton.losses
from h0rton.h0_inference import *
import h0rton.tdlmc_data
import h0rton.tdlmc_utils
import h0rton.train_utils as train_utils
import h0rton.models
from h0rton.trainval_data import XYData, XYCosmoData
from torch.utils.data import DataLoader

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Visualizing the BNN posterior

__Author:__ Ji Won Park (@jiwoncpark)

__Created:__ 11/01/2019

__Last run:__ 2/17/2020

__Goals:__
We visualize the BNN posterior for a single validation lens, overlaid against the BNN prior (here, the empirical distribution of the validation-set labels).

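__Before running:__
1. Train the BNN, e.g.
```bash
python -m h0rton.train h0rton/example_user_config.py
```
2. Update the config path and the saved-weights (`.mdl`) path in the cells below so they point at your trained model.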

In [None]:
# Release unused cached GPU memory held by PyTorch's allocator before loading anything.
torch.cuda.empty_cache()

# Train/val config of the model under inspection.
train_val_cfg_path = os.path.join('..', 'experiments', 'v11_train_val_config_file.json')
cfg = TrainValConfig.from_file(train_val_cfg_path)

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

# Instantiate loss function
loss_fn = getattr(h0rton.losses, cfg.model.likelihood_class)(Y_dim=cfg.data.Y_dim, device=device)
# Instantiate posterior (for logging)
bnn_post = getattr(h0rton.h0_inference.gaussian_bnn_posterior, loss_fn.posterior_name)(cfg.data.Y_dim, device, cfg.data.train_Y_mean, cfg.data.train_Y_std)
# Instantiate model
net = getattr(h0rton.models, cfg.model.architecture)(num_classes=loss_fn.out_dim)
# Load trained weights from saved state.
# NOTE(review): machine-specific absolute path — point this at your own checkpoint.
model_path = "/home/jwp/stage/sl/h0rton/v11_models/resnet34_epoch=1499_02-22-2020_10:01.mdl"
train_utils.load_state_dict_test(model_path, net, cfg.optim.n_epochs, device)

In [None]:
val_data = XYData(cfg.data.val_dir, data_cfg=cfg.data)
n_val = val_data.n_data
val_loader = DataLoader(val_data, batch_size=5, shuffle=False, drop_last=True)

# Only the first validation batch is needed: `pred` and `Y` are inspected by
# the cells below. With drop_last=True a too-small validation set yields no
# batch at all, so fail loudly instead of leaving `pred`/`Y` undefined.
first_batch = next(iter(val_loader), None)
if first_batch is None:
    raise ValueError(
        f"Validation loader yielded no batches (need at least "
        f"{val_loader.batch_size} examples with drop_last=True)."
    )
X_, Y_ = first_batch

with torch.no_grad():
    net.eval()  # inference mode for layers such as dropout/batch-norm
    X = X_.to(device)
    Y = Y_.to(device)
    pred = net(X)
    print(loss_fn(pred, Y).item())

In [None]:
# Loss on the first validation batch, displayed as a tensor.
batch_loss = loss_fn(pred, Y)
batch_loss

In [None]:
# Split the raw network output into the named mixture parameters used below
# ('mu', 'logvar', 'F', 'mu2', 'logvar2', 'F2', 'alpha').
# NOTE(review): assumes the likelihood class exposes `slice(pred)` returning this
# dict — confirm against h0rton.losses. (This cell previously held a stray `f`,
# which raised NameError, and `sliced` was never defined in the notebook.)
sliced = loss_fn.slice(pred)

In [None]:
# Length of the label vector; read from the config so it always matches the
# Y_dim that `loss_fn` was built with.
Y_dim = cfg.data.Y_dim
# Number of columns of the low-rank covariance factor F (reshaped to (Y_dim, rank) below).
# NOTE(review): must match the rank of the trained likelihood — confirm.
rank = 2

In [None]:
# Move every sliced prediction tensor to the CPU so numpy/scipy can consume them.
sliced.update({key: tensor.cpu() for key, tensor in sliced.items()})

In [None]:
from h0rton.losses import DoubleGaussianNLL, sigmoid
from scipy.stats import multivariate_normal

In [None]:
b = 0  # index of the example within the batch to inspect


def _mixture_component(sliced_pred, idx, mu_key, logvar_key, F_key):
    """Unpack one Gaussian component of the predicted mixture for example `idx`.

    Returns (mean, diagonal covariance built from exp(logvar),
    factor reshaped to (Y_dim, rank), low-rank covariance factor @ factor.T).
    """
    mean = sliced_pred[mu_key][idx, :]
    diag_cov = np.diagflat(np.exp(sliced_pred[logvar_key][idx, :]))
    factor = sliced_pred[F_key][idx, :].reshape(Y_dim, rank)
    return mean, diag_cov, factor, np.matmul(factor, factor.T)


mu_b, diag_b, F_b, low_rank_b = _mixture_component(sliced, b, 'mu', 'logvar', 'F')
mu2_b, diag2_b, F2_b, low_rank2_b = _mixture_component(sliced, b, 'mu2', 'logvar2', 'F2')

# Weight of the second mixture component.
w2_b = 0.5*sigmoid(sliced['alpha'][b])

In [None]:
# First-component mean for example `b`, shown next to the batch labels.
(mu_b, Y)

In [None]:
from scipy.special import logsumexp

# Evaluate the two-component Gaussian-mixture NLL for example `b` by hand, to
# cross-check the loss function. Work in log space: `pdf` underflows to 0 for
# poorly fit examples, which would turn the NLL into inf.
y_b = Y.cpu()[b, :Y_dim]
nll1 = -multivariate_normal.logpdf(y_b, mean=mu_b, cov=diag_b + low_rank_b)
nll2 = -multivariate_normal.logpdf(y_b, mean=mu2_b, cov=diag2_b + low_rank2_b)
w2 = float(w2_b)
# Stable form of -log((1 - w2) * exp(-nll1) + w2 * exp(-nll2)).
-logsumexp([-nll1, -nll2], b=[1.0 - w2, w2])

In [None]:
# Per-component negative log-likelihoods computed above.
(nll1, nll2)

In [None]:
# Inspect which label-column indices the data config marks for log parameterization.
data_cfg = cfg.data
data_cfg.Y_cols_to_log_parameterize_idx

In [None]:
# Double-Gaussian posterior object, fed with the network's predictions for the batch.
bnn_post = DoubleGaussianBNNPosterior(
    val_data.Y_dim,
    device,
    cfg.data.train_Y_mean,
    cfg.data.train_Y_std,
)
bnn_post.set_sliced_pred(pred)

In [None]:
# Stale commented-out calls to `bnn_post.transform_back` were removed from this cell;
# the component means are mapped back with `bnn_post.transform_back_mu` further down.

In [None]:
import baobab.sim_utils

# Draw posterior samples for every example in the batch; the fixed seed keeps
# the samples (and every figure built from them) reproducible.
n_samples = 3000  # number of BNN samples per lens
sample_seed = 0
bnn_samples = bnn_post.sample(n_samples, sample_seed=sample_seed)

In [None]:
# Batch labels passed through transform_back_mu, then moved to numpy.
truth_tensor = bnn_post.transform_back_mu(Y)
truth = truth_tensor.cpu().numpy()

In [None]:
lens_i = 1  # index of the lens within the batch to plot

# Posterior samples for this lens as a DataFrame, one column per label.
bnn_sample_df = pd.DataFrame(data=bnn_samples[lens_i], columns=cfg.data.Y_cols)
truth_lens_i = truth[lens_i]
print(truth_lens_i)

In [None]:
def _mu_to_numpy(mu_tensor):
    """Apply bnn_post.transform_back_mu to a mean tensor and return it as a CPU numpy array."""
    return bnn_post.transform_back_mu(mu_tensor).cpu().numpy()


# Means of the two mixture components.
mu_orig = _mu_to_numpy(bnn_post.mu)
mu_orig2 = _mu_to_numpy(bnn_post.mu2)

In [None]:
# First-component mean for the plotted lens.
lens_mu_orig = mu_orig[lens_i]
lens_mu_orig

In [None]:
param_idx = 1  # column of cfg.data.Y_cols to histogram

# Marginal BNN samples of one parameter for lens `lens_i`, with the means of
# both mixture components marked for comparison.
fig, ax = plt.subplots()
ax.hist(bnn_samples[lens_i, :, param_idx], bins=30)
ax.axvline(mu_orig[lens_i, param_idx], color='r', label='component 1 mean (mu)')
ax.axvline(mu_orig2[lens_i, param_idx], color='b', label='component 2 mean (mu2)')
ax.set(xlabel=cfg.data.Y_cols[param_idx], ylabel='count',
       title=f'BNN samples for lens {lens_i}')
ax.legend();

In [None]:
# Samples of the first label column for the plotted lens.
first_param_samples = bnn_samples[lens_i, :, 0]
first_param_samples

In [None]:
# Empirical "prior": the validation labels, passed through transform_back_mu.
prior_Y_raw = torch.Tensor(val_data.Y_df[cfg.data.Y_cols].values)
prior_Y = bnn_post.transform_back_mu(prior_Y_raw).cpu().numpy()
print(prior_Y.shape)

In [None]:
def _overlay_corner_kwargs(color):
    """Fresh corner.corner kwargs shared by the posterior and prior overlays.

    A new dict is built on every call because corner.corner may write into
    `hist_kwargs`/`contour_kwargs` (e.g. the color), so they must not be shared.
    """
    return dict(
        color=color,
        smooth=1.0,
        alpha=0.5,
        no_fill_contours=True,
        plot_datapoints=False,
        plot_contours=True,
        levels=[0.68, 0.95],
        contour_kwargs=dict(linestyles='solid'),
        quiet=True,
        hist_kwargs=dict(density=True),
    )


# BNN posterior for the selected lens (red)...
bnn_post_fig = corner.corner(
    bnn_sample_df[cfg.data.Y_cols],
    labels=cfg.plotting.Y_cols_latex_names,
    show_titles=True,
    **_overlay_corner_kwargs('tab:red'),
)

# ...with the prior (orange) and the true values (green) drawn on the same figure.
bnn_prior_fig = corner.corner(
    prior_Y,
    truths=truth_lens_i,
    fig=bnn_post_fig,
    truths_color='tab:green',
    range=[0.99]*len(cfg.data.Y_cols),
    **_overlay_corner_kwargs('tab:orange'),
)

In [None]:
# Zoomed plotting window: truth +/- the per-parameter half-widths from the plotting config.
plus_minus = np.array(cfg.plotting.Y_cols_range)
lower, upper = truth_lens_i - plus_minus, truth_lens_i + plus_minus
display_range = [(lo, hi) for lo, hi in zip(lower, upper)]

In [None]:
# Posterior corner plot zoomed to `display_range` around the true values.
_ = corner.corner(
    bnn_sample_df[cfg.data.Y_cols],
    labels=cfg.plotting.Y_cols_latex_names,
    range=display_range,
    truths=truth_lens_i,
    truths_color='tab:green',
    color='tab:red',
    smooth=1.0,
    alpha=0.5,
    no_fill_contours=True,
    plot_datapoints=False,
    plot_contours=True,
    show_titles=True,
    levels=[0.68, 0.95],
    contour_kwargs=dict(linestyles='solid'),
    quiet=True,
    hist_kwargs=dict(density=True,),
)

In [None]:
# Corner plot of the prior alone: validation labels after bnn_post.transform_back_mu.
bnn_prior_fig = corner.corner(prior_Y,
                              color='tab:orange',
                              smooth=1.0,
                              alpha=0.5,
                              no_fill_contours=True,
                              plot_datapoints=False,
                              plot_contours=True,
                              # Same label source as the other corner plots in this notebook.
                              labels=cfg.plotting.Y_cols_latex_names,
                              show_titles=True,
                              levels=[0.68, 0.95],
                              contour_kwargs=dict(linestyles='solid'),
                              quiet=True,
                              # Fraction of samples to include along each axis.
                              range=[0.99]*len(cfg.data.Y_cols),
                              hist_kwargs=dict(density=True))