In [1]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd
import corner
from jnkepler.jaxttv.utils import elements_to_pdic, params_to_elements
from jnkepler.nbodytransit import *
from jax.config import config
import numpyro, jax
# Enable 64-bit floats in JAX (the saved outputs below are float64) and select
# numpyro's CPU platform.
# NOTE(review): `from jax.config import config` (import cell) no longer works in
# recent JAX releases; `jax.config.update(...)` is the portable spelling — confirm
# against the jax version pinned for this notebook before switching.
config.update('jax_enable_x64', True)
numpyro.set_platform('cpu')

In [3]:
import os

# Directory holding the jnkepler example/test data files.
# Set the JNKEPLER_DATA_DIR environment variable to point at your own checkout;
# the fallback is the original author's local path.
# os.path.join(..., "") guarantees a trailing separator, because later cells
# build file names as `path + filename`.
path = os.path.join(
    os.environ.get("JNKEPLER_DATA_DIR", "/Users/k_masuda/Dropbox/repos/jnkepler/src/jnkepler/data/"),
    "",
)

In [4]:
# Observed transit times. The file is whitespace-delimited with no header row;
# `planum` is the planet index (1-based, given the `j + 1` filter below) and
# `tc` / `tcerr` are presumably the transit time and its uncertainty.
# sep=r"\s+" is the documented equivalent of `delim_whitespace=True`, which
# pandas deprecated in 2.2 and removed in 3.0.
d = pd.read_csv(path + "kep51_ttv_photodtest.txt", sep=r"\s+", header=None,
                names=['tnum', 'tc', 'tcerr', 'dnum', 'planum'])

# Initial orbital-period guesses, one per planet. The planet count is derived
# from this list so the per-planet split below cannot drift out of sync with it.
p_init = [45.155305, 85.31646, 130.17809]
n_planets = len(p_init)
tcobs = [jnp.array(d.tc[d.planum == j + 1].to_numpy()) for j in range(n_planets)]
errorobs = [jnp.array(d.tcerr[d.planum == j + 1].to_numpy()) for j in range(n_planets)]

# Integration window and time step. In the saved output this integrates from
# 155.00 to 1500.00 with a step of 1/45 of the innermost period.
dt = 1.0
t_start, t_end = 155., 2950. - 1450
nt = NbodyTransit(t_start, t_end, dt)
nt.set_tcobs(tcobs, p_init, errorobs=errorobs, print_info=True)

# integration starts at:           155.00
# first transit time in data:      159.11
# last transit time in data:       1489.75
# integration ends at:             1500.00
# integration time step:           1.0000 (1/45 of innermost period)


## Light-curve data

In [5]:
# Photometric light curve; later cells read its time, flux and flux_err columns.
df = pd.read_csv(f"{path}kep51_lc_photodtest.csv")

In [6]:
# Light-curve columns as JAX arrays.
times_lc = jnp.array(df["time"])
fluxes = jnp.array(df["flux"])
errors = jnp.array(df["flux_err"])

In [7]:
# Register the light-curve timestamps with the model. In the saved output this
# reports an exposure time of 29.4 min and a supersampling factor of 10.
nt.set_lcobs(times_lc)

# exposure time (min):             29.4
# supersampling factor:            10
# overlapping transit ignored.     


## NaN in the model transit times (tc)?
- Observed on an Apple M1 machine with jax 0.2.28 and numpyro 0.10.0.
- Not reproduced with jax 0.2.13 on an Intel MacBook Pro.
- TODO: check whether it reproduces on a Unix server.

In [8]:
# Record the JAX version: the NaN issue appears to depend on it.
print(jax.__version__)

0.2.28


In [9]:
# Record the numpyro version alongside the JAX version.
print(numpyro.__version__)

0.10.0


In [10]:
# Parameter set saved for the tc NaN investigation (tcbug_* files).
elements = np.loadtxt(path + "tcbug_elements.txt")
masses = np.loadtxt(path + "tcbug_masses.txt")

In [11]:
# Stellar radius and the two limb-darkening coefficients passed to nt.get_lc
# (presumably quadratic-law u1, u2 — TODO confirm).
rstar = 1.
u1 = 0.5
u2 = 0.2
# One radius per planet (presumably in units of rstar — TODO confirm).
prad = jnp.array([0.07, 0.05, 0.1])

In [12]:
# Model light curve plus transit times for the saved bug parameters.
# In the saved output, the second returned array (presumably the model transit
# times, listed planet by planet) is nan at flat index 31 — the third transit of
# planet 2, which would be expected near 380.64 + 85.3 ≈ 466.
nt.get_lc(elements, masses, rstar, prad, u1, u2)

(DeviceArray([1., 1., 1., ..., 1., 1., 1.], dtype=float64),
 DeviceArray([ 159.10962818,  204.26356894,  249.41762174,  294.57132468,
               339.72512225,  384.87907157,  430.03372513,  520.34407767,
               565.50045497,  610.65736847,  655.81379071,  700.97230582,
               746.12836105,  791.28633531,  836.44185803,  881.59781525,
               926.75233429,  971.9063965 , 1017.06031134, 1062.21363032,
              1107.36734885, 1152.52108559, 1197.67531679, 1242.83008414,
              1287.98533976, 1333.14145533, 1378.29826324, 1423.45470178,
              1468.61323666,  295.31954395,  380.63770666,           nan,
               551.26165241,  636.57517693,  892.51335784,  977.84176186,
              1148.47033585, 1233.80274638, 1319.11513582, 1489.75636904,
               212.02387996,  342.20839472,  472.39025711,  602.57377467,
               862.93290922,  993.10376866, 1123.2846728 , 1253.45058239,
              1383.6302946 ], dtype=float64))

In [13]:
# Transit times without the light-curve step. In the saved output, the first
# array matches get_lc's second output exactly, including the nan at index 31.
# This suggests the NaN originates in the transit-time computation rather than
# in the light-curve model.
# The second element is a scalar (1.70e-09 in the saved output) whose meaning is
# not shown here — TODO document.
nt.get_ttvs(elements, masses)

(DeviceArray([ 159.10962818,  204.26356894,  249.41762174,  294.57132468,
               339.72512225,  384.87907157,  430.03372513,  520.34407767,
               565.50045497,  610.65736847,  655.81379071,  700.97230582,
               746.12836105,  791.28633531,  836.44185803,  881.59781525,
               926.75233429,  971.9063965 , 1017.06031134, 1062.21363032,
              1107.36734885, 1152.52108559, 1197.67531679, 1242.83008414,
              1287.98533976, 1333.14145533, 1378.29826324, 1423.45470178,
              1468.61323666,  295.31954395,  380.63770666,           nan,
               551.26165241,  636.57517693,  892.51335784,  977.84176186,
              1148.47033585, 1233.80274638, 1319.11513582, 1489.75636904,
               212.02387996,  342.20839472,  472.39025711,  602.57377467,
               862.93290922,  993.10376866, 1123.2846728 , 1253.45058239,
              1383.6302946 ], dtype=float64),
 DeviceArray(1.70236603e-09, dtype=float64))

In [None]:
# TODO: this cell was never executed in the saved notebook (In [None]); run it
# and record whether the nan transit time also appears here.
nt.get_ttvs_nodata(elements, masses)