In [None]:
import os, sys
import numpy as np
import pandas as pd

# Plotting
import matplotlib
import matplotlib.pyplot as plt
import corner

# Sampling from the BNN posterior
import torch
import torchvision.models
from h0rton.configs import TrainValConfig
import h0rton.losses
from h0rton.h0_inference import *
import h0rton.tdlmc_data
import h0rton.tdlmc_utils
import h0rton.train_utils as train_utils
from h0rton.trainval_data import XYData
from torch.utils.data import DataLoader

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Visualizing the BNN posterior

__Author:__ Ji Won Park (@jiwoncpark)

__Created:__ 11/01/2019

__Last run:__ 1/07/2020

__Goals:__
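We visualize the BNN posterior for a single validation lens, overlaid on the BNN prior (the distribution of validation-set labels) and on the lens's true parameter values.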

__Before running:__
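1. Train the BNN, e.g.
```bash
python -m h0rton.train h0rton/example_user_config.py
```
2. Make sure the config file and the checkpoint path loaded below point to the config used for training and the resulting saved `.mdl` file.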

In [None]:
# Train/val configuration (data columns, model architecture, optimizer settings)
# that every later cell reads from.
config_path = os.path.join('..', 'h0rton', 'train_val_config_file.py')
cfg = TrainValConfig.from_file(config_path)

In [None]:
# Inference device. Pinned to the CPU; the cuda branch is kept so the device can
# be switched in one place.
device = torch.device('cpu')
if device.type == 'cuda':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

# Saved model to evaluate. Set the H0RTON_MODEL_PATH environment variable to use
# a different checkpoint without editing this cell.
MODEL_PATH = os.environ.get(
    'H0RTON_MODEL_PATH',
    "/home/jwp/stage/sl/h0rton/saved_models/resnet18_epoch=779_01-07-2020_15:37.mdl",
)
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(
        f"Trained model not found at {MODEL_PATH!r}; train the BNN first or set H0RTON_MODEL_PATH."
    )

# Rebuild the architecture used in training: a torchvision backbone whose final
# fully-connected layer is resized to the likelihood's output dimension.
loss_fn = getattr(h0rton.losses, cfg.model.likelihood_class)(Y_dim=cfg.data.Y_dim, device=device)
net = getattr(torchvision.models, cfg.model.architecture)(pretrained=cfg.model.load_pretrained)
n_filters = net.fc.in_features
net.fc = torch.nn.Linear(in_features=n_filters, out_features=loss_fn.out_dim) # replace final layer
net.to(device)
# Load trained weights from saved state
net, epoch = train_utils.load_state_dict_test(MODEL_PATH, net, cfg.optim.n_epochs, device)
# Evaluation mode; the trailing ';' keeps the (very long) model repr out of the output.
net.eval();

In [None]:
# Load the validation set and run the network on its first batch (< 10 seconds).
val_data = XYData(cfg.data.val_dir, data_cfg=cfg.data)
n_val = 4  # number of validation lenses (batch size) to visualize
# With drop_last=True a set smaller than one batch yields no batches at all,
# so fail loudly here instead of leaving X/Y/pred undefined.
if len(val_data) < n_val:
    raise ValueError(
        f"Validation set has {len(val_data)} examples; need at least n_val={n_val} for one full batch."
    )
val_loader = DataLoader(val_data, batch_size=n_val, shuffle=False, drop_last=True)

X_, Y_ = next(iter(val_loader))
X = X_.to(device)
Y = Y_.to(device)
# Inference only: do not build an autograd graph for the forward pass.
with torch.no_grad():
    pred = net(X)

In [None]:
n_samples = 5000  # BNN samples to draw for each lens
# Posterior built from the network output; it is given the training-set label
# mean/std and the indices of whitened / log-parameterized columns.
bnn_post = DoubleGaussianBNNPosterior(
    val_data.Y_dim,
    cfg.data.Y_cols_to_whiten_idx,
    cfg.data.train_Y_mean,
    cfg.data.train_Y_std,
    cfg.data.Y_cols_to_log_parameterize_idx,
    device,
)
bnn_post.set_sliced_pred(pred)

In [None]:
# True parameters of the batch, passed through the posterior's inverse transform.
# `.cpu()` is a no-op for CPU tensors and keeps this cell working if `device` is
# switched to CUDA (`.numpy()` rejects GPU tensors).
truth = bnn_post.transform_back(Y).cpu().numpy()

In [None]:
import baobab.sim_utils

# Draw n_samples posterior samples per lens and flatten them to one row per sample.
flat_samples = bnn_post.sample(n_samples, sample_seed=cfg.global_seed).reshape(-1, val_data.Y_dim)
bnn_samples = pd.DataFrame(flat_samples, columns=cfg.data.Y_cols)

# Add derived parameterizations: external-shear gamma/psi columns (only when the
# shear components are present) and q/phi columns (presumably derived from the
# e1/e2 ellipticity columns).
has_ext_shear = 'external_shear_gamma1' in bnn_samples.columns
if has_ext_shear:
    bnn_samples = baobab.sim_utils.add_gamma_psi_ext_columns(bnn_samples)
bnn_samples = baobab.sim_utils.add_qphi_columns(bnn_samples)

# Regroup the flat rows into (lens, sample, column); assumes sample() returns
# the samples grouped lens by lens.
bnn_samples_colnames = bnn_samples.columns.values
bnn_samples_values = bnn_samples.values.reshape(n_val, n_samples, -1)

In [None]:
# Select one lens of the batch to plot.
lens_i = 0
truth_lens_i = truth[lens_i]
bnn_sample_df = pd.DataFrame(bnn_samples_values[lens_i], columns=bnn_samples_colnames)

In [None]:
# Column names of the posterior samples, including any columns added by the baobab conversions.
bnn_samples_colnames

In [None]:
# Labels of every validation lens, plotted below as the "BNN prior". They are
# passed through the same bnn_post.transform_back used for the truths, presumably
# undoing the whitening / log-parameterization applied to the stored labels.
prior_Y_transformed = torch.as_tensor(
    val_data.Y_df[cfg.data.Y_cols].values, dtype=torch.float32, device=device
)
prior_Y = bnn_post.transform_back(prior_Y_transformed).cpu().numpy()
prior_Y.shape

In [None]:
def corner_overlay_kwargs():
    """Return the corner.corner options shared by the posterior and prior layers.

    A new dict (with new nested dicts) is built on every call. corner may write
    per-call defaults, such as the histogram colour, into the dicts it receives,
    so one dict reused across both calls could leak the first layer's colour
    into the second.
    """
    return dict(
        smooth=1.0,
        alpha=0.5,
        no_fill_contours=True,
        plot_datapoints=False,
        plot_contours=True,
        levels=[0.68, 0.95],
        contour_kwargs=dict(linestyles='solid'),
        quiet=True,
        hist_kwargs=dict(density=True),
    )


# Posterior samples of the selected lens (red), with axis labels and 1D titles.
bnn_post_fig = corner.corner(bnn_sample_df[cfg.data.Y_cols],
                             color='tab:red',
                             labels=cfg.data.Y_cols,
                             show_titles=True,
                             **corner_overlay_kwargs())

# Prior (validation-set labels, orange) drawn onto the same figure, with the
# selected lens's true parameters marked in green. Each 0.99 entry in `range`
# is intended to bound the axis by that fraction of the prior samples.
bnn_prior_fig = corner.corner(prior_Y,
                              color='tab:orange',
                              truths=truth_lens_i,
                              truths_color='tab:green',
                              fig=bnn_post_fig,
                              range=[0.99]*len(cfg.data.Y_cols),
                              **corner_overlay_kwargs())

# Identify the three overlaid layers; the upper-right corner of a corner plot is empty.
legend_handles = [
    matplotlib.lines.Line2D([], [], color='tab:red', label=f'BNN posterior (lens {lens_i})'),
    matplotlib.lines.Line2D([], [], color='tab:orange', label='BNN prior'),
    matplotlib.lines.Line2D([], [], color='tab:green', label='truth'),
]
bnn_post_fig.legend(handles=legend_handles, loc='upper right');