In [None]:
import os, sys
import numpy as np
import json
import matplotlib.pyplot as plt
import pandas as pd
from addict import Dict
import scipy.stats as stats
import corner
import lenstronomy.Util.param_util as param_util
from baobab import bnn_priors
from baobab.configs import BaobabConfig, tdlmc_diagonal_cosmo_config
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Visualizing the input prior PDF in the DiagonalCosmoBNNPrior and the resulting samples
__Author:__ Ji Won Park
    
__Created:__ 11/20/19
    
__Last run:__ 11/20/19

__Goals:__
Plot the (marginal) distributions of the parameters sampled from the diagonal cosmology-aware BNN prior, in which parameters follow physically reasonable relations.

__Before running this notebook:__
1. Generate some data. At the root of the `baobab` repo, run:
```
generate baobab/configs/tdlmc_diagonal_cosmo_config.py --n_data 1000
```
This generates 1000 samples from `DiagonalCosmoBNNPrior` and writes them to the current working directory (the repo root). 

2. At the end of the run, the `generate` script also exports a JSON log file to the current working directory. Its name follows the format `"log_%m-%d-%Y_%H:%M_baobab.json"`, where the date and time are those at which you ran the script. Set `baobab_log_path` in the cell below to the path of this log file.

In [None]:
# Path to the JSON log exported by `generate`. It can be overridden with the
# BAOBAB_LOG_PATH environment variable so the notebook reruns on another machine
# without editing this cell; otherwise the original hard-coded path is used.
baobab_log_path = os.environ.get('BAOBAB_LOG_PATH',
                                 '/home/jwp/stage/sl/h0rton/log_12-10-2019_01:30_baobab.json')
# The log is parsed as JSON and wrapped in an addict Dict so fields read as attributes
# (cfg.out_dir, cfg.bnn_omega, ...).
with open(baobab_log_path, 'r') as f:
    cfg = Dict(json.load(f))
# Per-sample metadata table written to the output directory recorded in the log.
meta = pd.read_csv(os.path.abspath(os.path.join(cfg.out_dir, 'metadata.csv')), index_col=None)
# Rebuild the same prior class used for generation, so its PDFs can be evaluated below.
bnn_prior = getattr(bnn_priors, cfg.bnn_prior_class)(cfg.bnn_omega, cfg.components)

Here are the parameters available. 

In [None]:
# Alphabetized list of every metadata column, for easier browsing.
sorted(meta.columns)

In [None]:
# Make both external-shear parametrizations available. If the metadata lacks the polar
# pair (gamma_ext, psi_ext), derive it from the Cartesian pair (gamma1, gamma2);
# otherwise derive the Cartesian pair from the polar one.
if 'external_shear_gamma_ext' not in meta.columns.values:
    gamma1 = meta['external_shear_gamma1'].values
    gamma2 = meta['external_shear_gamma2'].values
    psi_ext, gamma_ext = param_util.shear_cartesian2polar(gamma1, gamma2)
    meta['external_shear_gamma_ext'] = gamma_ext
    meta['external_shear_psi_ext'] = psi_ext
else:
    gamma_ext = meta['external_shear_gamma_ext'].values
    psi_ext = meta['external_shear_psi_ext'].values
    gamma1, gamma2 = param_util.shear_polar2cartesian(phi=psi_ext, gamma=gamma_ext)
    meta['external_shear_gamma1'] = gamma1
    meta['external_shear_gamma2'] = gamma2

# Same idea for every elliptical profile: fill in (q, phi) from (e1, e2), or the reverse.
for comp in cfg.components:
    if comp not in ['lens_mass', 'src_light', 'lens_light']:
        continue
    e1_key = '{:s}_e1'.format(comp)
    e2_key = '{:s}_e2'.format(comp)
    q_key = '{:s}_q'.format(comp)
    phi_key = '{:s}_phi'.format(comp)
    if e1_key in meta.columns.values:
        e1, e2 = meta[e1_key].values, meta[e2_key].values
        phi, q = param_util.ellipticity2phi_q(e1, e2)
        meta[q_key] = q
        meta[phi_key] = phi
    else:
        q, phi = meta[q_key].values, meta[phi_key].values
        e1, e2 = param_util.phi_q2_ellipticity(phi, q)
        meta[e1_key] = e1
        meta[e2_key] = e2

# Source position measured relative to the lens mass center.
meta['src_light_pos_offset_x'] = meta['src_light_center_x'] - meta['lens_mass_center_x']
meta['src_light_pos_offset_y'] = meta['src_light_center_y'] - meta['lens_mass_center_y']

