In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import JaxPeriodDrwFit


from tape.ensemble import Ensemble
from tape.utils import ColumnMapper

from warnings import simplefilter
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

In [None]:
from dask.distributed import Client, LocalCluster

# Start a local Dask cluster, attach a client to it, and let the cluster
# adapt its size between the given minimum and maximum.
cluster = LocalCluster()
client = Client(cluster)
cluster.adapt(minimum=1, maximum=40)

# Create the TAPE ensemble on top of that client and report the client details.
ens = Ensemble(client=client)
ens.client_info()


In [None]:
# Catalog location (only valid when running on baldur)
data_path = "/astro/store/epyc/data3/hipscat/catalogs/tape_test/"

# Map the catalog's column names onto the roles TAPE expects.
# Note that the magnitude column is used in the flux role.
col_map = ColumnMapper(
    id_col="SDSS_NAME_dr16q_constant",
    time_col="mjd_ztf_source",
    flux_col="mag_ztf_source",
    err_col="magerr_ztf_source",
    band_col="band_ztf_source",
)

# Load the object and source tables of the HiPSCat catalog into the ensemble.
ens.from_hipscat(
    data_path,
    source_subdir="tape_test_sources",
    object_subdir="tape_test_obj",
    column_mapper=col_map,
    additional_cols=True,
    sync_tables=True,
)

In [None]:
# Layout of the stored fit results, as noted by the original author:
#   0. Negative log-likelihood of the Gaussian Process model
#   1. log_drw_scale = theta[0]
#   2. log_drw_amp   = theta[1]
#   3. log_per_scale = theta[2]
#   4. log_per_amp   = theta[3]

# Fit results used later for build_gp (first file) and build_gp_drw (second file)
res_tsp_run_0 = pd.read_parquet(
    '/astro/users/ncaplar/data/res_tsp_run_g_0.parquet',
    engine='pyarrow',
)
res_tsp_run_0_drw = pd.read_parquet(
    '/astro/users/ncaplar/data/res_tsp_run_g_0_drw.parquet',
    engine='pyarrow',
)

# numpy version of the first results table
res_tsp_run_0_np = np.vstack(res_tsp_run_0.values)

# Earlier variant of the DRW load, kept for reference:
# res_tsp_run_0_drw = pd.read_parquet('/astro/users/ncaplar/data/res_tsp_run_0_drw', engine='pyarrow')
# res_tsp_run_0_drw_np = np.vstack(res_tsp_run_0_drw.data.values)

In [None]:
# Object whose light curve is inspected in the rest of the notebook
name_str = '023922.87-000119.6'

# Materialize the source table ONCE and keep only this object's rows.
# The original cell called `ens._source.compute()` three times (one of them
# as a discarded expression); each call re-materializes the whole table.
ts_computed = ens._source.compute().loc[name_str]
ts_computed_g = ts_computed[ts_computed['band_ztf_source'] == 'g']
ts_computed_r = ts_computed[ts_computed['band_ztf_source'] == 'r']

# Restrict the ensemble's source table to the g band, as the original cell did.
# NOTE(review): this filters `ens` in place, so re-running this cell afterwards
# would presumably leave `ts_computed_r` empty -- confirm whether later code
# actually needs the ensemble filtered.
ens.query("band_ztf_source == 'g'", table='source')

# g-band light curve of the object, sorted by time. The in-memory g-band
# selection holds the same rows the post-query compute would return.
ts = ts_computed_g
t = ts['mjd_ztf_source'].values
sort_array = np.argsort(t)
t = t[sort_array]
y = ts['mag_ztf_source'].values[sort_array]
yerr = ts['magerr_ztf_source'].values[sort_array]

# r-band light curve of the same object (left in table order; only used for plotting)
t_r = ts_computed_r['mjd_ztf_source'].values
y_r = ts_computed_r['mag_ztf_source'].values
yerr_r = ts_computed_r['magerr_ztf_source'].values

In [None]:
# Prediction grid: 2000 points spanning the g-band data, padded by 100 on each side
x = np.linspace(np.min(t) - 100, np.max(t) + 100, 2000)

JaxPeriodDrwFit_instance = JaxPeriodDrwFit.JaxPeriodDrwFit()

# Pull this object's fitted parameters by position (.iloc makes the positional
# slicing explicit).
# NOTE(review): positions 2:6 / 2:4 are assumed to hold the log-parameters;
# confirm against the column layout of the parquet files.
theta = res_tsp_run_0.loc[name_str].iloc[2:6].values
theta_drw = res_tsp_run_0_drw.loc[name_str].iloc[2:4].values

# Build the two GP models on the g-band data
gp = JaxPeriodDrwFit_instance.build_gp(theta, t, y, yerr)
gp_drw = JaxPeriodDrwFit_instance.build_gp_drw(theta_drw, t, y, yerr)

