In [None]:
import os
import numpy as np
import pandas as pd
import lenstronomy
print(lenstronomy.__path__)
from h0rton.configs import TrainValConfig, TestConfig
from baobab.configs import BaobabConfig
import matplotlib.pyplot as plt
from scipy.stats import norm
import h0rton.tdlmc_utils as tdlmc_utils
import baobab.sim_utils as sim_utils

import numba
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Plotting params: reset to Matplotlib's defaults, then apply the notebook-wide style
plt.rcParams.update(plt.rcParamsDefault)
plt.rcParams.update({
    'font.family': 'STIXGeneral',
    'font.size': 20,
    'xtick.labelsize': 'medium',
    'ytick.labelsize': 'medium',
    'text.usetex': True,  # render all text with LaTeX
    'axes.linewidth': 2,
    'axes.titlesize': 'large',
    'axes.labelsize': 'large',
})

# Model selection and evaluation

__Author:__ Ji Won Park (@jiwoncpark)

__Created:__ 8/20/2020

__Last run:__ 11/29/2020

__Goals:__
We perform model selection and evaluation based on the validation set.

__Before running:__
1. Train the BNN for all exposure times and dropout rates, e.g. for an exposure time of 2 HST orbits (v2) and dropout rate of 0.001,
```bash
python h0rton/train.py experiments/v2/train_val_cfg.json
```

2. Get BNN predictions on the validation set for various exposure times and dropout rates, e.g.
```bash
python h0rton/infer_mcmc_default.py experiments/v2/mcmc_default_samples_drop=0.001.json
```