In [None]:
# Radial distance of the source galaxy from the lens mass center: the magnitude of the
# (src_light_center - lens_mass_center) offset vector. It is measured from the lens
# center rather than the coordinate origin because the lens center is itself sampled.
meta['src_pos_offset'] = np.hypot(meta['src_light_center_x'] - meta['lens_mass_center_x'],
                                  meta['src_light_center_y'] - meta['lens_mass_center_y'])

In [None]:
def plot_prior_samples(eval_at, component, param, unit, n_edges=50):
    """Overlay a parameter's prior PDF on a normalized histogram of its sampled values.

    Parameters
    ----------
    eval_at : array-like
        Increasing grid of parameter values at which the PDF is evaluated. Its first and
        last entries also set the histogram range.
    component : str
        Component name, e.g. 'lens_mass'.
    param : str
        Parameter name within the component, e.g. 'theta_E'. The plotted metadata
        column is '<component>_<param>'.
    unit : str
        Unit label appended to the x-axis label.
    n_edges : int, optional
        Number of histogram bin edges (giving n_edges - 1 bins) spanning
        [eval_at[0], eval_at[-1]]. Must be at least 2. Defaults to 50.

    Raises
    ------
    NotImplementedError
        For the source center coordinates and for parameters listed in
        `bnn_prior.params_to_exclude`, which are not sampled independently; use
        `plot_derived_quantities` for those.
    """
    param_key = '{:s}_{:s}'.format(component, param)
    if param_key == 'src_light_pos_offset_x':
        # Source offsets from the lens center are compared against the src_light center prior.
        hyperparams = cfg.bnn_omega['src_light']['center_x']
    elif param_key == 'src_light_pos_offset_y':
        hyperparams = cfg.bnn_omega['src_light']['center_y']
    elif (param_key == 'src_light_center_x') or (param_key == 'src_light_center_y'):
        raise NotImplementedError("Use `plot_derived_quantities` instead.")
    elif (component, param) in bnn_prior.params_to_exclude:
        raise NotImplementedError("This parameter wasn't sampled independently. Please use `plot_derived_quantities` instead.")
    else:
        hyperparams = cfg.bnn_omega[component][param].copy()
    pdf_eval = bnn_prior.eval_param_pdf(eval_at, hyperparams)
    plt.plot(eval_at, pdf_eval, 'r-', lw=2, alpha=0.6, label='PDF')

    binning = np.linspace(eval_at[0], eval_at[-1], n_edges)
    samples = meta[param_key].values
    n_samples = len(samples)
    # Normalize by the total number of samples instead of using density=True.
    # density=True renormalizes over only the samples inside the plotted range, which
    # inflates the histogram relative to the PDF whenever `eval_at` truncates the
    # distribution. The bins are uniform (linspace), so one per-sample weight suffices.
    bin_width = binning[1] - binning[0]
    weights = np.full(n_samples, 1.0 / (max(n_samples, 1) * bin_width))
    plt.hist(samples, bins=binning, weights=weights, edgecolor='k', align='mid', label='sampled')
    print(hyperparams)
    plt.xlabel("{:s} ({:s})".format(param_key, unit))
    plt.ylabel("density")
    plt.legend()

In [None]:
def plot_derived_quantities(param_key, unit, binning=None):
    """Plot a density-normalized histogram of one metadata column, with no PDF overlay.

    Parameters
    ----------
    param_key : str
        Name of the `meta` column to histogram.
    unit : str
        Unit label appended to the x-axis label.
    binning : int or sequence, optional
        Forwarded as `bins` to `plt.hist`. Defaults to 30 bins.
    """
    if binning is None:
        binning = 30
    plt.hist(meta[param_key], bins=binning, edgecolor='k', density=True, align='mid',
             label='sampled')
    plt.xlabel("{:s} ({:s})".format(param_key, unit))
    plt.ylabel("density")
    plt.legend()

## Lens mass params

In [None]:
plot_prior_samples(np.linspace(0.7, 1.3, 30), 'lens_mass', 'theta_E', 'arcsec')

In [None]:
plot_prior_samples(np.linspace(-0.04, 0.04, 100), 'lens_mass', 'center_x', 'arcsec')

In [None]:
plot_prior_samples(np.linspace(-0.04, 0.04, 100), 'lens_mass', 'center_y', 'arcsec')

In [None]:
plot_prior_samples(np.linspace(1.7, 2.2, 30), 'lens_mass', 'gamma', 'dimensionless')

In [None]:
plot_prior_samples(np.linspace(0.5, 1.0, 30), 'lens_mass', 'q', 'dimensionless')

In [None]:
plot_prior_samples(np.linspace(-0.5*np.pi, 0.5*np.pi, 30), 'lens_mass', 'phi', 'dimensionless')

In [None]:
plot_derived_quantities('lens_mass_e1', 'dimensionless', 50)

In [None]:
plot_derived_quantities('lens_mass_e2', 'dimensionless', 50)

