In [None]:
import jaxoplanet
from jaxoplanet.light_curves import LimbDarkLightCurve
from jaxoplanet.orbits import TransitOrbit
import numpy as np
import matplotlib.pyplot as plt
import numpyro
import numpyro_ext.distributions, numpyro_ext.optim
import jax
import jax.numpy as jnp
import corner
import arviz as az
import copy

# --- Runtime configuration ---------------------------------------------------
# Expose 4 host devices for multi-core parallelism (useful when running
# multiple MCMC chains in parallel).
numpyro.set_host_device_count(4)
# Run on the CPU backend (use "gpu" for GPU).
numpyro.set_platform("cpu")
# JAX defaults to 32-bit, so switch on 64-bit precision explicitly.
jax.config.update("jax_enable_x64", True)

# --- Environment report ------------------------------------------------------
# Print one "<name>.__version__ = <version>" line per dependency, in the same
# order as before, so the environment used for a run is recorded in the output.
for _pkg_name, _pkg in (
    ("jaxoplanet", jaxoplanet),
    ("numpy", np),
    ("matplotlib", plt.matplotlib),  # reached through pyplot; no separate import
    ("numpyro", numpyro),
    ("numpyro_ext", numpyro_ext),
    ("jax", jax),
    ("corner", corner),
    ("arviz", az),
):
    print(f"{_pkg_name}.__version__ = {_pkg.__version__}")

In [None]:
# The light curve calculation requires an orbit object.
# We'll use TransitOrbit (similar to SimpleTransitOrbit in the exoplanet package),
# which is an orbit parameterized by the observables of a transiting system:
# period, speed/duration, time of transit, impact parameter, and radius ratio.
# TODO: confirm that `radius` actually is the radius ratio.
orbit = TransitOrbit.init(
    period=3.456, duration=0.12, time_transit=0.0, impact_param=0.0, radius=0.1
)

# Time grid on which both light curves are evaluated.
t = np.linspace(-0.1, 0.1, 1000)
# Coefficients handed to LimbDarkLightCurve.init.
u = [0.3, 0.2]
# One exposure-time value per time sample. The length is derived from `t` so the
# two cannot drift apart if the grid resolution is changed.
# NOTE(review): the units of the value 50 are not visible here -- TODO confirm.
texp = jnp.repeat(50, t.size)

# Build the limb-darkened light-curve object once and reuse it for both calls.
limb_dark = LimbDarkLightCurve.init(u)

# With `texp` given, the call is unpacked into six values; with `texp=None` it is
# unpacked into four. The trailing values of the first call get their own names
# so they are not silently overwritten by the second call; `x`, `y`, `z` keep
# referring to the `texp=None` results, as before.
light_curve, summed_lc, stencil, x_texp, y_texp, z_texp = limb_dark.light_curve(
    orbit, t, texp=texp
)
no_texp_lc, x, y, z = limb_dark.light_curve(orbit, t, texp=None)

# Plot both light curves, labelled so they can be told apart.
fig, ax = plt.subplots(dpi=150)
ax.plot(t, summed_lc, lw=2, label="summed_lc (texp=50)")
ax.plot(t, no_texp_lc, lw=2, label="no_texp_lc (texp=None)")
ax.set_xlabel("time [days]")
ax.set_ylabel("relative flux")
ax.set_xlim(t.min(), t.max())
ax.legend();

In [None]:
# Display the shape of `light_curve`, the first of the six values unpacked from
# the `texp=...` light-curve call in the previous cell.
light_curve.shape

In [None]:
# Spot-check a single entry: index 700 along the first axis of `light_curve`.
# NOTE(review): presumably the first axis follows the time grid `t`, which would
# make this the sample near t ~ 0.04 days -- TODO confirm.
light_curve[700]