In [None]:
import tables_io
import numpy as np
import matplotlib.pyplot as plt
from rail.raruma import plotting_functions as raruma_plot
from rail.raruma import utility_functions as raruma_util
from rail.core.data import DataStore
# Allow RAIL's DataStore to overwrite existing entries; presumably so cells that
# create stages/data can be re-run without "already exists" errors (confirm).
DataStore.allow_overwrite = True
import jax.numpy as jnp
import jax
import scipy.differentiate as scipy_diff

In [None]:
from rail.estimation.estimator import CatEstimatorWrapper
from rail.estimation.algos.k_nearneigh import KNearNeighEstimator
from rail.utils.catalog_utils import RubinCatalogConfig
# Apply the catalog configuration selected by RubinCatalogConfig.tag.
# NOTE(review): presumably this sets the band/column names that the RAIL stages
# created below pick up (e.g. knn.config.bands) — confirm.
RubinCatalogConfig.apply(RubinCatalogConfig.tag)

In [None]:
# Trained kNN model file, passed as `model=` to KNearNeighEstimator below.
# NOTE(review): the .pkl extension suggests a pickle — only load files you trust.
model_file = './model_inform_knn.pkl'

In [None]:
# Scratch check of the tiling pattern used in WrapFunc.__call__:
# np.atleast_2d turns `a` into a single row, and np.repeat(..., 4, axis=0)
# stacks 4 copies of that row into a 4x4 array.
a = np.arange(4)

In [None]:
np.repeat(np.atleast_2d(a), 4, axis=0)

In [None]:
# Build the kNN estimator stage from the trained model.
# NOTE(review): 'dummy.in' appears to be a placeholder input name, and
# output_mode='return' presumably makes the stage return results instead of
# writing them; nzbins=3001 presumably sets the size of the redshift grid — confirm.
knn = KNearNeighEstimator.make_stage(name='knn', model=model_file, input='dummy.in', output_mode='return', nzbins=3001)
# Wrap the stage so it can be called directly on magnitude arrays (see `knn_w(obj1)` below).
knn_w = CatEstimatorWrapper(knn)

In [None]:
# Input catalogue.
# NOTE(review): absolute, user-specific path — point this at your own copy of
# the sandbox data before running.
data_file = "/Users/echarles/pz/sandbox_data/roman_rubin_9925.hdf5"
d = tables_io.read(data_file)
# Keep every 10th row. The stop must be None to run to the end of the table;
# a stop of -1 would silently drop the final row whenever its index is a
# multiple of 10.
train = tables_io.sliceObj(d, slice(0, None, 10))
# Column names LSST_obs_u ... LSST_obs_y (presumably; built from the template).
band_names = raruma_util.make_band_names('LSST_obs_{band}', 'ugrizy')
# 2-D array of magnitudes; presumably one row per object, one column per band.
mags = raruma_util.extract_data_to_2d_array(train, band_names)
# Tell the stage which columns its inputs correspond to (its configured bands).
knn.stage_columns = knn.config.bands

In [None]:
# Magnitude vector of the first object (row 0 of `mags`).
obj1 = mags[0]

In [None]:
# Displays a copy of the row; the copy itself is not kept.
obj1.copy()

In [None]:
# Evaluate the wrapped estimator on a single object.
# NOTE(review): the return type/shape of CatEstimatorWrapper is not visible
# here — confirm before relying on it.
out = knn_w(obj1)

In [None]:
out

In [None]:
class WrapFunc:
    """Expose one coordinate of a multi-input function as a 1-D function.

    ``WrapFunc(func, vals, index)(x)`` evaluates ``func`` at a copy of ``vals``
    whose ``index``-th entry is replaced by ``x``. This is the form of callable
    that ``scipy.differentiate.derivative`` needs to take the partial
    derivative of ``func`` with respect to entry ``index``.

    Parameters
    ----------
    func : callable
        Function of the full input vector. In the vectorized case it is called
        with a 2-D array of shape ``(n_points, len(vals))``. It is assumed to
        return one value per row there (TODO confirm for the wrapped estimator).
    vals : array-like
        Base point. It is copied, so later changes by the caller do not leak in.
    index : int
        Position in ``vals`` that ``x`` replaces.
    """

    def __init__(self, func, vals, index):
        self.func = func
        self.vals = vals.copy()
        self.index = index

    def __call__(self, x):
        """Evaluate ``func`` with entry ``index`` replaced by ``x``.

        A scalar ``x`` (0-d, or a single-element 1-D array) returns
        ``func(vals)`` as-is. Any other ``x`` is treated elementwise, as
        ``scipy.differentiate`` requires: one copy of ``vals`` is built per
        element of ``x``, and the result is reshaped to ``x.shape``.
        """
        x_arr = np.asarray(x)
        if x_arr.ndim <= 1 and x_arr.size == 1:
            point = self.vals.copy()
            point[self.index] = float(x_arr.item())
            return self.func(point)
        # Vectorized call: scipy passes an array of abscissae (e.g. shape
        # (n_active, n_points)); evaluate all of them in one batch.
        flat_x = x_arr.reshape(-1)
        batch = np.repeat(np.atleast_2d(self.vals), flat_x.size, axis=0)
        batch[:, self.index] = flat_x
        return np.reshape(np.asarray(self.func(batch)), x_arr.shape)

