In [None]:
import tables_io
import numpy as np
import matplotlib.pyplot as plt
from rail.raruma import plotting_functions as raruma_plot
from rail.raruma import utility_functions as raruma_util
from rail.core.data import DataStore
from rail.core.stage import RailStage
DataStore.allow_overwrite = True
import jax.numpy as jnp
import jax
import scipy.differentiate as scipy_diff

In [None]:
from rail.estimation.algos.k_nearneigh import KNearNeighEstimator
from rail.utils.catalog_utils import RubinCatalogConfig
RubinCatalogConfig.apply(RubinCatalogConfig.tag)
from rail.raruma.wrapper_classes import CatEstimatorDerivativeWrapper, CatEstimatorJacobianWrapper, CatEstimatorHessianWrapper

In [None]:
# Pickled KNN model, passed to the estimator below as `model=`.
# NOTE(review): relative path — resolved against the kernel's working directory.
model_file = './model_inform_knn.pkl'

In [None]:
# Load the catalogue and keep only its first row (slice(0, 1)) as the
# reference object whose magnitudes seed the wrappers below.
# NOTE(review): hard-coded absolute local path; move it to a configurable
# data-directory constant so the notebook runs on other machines.
d = tables_io.read("/Users/echarles/pz/sandbox_data/roman_rubin_9925.hdf5")
train = tables_io.sliceObj(d, slice(0, 1))
# Column names built from the 'LSST_obs_{band}' template, presumably one per
# letter of 'ugrizy' — confirm against make_band_names.
band_names = raruma_util.make_band_names('LSST_obs_{band}', 'ugrizy')
# Stack those columns into a 2-D array. The next cells zip band_names with
# mags.T, so this assumes mags is (n_objects, n_bands) — TODO confirm.
mags = raruma_util.extract_data_to_2d_array(train,band_names)

In [None]:
mags.T

In [None]:
# Map each band name to its row of mags.T; passed to the single-band
# derivative wrapper, presumably as the values of the bands it does not vary.
init_values = {key:val for key, val in zip(band_names, mags.T)}

In [None]:
init_values

In [None]:
# Build the KNN estimator from the pickled model. input='dummy.in' looks like a
# placeholder, and output_mode='return' presumably returns results in memory
# instead of writing them — confirm against the RAIL stage documentation.
# nzbins=3001 is presumably the number of redshift-grid bins (magic number;
# TODO document the choice).
knn = KNearNeighEstimator.make_stage(name='knn', model=model_file, input='dummy.in', output_mode='return', nzbins=3001)
# Make the stage's column list the bands listed in its config.
knn.stage_columns = knn.config.bands
# Wrappers from rail.raruma.wrapper_classes that, judging by their names,
# expose the estimator to scipy.differentiate's jacobian / derivative / hessian.
# knn_j and knn_h are built over all of band_names; knn_d is built for
# 'LSST_obs_g' only, presumably taking the other bands from init_values.
knn_j = CatEstimatorJacobianWrapper(knn, band_names)
knn_d = CatEstimatorDerivativeWrapper(knn, 'LSST_obs_g', init_values=init_values)
knn_h = CatEstimatorHessianWrapper(knn, band_names)

In [None]:
# Inspect the model held by the wrapped estimator.
knn_j._estimator.model

In [None]:
# Evaluate the all-band wrapper at the reference magnitudes.
out = knn_j(mags.T)
print(out)

In [None]:
mags.T

In [None]:
mags.T[1]

In [None]:
# Evaluate the g-band wrapper at row 1 of mags.T (the g band, if the columns
# of mags follow the 'ugrizy' order of band_names).
out = knn_d(mags.T[1])

In [None]:
# NOTE(review): this import belongs in the top import cell.
from scipy.differentiate import jacobian, hessian, derivative

In [None]:
# Two g-magnitude columns: the reference value and the value + 1
# (shape (n, 2) when mags.T[1] is 1-D).
x = np.array([mags.T[1], mags.T[1]+1.]).T

