In [None]:
import tables_io
import numpy as np
import matplotlib.pyplot as plt
import pyccl as ccl
from rail.raruma import utility_functions as raruma_util
from rail.raruma import plotting_functions as raruma_plot
from rail.raruma import admixture_functions as raruma_admix
from rail.raruma import wrapper_classes as raruma_wrap
from scipy.stats import sigmaclip
from astropy.stats import biweight_location, biweight_scale


In [None]:
# Root directory of the photo-z data products read by the cells below.
# Set the PZ_DIR environment variable to point somewhere else without editing
# the notebook. Locations that have been used:
#   /global/cfs/cdirs/lsst/groups/PZ/DP1
#   /global/u2/e/echarles/dx   (use this if the one above fails)
# When PZ_DIR is unset the original local default is kept.
import os

pz_dir = os.environ.get('PZ_DIR', '/Users/echarles/pz')

In [None]:
# Locations of the per-algorithm model files, relative to the working directory.
model_file_knn, model_file_bpz, model_file_fzb, model_file_gpz = (
    './model_inform_knn.pkl',
    './model_inform_bpz.pkl',
    './model_inform_fzboost.pkl',
    './model_inform_gpz.pkl',
)

# Training and test tables come from dedicated files under pz_dir.
# (A disabled alternative in an earlier version of this cell derived them from
# `d` instead, via tables_io.sliceObj with slice(0, -1, 10) and slice(1, -1, 10).)
train = tables_io.read(
    f"{pz_dir}/data/train/dp1_ecdfs_matched_specgold_train.hdf5"
)
test = tables_io.read(
    f"{pz_dir}/data/test/dp1_ecdfs_matched_specgold_test.hdf5"
)

# Full on-sky table; used further down as the library for the flux admixture.
d = tables_io.read(
    f"{pz_dir}/data/on_sky/dp1_v29.0.0_gold_all.hdf5"
)

In [None]:
from rail.estimation.algos.k_nearneigh import KNearNeighEstimator
from rail.estimation.algos.bpz_lite import BPZliteEstimator
from rail.estimation.algos.flexzboost import FlexZBoostEstimator
from rail.estimation.algos.gpz import GPzEstimator
from rail.utils.catalog_utils import RubinCatalogConfig
# Apply the Rubin catalog configuration before any estimator stage is created.
# NOTE(review): presumably this sets the default band column names that the
# stages report through `config.bands` — confirm against rail.utils.catalog_utils.
RubinCatalogConfig.apply(RubinCatalogConfig.tag)

# Column names for the 'ugrizy' bands: magnitudes first, then their errors.
band_names, error_band_names = (
    raruma_util.make_band_names(name_template, 'ugrizy')
    for name_template in ('LSST_obs_{band}', 'LSST_obs_{band}_err')
)

In [None]:
# Extract the '{band}_gaap1p0Mag' values and the matching '{band}_gaap1p0MagErr'
# values from the training table for the 'ugrizy' bands.
train_features, error_bands = [
    raruma_util.get_band_values(train, column_template, 'ugrizy')
    for column_template in ('{band}_gaap1p0Mag', '{band}_gaap1p0MagErr')
]

In [None]:
# Same magnitude columns as the training features, but taken from the full
# on-sky table `d`; this is what gets mixed into the training magnitudes later.
library = raruma_util.get_band_values(
    d,
    '{band}_gaap1p0Mag',
    'ugrizy',
)

In [None]:
# k-nearest-neighbour estimator stage, loaded from its model file.
# output_mode='return' is requested, so `input` is only given a dummy name.
knn = KNearNeighEstimator.make_stage(
    name='knn',
    model=model_file_knn,
    input='dummy.in',
    output_mode='return',
    nzbins=3001,
)
# The stage's columns are set to the bands named in its own configuration.
knn.stage_columns = knn.config.bands
# Wrap the stage so it can be called directly on an array; magnitudes only.
knn_w = raruma_wrap.CatEstimatorWrapper(knn, band_names)