# Full model conditioned on the data, evaluated on the grid
cond_gp = gp.condition(y, x).gp
mu, var = cond_gp.loc, cond_gp.variance

# DRW-only model conditioned on the data
cond_gp_drw = gp_drw.condition(y, x).gp
mu_drw, var_drw = cond_gp_drw.loc, cond_gp_drw.variance

# Separate contributions of the two kernel components of the full model.
# Condition once per component and read both mean and variance from the
# result, instead of conditioning twice as the original did.
cond_gp_carma = gp.condition(y, x, kernel=gp.kernel.kernel1).gp
mu_carma, var_carma = cond_gp_carma.loc, cond_gp_carma.variance

cond_gp_periodic = gp.condition(y, x, kernel=gp.kernel.kernel2).gp
mu_periodic, var_periodic = cond_gp_periodic.loc, cond_gp_periodic.variance

In [None]:
# To run the MAP optimization for this single object, uncomment:
#JaxPeriodDrwFit_instance.optimize_map(t, y, yerr, n_init = 10000)

In [None]:
# Display the stored fit results (res_tsp_run_0) for the selected object
res_tsp_run_0.loc[name_str]

In [None]:
# Display the stored DRW fit results (res_tsp_run_0_drw) for the selected object
res_tsp_run_0_drw.loc[name_str]

In [None]:
# Display the r-band magnitudes of the selected object
y_r

In [None]:
# Figure: g/r measurements with the full model and its two separated components
plt.figure(figsize=(20, 10))

# fmt="." only selects the marker; the colour comes from `color=`. The original
# fmt=".k" also set a colour, which conflicts with the keyword and makes
# matplotlib warn (the keyword wins, so the rendered colours are unchanged).
plt.errorbar(t, y, yerr=yerr, fmt=".", capsize=4, label=' measurements, g', ms=10, color='green')
plt.errorbar(t_r, y_r, yerr=yerr_r, fmt=".", capsize=4, label=' measurements, r', ms=10, color='red')

# 1-sigma widths, computed once per curve
sigma_carma = np.sqrt(var_carma)
sigma_periodic = np.sqrt(var_periodic)
sigma = np.sqrt(var)

# Separated components: predictive mean and +/- 1 sigma band
plt.plot(x, mu_carma, color="blue", ls=':', lw=3, label='CARMA separated')
plt.plot(x, mu_periodic, color="red", ls=':', lw=3, label='periodic separated')

plt.fill_between(x, mu_carma + sigma_carma, mu_carma - sigma_carma, color="blue", alpha=0.2)
plt.fill_between(x, mu_periodic + sigma_periodic, mu_periodic - sigma_periodic, color="red", alpha=0.2)

# Full model: predictive mean and +/- 1 sigma band
plt.fill_between(x, mu + sigma, mu - sigma, color="magenta", alpha=0.2)
plt.plot(x, mu, color="magenta", lw=3, ls=':', label='CARMA + periodic')

plt.xlabel("Time (MJD)", fontsize=20)
plt.ylabel("Mag", fontsize=20)
plt.legend(fontsize=20)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# Zoom the y-axis onto the full-model mean; the trailing ';' hides the returned tuple
plt.ylim(np.min(mu) - 0.1, np.max(mu) + 0.1);

In [None]:
# Figure: g/r measurements with the DRW-only model
plt.figure(figsize=(20, 10))

# fmt="." only selects the marker; the colour comes from `color=`. The original
# fmt=".k" also set a colour, which conflicts with the keyword and makes
# matplotlib warn (the keyword wins, so the rendered colours are unchanged).
plt.errorbar(t, y, yerr=yerr, fmt=".", capsize=4, label=' measurements, g', ms=10, color='green')
plt.errorbar(t_r, y_r, yerr=yerr_r, fmt=".", capsize=4, label=' measurements, r', ms=10, color='red')

# DRW-only model: predictive mean and +/- 1 sigma band
sigma_drw = np.sqrt(var_drw)
plt.fill_between(x, mu_drw + sigma_drw, mu_drw - sigma_drw, color="blue", alpha=0.2)
plt.plot(x, mu_drw, color="blue", lw=3, ls=':', label='DRW')

plt.xlabel("Time (MJD)", fontsize=20)
plt.ylabel("Mag", fontsize=20)
plt.legend(fontsize=20)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# NOTE(review): the limits come from the full-model mean `mu`, not `mu_drw` --
# presumably so both figures share the same y-range; confirm this is intended.
plt.ylim(np.min(mu) - 0.1, np.max(mu) + 0.1)
# The trailing ';' hides the returned Text repr
plt.title('Only DRW',  fontsize=20);