## This notebook shows how to optimize a cut on the standard deviation of p(z) by exploring the trade-off between efficiency and purity

#### Do the usual imports

In [None]:
import tables_io
import numpy as np
import matplotlib.pyplot as plt
import qp
from rail.raruma import plotting_functions as raruma_plot
from rail.raruma import utility_functions as raruma_util

#### Change this to be the root of the current PZ working area

In [None]:
import os

# Root of the PZ working area.  Known alternatives:
#   shared group area:      '/global/cfs/cdirs/lsst/groups/PZ/DP1'
#   fallback if that fails: '/global/u2/e/echarles/dx'
# Set the PZ_DIR environment variable to point somewhere else without
# editing this cell; otherwise the default below is used.
pz_dir = os.environ.get('PZ_DIR', '/Users/echarles/pz')

#### Get the data: first the reference redshifts, then the p(z) estimates for each algorithm

In [None]:
# Reference redshifts for the matched DP1 test sample.
ref_path = f"{pz_dir}/data/test/dp1_matched_v4_test.hdf5"
d = tables_io.read(ref_path)
redshifts = d['redshift']

In [None]:
# p(z) estimates read with qp, one file per algorithm, keyed by algorithm name.
algos = ['knn', 'fzboost', 'tpz', 'bpz', 'dnf', 'lephare', 'gpz', 'cmnn']
pz_dict = {
    algo: qp.read(f"{pz_dir}/projects/dp1_v4/data/gold_dp1_6band_paper/output_estimate_{algo}.hdf5")
    for algo in algos
}

#### Functions we will use to do the optimization

In [None]:
def calc_frac_dels(qp_dstn, truth):
    """Return the fractional redshift error of the p(z) mode.

    Computes ``(z_mode - truth) / (1 + truth)`` object by object.

    Parameters
    ----------
    qp_dstn : qp ensemble
        Must carry a ``'zmode'`` entry in ``ancil``; it is squeezed so that a
        trailing length-1 axis does not break broadcasting against ``truth``.
    truth : array-like
        Reference redshifts, aligned with the ensemble's objects.

    Returns
    -------
    np.ndarray
        The fractional deviation for each object.
    """
    z_mode = np.squeeze(qp_dstn.ancil['zmode'])
    one_plus_z = 1 + truth
    return (z_mode - truth) / one_plus_z

In [None]:
def calc_std(qp_dstn, grid):
    """Return the standard deviation of each p(z), evaluated on ``grid``.

    The p(z) values sampled on ``grid`` are used as discrete weights. The
    mean and variance are normalized by the sum of those weights, so the grid
    spacing cancels out.

    Parameters
    ----------
    qp_dstn : qp ensemble
        ``qp_dstn.pdf(grid)`` is expected to return an ``(n_obj, n_grid)``
        array (the sums below run over axis 1).
    grid : array-like
        Redshift values at which p(z) is evaluated.

    Returns
    -------
    np.ndarray
        Shape ``(n_obj, 1)``: one standard deviation per object.
    """
    grid_arr = np.asarray(grid)
    pdf_vals = qp_dstn.pdf(grid)
    weight_sum = np.sum(pdf_vals, axis=1)
    mean_z = (pdf_vals * grid_arr).sum(axis=1) / weight_sum
    # offsets[i, j] = grid[j] - mean of object i
    offsets = grid_arr[np.newaxis, :] - mean_z[:, np.newaxis]
    variance = (offsets * offsets * pdf_vals).sum(axis=1) / weight_sum
    return np.sqrt(variance)[:, np.newaxis]

In [None]:
def effic_and_purirty_curves(std, frac_dels, cuts, purity_cut=0.05):
    """Scan cuts on the p(z) width and return efficiency and purity curves.

    For each value ``cut_`` in ``cuts``, objects with ``|std| < cut_`` are
    selected.

    Parameters
    ----------
    std : array-like
        Per-object p(z) standard deviation, e.g. the ``(n_obj, 1)`` output of
        ``calc_std``; it is squeezed before use.
    frac_dels : array-like
        Per-object fractional redshift error (e.g. from ``calc_frac_dels``),
        aligned with ``std``.
    cuts : iterable of float
        Cut values on the p(z) width to scan.
    purity_cut : float
        An object counts as "good" when ``|frac_dels| < purity_cut``.

    Returns
    -------
    effic : np.ndarray
        For each cut, the fraction of all objects that pass it.
    purity : np.ndarray
        For each cut, the fraction of the selected objects that are "good".
        NaN when a cut selects no objects, since purity is undefined there.
    """
    # The width does not depend on the cut, so squeeze/abs it once.
    abs_std = np.abs(np.squeeze(std))
    effic_list = []
    purity_list = []
    for cut_ in cuts:
        mask = abs_std < cut_
        effic = mask.sum() / mask.size if mask.size else np.nan
        good_mask = np.abs(np.squeeze(frac_dels[mask])) < purity_cut
        # An empty selection (e.g. a very tight cut) would otherwise compute
        # 0/0 and emit a RuntimeWarning; report NaN explicitly instead.
        purity = good_mask.sum() / good_mask.size if good_mask.size else np.nan
        effic_list.append(effic)
        purity_list.append(purity)
    return np.array(effic_list), np.array(purity_list)