In [None]:
x.shape

In [None]:
mags.T

In [None]:
# Numerical Hessian of the all-band wrapper at the reference magnitudes, with
# a coarse initial step and loose tolerances (atol = rtol = 0.1).
hesse = hessian(knn_h, mags.T, initial_step=0.4, tolerances=dict(atol=0.1, rtol=0.1))

In [None]:
hesse.success.shape

In [None]:
mags.T.shape

In [None]:
# Evaluate the g-band wrapper on both columns of x.
out = knn_d(x)

In [None]:
# 131 evenly spaced g magnitudes from 24 to 24.5 (both ends included).
g_vals = np.linspace(24, 24.5, 131)

In [None]:
# Wrapper output along the g grid.
z_vals = knn_d(g_vals)

In [None]:
_ = plt.plot(g_vals, z_vals)

In [None]:
# Numerical derivative of the g-band wrapper at every grid point, with loose
# tolerances (atol = rtol = 0.1).
der = derivative(knn_d, g_vals, initial_step=0.2, tolerances=dict(atol=0.1, rtol=0.1))

In [None]:
# Derivative estimates against the points where they were evaluated.
_ = plt.plot(der.x, der.df)

In [None]:
mags.T

In [None]:
# Numerical Jacobian of the all-band wrapper at the reference magnitudes.
jac_matrix = jacobian(knn_j, mags.T, initial_step=0.1, tolerances=dict(atol=0.1, rtol=0.1))


In [None]:
jac_matrix

In [None]:
hesse

In [None]:
class WrapFunc:
    """Expose one coordinate of a multivariate function as a univariate function.

    Calling the wrapper evaluates ``func`` at the stored reference point
    ``vals`` with only element ``index`` replaced by the argument. This lets a
    univariate differentiator (e.g. ``scipy.differentiate.derivative``)
    estimate the partial derivative of ``func`` along that coordinate.

    Parameters
    ----------
    func : callable
        Function of a full input vector. It is called with a 1-D vector for
        scalar input and with a 2-D array (one row per point) for array input.
    vals : array-like
        Reference point. It is copied, so later changes by the caller do not
        leak in.
    index : int
        Position in ``vals`` that is varied.
    """

    def __init__(self, func, vals, index):
        self.func = func
        self.vals = vals.copy()
        self.index = index

    def __call__(self, x):
        """Evaluate ``func`` with coordinate ``index`` set to ``x``.

        Parameters
        ----------
        x : scalar or array-like
            New value(s) for coordinate ``index``.

        Returns
        -------
        The result of ``func``, in one of three forms:

        * For scalar (or single-element 1-D) ``x``: the result at one point.
        * For a 1-D ``x`` with several elements: ``func`` evaluated at one row
          per element.
        * For ``x`` with ``ndim > 1``: the same as the previous case, with the
          last axis of ``x`` holding the points, wrapped with
          ``np.atleast_2d``.
        """
        # np.asarray replaces the old `len(x.shape)` probe guarded by a bare
        # `except:`, which also swallowed KeyboardInterrupt and real errors.
        x_arr = np.asarray(x)
        if x_arr.ndim > 1:
            # Batched input (e.g. from a vectorised differentiator): the last
            # axis holds the points. Build one copy of the reference row per
            # point and overwrite the varied column.
            n_pts = x_arr.shape[-1]
            rows = np.repeat(np.atleast_2d(self.vals), n_pts, axis=0)
            rows[:, self.index] = x_arr
            return np.atleast_2d(self.func(rows))
        if x_arr.ndim == 1 and x_arr.size > 1:
            # Several points in a flat array: previously float(x) raised here.
            rows = np.repeat(np.atleast_2d(self.vals), x_arr.size, axis=0)
            rows[:, self.index] = x_arr
            return self.func(rows)
        vals = self.vals.copy()
        # .item() avoids NumPy's deprecated float() conversion of ndim > 0 arrays.
        vals[self.index] = float(x_arr.item())
        return self.func(vals)

