In [None]:
import os, sys
import numpy as np
from scipy import stats
import json
from addict import Dict
import matplotlib.pyplot as plt
from h0rton.h0_inference import plotting_utils, h0_utils
from h0rton.configs import TrainValConfig, TestConfig
from scipy.stats import norm, gaussian_kde
import glob
import matplotlib.image as mpimg
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Plotting params: reset to matplotlib's built-in defaults first, then layer
# the figure style on top in a single update.
plt.rcParams.update(plt.rcParamsDefault)
plt.rcParams.update({
    'font.family': 'STIXGeneral',
    'font.size': 20,
    'xtick.labelsize': 'medium',
    'ytick.labelsize': 'medium',
    'text.usetex': True,  # text is typeset through LaTeX
    'axes.linewidth': 2,
    'axes.titlesize': 'large',
    'axes.labelsize': 'large',
})

# Illustrating the importance of the $D_{\Delta t}$ fit distribution

__Author:__ Ji Won Park (@jiwoncpark)

__Created:__ 8/20/2020

__Last run:__ 11/29/2020

__Goals:__
We illustrate the importance of the $D_{\Delta t}$ fit distribution by comparing the fit quality between normal, lognormal, and KDE fits.

__Before running:__
1. Train the BNN, e.g.
```bash
python h0rton/train.py experiments/v2/train_val_cfg.json
```

2. Get inference results for the trained model and the precision ceiling, e.g.
```bash
python h0rton/infer_h0_mcmc_default.py experiments/v2/mcmc_default.json
python h0rton/infer_h0_simple_mc_truth.py experiments/v0/simple_mc_default.json
```

3. Summarize the inference results, e.g.
```bash
python h0rton/summarize.py 2 mcmc_default
python h0rton/summarize.py 0 simple_mc_default
```

In [None]:
# Read in the inference config.
# The root of the h0rton checkout can be overridden with the H0RTON_ROOT
# environment variable; the default keeps the original author's location so
# existing setups behave exactly as before.
h0rton_root = os.environ.get('H0RTON_ROOT', '/home/jwp/stage/sl/h0rton')
default_version_id = 2 # corresponds to 2 HST orbits
default_version_dir = os.path.join(h0rton_root, 'experiments', 'v{:d}'.format(default_version_id))
test_cfg_path = os.path.join(default_version_dir, 'mcmc_default.json')
test_cfg = TestConfig.from_file(test_cfg_path)

In [None]:
plt.close('all')
# Plot D_dt histograms for the pipeline diagram: one panel per lens, each
# overlaying the KDE, the lognormal fit, and the normal fit on the samples.

# --- Settings shared by every panel (hoisted out of the loop) ---
lens_ids = [20, 169, 31]  # lens indices used in the saved-sample file names
oversampling = 20  # number of kappa_ext draws per D_dt sample
D_dt_range = [0, 15000]  # plotted D_dt range (Mpc, per the x-axis label)
n_grid = 100  # number of points at which the fitted densities are evaluated
n_bins = 100  # number of histogram bins
# Seeded generator so the kappa_ext draws (and hence the figure) are reproducible
rng = np.random.RandomState(123)
# Reuse the experiment directory defined in the config cell rather than
# repeating the absolute path here
samples_dir = os.path.join(default_version_dir, 'mcmc_default')
# Frozen scipy.stats distribution named by the test config
k_ext_rv = getattr(stats, test_cfg.kappa_ext_prior.dist)(**test_cfg.kappa_ext_prior.kwargs)
D_dt_grid = np.linspace(D_dt_range[0], D_dt_range[1], n_grid)

fig, axes = plt.subplots(1, len(lens_ids), figsize=(15, 5))
for i, lens_i in enumerate(lens_ids):
    samples_path = os.path.join(samples_dir, 'D_dt_dict_{0:04}.npy'.format(lens_i))
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
    # files produced by a trusted inference run.
    saved_dict = np.load(samples_path, allow_pickle=True).item()
    uncorrected_D_dt_samples = saved_dict['D_dt_samples']
    # Presumably drops samples beyond 3 sigma of a lognormal fit -- confirm
    # against h0_utils.remove_outliers_from_lognormal.
    uncorrected_D_dt_samples = h0_utils.remove_outliers_from_lognormal(uncorrected_D_dt_samples, 3).reshape(-1, 1) # [n_samples, 1]
    k_ext = k_ext_rv.rvs(size=[len(uncorrected_D_dt_samples), oversampling], random_state=rng) # [n_samples, oversampling]
    if test_cfg.kappa_ext_prior.transformed:
        # The drawn value multiplies D_dt directly (presumably the prior is
        # already on 1/(1 - kappa_ext) -- verify against the config).
        D_dt_samples = (uncorrected_D_dt_samples*k_ext).flatten() # [n_samples*oversampling,]
    else:
        D_dt_samples = (uncorrected_D_dt_samples/(1.0 - k_ext)).flatten() # [n_samples*oversampling,]

    ax = axes[i]
    # Plot KDE
    kde = gaussian_kde(D_dt_samples, bw_method='scott')
    ax.plot(D_dt_grid, kde(D_dt_grid), color='k', linestyle='solid', label='KDE', linewidth=2)
    # Plot lognormal fit
    D_dt_stats = h0_utils.get_lognormal_stats(D_dt_samples)
    ax.plot(D_dt_grid, plotting_utils.lognormal(D_dt_grid, D_dt_stats['mu'], D_dt_stats['sigma']), color='k', linestyle='dashed', label='Lognormal fit')
    # Plot normal fit
    D_dt_stats_normal = h0_utils.get_normal_stats(D_dt_samples)
    ax.plot(D_dt_grid, norm.pdf(D_dt_grid, loc=D_dt_stats_normal['mean'], scale=D_dt_stats_normal['std']), color='k', ls='dotted', label='Normal fit', linewidth=2)
    # Plot samples (raw string: '\D' is not a valid Python escape sequence)
    ax.hist(D_dt_samples, range=D_dt_range, bins=n_bins, color='#d6616b', density=True, label=r'$D_{\Delta t}$ posterior samples', alpha=0.75)
    ax.set_yticks([])
    ax.set_xlabel(r'$D_{\Delta t}$ (Mpc)', fontsize=25)

# One legend, anchored to the first panel and stretched across the row
global_legend = axes[0].legend(bbox_to_anchor=(0.03, 1.03, 2 + 1.1, 0.102), loc='upper center', ncol=4, mode="expand", borderaxespad=-0.5, fontsize=18, frameon=False, columnspacing=0.08)
# Register the legend as an explicit artist of the first panel
axes[0].add_artist(global_legend)
axes[0].set_ylabel('Density', fontsize=28)
plt.subplots_adjust(wspace=0.08)

plt.show()
# Uncomment to save the figure:
#fig.savefig('../plots/kde_vs_lognormal_vs_normal.png', bbox_inches='tight', pad_inches=0)