# Evaluating the KLD from posterior samples of cosmological parameters

_Alex I. Malz (GCCL@RUB)_

In [None]:
import itertools
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle as pkl
import re
from scipy import stats as sps
import sys

In [None]:
import proclam
from proclam.metrics.util import *
from proclam.metrics.util import RateMatrix

## simple demo

We will begin with samples of $(w, \Omega_{m})$ pairs, where one set of samples is defined as the reference sample corresponding to a best-case scenario of a 100% pure SN Ia data set.

In [None]:
# # replace with reading in the data
# def measure(n, w_bar, w_sig, Omm_bar,Omm_sig):
#     "Measurement model, return two coupled measurements."
#     w = np.random.normal(loc=w_bar, scale=w_sig, size=n)
#     Omm = np.random.normal(loc=Omm_bar, scale=Omm_sig, size=n)
#     return w, Omm

def measure(path, cols):
    """
    Load a pickled collection of posterior samples and pull out selected entries.

    Parameters
    ----------
    path: str
        location of the pickle file to read
    cols: list
        keys to look up in the unpickled object, in the order they should be returned

    Returns
    -------
    samples: list
        one entry per element of `cols`, i.e. `[alldims[col] for col in cols]`

    Notes
    -----
    Unpickling can execute arbitrary code, so only point this at trusted files.
    """
    # the context manager closes the file handle even if unpickling fails
    with open(path, 'rb') as chainfile:
        alldims = pkl.load(chainfile)
    return [alldims[col] for col in cols]

In [None]:
# '/media/RESSPECT/data/PLAsTiCC/SALT2mu_posteriors/perfect_classifier/chains_plasticc_perfect.pkl'
# '/media/RESSPECT/data/PLAsTiCC/SALT2mu_posteriors/static/DDF/train_10/batch_10/UncSampling/chains/chains_loop_99.pkl'
# '/media/emille/git/COIN/RESSPECT_work/PLAsTiCC/metrics_paper/resspect_metric/posteriors/'
# '/media/RESSPECT/data/PLAsTiCC/for_metrics/posteriors/'
# '/media/emille/data/PLAsTiCC/posteriors/'
# '/media2/RESSPECT2/data/posteriors'
# root directory under which all posterior chains are stored
postpath = '/media2/RESSPECT2/data/posteriors/'

# reference ("perfect") chains first, comparison ("fiducial") chains second
refpath, comppath = (
    f"{postpath}{subdir}/chains_{subdir}.pkl" for subdir in ('perfect', 'fiducial')
)

kinda slow

In [None]:
# w_ref, Omm_ref = measure(1000, -1., 0.1, 0.5, 0.1)
# w_comp, Omm_comp = measure(1000, -1.1, 0.2, 0.25, 0.05)

# pull the 'w' and 'om' entries out of the reference and comparison chain files
w_ref, Omm_ref = measure(path=refpath, cols=['w', 'om'])
w_comp, Omm_comp = measure(path=comppath, cols=['w', 'om'])

In [None]:
# plt.scatter(w_ref, Omm_ref, s=1, alpha=0.2, label='best possible')
# plt.scatter(w_comp, Omm_comp, s=1, alpha=0.2, label='approximation')
# plt.legend(loc='lower left')