def gradient(func, x_vals, maxiter=4, initial_step=0.1, tolerances=None):
    """Estimate the gradient of ``func`` at ``x_vals`` one coordinate at a time.

    Each partial derivative comes from ``scipy.differentiate.derivative``
    applied to a ``WrapFunc`` that varies only that coordinate. All other
    coordinates are held at ``x_vals``.

    Parameters
    ----------
    func : callable
        Function of a full input vector (see ``WrapFunc``).
    x_vals : array-like
        Point at which the gradient is evaluated. It is not modified.
    maxiter : int, optional
        Forwarded to ``derivative``. The default of 4 matches the previous
        hard-coded setting.
    initial_step : float, optional
        Forwarded to ``derivative``. The default of 0.1 matches the previous
        hard-coded setting.
    tolerances : dict, optional
        Forwarded to ``derivative``. ``None`` means ``dict(atol=0.2)``, the
        previous hard-coded setting.

    Returns
    -------
    numpy.ndarray
        One derivative estimate per element of ``x_vals``.

    Notes
    -----
    ``derivative``'s ``success`` flag is not checked. An unconverged estimate
    is returned as-is.
    """
    if tolerances is None:
        # Built per call so no mutable default is shared between calls.
        tolerances = dict(atol=0.2)
    n_val = len(x_vals)
    out = np.zeros(n_val)
    for i in range(n_val):
        wf = WrapFunc(func, x_vals, i)
        dd = scipy_diff.derivative(
            wf, x_vals[i],
            maxiter=maxiter, initial_step=initial_step, tolerances=tolerances,
        )
        # .item() accepts any single-element result without NumPy's
        # deprecated float() conversion of ndim > 0 arrays.
        out[i] = np.asarray(dd.df).item()
    return out
        

In [None]:
# NOTE(review): `knn_w` is not defined anywhere in this notebook, so this and
# the following cells only work with leftover kernel state. Define it
# explicitly before use (presumably a function of one full magnitude vector).
# NOTE(review): mags[3] and the 1000-row loop below need at least 1000 rows,
# but the loading cell keeps only one catalogue row (slice(0, 1)). `mags` here
# must therefore also come from hidden state — TODO reconcile.
grad = gradient(knn_w, mags[3])

In [None]:
# grad

In [None]:
# Cross-check: SciPy's own Jacobian at the same point, mags[3].
scipy_diff.jacobian(knn_w, mags[3])

In [None]:
# Per-coordinate gradient for each of the first 1000 rows of mags, printing a
# progress counter every 50 rows. gout is (1000, number of columns of mags).
grads = []
for i in range(1000):
    if i % 50 == 0:
        print(i)
    grads.append(gradient(knn_w, mags[i]))
gout = np.array(grads)

In [None]:
gout

In [None]:
# NOTE(review): duplicate import — matplotlib.pyplot is already imported in
# the first cell.
import matplotlib.pyplot as plt

In [None]:
# Column 3 of mags versus log10 |derivative w.r.t. column 4|. The 1e-4 floor
# keeps log10 finite where the derivative is exactly zero.
_ = plt.scatter(mags[0:1000, 3], np.log10(np.abs(gout[:,4]) + 0.0001))

In [None]:
# Histogram of asinh-compressed column-4 derivatives, on a symlog x axis
# limited to [-1, 1].
# NOTE(review): np.asinh is a newer-NumPy alias of np.arcsinh — confirm the
# installed NumPy provides it.
_ = plt.hist(np.asinh(gout[:,4]), bins=500)
_ = plt.xscale('symlog')
_ = plt.xlim(-1, 1.)

In [None]:
# Raw column-4 derivatives versus their asinh transform, zoomed to
# x in [-400, 400] and y in [-10, 10].
_ = plt.scatter((gout[:,4]), np.asinh(gout[:,4]))
_ = plt.xlim(-400, 400)
_ = plt.ylim(-10, 10)