In [None]:
import os, sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats as stats
import corner
from baobab import bnn_priors
from baobab.configs import Config, tdlmc_diagonal_config
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Visualizing the input prior PDF in the DiagonalBNNPrior and the resulting samples
__Author:__ Ji Won Park
    
__Created:__ 8/30/19
    
__Last run:__ 9/05/19

__Goals:__
Plot the (marginal) distributions of the parameters sampled from the diagonal BNN prior, in which all parameters are assumed to be independent.

__Before running this notebook:__
Generate some data. At the root of the `baobab` repo, run:
```
generate baobab/configs/tdlmc_diagonal_config.py --n_data 1000
```
This generates 1000 samples using `DiagonalBNNPrior` at the location this notebook expects.

In [None]:
# TODO add description

In [None]:
# Parse the baobab config that ships with the package and load the metadata
# written by the corresponding `generate` run.
# NOTE: to inspect a dataset generated from another repo (e.g. time_delay_lens_modeling_challenge),
# point `cfg_path` at that repo's config file and root `out_data_dir` at that repo instead.
cfg_path = tdlmc_diagonal_config.__file__
cfg = Config.fromfile(cfg_path)

out_data_dir = os.path.join('..', 'out_data', cfg.out_dir)
print(out_data_dir)

metadata_path = os.path.join(out_data_dir, 'metadata.csv')
meta = pd.read_csv(metadata_path, index_col=None)

# Instantiate the prior class named in the config so its PDFs can be evaluated below
prior_cls = getattr(bnn_priors, cfg.bnn_prior_class)
bnn_prior = prior_cls(cfg.bnn_omega, cfg.components)

Here are the parameters (columns) available in the loaded `metadata.csv`:

In [None]:
# Alphabetically sorted list of the column names in `meta`
sorted(meta.columns.values)

In [None]:
# Add derived columns:
#   - Cartesian shear components (gamma_ext1, gamma_ext2) from the shear modulus and angle
#   - ellipticity modulus and orientation angle from (e1, e2), for each component
shear_modulus = meta['external_shear_gamma_ext']
shear_angle = meta['external_shear_psi_ext']
meta['external_shear_gamma_ext1'] = shear_modulus*np.cos(2.0*shear_angle)
meta['external_shear_gamma_ext2'] = shear_modulus*np.sin(2.0*shear_angle)
for comp in ['lens_mass', 'src_light', 'lens_light']:
    e1 = meta['{:s}_e1'.format(comp)]
    e2 = meta['{:s}_e2'.format(comp)]
    meta['{:s}_ellip'.format(comp)] = np.sqrt(e1**2.0 + e2**2.0)
    # arctan2 keeps the quadrant of (e1, e2) and is well-defined at e1 == 0, so phi spans
    # (-pi/2, pi/2]; arctan(e2/e1) would fold the angle into (-pi/4, pi/4) and divide by zero.
    meta['{:s}_phi'.format(comp)] = 0.5*np.arctan2(e2, e1)

In [None]:
# Radial distance of the source galaxy center from the coordinate origin (0, 0)
src_x = meta['src_light_center_x']
src_y = meta['src_light_center_y']
meta['src_pos_offset'] = np.sqrt(src_x**2.0 + src_y**2.0)

In [None]:
def plot_prior_samples(eval_at, component, param, unit):
    """Overlay the prior PDF of one parameter on a histogram of its sampled values.

    Reads the notebook-level `cfg`, `bnn_prior` and `meta`, draws on the current
    matplotlib axes, and prints the hyperparameters used for the PDF.

    Parameters
    ----------
    eval_at : array-like
        Grid on which the PDF is evaluated. Its first and last entries also
        set the range covered by the histogram bins.
    component : str
        Component name, e.g. 'lens_mass'.
    param : str
        Parameter name within `component`, e.g. 'theta_E'.
    unit : str
        Unit string shown in parentheses in the x-axis label.

    """
    column = '{:s}_{:s}'.format(component, param)
    # Work on a copy of the config entry (presumably so the config itself is never mutated)
    prior_spec = cfg.bnn_omega[component][param].copy()
    density_curve = bnn_prior.eval_param_pdf(eval_at, prior_spec)
    plt.plot(eval_at, density_curve, 'r-', lw=2, alpha=0.6, label='PDF')
    # 50 equally spaced bin edges (i.e. 49 bins) spanning exactly the evaluation grid
    bin_edges = np.linspace(eval_at[0], eval_at[-1], 50)
    hist_style = dict(edgecolor='k', density=True, align='mid', label='sampled')
    plt.hist(meta[column], bins=bin_edges, **hist_style)
    print(prior_spec)
    x_label = "{:s} ({:s})".format(column, unit)
    plt.xlabel(x_label)
    plt.ylabel("density")
    plt.legend()

