In [None]:
import os, sys
import numpy as np
import pandas as pd

# Plotting
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import corner
import matplotlib.cm as cm
# Sampling from the BNN posterior
from lenstronomy.Cosmo.lcdm import LCDM
from h0rton.configs import TrainValConfig, TestConfig
import h0rton.losses
from h0rton.h0_inference import *
import h0rton.tdlmc_data
import h0rton.tdlmc_utils
import h0rton.train_utils as train_utils
import h0rton.models
from h0rton.trainval_data import XYData
from torch.utils.data import DataLoader
from baobab.data_augmentation.noise_lenstronomy import NoiseModelNumpy
from baobab.configs import BaobabConfig
from baobab.sim_utils import metadata_utils
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Reset to matplotlib's defaults first so the style below does not depend on
# whatever a previous run (or another notebook) left in rcParams.
plt.rcParams.update(plt.rcParamsDefault)

# Notebook-wide figure style, applied in one shot.
plt.rcParams.update({
    'font.family': 'STIXGeneral',
    'font.size': 20,
    'xtick.labelsize': 'medium',
    'ytick.labelsize': 'medium',
    'text.usetex': True,  # text is rendered through LaTeX
    'axes.linewidth': 2,
    'axes.titlesize': 'large',
    'axes.labelsize': 'large',
})

# Individual cornerplot

__Author:__ Ji Won Park (@jiwoncpark)

__Created:__ 8/20/2020

__Last run:__ 11/29/2020

__Goals:__
We compare three distributions in a cornerplot:
- the BNN-inferred posterior over the lens model parameters and $D_{\Delta t}$
- the forward modeling equivalent
- the interim prior (training set distribution)

__Before running:__
1. Train the BNN, e.g.
```bash
python h0rton/train.py experiments/v2/train_val_cfg.json
```

2. Get BNN inference results for the trained model, e.g.
```bash
python h0rton/infer_h0_mcmc_default.py experiments/v2/mcmc_default.json
```

3. Get forward modeling inference results for the trained model, e.g.
```bash
python h0rton/infer_h0_mcmc_forward_modeling.py experiments/v2/mcmc_default_chain_test.json
```

Let's read in the inference config and load the test and validation datasets.

In [None]:
# Read in inference config
# Cosmology assumed to be the truth when converting redshifts to D_dt.
true_H0 = 70.0
true_Om0 = 0.3

# Resolve both configs relative to the notebook's location instead of a
# machine-specific absolute path, so the notebook runs from any checkout.
experiments_dir = os.path.join('..', 'experiments', 'v2')
cfg = TrainValConfig.from_file(os.path.join(experiments_dir, 'train_val_cfg.json'))
test_cfg = TestConfig.from_file(os.path.join(experiments_dir, 'mcmc_default_chain_test.json'))
test_baobab_cfg = BaobabConfig.from_file(test_cfg.data.test_baobab_cfg_path)

# Read in the data config used for plotting the interim prior.
# NOTE: despite the `val_` prefix, this metadata is read from the config at
# `cfg.data.train_baobab_cfg_path`, and only the first 5000 rows are loaded.
val_baobab_cfg = BaobabConfig.from_file(cfg.data.train_baobab_cfg_path)
val_truth = pd.read_csv(os.path.join(val_baobab_cfg.out_dir, 'metadata.csv'), index_col=None, nrows=5000)

# Source position expressed as an offset from the lens mass center.
val_truth['src_light_center_x_offset'] = val_truth['src_light_center_x'] - val_truth['lens_mass_center_x']
val_truth['src_light_center_y_offset'] = val_truth['src_light_center_y'] - val_truth['lens_mass_center_y']

# Append the derived columns provided by baobab's metadata utilities.
val_truth = metadata_utils.add_qphi_columns(val_truth)
val_truth = metadata_utils.add_g1g2_columns(val_truth)

The validation data doesn't include the true $D_{\Delta t}$ for each lens, so we calculate it here from each lens's lens and source redshifts, assuming the true cosmology defined above.