[`chippr`](https://github.com/aimalz/chippr/) contains code for calculating the KLD of PDFs evaluated on a grid, so we start by fitting a 2D KDE to the samples.
The PDFs must be $\geq0$ over the entire range of the grid, so we make a grid based on the reference sample's range.

In [None]:
# ngrid_x = 100
# ngrid_y = 100
# xmin = w_ref.min()#-1.2
# xmax = w_ref.max()#-0.8
# ymin = Omm_ref.min()#0.2
# ymax = Omm_ref.max()#0.4

# w_grid, Omm_grid = np.mgrid[xmin:xmax:100*1.j, ymin:ymax:100*1.j]
# w_vec, Omm_vec = w_grid[:, 0], Omm_grid[0, :]
# dw = (xmax - xmin) / (ngrid_x - 1)
# dOmm = (ymax - ymin) / (ngrid_y - 1)
# # use meshgrid instead of mgrid

def make_grid(x, y, x_ngrid=100, y_ngrid=100):
    """
    Build a regular 2D grid spanning the ranges of two sets of samples.

    Parameters
    ----------
    x: array-like providing `.min()` and `.max()`
        samples that set the extent of the first grid axis
    y: array-like providing `.min()` and `.max()`
        samples that set the extent of the second grid axis
    x_ngrid: int, optional
        number of grid points along the first axis
    y_ngrid: int, optional
        number of grid points along the second axis

    Returns
    -------
    extrema: tuple
        `((x_min, y_min), (x_max, y_max))`
    grids: tuple
        `(x_grid, y_grid)`, the 2D coordinate arrays produced by `np.mgrid`
    vecs: tuple
        `(x_vec, y_vec)`, the 1D coordinates along each axis
    deltas: tuple
        `(dx, dy)`, the spacing between neighbouring grid points on each axis
    """
    lower = (x.min(), y.min())
    upper = (x.max(), y.max())

    # an imaginary "step" makes mgrid read it as a point count, endpoints included
    x_grid, y_grid = np.mgrid[lower[0]:upper[0]:x_ngrid*1.j,
                              lower[1]:upper[1]:y_ngrid*1.j]
    vecs = (x_grid[:, 0], y_grid[0, :])
    deltas = ((upper[0] - lower[0]) / (x_ngrid - 1),
              (upper[1] - lower[1]) / (y_ngrid - 1))

    return (lower, upper), (x_grid, y_grid), vecs, deltas

In [None]:
# build the grid from the reference samples, then unpack each piece by name
ref_extrema, ref_grids, ref_vecs, ref_ds = make_grid(w_ref, Omm_ref)
((xmin, ymin), (xmax, ymax)) = ref_extrema
w_grid, Omm_grid = ref_grids
w_vec, Omm_vec = ref_vecs
dw, dOmm = ref_ds

In [None]:
# plt.hist2d(w_ref, Omm_ref, bins=[w_vec, Omm_vec], density=True, cmap=plt.cm.Blues, alpha=0.5)
# plt.hist2d(w_comp, Omm_comp, bins=[w_vec, Omm_vec], density=True, cmap=plt.cm.Reds, alpha=0.5)

In [None]:
# hist_ref, xgrid, ygrid = np.histogram2d(w_ref, Omm_ref, bins=[w_vec, Omm_vec], density=True)
# hist_comp, xgrid, ygrid = np.histogram2d(w_comp, Omm_comp, bins=[w_vec, Omm_vec], density=True)
# print(np.sum(hist_ref * ((np.ones_like(hist_ref) * dw).T * dOmm).T))
# print(np.sum(hist_comp * ((np.ones_like(hist_comp) * dw).T * dOmm).T))
# plt.imshow(hist_ref, origin='lower', extent=[xmin, xmax, ymin, ymax], cmap=plt.cm.Blues, alpha=0.5)
# plt.imshow(hist_comp, origin='lower', extent=[xmin, xmax, ymin, ymax], cmap=plt.cm.Reds, alpha=0.5)

In [None]:
# smallest value we are willing to take the log of
eps = 2. * sys.float_info.min

def safe_log(arr, threshold=eps):
    """
    Takes the natural logarithm of an array that might contain zeros.

    The input is never modified; flooring happens on a private float copy.

    Parameters
    ----------
    arr: ndarray, float
        array of values to be logged
    threshold: float, optional
        small, positive value to replace zeros and negative numbers

    Returns
    -------
    logged: ndarray
        logged values, with small value replacing un-loggable values
    """
    # np.array copies (np.asarray would alias the caller's array, so the floor
    # below would overwrite their data); dtype=float keeps a tiny threshold from
    # being truncated to 0 when the input holds integers
    floored = np.array(arr, dtype=float)
    floored[floored < threshold] = threshold
    logged = np.log(floored)
    return logged

def make_kde(Xgrid, Ygrid, Xsamps, Ysamps, to_log=False):
    """
    Fit a 2D Gaussian KDE to paired samples and evaluate it on a grid.

    Parameters
    ----------
    Xgrid: ndarray, float
        first-coordinate values of the evaluation grid
    Ygrid: ndarray, float
        second-coordinate values of the evaluation grid, same shape as `Xgrid`
    Xsamps: array-like, float
        first coordinate of the samples
    Ysamps: array-like, float
        second coordinate of the samples, same length as `Xsamps`
    to_log: boolean, optional
        if True, return the log of the KDE evaluations (via `safe_log`)

    Returns
    -------
    Z: ndarray, float
        KDE (or log-KDE) evaluated at every grid point, with the shape of `Xgrid`
    """
    positions = np.vstack([Xgrid.ravel(), Ygrid.ravel()])
    values = np.vstack([Xsamps, Ysamps])
    kernel = sps.gaussian_kde(values, bw_method='scott')
    Z = np.reshape(kernel(positions).T, Xgrid.shape)
    if to_log:
        # this used to call the undefined name `save_log`, raising NameError
        return safe_log(Z)
    return Z
# TODO: normalize up here before log!

In [None]:
# KDE of the reference samples, evaluated on the reference grid
kde_ref = make_kde(Xgrid=w_grid, Ygrid=Omm_grid, Xsamps=w_ref, Ysamps=Omm_ref)
# plt.imshow(kde_ref, extent=[xmin, xmax, ymin, ymax], origin='lower', cmap=plt.cm.Blues)

In [None]:
# replace with reading in other sets of posteriors
# KDE of the comparison samples, evaluated on the same (reference) grid
kde_comp = make_kde(Xgrid=w_grid, Ygrid=Omm_grid, Xsamps=w_comp, Ysamps=Omm_comp)
# plt.imshow(kde_comp, extent=[xmin, xmax, ymin, ymax], origin='lower', cmap=plt.cm.Reds)

Now that we have the 2D PDFs, let's define the KLD.

In [None]:
# stolen from chippr
def calculate_kld(lpe, lqe, dx, from_log=False, vb=True):
    """
    Calculates the Kullback-Leibler Divergence between two N-dimensional PDFs
    evaluated on a shared, regular grid (irregular grids are not supported).

    Parameters
    ----------
    lpe: numpy.ndarray, float
        (log-)probability distribution `p` evaluated on the grid; the
        divergence is weighted by this distribution
    lqe: numpy.ndarray, float
        (log-)probability distribution `q` evaluated on the same grid
    dx: numpy.ndarray, float
        separation of grid values in each dimension
    from_log: boolean, optional
        if False, lpe, lqe are probability distributions, not log-probability
        distributions
    vb: boolean, optional
        report on progress to stdout? (currently unused; kept so existing
        calls keep working)

    Returns
    -------
    Dpq: float
        the Kullback-Leibler Divergence D(p || q), approximated on the grid as
        sum(p * (log(p) - log(q))) * cell_volume
    """
    # volume of one grid cell, i.e. the integration measure for a Riemann sum
    cell_volume = np.prod(dx)
    if from_log:
        pe = np.exp(lpe)
        qe = np.exp(lqe)
    else:
        pe = np.asarray(lpe, dtype=float)
        qe = np.asarray(lqe, dtype=float)
    # Normalize the evaluations so each integrates to 1 over the grid
    # (very approximately!) by simple summation:
    pn = pe / (np.sum(pe) * cell_volume)
    qn = qe / (np.sum(qe) * cell_volume)
    # Compute the log of the normalized PDFs
    logp = safe_log(pn)
    logq = safe_log(qn)
    # pn is a density, so the sum must be weighted by the cell volume to
    # approximate the integral; without that factor the result is
    # KLD / cell_volume and changes with the grid spacing
    Dpq = np.sum(pn * (logp - logq)) * cell_volume
    return Dpq

Now we can evaluate it for our reference sample and a comparison sample.

In [None]:
# KLD of the comparison KDE relative to the reference KDE on the demo grid
calculate_kld(lpe=kde_ref, lqe=kde_comp, dx=np.array([dw, dOmm]))

could also do this as a function of chain iteration

### also some deterministic metrics for comparison

wouldn't need this if I'd been smart enough to put it in `proclam` already. . .

In [None]:
class det_mets(RateMatrix):
    """
    Binary classification metrics derived from a rate matrix.

    The methods read `TP`, `FP`, `FN`, `TN`, `TPR` and `TNR` from `self`, so the
    parent class must provide them (presumably as namedtuple fields, since the
    notebook calls `_asdict()` on a `RateMatrix` -- TODO confirm against proclam).
    Every metric is computed once at construction and stored as an attribute.
    """
    def __init__(self, **rates):
        """
        Call like `thing = det_mets(**rates._asdict())`
        """
        # `rates` is not stored here; the values are expected to already be
        # available as attributes through the parent class
        self._get_tots()
        self._from_rates()
        self._sn_mets()
        self._translate()

    def _get_tots(self):
        "marginal totals of the confusion matrix"
        self.CP = self.TP + self.FN  # condition positive
        self.CN = self.TN + self.FP  # condition negative
        self.T = self.TP + self.TN  # correct predictions
        self.F = self.FP + self.FN  # incorrect predictions
        self.P = self.TP + self.FP  # predicted positive
        self.N = self.TN + self.FN  # predicted negative

    def _from_rates(self):
        "metrics computed directly from the counts and rates"
        self.PPV = self.TP / (self.TP + self.FP)
        self.NPV = self.TN / (self.TN + self.FN)
        self.PT = (np.sqrt(self.TPR * (1. - self.TNR)) + self.TNR - 1.) / (self.TPR + self.TNR - 1.)
        self.TS = self.TP / (self.TP + self.FN + self.FP)
        self._derived()

    def _derived(self):
        "metrics that build on the totals and on PPV/NPV"
        self.ACC = (self.TP + self.TN) / (self.CP + self.CN)
        # a trailing comma here used to turn BA into a 1-tuple instead of a number
        self.BA = (self.TPR + self.TNR) / 2
        self.F1S = 2. * self.PPV * self.TPR / (self.PPV + self.TPR)
        self.MCC = (self.TP * self.TN - self.FP * self.FN) / (np.sqrt(self.P * self.CP * self.CN * self.N))
        self.FM = np.sqrt(self.PPV * self.TPR)
        self.BM = self.TPR + self.TNR - 1.
        self.MK = self.PPV + self.NPV - 1.

    def _translate(self):
        "aliases exposing the metrics under their longer names"
        self.positive = self.CP
        self.negative = self.CN
        self.sensitivity = self.TPR
        self.recall = self.TPR
        self.specificity = self.TNR
        self.selectivity = self.TNR
        self.precision = self.PPV
        self.FDR = 1. - self.PPV
        self.FOR = 1. - self.NPV
        self.CSI = self.TS
        self.accuracy = self.ACC
        self.f1_score = self.F1S
        self.informedness = self.BM
        self.deltaP = self.MK

    def _sn_mets(self):
        "efficiency and purity, stored as attributes"
        self.get_efficiency()
        self.get_purity()

    def get_efficiency(self):
        "store and return TP / CP"
        self.efficiency = self.TP / self.CP
        return self.efficiency

    def get_purity(self):
        "store and return TP / P"
        self.purity = self.TP / self.P
        return self.purity

    def get_fom(self, penalty):
        """
        Figure of merit: efficiency times a purity in which false positives
        are weighted by `penalty`.

        Parameters
        ----------
        penalty: float
            multiplier applied to FP in the pseudo-purity denominator

        Returns
        -------
        fom: float
            `pseudo_purity * efficiency`; `pseudo_purity` is also stored on self
        """
        self.pseudo_purity = self.TP / (self.TP + penalty * self.FP)
        return self.pseudo_purity * self.efficiency

## with "real" contaminated samples

In [None]:
# directory holding the saved classification results, with one metrics folder per field
savedpath = '/media/RESSPECT/data/PLAsTiCC/for_metrics/'
metpaths = {field: f"{savedpath}{field}/metrics/" for field in ('ddf', 'wfd')}
# metpaths = {field: savedpath+'metrics/' for field in ['ddf']}

In [None]:
# integer class codes mapped to the names used for each transient type
maybe_sn_classes = {
    90: 'SNIa',
    67: 'SNIa-91bg',
    52: 'SNIax',
    42: 'SNII',
    62: 'SNIbc',
    95: 'SLSN-I',
    88: 'AGN',
    15: 'TDE',
    64: 'KN',
}

# the class being selected; every other class is treated as a contaminant
sel_class = 90

# ia_percents = np.array([50, 68, 75, 90, 95, 98, 99])
# mix_percents = 100 - ia_percents
contaminants = {
    code: name for code, name in maybe_sn_classes.items() if code != sel_class
}

evaluate on the grid for the perfect samples as reference

The KDEs are the slow step here. . . don't run this more than once

TODO: save the KDEs, just in case?

In [None]:
# per-field grid spacings, grid coordinates and reference KDEs, keyed by field name
d_ref, grid_ref, kde_ref = {}, {}, {}
for field in ['ddf', 'wfd']:
    # the 'wfd' chains are read from a 'WFD/' subdirectory of postpath
    prepend = 'WFD/' if field == 'wfd' else ''
    w_ref, Omm_ref = measure(postpath+prepend+'perfect/chains_perfect.pkl', ['w', 'om'])
    ref_extrema, ref_grids, ref_vecs, ref_ds = make_grid(w_ref, Omm_ref)
    ((xmin, ymin), (xmax, ymax)) = ref_extrema
    w_grid, Omm_grid = ref_grids
    w_vec, Omm_vec = ref_vecs
    dw, dOmm = ref_ds
    d_ref[field] = dict(w=dw, Omm=dOmm)
    grid_ref[field] = dict(w=w_grid, Omm=Omm_grid)
    # this KDE evaluation is the slow step
    kde_ref[field] = make_kde(w_grid, Omm_grid, w_ref, Omm_ref)

In [None]:
# one row per (field, contaminant, percent) test case; KLD is filled in below
allmets = pd.read_csv(savedpath+'directory.csv')
allmets['KLD'] = None
allmets = allmets.drop_duplicates(ignore_index=True)

for ind, row in allmets.iterrows():
    field = row['field']
    testname = str(100-row['percent'])+str(maybe_sn_classes[sel_class])+str(row['percent'])+row['contaminant']
    # the 'wfd' chains are read from a 'WFD/' subdirectory of postpath
    comppath = postpath + ('WFD/' if field == 'wfd' else '') + testname
    compfn = comppath+'/chains_'+testname+'_lowz_withbias.pkl'
    print(compfn)
    if not os.path.exists(compfn):
        # no chains for this test case: its KLD stays empty
        continue
    w_comp, Omm_comp = measure(compfn, ['w', 'om'])
    # evaluate on the reference grid of the same field so the PDFs share a grid
    kde_comp = make_kde(grid_ref[field]['w'], grid_ref[field]['Omm'], w_comp, Omm_comp)
    # write through a single .loc call: the chained form
    # `allmets['KLD'].loc[ind] = ...` assigns to an intermediate object and
    # does not update `allmets` when pandas copy-on-write is enabled
    allmets.loc[ind, 'KLD'] = calculate_kld(kde_ref[field], kde_comp,
                                            np.array([d_ref[field]['w'], d_ref[field]['Omm']]))
allmets.to_csv(savedpath+'KLD.csv', index=False)

run me only once

In [None]:
# start from the table that already holds the KLD and add the deterministic metrics
allmets = pd.read_csv(savedpath+'KLD.csv')
for newcol in ['fom1', 'fom3', 'purity', 'efficiency', 'f1']:
    allmets[newcol] = None

# invert the code -> name mapping once instead of searching the lists for every row
code_by_name = {name: code for code, name in contaminants.items()}

for ind, row in allmets.iterrows():
    concode = code_by_name[row['contaminant']]
    testname = f"{100-row['percent']}_{sel_class}_{row['percent']}_{concode}"
    metfn = metpaths[row['field']]+testname+'.pkl'#f'{100-perc}_{sel_class}_{perc}_{key}.pkl'
    with open(metfn, 'rb') as metfile:
        # use the RateMatrix name imported at the top of the notebook
        # (`from proclam.metrics.util import RateMatrix`)
        rates = RateMatrix(**pkl.load(metfile))
    mets = det_mets(**rates._asdict())
    # write through a single .loc call: the chained form
    # `allmets[col].loc[ind] = ...` assigns to an intermediate object and
    # does not update `allmets` when pandas copy-on-write is enabled
    allmets.loc[ind, 'purity'] = mets.purity
    allmets.loc[ind, 'efficiency'] = mets.efficiency
    allmets.loc[ind, 'f1'] = mets.f1_score
    allmets.loc[ind, 'fom1'] = mets.get_fom(1.)
    allmets.loc[ind, 'fom3'] = mets.get_fom(3.)
allmets.to_csv(savedpath+'FOM.csv', index=False)

In [None]:
# reload the saved table so the cells above do not have to be re-run
allmets = pd.read_csv(f"{savedpath}FOM.csv")

In [None]:
# display the table of metrics
allmets

## plotting cosmo/principled vs. traditional/deterministic metrics

~~TODO: actually make these~~
~~- [X] by contaminant (markershape)~~
~~- [X] by contamination rate (continuous colors? or markersize?)~~
~~- [X] by field (markersize? or discrete colors?)~~

TODO:
- [X] hardcode the shapes
- [X] open/closed for field
- [ ] logscale colors
- [X] check WFD directory for posteriors

In [None]:
# sizes = {'ddf': 50, 'wfd': 150}
# for i, (k, v) in enumerate(maybe_sn_classes.items()):
#     shapes[v] = (np.mod(i, 3)+3, int(i / 3), np.mod(i, 4)*45)
# each pair holds the marker used for the first field and for the second field
shape_pairs = [('.', 'o'), ('1', 'v'), ('2', '^'), ('3', '<'), ('4', '>'), ('+', 'P'), ('x', 'X'), ('*', 'p')]

# all_shapes[field][contaminant] -> marker; contaminants are taken in the order
# in which they first appear in allmets
all_shapes = {}
for i, field in enumerate(['ddf', 'wfd']):
    shapes = {contaminant: shape_pairs[j][i]
              for j, contaminant in enumerate(allmets['contaminant'].unique())}
    all_shapes[field] = shapes

colors = {i: plt.get_cmap('plasma_r')(i) for i in allmets['percent'].unique()}

# deterministic metrics to compare against one another
alldets = ['f1', 'fom1', 'fom3', 'purity', 'efficiency']
dim = len(alldets)

fave_cmap = 'viridis_r'#'plasma_r'#'cool'

TODO: add in 1D histograms for metric value for each contaminant, for percent

In [None]:
# one panel per unordered pair of deterministic metrics
pairs = list(itertools.combinations(range(dim), 2))
fig, axs = plt.subplots(dim-1, dim-1, figsize=(5*(dim-1), 5*(dim-1)))
norm = mpl.colors.Normalize(vmin=0., vmax=50.)
# the bottom-left axes hosts both the colorbar and the marker legend
legend_ax = axs[-1][0]
fig.colorbar(mpl.cm.ScalarMappable(cmap=plt.get_cmap(fave_cmap), norm=norm), cax=legend_ax,
             ticks=allmets['percent'].unique())

for c in allmets['contaminant'].unique():
    for f in ['ddf', 'wfd']:
        marker = all_shapes[f][c]
        # dummy point that only exists to create a legend entry
        legend_ax.scatter(-1, -1, marker=marker, color='k', label=f+':'+c)
        one_c = allmets[(allmets['contaminant'] == c) & (allmets['field'] == f)]
        for xdet, ydet in pairs:
            axs[xdet][ydet-1].scatter(one_c[alldets[xdet]], one_c[alldets[ydet]],
                                      alpha=0.5, s=100,
                                      marker=marker,
                                      color=plt.get_cmap(fave_cmap)(one_c['percent']/50.))
for xdet, ydet in pairs:
    panel = axs[xdet][ydet-1]
    panel.set_xlabel(alldets[xdet])
    panel.set_ylabel(alldets[ydet])
legend_ax.legend()
fig.savefig('draft_dets.png', bbox_inches='tight', pad_inches=0)

In [None]:
# NOTE(review): in the visible notebook `dets_to_plot` is first assigned in the
# following cell, so this display only works once that cell has been run.
dets_to_plot

In [None]:
# deterministic metrics to plot the KLD against, one panel each
dets_to_plot = ['fom3']
dim = len(dets_to_plot)

# the extra, last axes hosts the colorbar and the marker legend
fig, axs = plt.subplots(1, dim+1, figsize=(5*(dim+1), 5))#,
#                         gridspec_kw={'width_ratios': [10]*dim+[1]})
norm = mpl.colors.Normalize(vmin=0., vmax=50.)
cmap = plt.get_cmap(fave_cmap)
fig.colorbar(mpl.cm.ScalarMappable(cmap=cmap, norm=norm), cax=axs[-1],
             ticks=allmets['percent'].unique())
for i, metric in enumerate(dets_to_plot):
    # iterate over all_shapes rather than `sizes` (only defined in a
    # commented-out line) and the leftover loop variable `shapes`
    for field, field_shapes in all_shapes.items():
        for contaminant, marker in field_shapes.items():
            subset = allmets[(allmets['field'] == field) & (allmets['contaminant'] == contaminant)]
            axs[i].scatter(subset[metric], subset['KLD'], alpha=0.75, s=100,
                           marker=marker,
                           color=cmap(subset['percent']/50.))
    axs[i].semilogy()
    axs[i].set_xlabel(metric)
    axs[i].set_ylabel('KLD')
    axs[i].set_xlim(-0.1, 1.1)
# dummy points that only exist to create legend entries; only contaminants that
# have a marker are listed, so a class absent from allmets cannot raise KeyError
for field, field_shapes in all_shapes.items():
    for contaminant, marker in field_shapes.items():
        axs[-1].scatter(-1, -1, s=50, marker=marker, color='k',
                        label=field+':'+contaminant)
axs[-1].legend(loc='lower left', fontsize='small')
fig.savefig('draft_kld.png', bbox_inches='tight', pad_inches=0)

TODO: thought re: histograms of w\_est, sigma\_w, make two versions: color by which contaminant (discrete colors) and by % contaminant (continuous colors)

these live in postpath = '/media2/RESSPECT2/data/posteriors/' stan\_summary .dat

In [None]:
# display the table of metrics
allmets