In [None]:
def plot_derived_quantities(param_key, unit):
    """Plot a normalized 30-bin histogram of one column of the notebook-level `meta`.

    No prior PDF is overlaid.

    Parameters
    ----------
    param_key : str
        Name of the column of `meta` to histogram.
    unit : str
        Unit string shown in parentheses in the x-axis label.

    """
    sampled_values = meta[param_key]
    hist_style = dict(bins=30, edgecolor='k', density=True, align='mid', label='sampled')
    plt.hist(sampled_values, **hist_style)
    x_label = "{:s} ({:s})".format(param_key, unit)
    plt.xlabel(x_label)
    plt.ylabel("density")
    plt.legend()

## Lens mass params

In [None]:
# lens_mass_theta_E: prior PDF on a 100-point grid over [0.5, 1.5], overlaid on the sampled histogram
plot_prior_samples(np.linspace(0.5, 1.5, 100), 'lens_mass', 'theta_E', 'arcsec')

In [None]:
# lens_mass_center_x: prior PDF on a 100-point grid over [-0.04, 0.04], overlaid on the sampled histogram
plot_prior_samples(np.linspace(-0.04, 0.04, 100), 'lens_mass', 'center_x', 'arcsec')

In [None]:
# lens_mass_center_y: prior PDF on a 100-point grid over [-0.04, 0.04], overlaid on the sampled histogram
plot_prior_samples(np.linspace(-0.04, 0.04, 100), 'lens_mass', 'center_y', 'arcsec')

In [None]:
# lens_mass_gamma: prior PDF on a 100-point grid over [1.5, 2.5], overlaid on the sampled histogram
plot_prior_samples(np.linspace(1.5, 2.5, 100), 'lens_mass', 'gamma', 'dimensionless')

In [None]:
# lens_mass_e1: prior PDF on a 100-point grid over [-1, 1], overlaid on the sampled histogram
plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'lens_mass', 'e1', 'dimensionless')

In [None]:
# lens_mass_e2: prior PDF on a 100-point grid over [-1, 1], overlaid on the sampled histogram
plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'lens_mass', 'e2', 'dimensionless')

In [None]:
# lens_mass_ellip: histogram of the sampled values only (no prior PDF overlaid)
plot_derived_quantities('lens_mass_ellip', 'dimensionless')

In [None]:
# lens_mass_phi: histogram of the sampled values only (no prior PDF overlaid)
plot_derived_quantities('lens_mass_phi', 'rad')

## External shear params

In [None]:
# external_shear_gamma_ext: prior PDF on a 100-point grid over [0, 1], overlaid on the sampled histogram
plot_prior_samples(np.linspace(0, 1.0, 100), 'external_shear', 'gamma_ext', 'no unit')

In [None]:
# external_shear_psi_ext: prior PDF on a 100-point grid over [0, 2*pi] padded by 0.5 on each side,
# overlaid on the sampled histogram
plot_prior_samples(np.linspace(0.0 - 0.5, 2.0*np.pi + 0.5, 100), 'external_shear', 'psi_ext', 'rad')

## Lens light params

In [None]:
# lens_light_magnitude: histogram of the sampled values only (no prior PDF overlaid)
plot_derived_quantities('lens_light_magnitude', 'mag')

In [None]:
# lens_light_n_sersic: prior PDF on a 100-point grid over [2, 6], overlaid on the sampled histogram
plot_prior_samples(np.linspace(2, 6, 100), 'lens_light', 'n_sersic', 'dimensionless')

In [None]:
# lens_light_R_sersic: prior PDF on a 100-point grid over [0, 2], overlaid on the sampled histogram
plot_prior_samples(np.linspace(0.0, 2.0, 100), 'lens_light', 'R_sersic', 'arcsec')

In [None]:
# lens_light_e1: prior PDF on a 100-point grid over [-1, 1], overlaid on the sampled histogram
plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'lens_light', 'e1', 'dimensionless')

