In [None]:
import jaxoplanet
from jaxoplanet.light_curves import limb_dark_light_curve
from jaxoplanet.orbits import TransitOrbit
import numpy as np
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
import numpyro_ext
import numpyro_ext.distributions as distx, numpyro_ext.optim as optimx
import jax
import jax.numpy as jnp
import corner
import arviz as az
import copy

# Expose 2 host (CPU) devices so multiple MCMC chains can run in parallel.
numpyro.set_host_device_count(2)

# Run on the CPU backend (pass "gpu" instead to use a GPU).
numpyro.set_platform("cpu")

# Use 64-bit floats; JAX defaults to 32-bit precision.
jax.config.update("jax_enable_x64", True)

In [None]:
# Fiducial ("true") transit parameters used to build and simulate the light curves.
PERIOD = 10.0  # orbital period [day]
DURATION = 0.3  # transit duration [day]
T0 = 0.0  # mid-transit time [day]
B = 0.3  # impact parameter
ROR = 0.1  # planet-to-star radius ratio

In [None]:
# Time grid spanning +/- 1 day around mid-transit.
t = np.linspace(-1, 1, 500)

# Fiducial parameter dictionary. "u" is only a placeholder: the next cell calls
# limb_dark_light_curve without limb-darkening coefficients.
params = dict(
    period=PERIOD,
    duration=DURATION,
    b=B,
    t0=T0,
    ror=ROR,
    u=None,
)

In [None]:
# Fiducial orbit built from the parameter dictionary above.
orbit = TransitOrbit(
    radius_ratio=params["ror"],
    time_transit=params["t0"],
    impact_param=params["b"],
    duration=params["duration"],
    period=params["period"],
)

# limb_dark_light_curve returns a function of time; evaluate it on the grid.
light_curve_fn = limb_dark_light_curve(orbit)
lc = light_curve_fn(t)

In [None]:
# Fiducial light curve (relative flux vs. time).
fig, ax = plt.subplots(dpi=200)
ax.plot(t, lc)
ax.set_xlabel("time [day]")
ax.set_ylabel("relative flux");

In [None]:
# Draw 50,000 quadratic limb-darkening coefficient pairs (u1, u2) from the
# QuadLDParams prior (Kipping 2013, as implemented in numpyro_ext).
ldc_key = jax.random.PRNGKey(100)
ldcs = distx.QuadLDParams().sample(ldc_key, sample_shape=(50_000,))

In [None]:
# Aspect ratio of the hexbin grid; scaled by 10 below, giving 90 x 40 cells.
gridsize = np.array([9, 4])

fig, ax = plt.subplots(dpi=200)
ax.hexbin(ldcs[:, 0], ldcs[:, 1], gridsize=gridsize * 10, mincnt=1)

# Reference (u1, u2) point marked with cross-hairs. It lies (to ~1e-8) on the
# line u1 + 2*u2 = 0.
# NOTE(review): the provenance of these particular values is not documented — confirm.
ax.axvline(x=0.72844981)
ax.axhline(y=-0.3642249)

ax.set_xlabel("u1")
ax.set_ylabel("u2")

In [None]:
def eval_light_curve_for_ldc(u, t):
    """Evaluate the fiducial transit light curve for limb-darkening coefficients `u`.

    Named differently from the params-dict `eval_limb_dark_light_curve` defined
    later in the notebook, so that one no longer silently shadows this function.
    The orbit reads the module constants directly rather than the global `params`,
    because `params` is later redefined without a "ror" key. The old version
    therefore raised KeyError when this cell was re-run.

    Parameters
    ----------
    u : array
        Limb-darkening coefficients passed to `limb_dark_light_curve`.
    t : array
        Times at which to evaluate the light curve.
    """
    orbit = TransitOrbit(
        period=PERIOD,
        duration=DURATION,
        impact_param=B,
        time_transit=T0,
        radius_ratio=ROR,
    )
    return limb_dark_light_curve(orbit, u)(t)