## External shear params

In [None]:
plot_prior_samples(np.linspace(0, 0.1, 100), 'external_shear', 'gamma_ext', 'no unit')

In [None]:
plot_prior_samples(np.linspace(-0.5*np.pi, 0.5*np.pi, 30), 'external_shear', 'psi_ext', 'rad')

In [None]:
plot_derived_quantities('external_shear_gamma1', 'dimensionless')

In [None]:
plot_derived_quantities('external_shear_gamma2', 'dimensionless')

## Lens light params

In [None]:
plot_prior_samples(np.linspace(15.0, 20.0, 30), 'lens_light', 'magnitude', 'mag')

In [None]:
plot_prior_samples(np.linspace(1, 6, 100), 'lens_light', 'n_sersic', 'dimensionless')

In [None]:
plot_prior_samples(np.linspace(0.2, 1.5, 30), 'lens_light', 'R_sersic', 'arcsec')

In [None]:
plot_prior_samples(np.linspace(0.3, 1.2, 30), 'lens_light', 'q', 'dimensionless')

In [None]:
plot_prior_samples(np.linspace(-0.5*np.pi, 0.5*np.pi, 30), 'lens_light', 'phi', 'rad')

In [None]:
plot_derived_quantities('lens_light_e1', 'dimensionless')

In [None]:
plot_derived_quantities('lens_light_e2', 'rad')

## Source light params

In [None]:
plot_prior_samples(np.linspace(15, 25, 100), 'src_light', 'magnitude', 'mag')

In [None]:
plot_prior_samples(np.linspace(0.5, 6.0, 100), 'src_light', 'n_sersic', 'dimensionless')

In [None]:
plot_prior_samples(np.linspace(0.1, 0.6, 100), 'src_light', 'R_sersic', 'dimensionless')

In [None]:
plot_prior_samples(np.linspace(-0.1, 0.1, 100), 'src_light', 'pos_offset_x', 'arcsec')

In [None]:
plot_prior_samples(np.linspace(-0.1, 0.1, 100), 'src_light', 'pos_offset_y', 'arcsec')

In [None]:
plot_prior_samples(np.linspace(0.2, 1.0, 100), 'src_light', 'q', 'dimensionless')

In [None]:
plot_prior_samples(np.linspace(-0.5*np.pi, 0.5*np.pi, 100), 'src_light', 'phi', 'rad')

In [None]:
plot_derived_quantities('src_light_e1', 'dimensionless', 30)

In [None]:
plot_derived_quantities('src_light_e2', 'dimensionless', 30)

## AGN light params

In [None]:
# AGN point-source magnitude (this cell previously re-plotted the src_light magnitude).
plot_prior_samples(np.linspace(15, 25, 100), 'agn_light', 'magnitude', 'mag')

## Total magnification

In [None]:
# Histogram of the total magnification, binned over [0, 300].
plot_derived_quantities(param_key='total_magnification', unit='dimensionless',
                        binning=np.linspace(0, 300, 30))

## Other quantities

In [None]:
plot_derived_quantities('z_lens', 'dimensionless', 20)

In [None]:
plot_derived_quantities('z_src', 'dimensionless', 20)

In [None]:
_ = plt.hist(meta['z_src'] - meta['z_lens'], edgecolor='k', bins=30)

## Pairwise distributions

In [None]:
def plot_pairwise_dist(df, cols, fig=None):
    """Draw a corner plot of the pairwise and marginal distributions of selected columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Table of samples to plot. Previously ignored in favor of the global `meta`.
    cols : list of str
        Columns of `df` to include; also used as the axis labels.
    fig : matplotlib.figure.Figure, optional
        Existing figure to draw into; forwarded to `corner.corner`.

    Returns
    -------
    The figure object returned by `corner.corner`.
    """
    n_params = len(cols)
    plot = corner.corner(df[cols],
                         color='tab:blue',
                         smooth=1.0,
                         labels=cols,
                         show_titles=True,
                         fill_contours=True,
                         levels=[0.68, 0.95, 0.997],
                         fig=fig,
                         # A float per axis makes corner choose equal-tailed bounds
                         # containing that fraction (99%) of the samples.
                         range=[0.99]*n_params,
                         hist_kwargs=dict(density=True))
    return plot

In [None]:
cols = ['src_pos_offset', 'total_magnification',
        'external_shear_gamma_ext', 'external_shear_psi_ext',
        'lens_mass_q', 'lens_mass_theta_E',
        'src_light_q', 'src_light_R_sersic']
_ = plot_pairwise_dist(meta, cols)

In [None]:
cols = ['lens_mass_gamma', 'lens_light_n_sersic' ]
_ = plot_pairwise_dist(meta, cols)