def gradient(func, x_vals, *, maxiter=4, initial_step=0.1, atol=0.2):
    """Finite-difference gradient of ``func`` at ``x_vals``, one input at a time.

    Each partial derivative is estimated by ``scipy.differentiate.derivative``
    applied to ``WrapFunc(func, x_vals, i)``. That is ``func`` with every
    entry except ``i`` held fixed.

    Parameters
    ----------
    func : callable
        Function of the full input vector; it must be usable through ``WrapFunc``.
    x_vals : array-like, 1-D
        Point at which to differentiate. Non-floating input is promoted to
        float; otherwise the perturbed coordinate would be truncated when
        written back into the input vector.
    maxiter, initial_step, atol : optional, keyword-only
        Forwarded to ``scipy.differentiate.derivative`` (``atol`` via
        ``tolerances``). Defaults: 4, 0.1 and 0.2.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(x_vals),)``. Entries are taken from
        ``result.df`` without checking ``result.success``. With loose settings
        an entry may therefore be an unconverged estimate.
    """
    x_vals = np.asarray(x_vals)
    if not np.issubdtype(x_vals.dtype, np.floating):
        x_vals = x_vals.astype(float)
    n_val = len(x_vals)
    grad = np.zeros(n_val)
    for i in range(n_val):
        wf = WrapFunc(func, x_vals, i)
        res = scipy_diff.derivative(
            wf,
            x_vals[i],
            maxiter=maxiter,
            initial_step=initial_step,
            tolerances=dict(atol=atol),
        )
        grad[i] = float(res.df)
    return grad
        

In [None]:
# Partial derivatives of the kNN output with respect to each entry of object 3's magnitudes.
grad = gradient(knn_w, mags[3])

In [None]:
# grad

In [None]:
# NOTE(review): scipy.differentiate.jacobian calls f with arrays whose first
# axis has length len(x), one slice per input component. knn_w is called
# elsewhere in this notebook with one object per row, so confirm that it
# accepts this layout before trusting the result.
scipy_diff.jacobian(knn_w, mags[3])

In [None]:
# Per-object gradients for the first 1000 rows of `mags`, printing the row
# index every 50 rows as a progress indicator.
N_GRAD_OBJECTS = 1000
PROGRESS_EVERY = 50
grads = []
for row in range(N_GRAD_OBJECTS):
    if not row % PROGRESS_EVERY:
        print(row)
    row_grad = gradient(knn_w, mags[row])
    grads.append(row_grad)
gout = np.array(grads)

In [None]:
# Gradient matrix: one row per object, one column per input magnitude
# (presumably 6 columns for 'ugrizy' — confirm against mags.shape).
gout

In [None]:
import matplotlib.pyplot as plt

In [None]:
# |gradient| with respect to input column 4 (the 'z' entry if columns follow
# 'ugrizy') against magnitude column 3 ('i').
# NOTE(review): confirm this band pairing is intended.
# Using gout's row count keeps x and y the same length if the loop size changes;
# the 1e-4 offset keeps log10 finite where the gradient is exactly 0.
n_grad = gout.shape[0]
_ = plt.scatter(mags[0:n_grad, 3], np.log10(np.abs(gout[:, 4]) + 0.0001))
_ = plt.xlabel('magnitude column 3')
_ = plt.ylabel('log10(|gradient wrt column 4| + 1e-4)')

In [None]:
# arcsinh is ~linear near 0 and ~logarithmic for large |x|, so it compresses
# large gradients while keeping sign. np.arcsinh is used rather than np.asinh,
# which only exists in NumPy >= 2.0.
_ = plt.hist(np.arcsinh(gout[:, 4]), bins=500)
_ = plt.xscale('symlog')
_ = plt.xlim(-1, 1.)
_ = plt.xlabel('arcsinh(gradient wrt column 4)')
_ = plt.ylabel('count')

In [None]:
# Visual check of the arcsinh transform applied to the column-4 gradients.
# np.arcsinh is used rather than np.asinh, which only exists in NumPy >= 2.0.
_ = plt.scatter(gout[:, 4], np.arcsinh(gout[:, 4]))
_ = plt.xlim(-400, 400)
_ = plt.ylim(-10, 10)
_ = plt.xlabel('gradient wrt column 4')
_ = plt.ylabel('arcsinh(gradient)')