In [None]:
# GPz estimator stage, loaded from the GPz model file.
# This previously built a KNearNeighEstimator (copy-paste from the knn cell),
# which paired the wrong estimator class with model_inform_gpz.pkl and left
# the imported GPzEstimator unused.
gpz = GPzEstimator.make_stage(
    name='gpz',
    model=model_file_gpz,
    input='dummy.in',
    output_mode='return',
    nzbins=3001,
)
# The stage's columns are set to the bands named in its own configuration.
gpz.stage_columns = gpz.config.bands
# NOTE(review): GPz is understood to read magnitude errors as well as magnitudes
# (RAIL's GPz stages carry an `err_bands` setting), so the wrapper is given both
# column lists, as the bpz and fzboost wrappers are — confirm against
# CatEstimatorWrapper and the GPz stage configuration.
gpz_w = raruma_wrap.CatEstimatorWrapper(gpz, band_names+error_band_names)

In [None]:
# BPZ-lite estimator stage, loaded from its model file.
# output_mode='return' is requested, so `input` is only given a dummy name.
bpz = BPZliteEstimator.make_stage(
    name='bpz',
    model=model_file_bpz,
    input='dummy.in',
    output_mode='return',
    nzbins=3001,
)
# The stage's columns are set to the bands named in its own configuration.
bpz.stage_columns = bpz.config.bands
# This wrapper is given the error columns in addition to the magnitudes.
bpz_w = raruma_wrap.CatEstimatorWrapper(
    bpz,
    band_names+error_band_names,
)

In [None]:
# FlexZBoost estimator stage, loaded from its model file, with summary
# statistics switched on (calc_summary_stats=True).
fzb = FlexZBoostEstimator.make_stage(
    name='fzboost',
    model=model_file_fzb,
    input='dummy.in',
    output_mode='return',
    nzbins=3001,
    calc_summary_stats=True,
)
# The stage's columns are set to the bands named in its own configuration.
fzb.stage_columns = fzb.config.bands
# Magnitudes plus errors go in; 'z_mode' is requested as the point estimate.
fzb_w = raruma_wrap.CatEstimatorWrapper(
    fzb,
    band_names+error_band_names,
    point_estimate='z_mode',
)

In [None]:
# Admixture fractions to scan: 17 log-spaced values from 1e-4 up to 1
# (four points per decade).
admix_grid = np.logspace(start=-4, stop=0, num=17)

In [None]:
# Display the combined column list (magnitude names followed by error names),
# i.e. the columns the bpz and fzboost wrappers are constructed with.
band_names+error_band_names

In [None]:
def doit(wrapper, nclip=3):
    the_dict = {}
    inputs = np.hstack([train_features, error_bands])
    est_orig = wrapper(inputs.T)
    means = []
    stds = []
    outlier_fracs = []
    for admix in admix_grid:
        mixed_mags = raruma_admix.make_admixture(train_features, library, admixture=admix)
        inputs = np.hstack([mixed_mags, error_bands])
        ad_vals = wrapper(inputs.T)
        raruma_plot
        deltas = (ad_vals - est_orig)/(1 + est_orig)
        subset_clip, _, _ = sigmaclip(deltas, low=3, high=3)
        for _j in range(nclip):
            subset_clip, _, _ = sigmaclip(subset_clip, low=3, high=3)

        the_dict[admix] = deltas
        outliers = (np.fabs(deltas) > 0.15).sum() / float(deltas.size)
        outlier_fracs.append(outliers)
        #means.append(deltas.mean())
        #stds.append(deltas.std())
        means.append(biweight_location(subset_clip))
        stds.append(biweight_scale(subset_clip))

    _ = plt.plot(admix_grid, means, label=r"mean $\delta z$")
    _ = plt.plot(admix_grid, stds, label=r"RMS $\delta z$")
    _ = plt.plot(admix_grid, outlier_fracs, label=r"f $\delta z > 0.15$")
    _ = plt.xscale('log')
    _ = plt.legend()
    _ = plt.xlabel("Flux Admixture Fraction")
    _ = plt.ylabel(r"$\delta z$")

In [None]:
doit(fzb_w)

In [None]:
doit(knn_w)

In [None]:
doit(gpz_w)

In [None]:
for k, v in the_dict.items():
    _ = plt.hist(v, bins=np.linspace(-1, 1, 101), label=k)