In [None]:
# Compute the true D_dt for every row from its lens and source redshifts,
# under the cosmology (true_H0, true_Om0).
#
# The values are collected in row order and assigned as one column. This avoids
# reading rows by position (`iloc[i]`) but writing them by label (`loc[i]`),
# which is only equivalent when the frame carries a default 0..n-1 index, and
# it avoids slow row-by-row scalar assignment.
val_truth['D_dt'] = [
    LCDM(z_lens=z_lens, z_source=z_src, flat=True).D_dt(H_0=true_H0, Om0=true_Om0)
    for z_lens, z_src in zip(val_truth['z_lens'].values, val_truth['z_src'].values)
]

In [None]:
# Index of the test-set lens whose posteriors are plotted.
lens_i = 43

version_id = 2 # 2 HST orbits
# Locate the experiment outputs relative to the notebook (as is done for the
# train/val config) rather than through a machine-specific absolute path.
version_dir = os.path.join('..', 'experiments', 'v{:d}'.format(version_id))
default_samples_path = os.path.join(version_dir, 'mcmc_default_chain_test', 'mcmc_samples_{0:04d}.csv'.format(lens_i))

# BNN posterior samples
default_samples = pd.read_csv(default_samples_path, index_col=None)

# Forward modeling posterior samples
fm_samples_path = os.path.join(version_dir, 'forward_modeling_{:d}'.format(lens_i), 'mcmc_samples_{0:04d}.csv'.format(lens_i))
fm_samples = pd.read_csv(fm_samples_path, index_col=None)

# Discard burn-in from the forward modeling MCMC samples: only the last 10% of
# the chain is kept. `.copy()` makes the kept slice an independent frame so the
# column assignments below do not write into a view of the full chain.
burn_in_fraction = 0.9
print("Number of MCMC samples before discarding burn-in: ", fm_samples.shape[0])
fm_samples = fm_samples.iloc[int(fm_samples.shape[0]*burn_in_fraction):].copy()
print("Number of MCMC samples after discarding burn-in: ", fm_samples.shape[0])

# Redefine source position in absolute terms for both BNN and forward modeling posterior samples
# (the stored source position is added to the lens mass center).
for samples in (default_samples, fm_samples):
    samples['src_light_center_x'] = samples['src_light_center_x'] + samples['lens_mass_center_x']
    samples['src_light_center_y'] = samples['src_light_center_y'] + samples['lens_mass_center_y']

# Read in truth values for this lens so we can overlay them
truth = pd.read_csv(os.path.join(test_baobab_cfg.out_dir, 'metadata.csv'), index_col=None)
truth['src_light_center_x_offset'] = truth['src_light_center_x'] - truth['lens_mass_center_x']
truth['src_light_center_y_offset'] = truth['src_light_center_y'] - truth['lens_mass_center_y']
truth = metadata_utils.add_qphi_columns(truth)
truth = metadata_utils.add_g1g2_columns(truth)
# Take an explicit copy of the row: a new 'D_dt' entry is added to it below, and
# writing into the object returned by `iloc` can trigger SettingWithCopyWarning.
truth_lens_i = truth.iloc[lens_i].copy()

# True D_dt of this lens under the assumed true cosmology.
lcdm = LCDM(z_lens=truth_lens_i['z_lens'], z_source=truth_lens_i['z_src'], flat=True)
true_D_dt = lcdm.D_dt(H_0=true_H0, Om0=true_Om0)
truth_lens_i['D_dt'] = true_D_dt

# D_dt posterior samples must be convolved with the kappa_ext prior since they assume kappa_ext = 0.0
# Each sample is multiplied by a draw from N(1, kappa_ext_sigma). A seeded
# generator is used so that re-running the notebook reproduces the same figure.
kappa_ext_sigma = 0.025
kappa_ext_rng = np.random.default_rng(123)
default_samples['D_dt'] *= kappa_ext_rng.normal(1, kappa_ext_sigma, default_samples.shape[0])
fm_samples['D_dt'] *= kappa_ext_rng.normal(1, kappa_ext_sigma, fm_samples.shape[0])

