# A simple transit fit with NumPyro

This demo follows the [Transits, occultations, and eclipses](https://docs.exoplanet.codes/en/latest/tutorials/data-and-models/#transits-occultations-and-eclipses) tutorial in the exoplanet docs, but updated to use NumPyro and jaxoplanet.

First we simulate some data:

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

# JAX defaults to 32-bit floats; enable 64-bit (double) precision for the fit.
jax.config.update("jax_enable_x64", True)

import numpyro
from numpyro import distributions as dist, infer

# Expose two host (CPU) devices so the two MCMC chains requested later
# (`num_chains=2`) can run in parallel. NumPyro documents that this only takes
# effect at the start of the program, before JAX initializes its backend, so
# keep it here near the top.
numpyro.set_host_device_count(2)

from numpyro_ext import distributions as distx, optim

import corner
import arviz as az

from jaxoplanet.orbits import KeplerianBody
from jaxoplanet.light_curves import LimbDarkLightCurve


def light_curve_model(params, t):
    """Evaluate the limb-darkened transit model and scale it by 1e3.

    The 1e3 factor expresses the output in parts per thousand (ppt), matching
    the "relative flux [ppt]" units used for the data in this notebook.

    Parameters
    ----------
    params : dict
        Must contain ``"period"``, ``"t0"`` (reference transit time),
        ``"b"`` (impact parameter), ``"r"`` (passed as the body radius; the
        model below treats it as a radius ratio) and ``"u"`` (limb-darkening
        coefficients).
    t : array-like
        Times at which to evaluate the light curve.

    Returns
    -------
    1e3 times the output of ``LimbDarkLightCurve.light_curve`` for the orbit.
    """
    body = KeplerianBody.init(
        period=params["period"],
        time_transit=params["t0"],
        impact_param=params["b"],
        radius=params["r"],
    )
    limb_darkened = LimbDarkLightCurve.init(params["u"])
    flux = limb_darkened.light_curve(orbit=body, t=t)
    # Rescale to parts per thousand.
    return 1e3 * flux


# Ground-truth parameters used to generate the synthetic observations.
true_params = dict(
    u=np.array([0.3, 0.2]),  # quadratic limb-darkening coefficients
    period=7.2125,
    t0=4.35,
    b=0.35,
    r=0.04,
)

# Seeded generator so the simulated dataset is reproducible.
random = np.random.default_rng(123)
t = np.arange(0, 35, 0.02)
yerr = 0.3

# Gaussian white noise with standard deviation `yerr` at every time stamp.
noise = random.normal(0, yerr, t.size)
y = light_curve_model(true_params, t) + noise

ax = plt.gca()
ax.plot(t, y, ".k", ms=1)
ax.set_xlabel("time [days]")
ax.set_ylabel("relative flux [ppt]");

Here's the simple probabilistic model implemented in NumPyro:

In [None]:
# Number of transits spanned by the two reference transit times below.
num_transits = 5


def model(t, yerr, y=None):
    """NumPyro model for a single transiting planet with extra white-noise jitter.

    Parameters
    ----------
    t : array-like
        Observation times.
    yerr : float or array-like
        Reported measurement uncertainty, combined in quadrature with the
        fitted jitter term.
    y : array-like, optional
        Observed flux. When ``None`` the ``"y"`` site is sampled instead of
        conditioned on.
    """
    # Constant baseline offset of the light curve.
    mean = numpyro.sample("mean", dist.Normal(0.0, 1.0))

    # Parameterize the ephemeris by the times of the first and last reference
    # transits instead of (t0, period); this is often the better choice. The
    # period is the spacing implied by `num_transits` transits between them.
    t0 = numpyro.sample("t0", dist.Normal(4.35, 1.0))
    t_last = numpyro.sample("t1", dist.Normal(33.2, 1.0))
    period = numpyro.deterministic("period", (t_last - t0) / (num_transits - 1))

    # Quadratic limb darkening in the Kipping (2013) parameterization.
    u = numpyro.sample("u", distx.QuadLDParams())

    # Radius ratio (sampled in log space) and impact parameter. These two are
    # prone to strong covariances. `b` is written as a fraction `frac` of
    # (1 + r), which bounds it to [0, 1 + r].
    log_r = numpyro.sample("log_r", dist.Normal(np.log(0.04), 2.0))
    r = numpyro.deterministic("r", jnp.exp(log_r))
    frac = numpyro.sample("f", dist.Uniform(0, 1))
    b = numpyro.deterministic("b", frac * (1 + r))

    # Additional "jitter" on top of the reported measurement uncertainties.
    log_jitter = numpyro.sample("log_jitter", dist.Normal(np.log(0.3), 1.0))

    transit_params = {"u": u, "period": period, "t0": t0, "b": b, "r": r}
    predicted = mean + light_curve_model(transit_params, t)

    # Total per-point scatter: reported error and jitter added in quadrature.
    sigma = jnp.sqrt(yerr**2 + jnp.exp(2 * log_jitter))
    numpyro.sample("y", dist.Normal(predicted, sigma), obs=y)

For models like this one, it is often useful to start from the maximum a posteriori (MAP) parameters rather than initializing by sampling from the prior:

In [None]:
# Optimize the model to get MAP starting values for the sampler.
find_map = optim.optimize(model)
soln = find_map(jax.random.PRNGKey(0), t, yerr, y=y)
soln

Then we can run MCMC with the No-U-Turn Sampler (NUTS), initialized at the MAP solution:

In [None]:
# NUTS with a dense mass matrix, started from the MAP solution `soln`.
nuts_kernel = infer.NUTS(
    model,
    dense_mass=True,
    init_strategy=infer.init_to_value(values=soln),
    max_tree_depth=6,
)

# Two chains, matching the two host devices requested at the top of the notebook.
sampler = infer.MCMC(nuts_kernel, num_warmup=1000, num_samples=1000, num_chains=2)
sampler.run(jax.random.PRNGKey(0), t, yerr, y=y)

And look at the results:

In [None]:
# Convert the MCMC output into an ArviZ InferenceData object.
inf_data = az.from_numpyro(sampler)

# Corner plot over the directly sampled (non-deterministic) parameters.
sampled_vars = ["mean", "t0", "t1", "log_r", "f", "u", "log_jitter"]
corner.corner(inf_data, var_names=sampled_vars)

az.summary(inf_data)

## Exposing a pure-JAX probability function

To use this model outside NumPyro, we can use the `numpyro.infer.util.initialize_model` function to extract a pure-JAX version of the probability function:

In [None]:
param_info, potential_fn, postprocess_fn, _ = infer.util.initialize_model(
    jax.random.PRNGKey(0),
    model,
    model_args=(t, yerr),
    model_kwargs={"y": y},
    dynamic_args=True,
    init_strategy=infer.init_to_value(values=soln),
)

# These are the initial parameter coordinates.
# NOTE: These are the _unconstrained_ parameters; the potential function
# automatically handles remapping them to the constrained space.
initial_params = param_info.z

# Because we passed `dynamic_args=True`, `potential_fn` is a factory: calling
# it with the model arguments returns a function that takes the unconstrained
# parameters and returns the _negative_ log probability (NumPyro's potential
# energy, which includes the log-Jacobian of the unconstrained-to-constrained
# transform). Build that function once rather than on every evaluation.
potential = potential_fn(t, yerr, y=y)


def log_prob_function(params):
    """Return the log probability at the unconstrained parameters ``params``."""
    return -potential(params)


print(initial_params)
print(log_prob_function(initial_params))

If you need the log probability function to operate on a flat array rather than a dictionary of parameters, you can use `jax.flatten_util.ravel_pytree`:

In [None]:
from jax.flatten_util import ravel_pytree

# Flatten the dict of unconstrained parameters into a single 1-D array.
# `unravel` maps such an array back to the original dict structure.
flat_initial_params, unravel = ravel_pytree(initial_params)


def flat_log_prob_function(flat_params):
    """Return the log probability for a flat 1-D array of unconstrained parameters."""
    return log_prob_function(unravel(flat_params))


print(flat_initial_params)
print(flat_log_prob_function(flat_initial_params))