## Table of contents
1. [Model selection (choosing the dropout rate)](#dropout)
2. [Model evaluation (accuracy and precision)](#model_eval)

## 1. Model selection (choosing the dropout rate) <a name="dropout"></a> 

In [None]:
def get_bnn_predictions(version_id, dropout='drop=0.001', n_val=512,
                        experiments_dir='/home/jwp/stage/sl/h0rton/experiments'):
    """Load the BNN predictions and truth values for one exposure time and dropout rate

    Parameters
    ----------
    version_id : int
        version ID corresponding to the orbit
    dropout : str
        identifying string for dropout. Default: 'drop=0.001'
    n_val : int
        number of validation lenses to keep (the first `n_val` rows of the
        metadata and of the samples array). Default: 512
    experiments_dir : str
        directory containing the `v{version_id}` experiment folders

    Returns
    -------
    pred_std : pd.DataFrame
        per-lens standard deviation of the BNN samples, one column per entry
        of `mcmc_Y_cols`
    pred_mean : pd.DataFrame
        per-lens mean of the BNN samples, with the extra columns added by
        `sim_utils.add_qphi_columns` and `sim_utils.add_gamma_psi_ext_columns`
    truth_info : pd.DataFrame
        truth metadata with an `id` column, source positions expressed
        relative to the lens center, and the columns added by
        `sim_utils.add_g1g2_columns`
    samples : np.array of shape [n_lenses, Y_dim, n_samples]
        BNN samples with the two sample axes merged into one
    mcmc_Y_cols : list
        parameter names labeling the `Y_dim` axis

    """
    version_dir = os.path.join(experiments_dir, 'v{:d}'.format(version_id))
    test_cfg_path = os.path.join(version_dir, 'mcmc_default_samples_{:s}.json'.format(dropout))
    test_cfg = TestConfig.from_file(test_cfg_path)
    baobab_cfg = BaobabConfig.from_file(test_cfg.data.test_baobab_cfg_path)
    train_val_cfg = TrainValConfig.from_file(test_cfg.train_val_config_file_path)
    # Read in truth metadata of the validation set
    truth_info = pd.read_csv(os.path.join(baobab_cfg.out_dir, 'metadata.csv'), index_col=None, nrows=n_val)
    # Assign lens ID based on the row index of metadata, for merging with summary
    truth_info['id'] = truth_info.index
    # Note that metadata stores the absolute source position, so get relative to lens center
    truth_info['src_light_center_x'] -= truth_info['lens_mass_center_x']
    truth_info['src_light_center_y'] -= truth_info['lens_mass_center_y']
    # Read in the MC dropout BNN samples (same experiment folder as the test config)
    samples_path = os.path.join(version_dir, 'mcmc_default_samples_val_{:s}'.format(dropout), 'samples.npy')
    samples = np.load(samples_path)
    # Merge MC dropout axis with samples per dropout:
    # [n_lenses, a, b, Y_dim] -> [n_lenses, Y_dim, a*b]
    n_lenses = samples.shape[0]
    mcmc_Y_dim = samples.shape[-1]
    samples = samples.transpose(0, 3, 1, 2).reshape([n_lenses, mcmc_Y_dim, -1])[:n_val]
    # Store BNN predictions (expected value and spread) as dataframes
    mcmc_Y_cols = list(train_val_cfg.data.Y_cols)
    pred_mean = pd.DataFrame(np.mean(samples, axis=-1), columns=mcmc_Y_cols) # expected value
    pred_std = pd.DataFrame(np.std(samples, axis=-1), columns=mcmc_Y_cols)
    # Only the central prediction receives the derived columns; pred_std keeps mcmc_Y_cols only
    pred_mean = sim_utils.add_qphi_columns(pred_mean)
    pred_mean = sim_utils.add_gamma_psi_ext_columns(pred_mean)
    truth_info = sim_utils.add_g1g2_columns(truth_info)
    return pred_std, pred_mean, truth_info, samples, mcmc_Y_cols

In [None]:
orbits_available = [0.5, 1, 2] # HST orbits
orbit_to_ver = {0.5: 4, 1: 3, 2: 2} # maps HST orbit to version folder
dropouts_available = ['no_dropout', 'drop=0.001', 'drop=0.005']
dropout_label_to_dropout_float = {'no_dropout': 0.0, 'drop=0.001': 0.001, 'drop=0.005': 0.005}

# Result containers, each with structure dict[orbit][dropout_rate]
std = {orbit: {} for orbit in orbits_available} # BNN-predicted uncertainty
mean = {orbit: {} for orbit in orbits_available} # BNN-predicted center
truth = {orbit: {} for orbit in orbits_available} # True parameter value
samples = {orbit: {} for orbit in orbits_available} # BNN posterior samples

for o in orbits_available:
    for d in dropouts_available:
        # FIXME: truth doesn't vary with dropout or orbit, yet it is stored for every combination
        # mcmc_Y_cols is overwritten on every iteration and read by later cells
        (std[o][d],
         mean[o][d],
         truth[o][d],
         samples[o][d],
         mcmc_Y_cols) = get_bnn_predictions(version_id=orbit_to_ver[o], dropout=d)

We now generate the 2D calibration plot we defined in Wagner-Carena et al 2020. The metric asks: for a given fraction $p_X$ of the BNN posterior probability volume, what fraction $p_Y$ of the lenses have the truth contained within this volume? If the posterior is perfectly calibrated, we would expect the volume enclosing $p_X$ of the posterior samples to contain the truth $p_Y = p_X$ of the time, for every value of $p_X$. We can apply this metric on the validation set as a whole by averaging, over the individual lenses, the indicator of whether the truth falls within the volume:

$$p_{ Y}^{\rm val}(p_X) = \frac{1}{N^{\rm val}} \sum_{k=1}^{N^{\rm val}} {1}\left\{ \frac{ \sum_{n=1}^N  {1}\left\{d(\xi_n^{(k)}) < d(\xi_{\rm true}^{(k)})\right\}}{N} < p_X \right\}$$

where:
- $k$ indexes the validation lenses;
- $\{\xi_n^{(k)}\}_{n=1}^N$ refers to the $N$ parameter samples drawn from the BNN posterior for some lens $k$;
- $1\{\cdot\}$ is an indicator function that evaluates to 1 when the argument is true and 0 otherwise; and
- $d(\xi)$ is a measure of distance of a particular point $\xi$ from the posterior predictive mean given the posterior width.

There are many choices for the distance measure $d$. Park et al 2020 defines $d$ as the distance of $\xi$ from the posterior predictive mean given the posterior width *for that lens*. Instead, Wagner-Carena et al 2020 approximated $d(\xi)$ as the position of $\xi$ given the covariance matrix of the *training set* (the interim prior), out of speed considerations. The results aren't too different but we enable both options here, via `get_p_Y_val_mahalanobis` (Park et al 2020) and `get_p_Y_val_approx_mahalanobis` (Wagner-Carena et al 2020).

In [None]:
def get_mae(predicted, true):
    """Get the median absolute error, or median(|predicted - true|)

    Parameters
    ----------
    predicted : np.array
        predicted values
    true : np.array
        true values, broadcastable against `predicted`

    Returns
    -------
    float
        median of the element-wise absolute differences

    """
    return np.median(np.abs(predicted - true))


# The evaluation cells label this metric "MAD" and call it as `get_mad`;
# expose that name too so both spellings resolve to the same function.
get_mad = get_mae

In [None]:
def get_weighted_bias(predicted_mean, predicted_std, true):
    """Get the inverse-variance-weighted error (bias) of a parameter and its spread

    Each lens is weighted by 1/predicted_std**2, so confidently predicted
    lenses contribute more.

    Returns
    -------
    tuple
        (weighted mean of `predicted_mean - true`, weighted root-mean-square
        scatter of the errors about that weighted mean)

    """
    inverse_variance = 1.0/predicted_std**2.0
    error = predicted_mean - true
    weighted_bias = np.average(error, weights=inverse_variance)
    squared_scatter = (error - weighted_bias)**2.0
    weighted_bias_spread = np.average(squared_scatter, weights=inverse_variance)**0.5
    return weighted_bias, weighted_bias_spread

In [None]:
def get_p_Y_val_approx_mahalanobis(post_samples, y_mean, y_truth, cov):
    """Calculate, for every lens in the validation set, the fraction of draws
    from the predicted distribution that lie closer to the central prediction
    than the truth does.

    The distance is the quadratic form ``dif @ cov @ dif`` (metric used in
    Wagner-Carena et al 2020), with the same matrix for every lens. Note that
    `cov` is used as given; it is not inverted here.

    Parameters
    ----------
    post_samples : np.array of shape [n_samples, n_lenses, Y_dim]
        BNN posterior samples
    y_mean : np.array of shape [n_lenses, Y_dim]
        Central prediction to use in the distance calculation
    y_truth: np.array of shape [n_lenses, Y_dim]
        True parameter values
    cov : np.array of shape [Y_dim, Y_dim]
        Matrix to use in the distance calculation

    Returns
    -------
    np.array of shape [n_lenses,]
        Fraction of samples whose distance is strictly below the distance of
        the truth, per lens

    Notes
    -----
    Adapted from https://github.com/swagnercarena/ovejero. The distance is
    evaluated with a vectorized `np.einsum` rather than a jitted double loop,
    so nothing has to be recompiled each time this function is called.

    """
    cov = np.asarray(cov)

    def approx_mahalanobis(dif):
        """Quadratic form dif^T cov dif over the last axis of `dif`"""
        return np.einsum('...k,kl,...l->...', dif, cov, dif)

    d_samples = approx_mahalanobis(post_samples - y_mean) # [n_samples, n_lenses]
    d_truth = approx_mahalanobis(y_truth - y_mean) # [n_lenses,], broadcast over samples
    p_Y_val = np.mean(d_samples < d_truth, axis=0)
    return p_Y_val

In [None]:
def get_p_Y_val_mahalanobis(post_samples, y_truth):
    """Calculate, for every lens in the validation set, the fraction of draws
    from the predicted distribution that lie closer to the posterior mean
    than the truth does.

    The distance is the squared Mahalanobis distance (metric used in Park et
    al 2020), evaluated with the sample mean and sample covariance of each
    lens's own posterior samples.

    Parameters
    ----------
    post_samples : np.array of shape [n_samples, n_lenses, Y_dim]
        BNN posterior samples
    y_truth: np.array of shape [n_lenses, Y_dim]
        True parameter values

    Returns
    -------
    np.array of shape [n_lenses,]
        Fraction of samples whose distance is strictly below the distance of
        the truth, per lens

    Raises
    ------
    np.linalg.LinAlgError
        If the sample covariance of some lens is singular

    """
    post_samples = np.asarray(post_samples)
    y_truth = np.asarray(y_truth)
    _, n_lenses, Y_dim = post_samples.shape
    y_mean = np.mean(post_samples, axis=0) # [n_lenses, Y_dim]
    prec_val = np.empty((n_lenses, Y_dim, Y_dim)) # precision matrix of all lenses
    for lens_i in range(n_lenses): # loop over the lenses
        # Rows are samples and columns are parameters, hence rowvar=False.
        # atleast_2d covers Y_dim == 1, where np.cov returns a 0-d array.
        cov_i = np.atleast_2d(np.cov(post_samples[:, lens_i, :], rowvar=False)) # [Y_dim, Y_dim]
        prec_val[lens_i] = np.linalg.inv(cov_i)

    def mahalanobis(dif):
        """Per-lens quadratic form dif^T prec dif over the last axis of `dif`"""
        return np.einsum('...lk,lkm,...lm->...l', dif, prec_val, dif)

    d_samples = mahalanobis(post_samples - y_mean) # [n_samples, n_lenses]
    d_truth = mahalanobis(y_truth - y_mean) # [n_lenses,], broadcast over samples
    p_Y_val = np.mean(d_samples < d_truth, axis=0)
    return p_Y_val

In [None]:
# Empirical covariance of the training-set truth parameters
TRAIN_METADATA_PATH = '/home/jwp/stage/sl/h0rton/v7_train_prior=DiagonalCosmoBNNPrior_seed=1113/metadata.csv'
train_truth = sim_utils.add_g1g2_columns( # convert gamma_ext, psi_ext to g1, g2
    pd.read_csv(TRAIN_METADATA_PATH, index_col=None)
)
# np.cov expects one row per variable, so transpose the [n_train, Y_dim] table
train_Y_values = train_truth[mcmc_Y_cols].values
train_cov = np.cov(train_Y_values.T) # [Y_dim, Y_dim]

In [None]:
def _jackknife_calibration_curve(p_Y_val, percentages):
    """Get the calibration curve and its leave-one-out jackknife standard deviation

    Parameters
    ----------
    p_Y_val : np.array of shape [n_lenses,]
        Per-lens fraction of samples closer to the center than the truth
    percentages : np.array of shape [n_perc_points,]
        Grid of p_X values

    Returns
    -------
    p_images : np.array of shape [n_perc_points,]
        Fraction of lenses with `p_Y_val <= p_X`, for each grid point
    p_Y_val_std : np.array of shape [n_perc_points,]
        Jackknife estimate of the standard deviation of `p_images`

    """
    p_Y_val = np.asarray(p_Y_val)
    n_lenses = len(p_Y_val)
    below = p_Y_val[:, np.newaxis] <= percentages[np.newaxis, :] # [n_lenses, n_perc_points]
    n_below = below.sum(axis=0)
    p_images = n_below/n_lenses
    # Leaving lens i out removes its own 0/1 contribution from the count
    p_images_jn = (n_below - below)/(n_lenses - 1) # [n_lenses, n_perc_points]
    p_Y_val_std = np.sqrt((n_lenses - 1)*np.mean(np.square(p_images_jn - np.mean(p_images_jn, axis=0)), axis=0))
    return p_images, p_Y_val_std


def plot_calibration(post_samples, y_mean, y_truth, cov, color_map=["#377eb8", "#4daf4a"], n_perc_points=20, figure=None, ls='--', legend=None, show_plot=True, block=True, title=None, dpi=200):
    """Plot the calibration metric for a grid of p_X percentages, with error bars
    obtained through jackknife sampling

    Parameters
    ----------
    post_samples, y_mean, y_truth, cov
        See the docstring for `get_p_Y_val_approx_mahalanobis`.
    color_map : list of two colors
        Colors of the perfect-calibration line and of the network curve
    n_perc_points : int
        Grid size of p_X (probability volume) to compare p_Y against
    figure : matplotlib.figure.Figure or None
        Existing figure to add the network curve to. If None, a new figure is
        created and the perfect-calibration line, grid, axis labels and
        annotations are drawn as well.
    ls : str
        Line style of the network curve
    legend : list of two str or None
        Labels of the perfect-calibration line and of the network curve.
        If None, ['Perfect calibration', 'Network Calibration'] is used.
    show_plot : bool
        Whether to call `plt.show`
    block : bool
        Passed to `plt.show`
    title : str or None
        Axes title; no title is set when None
    dpi : int
        Resolution of a newly created figure

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn on

    Notes
    -----
    Adapted from https://github.com/swagnercarena/ovejero

    """
    labels = ['Perfect calibration', 'Network Calibration'] if legend is None else legend
    p_Y_val = get_p_Y_val_approx_mahalanobis(post_samples, y_mean, y_truth, cov=cov)

    # Fraction of lenses that have at most x% of draws closer to the center
    # than the truth, with a jackknife estimate of the sample variance.
    percentages = np.linspace(0.0, 1.0, n_perc_points)
    p_images, p_Y_val_std = _jackknife_calibration_curve(p_Y_val, percentages)

    is_new_figure = figure is None
    if is_new_figure:
        fig = plt.figure(figsize=(8,8), dpi=dpi)
        ax = fig.gca()
        ax.plot(percentages, percentages, c=color_map[0], ls='--', label=labels[0])
    else:
        fig = figure
        # Draw on the figure that was passed in, whichever figure is current in pyplot
        ax = fig.gca()

    ax.plot(percentages, p_images, c=color_map[1], ls=ls, label=labels[1])
    # Plot the 1 sigma contours from the jackknife estimate to get an idea of our sample variance.
    ax.fill_between(percentages, p_images+p_Y_val_std, p_images-p_Y_val_std, color=color_map[1], alpha=0.2)
    if is_new_figure:
        ax.grid(True, ls='dotted', alpha=0.5)
        ax.set_xlabel(r'Fraction of posterior volume = $p_X$', fontsize=15)
        ax.set_ylabel(r'Fraction of validation lenses with truth in the volume = $p_Y^{\mathrm{val}}$', fontsize=15)
        ax.text(-0.03, 1,'Underconfident', fontsize=15)
        ax.text(0.80, 0,'Overconfident', fontsize=15)
    if title is not None:
        ax.set_title(title)
    # Labels come from the plotted artists, so one call serves both the default and custom legend
    ax.legend(fontsize=15, loc=9)
    if show_plot:
        plt.show(block=block)

    return fig

In [None]:
plt.close('all')
# Colors for each exposure time (in HST orbits)
colors_dict = {2: '#880519', 1: '#c04546', 0.5: '#f97978'}
dropout_to_linestyle = {0: 'dotted', 0.001: 'solid', 0.005: 'dashdot'}
dropout_to_label = {0: r'$p_{\rm drop} = 0\%$',
                    0.001: r'$p_{\rm drop} = 0.1\%$',
                    0.005: r'$p_{\rm drop} = 0.5\%$'}

show_orbit = 0.5 # orbit for which the dropouts are compared

first_dropout = dropouts_available[0]
# The first call receives figure=None and therefore creates the figure together
# with the perfect-calibration line; later calls add their curve to that figure.
fig = None
for d in dropouts_available:
    dropout_rate = dropout_label_to_dropout_float[d]
    fig = plot_calibration(post_samples=np.transpose(samples[show_orbit][d], [2, 0, 1]),
                           y_mean=mean[show_orbit][d][mcmc_Y_cols].values,
                           y_truth=truth[show_orbit][d][mcmc_Y_cols].values,
                           cov=train_cov,
                           figure=fig,
                           show_plot=False,
                           ls=dropout_to_linestyle[dropout_rate],
                           color_map=['tab:gray', colors_dict[show_orbit]],
                           legend=['Perfect calibration', dropout_to_label[dropout_rate]])
#fig.savefig('../plots/calibration.png', bbox_inches='tight', pad_inches=0, dpi=200)
plt.show()

## 2. Model evaluation  <a name="model_eval"></a>

In [None]:
chosen_dropout = 'drop=0.001'

# Accuracy (MAD), weighted bias and precision of every parameter, per exposure time.
# The MAD is computed with `get_mae`, the median-absolute-error helper of this notebook.
for param in ['lens_mass_center_x', 'src_light_center_x', 'lens_mass_center_y', 'src_light_center_y', 'lens_mass_gamma', 'lens_mass_theta_E', 'lens_mass_e1', 'lens_mass_e2', 'external_shear_gamma1', 'external_shear_gamma2', 'src_light_R_sersic']:
    print(param)
    for o in orbits_available:
        param_mean = mean[o][chosen_dropout][param].values
        param_std = std[o][chosen_dropout][param].values
        param_truth = truth[o][chosen_dropout][param].values
        print("HST orbit(s): ", o)
        print("MAD: ", get_mae(param_mean, param_truth))
        print("bias: ", get_weighted_bias(param_mean, param_std, param_truth))
        print("precision: ", np.median(param_std))
    print("===================================")

In [None]:
# We will take min_img_conf number of doubles/quads from the 512 validation lenses,
# i.e. the size of the smaller of the two groups, so both are equally represented

chosen_orbit = 2
chosen_mean = mean[chosen_orbit][chosen_dropout]
chosen_n_img = truth[chosen_orbit][chosen_dropout]['n_img']
n_doubles = chosen_mean.loc[chosen_n_img == 2].shape[0]
n_quads = chosen_mean.loc[chosen_n_img == 4].shape[0]
min_img_conf = min(n_doubles, n_quads)
print(min_img_conf) # number of doubles/quads to take

In [None]:
# Accuracy (MAD), weighted bias and precision, separately for doubles and quads.
# The MAD is computed with `get_mae`, the median-absolute-error helper of this notebook.
for o in [2]:
    print("HST orbit(s): ", o)
    n_img = truth[o][chosen_dropout]['n_img']
    is_double = n_img == 2
    is_quad = n_img == 4
    # Keep the first min_img_conf lenses of each image configuration
    doubles_mean = mean[o][chosen_dropout].loc[is_double].iloc[:min_img_conf]
    quads_mean = mean[o][chosen_dropout].loc[is_quad].iloc[:min_img_conf]
    doubles_std = std[o][chosen_dropout].loc[is_double].iloc[:min_img_conf]
    quads_std = std[o][chosen_dropout].loc[is_quad].iloc[:min_img_conf]
    doubles_truth = truth[o][chosen_dropout].loc[is_double].iloc[:min_img_conf]
    quads_truth = truth[o][chosen_dropout].loc[is_quad].iloc[:min_img_conf]
    groups = [("Doubles", doubles_mean, doubles_std, doubles_truth),
              ("Quads", quads_mean, quads_std, quads_truth)]
    for param in ['lens_mass_center_x', 'src_light_center_x', 'lens_mass_center_y', 'src_light_center_y', 'lens_mass_gamma', 'lens_mass_theta_E', 'lens_mass_e1', 'lens_mass_e2', 'external_shear_gamma1', 'external_shear_gamma2', 'src_light_R_sersic']:
        print(param)
        for group_label, group_mean, group_std, group_truth in groups:
            print(group_label)
            print("MAD: ", get_mae(group_mean[param].values, group_truth[param].values))
            print("bias: ", get_weighted_bias(group_mean[param].values, group_std[param].values, group_truth[param].values))
            print("precision: ", np.median(group_std[param].values))
        print("===================================")