In [None]:
# Report the redshifts of the selected lens system.
print(
    "Lens redshift: ", truth_lens_i['z_lens'],
    "source_redshift: ", truth_lens_i['z_src'],
)

In [None]:
# All parameters for which a label and a display half-width are defined below.
# `labels` and `offset` are matched to this list by position.
total_cols = [
    'lens_mass_gamma',
    'lens_mass_theta_E',
    'lens_mass_e1',
    'lens_mass_e2',
    'external_shear_gamma1',
    'external_shear_gamma2',
    'src_light_R_sersic',
    'src_light_center_x',
    'src_light_center_y',
    'D_dt'
]

# Subset (and order) of parameters actually drawn; currently all of them.
# Edit this list to plot fewer columns -- labels and ranges are looked up by name.
cols_to_plot = list(total_cols)
#cols_to_plot = cfg.data.Y_cols

# Axis labels. Every label is a raw string so that LaTeX backslashes (e.g.
# `\gamma`) are never interpreted as Python string escapes.
labels = [
    r'$\gamma_{\rm lens}$',
    r'${\theta}_E (^{\prime \prime})$',
    r'$e_1$',
    r'$e_2$',
    r'$\gamma_1$',
    r'$\gamma_2$',
    r'$R_{\rm src} (^{\prime \prime})$',
    r"$x_{\rm src} ('')$",
    r"$y_{\rm src} ('')$",
    r'$D_{\Delta t}$ (Mpc)'
]

# Half-width of the displayed range around the truth, per parameter.
# The listed values are doubled by the trailing `*2.0`.
offset = np.array([0.3, #0.05,
                   0.1,
                   0.2,
                   0.2,
                   0.1,
                   0.1,
                   0.2,
                   0.1,
                   0.1,
                   2000])*2.0

# `zip` silently truncates to the shortest input, so guard against the three
# position-matched lists drifting out of sync.
assert len(total_cols) == len(labels) == len(offset), "total_cols, labels and offset must have equal length"

offset_dict = dict(zip(total_cols, offset))
# Center every panel on the true parameter values of the selected lens.
middle = truth_lens_i[cols_to_plot].values
offset_to_plot = [offset_dict[c] for c in cols_to_plot]
display_range = list(zip(middle - offset_to_plot, middle + offset_to_plot))
# D_dt cannot be negative: clamp its lower display bound at zero. The column is
# located by name so the clamp still targets D_dt if cols_to_plot is reordered.
display_range = [
    (max(0, lower), upper) if col == 'D_dt' else (lower, upper)
    for col, (lower, upper) in zip(cols_to_plot, display_range)
]
print(display_range[-1])
labels_dict = dict(zip(total_cols, labels))

In [None]:
# Draw the three-layer cornerplot for the selected lens. The layers are drawn
# onto the same figure in this order: training distribution (gray, dashed
# contours), BNN posterior (red), forward modeling posterior (green, hatched).
# All three calls share `display_range` so their panels line up.
plt.close('all')

# Layer 1: training distribution (interim prior). This call creates the figure
# that the following two calls draw onto via `fig=bnn_post_fig`.
bnn_post_fig = corner.corner(val_truth[cols_to_plot],
                            color='tab:gray',
                            smooth=1.5,
                            alpha=0.5,
                            fill_contours=True,
                            plot_datapoints=False,
                            labels=[labels_dict[c] for c in cols_to_plot],
                            label_kwargs={'fontsize': 30},

                            plot_contours=True,
                            show_titles=False,
                            plot_density=False,
                            levels=[0.68, 0.95],
                            contour_kwargs=dict(linestyles='--'),
                            #contourf_kwargs=dict(colors='tab:gray'),
                            quiet=True,
                            #fig=bnn_post_fig,
                            range=display_range,
                            #range=[0.99]*len(cols_to_plot),
                            use_math_text=True,
                            hist_kwargs=dict(density=True,),
                            # NOTE(review): passed as a keyword literally named `hist2d_kwargs`;
                            # confirm against the installed corner version that this nested dict
                            # actually reaches the 2D panels.
                            hist2d_kwargs=dict(pcolor_kwargs=dict(alpha=0.1)))