In [None]:
# lens_light_e2: prior PDF on a 100-point grid over [-1, 1], overlaid on the sampled histogram
plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'lens_light', 'e2', 'dimensionless')

In [None]:
# lens_light_ellip: histogram of the sampled values only (no prior PDF overlaid)
plot_derived_quantities('lens_light_ellip', 'dimensionless')

In [None]:
# lens_light_phi: histogram of the sampled values only (no prior PDF overlaid)
plot_derived_quantities('lens_light_phi', 'rad')

## Source light params

In [None]:
# src_light_magnitude: histogram of the sampled values only (no prior PDF overlaid)
plot_derived_quantities('src_light_magnitude', 'mag')

In [None]:
# src_light_n_sersic: prior PDF on a 100-point grid over [0, 6], overlaid on the sampled histogram
plot_prior_samples(np.linspace(0.0, 6.0, 100), 'src_light', 'n_sersic', 'dimensionless')

In [None]:
# src_light_R_sersic: prior PDF on a 100-point grid over [0, 2], overlaid on the sampled histogram
plot_prior_samples(np.linspace(0.0, 2.0, 100), 'src_light', 'R_sersic', 'arcsec')

In [None]:
# src_light_center_x: prior PDF on a 100-point grid over [-0.04, 0.04], overlaid on the sampled histogram
plot_prior_samples(np.linspace(-0.04, 0.04, 100), 'src_light', 'center_x', 'arcsec')

In [None]:
# src_light_center_y: prior PDF on a 100-point grid over [-0.04, 0.04], overlaid on the sampled histogram
plot_prior_samples(np.linspace(-0.04, 0.04, 100), 'src_light', 'center_y', 'arcsec')

In [None]:
# src_light_e1: prior PDF on a 100-point grid over [-1, 1], overlaid on the sampled histogram
plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'src_light', 'e1', 'dimensionless')

In [None]:
# src_light_e2: prior PDF on a 100-point grid over [-1, 1], overlaid on the sampled histogram
plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'src_light', 'e2', 'dimensionless')

In [None]:
# src_light_ellip: histogram of the sampled values only (no prior PDF overlaid)
plot_derived_quantities('src_light_ellip', 'dimensionless')

In [None]:
# src_light_phi: histogram of the sampled values only (no prior PDF overlaid)
plot_derived_quantities('src_light_phi', 'rad')

## AGN light params

In [None]:
# agn_light_magnitude: histogram of the sampled values only (no prior PDF overlaid)
plot_derived_quantities('agn_light_magnitude', 'mag')

## Total magnification

In [None]:
# total_magnification: histogram of the sampled values only (no prior PDF overlaid)
plot_derived_quantities('total_magnification', 'dimensionless')

## Pairwise distributions

In [None]:
def plot_pairwise_dist(df, cols, fig=None):
    """Draw a corner plot (1D marginals and 2D pairwise densities) of selected columns.

    Parameters
    ----------
    df : pd.DataFrame
        Table holding the samples to plot.
    cols : list of str
        Columns of `df` to include; also used as the axis labels.
    fig : matplotlib figure, optional
        Existing figure to draw on, forwarded to `corner.corner`. Default: None

    Returns
    -------
    The object returned by `corner.corner`.

    """
    n_params = len(cols)
    # Plot the frame that was passed in, rather than the notebook-level `meta`,
    # so the function works for any table the caller supplies.
    plot = corner.corner(df[cols],
                        color='tab:blue', 
                        smooth=1.0, 
                        labels=cols,
                        show_titles=True,
                        fill_contours=True,
                        levels=[0.68, 0.95, 0.997],
                        fig=fig,
                        # a float entry asks corner to bound each axis by that fraction of the samples
                        range=[0.99]*n_params,
                        hist_kwargs=dict(density=True, ))
    return plot

In [None]:
# Corner plot of seven columns of `meta`, mixing columns read from the metadata
# with columns derived earlier in this notebook (offset, ellipticity moduli)
cols = ['src_pos_offset', 'total_magnification',
        'external_shear_gamma_ext', 'external_shear_psi_ext',
        'lens_mass_ellip', 'lens_mass_theta_E',
        'src_light_ellip', ]
_ = plot_pairwise_dist(meta, cols)

In [None]:
# Corner plot of lens_mass_gamma against lens_light_n_sersic (note: rebinds `cols`)
cols = ['lens_mass_gamma', 'lens_light_n_sersic' ]
_ = plot_pairwise_dist(meta, cols)