#### Compute the standard deviations of p(z) using a grid from 0 to 4.

In [None]:
# Redshift grid: 401 points from 0 to 4 (spacing 0.01).
grid = np.linspace(0.0, 4.0, num=401)
# p(z) widths for the knn estimator alone (used for the knn-only curves).
std = calc_std(pz_dict['knn'], grid)


In [None]:
# p(z) widths for every algorithm on the same grid, keyed by algorithm name.
std_dict = {name: calc_std(pz_dict[name], grid) for name in algos}

#### Compute the fractional deviations of the p(z) mode from the reference redshift, $(z_{\rm mode} - z_{\rm ref}) / (1 + z_{\rm ref})$

In [None]:
# Fractional redshift errors for the knn estimator alone.
frac_dels = calc_frac_dels(qp_dstn=pz_dict['knn'], truth=redshifts)

In [None]:
# Fractional redshift errors for every algorithm, keyed by algorithm name.
frac_del_dict = {name: calc_frac_dels(pz_dict[name], redshifts) for name in algos}

#### Make a grid of cut values to scan for the cut on the standard deviation of p(z)

In [None]:
# Candidate upper limits on sigma_p(z): 100 evenly spaced values from 0.01 to 1.5 (step ~0.015).
cuts = np.linspace(start=0.01, stop=1.50, num=100)

#### Get the efficiency and purity as a function of the cut

In [None]:
# knn-only efficiency and purity curves; "good" means |dz|/(1+z) < 0.15.
effic, purity = effic_and_purirty_curves(std, frac_dels, cuts, purity_cut=0.15)

In [None]:
# Efficiency and purity curves for every algorithm, using the same cut grid
# and the same purity threshold (0.15) as the knn-only curves.
eff_dict = {}
pur_dict = {}
for name in algos:
    curves = effic_and_purirty_curves(std_dict[name], frac_del_dict[name], cuts, purity_cut=0.15)
    eff_dict[name] = curves[0]
    pur_dict[name] = curves[1]

### Make some plots

#### Efficiency v. cut value curve

In [None]:
# Fraction of objects passing the sigma_p(z) cut, one curve per algorithm.
fig, ax = plt.subplots()
for algo in algos:
    ax.plot(cuts, eff_dict[algo], label=algo)
ax.set_xlabel(r'cut on $\sigma_{p(z)}$')
ax.set_ylabel('Efficiency')
ax.legend()
fig.savefig('efficiency.pdf')

#### Purity v. cut value curve

In [None]:
# Fraction of selected objects with |dz|/(1+z_spec) < 0.15, one curve per algorithm.
fig = plt.figure()
for algo in algos:
    _ = plt.plot(cuts, pur_dict[algo], label=algo)
_ = plt.xlabel(r'cut on $\sigma_{p(z)}$')
# Purity is computed on |dz| (absolute value); keep the parentheses outside
# the mathtext span so they balance.
_ = plt.ylabel(r'Purity: ($\frac{|\delta z|}{1 + z_{\rm spec}} < 0.15$)')
_ = plt.ylim(0.8, 1.)
_ = plt.legend()
fig.savefig('purity.pdf')

#### Efficiency v. purity curve for knn, with the cut value in color

In [None]:
# Efficiency v. purity for the knn estimator, coloured by the sigma_p(z) cut.
# The grid point closest to a cut of 0.15 is marked with a red star.
ref_cut = 0.15
fig = plt.figure()
sc = plt.scatter(purity, effic, c=cuts)
# With cuts = linspace(0.01, 1.50, 100) the step is ~0.015, so no grid point
# lies within 1e-4 of 0.15 and a tolerance match selects nothing.
# Mark the nearest grid point instead.
mask = np.arange(cuts.size) == np.argmin(np.abs(cuts - ref_cut))
_ = plt.scatter(purity[mask], effic[mask], marker="*", s=50, color="red")

_ = plt.xlabel('Purity')
_ = plt.ylabel('Efficiency')
_ = plt.ylim(0, 1.)
_ = plt.xlim(0.875, 1.)
# Pass the cut-coloured scatter explicitly: plt.scatter makes each new
# collection the "current image", so a bare plt.colorbar() would attach to
# the red-star scatter, which carries no colour data.
_ = plt.colorbar(sc, label=r'cut on $\sigma_{p(z)}$')
fig.savefig('purity_v_effic.pdf')

#### Efficiency v. purity curve for all algorithms

In [None]:
# Efficiency v. purity for every algorithm, one curve per estimator.
fig, ax = plt.subplots()
for algo in algos:
    ax.plot(pur_dict[algo], eff_dict[algo], label=algo)
ax.set_xlabel('Purity')
ax.set_ylabel('Efficiency')
ax.set_ylim(0, 1.)
ax.set_xlim(0.80, 1.)
_ = ax.legend()
# NOTE: this figure is not saved. 'purity_v_effic.pdf' is already written by
# the knn-only plot, so choose a different filename if saving is wanted.