# BNN
# Layer 2: the only layer that shows panel titles and overlays the truth values
# of the selected lens (black lines).
_ = corner.corner(default_samples[cols_to_plot],
                 color='#d6616b',
                 smooth=0.8,
                 alpha=1.0,
                 labels=[labels_dict[c] for c in cols_to_plot],
                 label_kwargs={'fontsize': 30},
                 fill_contours=True,
                 plot_datapoints=False,
                 plot_contours=True,
                 show_titles=True,
                 levels=[0.68, 0.95],
                 truths=truth_lens_i[cols_to_plot],
                  truth_color='k',
                 contour_kwargs=dict(linestyles='solid', colors='k'),
                 quiet=True,
                 title_fmt=".1g",
                 fig=bnn_post_fig,
                 title_kwargs={'fontsize': 18},
                 range=display_range,
                 use_math_text=True,
                 hist_kwargs=dict(density=True, histtype='stepfilled'))

# Forward modeling
# Layer 3: drawn last, with hatched semi-transparent 1D histograms so the BNN
# layer underneath stays visible.
_ = corner.corner(fm_samples[cols_to_plot],
                color='#8ca252',
                smooth=0.8,
                alpha=1.0,
                fig=bnn_post_fig,
                fill_contours=True,
                plot_datapoints=False,
                labels=[labels_dict[c] for c in cols_to_plot],
                label_kwargs={'fontsize': 30},
                plot_density=False,
                plot_contours=True,
                show_titles=False,
                levels=[0.68, 0.95],
                contour_kwargs=dict(linestyles='solid', colors='k'),
                quiet=True,
                range=display_range,
                use_math_text=True,
                hist_kwargs=dict(density=True, 
                                 histtype='stepfilled', 
                                 alpha=0.5, 
                                 color=cm.tab20b(0.26315789), 
                                 hatch='//', 
                                 edgecolor='k'),
                # NOTE(review): same caveat as for the first layer -- verify that this
                # nested `hist2d_kwargs` dict has an effect on the 2D panels.
                hist2d_kwargs=dict(hatch='//'))

# Subplot edges are placed at 1.5 (beyond 1.0) on the right and top.
# NOTE(review): presumably to enlarge the panels; the saved image relies on
# bbox_inches='tight' at save time to capture the full extent -- confirm.
bnn_post_fig.subplots_adjust(right=1.5, top=1.5)

# Proxy patches for the legend, mirroring the colors/hatching of the three layers.
legend_elements = [
    Patch(facecolor='tab:gray', edgecolor='tab:gray', label=r'Training distribution (implicit prior)'),
    Patch(facecolor='#d6616b', edgecolor='k', alpha=1.0, label=r'BNN posterior'),
    Patch(facecolor='#8ca252', edgecolor='k', alpha=1.0, hatch='//', label=r'Forward modeling posterior'),
               ]
# The legend is attached to the current axes via pyplot; `loc` is a hand-tuned
# position. NOTE(review): it will likely need re-tuning if the number of
# plotted columns changes.
plt.legend(handles=legend_elements, fontsize=40, loc=[-3., 7])
# Use the same tick label size on every panel.
for ax in bnn_post_fig.get_axes():
    ax.tick_params(axis='both', labelsize=20)

plt.show()
#plt.savefig('../example_cornerplot_no_image_{:d}.png'.format(lens_i))

In [None]:
# Echo which lens the saved figure corresponds to, then write the cornerplot
# one directory above the notebook.
print(lens_i)
bnn_post_fig.savefig(
    '../example_cornerplot_cwp_{:d}.png'.format(lens_i),
    dpi=100,
    bbox_inches='tight',
    pad_inches=0.3,
)