# Number of prior LDC draws to evaluate and overplot. Only these are evaluated;
# previously all 50,000 draws were computed but just 1,000 were plotted.
n_plot_draws = 1000

fig, ax = plt.subplots(dpi=200)
lcs = jax.vmap(eval_light_curve_for_ldc, in_axes=(0, None))(ldcs[:n_plot_draws], t)

for lc_ldc in lcs:
    ax.plot(t, lc_ldc, alpha=0.8, color="k", lw=0.1)
# Reference level at -ROR**2.
ax.axhline(-(ROR**2), zorder=-100, lw=0.5)
ax.set_ylim(-0.011, -0.009)
ax.set_xlim(-0.12, 0.12)
ax.set_xlabel("time [day]")
ax.set_ylabel("relative flux");

In [None]:
# Shape of the vmapped light curves: (number of LDC draws evaluated, number of time points).
lcs.shape

In [None]:
# Point-to-point flux differences for the first 1000 LDC draws, to inspect how
# smoothly the light curves vary between consecutive time samples.
fig, ax = plt.subplots(dpi=200)
for lc_ldc in lcs[:1000]:
    ax.plot(np.diff(lc_ldc), alpha=0.8, color="k", lw=0.05)
ax.set_xlabel("time index")
ax.set_ylabel("flux difference between consecutive samples");

In [None]:
# Shorter time grid used for the simulated spectroscopic light curves.
t = np.linspace(-0.5, 0.5, 300)

# NOTE(review): U and DEPTHS are not defined anywhere in this notebook. They must be
# created by an earlier cell (presumably the true LD coefficients and per-light-curve
# transit depths). Add that cell so Restart & Run All works.
params = dict(
    period=PERIOD,
    duration=DURATION,
    b=B,
    t0=T0,
    u=U,
    rors=jnp.sqrt(DEPTHS),
)


def eval_limb_dark_light_curve(params, t):
    """Evaluate a limb-darkened transit light curve at times `t`.

    Parameters
    ----------
    params : dict
        Must contain "period", "duration", "b" (impact parameter), "t0"
        (transit time), "rors" (radius ratio) and "u" (limb-darkening
        coefficients forwarded to `limb_dark_light_curve`).
    t : array
        Times at which to evaluate the light curve.

    Returns
    -------
    The light curve evaluated at `t`.
    """
    # TransitOrbit keyword -> key in `params`.
    orbit_arg_to_key = {
        "period": "period",
        "duration": "duration",
        "impact_param": "b",
        "time_transit": "t0",
        "radius_ratio": "rors",
    }
    orbit = TransitOrbit(
        **{arg: params[key] for arg, key in orbit_arg_to_key.items()}
    )
    light_curve = limb_dark_light_curve(orbit, params["u"])
    return light_curve(t)


# Simulate one noiseless light curve per transit depth. vmap maps over "rors" only;
# every other parameter is shared by all light curves.
y_true = jax.vmap(
    eval_limb_dark_light_curve,
    in_axes=(
        {
            "period": None,
            "duration": None,
            "b": None,
            "t0": None,
            "u": None,
            "rors": 0,
        },
        None,
    ),
)(params, t)

# White-noise level per light curve, growing as wavelength**3.
# NOTE(review): `wavelengths` is not defined in this notebook; it must come from an
# earlier cell.
stddevs = 5e-6 * wavelengths**3
# Per-point uncertainties, one row per light curve. Built directly from `stddevs`
# so it no longer relies on the undefined global `num_lcs`.
yerr = np.repeat(stddevs[:, None], t.size, axis=1)
# Independent Gaussian noise for each light curve.
keys = jax.random.split(jax.random.PRNGKey(89), num=stddevs.size)
dy = jax.vmap(
    lambda stddev, key: stddev * jax.random.normal(key, shape=(t.size,)), in_axes=(0, 0)
)(stddevs, keys)
y = y_true + dy

# Let's check our spectroscopic light curves
fig, ax = plt.subplots(dpi=200)
offset = 0.0
for _y_true, _y, stddev, wv in zip(y_true, y, stddevs, wavelengths):
    ax.plot(t, _y_true + offset, lw=0.5, color="k")
    ax.errorbar(
        t, _y + offset, yerr=stddev, marker=".", ms=1, ls="none", lw=0.8, capsize=0
    )
    # NOTE(review): labelled as 100*wv nm here but as wv nm in the spectroscopic MAP
    # plot further down. Confirm the units of `wavelengths`.
    ax.annotate(f"{100*wv:.1f} nm", xy=(-0.5, 0.002 + offset), fontsize=8)
    offset += 0.01
ax.set_xlabel("time [day]", fontsize=10)
ax.set_ylabel("relative flux + arbitrary offset", fontsize=10);

## Broadband light curve
Let's also see what the broadband light curve would look like. We'll use [inverse-variance weighting](https://en.wikipedia.org/wiki/Inverse-variance_weighting) to combine the spectroscopic light curves.

In [None]:
# Inverse-variance weights, one per spectroscopic light curve.
inv_var = 1 / stddevs**2
total_inv_var = np.sum(inv_var)

# Weighted mean flux at each time stamp, and its propagated uncertainty.
y_bb = np.dot(y.T, inv_var) / total_inv_var
yerr_bb = np.sqrt(total_inv_var ** -1)

fig, ax = plt.subplots(dpi=200)
ax.errorbar(t, y_bb, yerr=yerr_bb, marker=".", ls="none", lw=0.8, color="k")
ax.set_xlabel("time [day]", fontsize=10)
ax.set_ylabel("relative flux", fontsize=10);

## Setting up our Numpyro model

We'll follow a pretty similar setup for the numpyro model as the one we set up in the [single transit tutorial](transit.ipynb).

Let's also assume we have some informative priors from previous measurements that are relatively close to the true values for all the parameters besides the limb-darkening coefficients (LDCs). For those, we'll use the `QuadLDParams` distribution from the `numpyro_ext` package which implements the uninformative prior for quadratic LD as specified in [Kipping (2013)](https://doi.org/10.1093/mnras/stt1435).

In [None]:
def jitter_value(value, jitter_fraction, key):
    """Perturb `value` by Gaussian noise with standard deviation `jitter_fraction * value`.

    The perturbation is relative to `value`, so a `value` of 0 is returned unchanged.

    Parameters
    ----------
    value : float or array
        Value to perturb.
    jitter_fraction : float
        Fractional scale of the perturbation.
    key : jax PRNG key
        Key used to draw a single standard-normal variate.
    """
    scale = jitter_fraction * value
    noise = jax.random.normal(key)
    return value + scale * noise


# Priors: centre each prior on the true value perturbed by a little Gaussian noise.
mu_duration = jitter_value(DURATION, 1e-3, jax.random.PRNGKey(8))
# jitter_value applies a *relative* jitter, which leaves T0 == 0 exactly unchanged.
# Use an absolute jitter of 1e-4 day for the transit time instead.
mu_t0 = T0 + 1e-4 * jax.random.normal(jax.random.PRNGKey(131))
mu_b = jitter_value(B, 1e-2, jax.random.PRNGKey(23))

# One key per light curve, taken from the size of DEPTHS rather than the
# undefined global `num_lcs`. The 1e-9 relative jitter on the depths is negligible.
depth_keys = jax.random.split(jax.random.PRNGKey(55), num=jnp.shape(DEPTHS)[0])
mu_depths = jax.vmap(jitter_value, in_axes=(0, None, 0))(DEPTHS, 1e-9, depth_keys)
# Collapse to a single shared depth prior centre (scalar).
mu_depths = np.mean(mu_depths)

In [None]:
# Prior centre for the transit depth (the mean over light curves, i.e. a single shared value).
mu_depths

In [None]:
def model(t, yerr, y=None):
    """Numpyro transit model for one or more light curves.

    Shared parameters: transit duration (sampled as ``logD``, recorded as the
    deterministic ``duration``), transit time ``t0``, impact parameter ``b`` and
    quadratic limb-darkening coefficients ``u``. Per-light-curve parameters:
    ``depths`` and the derived radius ratios ``rors``. The shape of ``depths``
    follows the global prior centre ``mu_depths``.

    Parameters
    ----------
    t : array
        Observation times.
    yerr : scalar or array
        Flux uncertainties, broadcast against the model light curves.
    y : array, optional
        Observed fluxes. If None, the ``obs`` site is sampled instead of conditioned.
    """
    # Priors

    ## Parameters shared across spectroscopic light curves
    # Sample the duration in log space so it stays positive.
    logD = numpyro.sample("logD", dist.Normal(jnp.log(mu_duration), 1e-2))
    duration = numpyro.deterministic("duration", jnp.exp(logD))

    # We usually have pretty good constraints on t0.
    t0 = numpyro.sample("t0", dist.Normal(mu_t0, 1e-3))
    b = numpyro.sample(
        "b",
        dist.TruncatedNormal(mu_b, 0.1, low=0.0, high=1.0),
    )
    # Uninformative quadratic limb-darkening prior (Kipping 2013).
    u = numpyro.sample("u", distx.QuadLDParams())

    ## Parameters for each light curve
    depths = numpyro.sample(
        "depths",
        dist.TruncatedNormal(
            mu_depths,
            1e-1 * jnp.ones_like(mu_depths),
            low=0.0,
            high=1.0,
        ),
    )
    # vmap below maps over a leading light-curve axis, so promote a scalar to 1-D.
    rors = jnp.atleast_1d(numpyro.deterministic("rors", jnp.sqrt(depths)))

    params = {
        "period": PERIOD,
        "duration": duration,
        "t0": t0,
        "b": b,
        "u": u,
        "rors": rors,
    }

    # One model light curve per radius ratio; all other parameters are shared.
    y_model = jax.vmap(
        eval_limb_dark_light_curve,
        in_axes=(
            {
                "period": None,
                "duration": None,
                "b": None,
                "t0": None,
                "u": None,
                "rors": 0,
            },
            None,
        ),
    )(params, t)

    numpyro.sample("obs", dist.Normal(y_model, yerr), obs=y)

## Checking priors
Let's check our priors to:
1. make sure the ranges of our priors are physically sensible, and
2. make sure we're not *too* far off from the true values.

In [None]:
# Draw from the prior (and prior predictive) of the model.
n_prior_samples = 2000
prior_predictive = numpyro.infer.Predictive(model, num_samples=n_prior_samples)
prior_samples = prior_predictive(jax.random.PRNGKey(0), t, yerr)

# ArviZ expects arrays shaped (chain, draw, *shape); add a leading chain axis of size 1.
converted_prior_samples = {
    name: np.expand_dims(samples, axis=0) for name, samples in prior_samples.items()
}
prior_samples_inf_data = az.from_dict(converted_prior_samples)

# Corner plot of the prior draws, with the true values marked.
prior_truths = [T0, DURATION, B, *U, jnp.sqrt(jnp.mean(DEPTHS)), jnp.mean(DEPTHS)]
fig = plt.figure(figsize=(20, 20))
_ = corner.corner(
    prior_samples_inf_data,
    fig=fig,
    var_names=["t0", "duration", "b", "u", "rors", "depths"],
    truths=prior_truths,
    show_titles=True,
    title_kwargs={"fontsize": 10},
    label_kwargs={"fontsize": 10},
)

## Initial fit to the broadband light curve

In [None]:
# Starting values for the broadband fit. In `model` the duration is sampled as
# "logD" ("duration" is only a deterministic), so "logD" is what must be both
# initialised and listed in `sites`. "duration" alone neither starts nor frees it.
init_params = {
    "period": PERIOD,
    "duration": mu_duration,
    "logD": jnp.log(mu_duration),
    "b": mu_b,
    "u": U,
    "t0": mu_t0,
    "depths": np.mean(mu_depths),
}

keys = jax.random.split(jax.random.PRNGKey(535), num=3)

# Optimise in stages: shared shape parameters, then the depth, then everything jointly.
soln = optimx.optimize(
    model,
    sites=["logD", "t0", "b", "u"],
    start=init_params,
)(keys[0], t, yerr_bb, y=y_bb)

soln = optimx.optimize(
    model,
    sites=["depths"],
    start=soln,
)(keys[1], t, yerr_bb, y=y_bb)

soln = optimx.optimize(
    model,
    start=soln,
)(keys[2], t, yerr_bb, y=y_bb)

In [None]:
# Inspect the broadband optimisation result. `map_params` is not defined until the
# next cell, so displaying it here failed on a fresh kernel.
soln

In [None]:
# Collect the MAP values in the layout expected by `eval_limb_dark_light_curve`.
param_keys = [k for k in params.keys() if k != "period"]
map_params = {"period": PERIOD} | {k: soln[k] for k in param_keys}
# "rors" is a scalar when a single shared depth is fitted; vmap over in_axes=0
# needs a leading light-curve axis.
map_params["rors"] = jnp.atleast_1d(map_params["rors"])

in_axes = {
    "period": None,
    "duration": None,
    "b": None,
    "t0": None,
    "u": None,
    "rors": 0,
}

# Evaluate the MAP model on a finer time grid for plotting.
t_model = np.linspace(t[0], t[-1], 1000)
y_model = jax.vmap(eval_limb_dark_light_curve, in_axes=(in_axes, None))(
    map_params, t_model
)

fig, ax = plt.subplots(dpi=200)
ax.errorbar(
    t,
    y_bb,
    yerr=yerr_bb,
    marker=".",
    ms=1,
    ls="none",
    lw=0.8,
    capsize=0,
    label="broadband data",
)
# y_model has one row per fitted light curve; draw each as a solid, labelled line
# (previously plotted with ls="none" and no label, leaving the legend empty).
ax.plot(t_model, y_model.T, lw=0.8, color="k", label="MAP model")
ax.set_xlabel("time [day]", fontsize=10)
ax.set_ylabel("relative flux", fontsize=10)
ax.legend(markerscale=2, edgecolor="k");

In [None]:
# MAP parameters from the broadband fit.
map_params

In [None]:
# NOTE(review): duplicate of the previous cell (same value displayed again); can be removed.
map_params

## Optimize and get MAP estimate
Let's optimize the model to calculate the *maximum a posteriori* (MAP) estimate so that we can use it as the starting point for our MCMC run.

We've found the optimization to be more robust (i.e., not sensitive to the random seed) when we optimize the parameters in batches instead of all at once. 

In [None]:
# Starting values for the spectroscopic fit. In `model` the duration is sampled as
# "logD" ("duration" is only a deterministic), so "logD" is what must be both
# initialised and listed in `sites`.
init_params = {
    "period": PERIOD,
    "duration": mu_duration,
    "logD": jnp.log(mu_duration),
    "b": mu_b,
    "u": np.array([0.2, 0.04]),
    "t0": mu_t0,
    "depths": mu_depths,
}

keys = jax.random.split(jax.random.PRNGKey(535), num=3)

# Optimise in stages: shared shape parameters, then the depths, then everything jointly.
soln = optimx.optimize(
    model,
    sites=["logD", "t0", "b", "u"],
    start=init_params,
)(keys[0], t, yerr, y=y)

soln = optimx.optimize(
    model,
    sites=["depths"],
    start=soln,
)(keys[1], t, yerr, y=y)

soln = optimx.optimize(
    model,
    start=soln,
)(keys[2], t, yerr, y=y)

Let's extract the model parameters from the `soln` dictionary and plot our MAP model.

In [None]:
# Collect the MAP values in the layout expected by `eval_limb_dark_light_curve`.
param_keys = [k for k in params.keys() if k != "period"]
map_params = {"period": PERIOD} | {k: soln[k] for k in param_keys}
# "rors" is a scalar when a single shared depth is fitted; vmap over in_axes=0
# would fail on a 0-d array, so give it a leading light-curve axis.
map_params["rors"] = jnp.atleast_1d(map_params["rors"])

in_axes = {
    "period": None,
    "duration": None,
    "b": None,
    "t0": None,
    "u": None,
    "rors": 0,
}

y_model = jax.vmap(eval_limb_dark_light_curve, in_axes=(in_axes, None))(map_params, t)
# With one shared depth there is a single model row. Broadcast it so every observed
# light curve is paired with the model (zip would otherwise stop after the first).
y_model = jnp.broadcast_to(y_model, y.shape)

fig, ax = plt.subplots(dpi=200)
offset = 0.0
_label = "MAP model"
for _y_model, _y, stddev, wv in zip(y_model, y, stddevs, wavelengths):
    ax.errorbar(
        t, _y + offset, yerr=stddev, marker=".", ms=1, ls="none", lw=0.8, capsize=0
    )
    # NOTE(review): the simulated-data plot labels this as 100*wv nm. Confirm the
    # units of `wavelengths`.
    ax.annotate(f"{wv:.1f} nm", xy=(-0.5, 0.002 + offset), fontsize=8)
    ax.plot(t, _y_model + offset, lw=0.5, color="k", label=_label)
    offset += 0.01
    # Only label the first model line so the legend has a single entry.
    _label = None
ax.set_xlabel("time [day]", fontsize=10)
ax.set_ylabel("relative flux + arbitrary offset", fontsize=10)
ax.legend(markerscale=2, edgecolor="k");

In [None]:
# MAP parameters from the spectroscopic fit.
map_params

## Sampling

In [None]:
# NUTS with a dense, regularised mass matrix, started at the MAP solution.
# `soln` holds values for the model's actual sample sites ("logD", "depths", ...).
# `map_params` only carries the deterministics "duration"/"rors", so with it
# init_to_value fell back to the default initialisation for logD and depths.
sampler = numpyro.infer.MCMC(
    numpyro.infer.NUTS(
        model,
        dense_mass=True,
        regularize_mass_matrix=True,
        init_strategy=numpyro.infer.init_to_value(values=soln),
    ),
    num_warmup=300,
    num_samples=100,
    num_chains=2,
    progress_bar=True,
)

In [None]:
# Run the sampler. The former jax.profiler.trace("./") wrapper was leftover
# profiling that wrote trace files into the working directory. MCMC.run returns
# None, so wrapping it in block_until_ready had no effect.
sampler.run(jax.random.PRNGKey(10), t, yerr, y=y)

In [None]:
# Convert the posterior to ArviZ InferenceData and print numpyro's per-site summary table.
inf_data = az.from_numpyro(sampler)
sampler.print_summary()

In [None]:
# Truth for "rors" is the radius ratio sqrt(depth), not the depth itself. When a
# single shared depth is fitted (scalar mu_depths), compare against
# sqrt(mean(DEPTHS)); otherwise use one truth per light curve.
if np.ndim(mu_depths) == 0:
    rors_truth = jnp.sqrt(jnp.mean(DEPTHS))
else:
    rors_truth = jnp.sqrt(DEPTHS)

corner.corner(
    inf_data,
    var_names=["duration", "u", "rors", "b", "t0"],
    truths=[DURATION, *U, *jnp.atleast_1d(rors_truth), B, T0],
);