In [None]:
import jax.numpy as jnp
from jaxoplanet.orbits import keplerian
from jaxoplanet.starry.orbit import SurfaceSystem, Surface
from jaxoplanet.light_curves.emission import light_curve
from jaxoplanet.light_curves.transforms import integrate
from tinygp import GaussianProcess, kernels
from jaxoplanet.units import unit_registry as ureg


def system_model(params):
    """Build a two-body ``SurfaceSystem`` from sampled parameters.

    Most inputs are sampled in log space and exponentiated here. The
    companion's radius and mass are expressed relative to the primary through
    the log-ratios ``log_k`` (radius) and ``log_q`` (mass).

    Parameters
    ----------
    params : dict
        Keys read: ``log_R1``, ``log_M1``, ``log_k``, ``log_q``,
        ``log_period``, ``t0``, ``ecs``, ``inclination``, ``log_s``,
        ``u1``, ``u2``.

    Returns
    -------
    SurfaceSystem
        The primary has limb-darkening coefficients ``(u1, u2)`` and amplitude
        1.0. It carries a single companion whose surface amplitude is
        ``exp(log_s)``. Both surfaces use ``normalize=False``.
    """
    log_r1 = params["log_R1"]
    log_m1 = params["log_M1"]

    # Primary in solar units; the companion is offset in log space by the ratios.
    primary_radius = jnp.exp(log_r1) * ureg.R_sun
    primary_mass = jnp.exp(log_m1) * ureg.M_sun
    companion_radius = jnp.exp(log_r1 + params["log_k"]) * ureg.R_sun
    companion_mass = jnp.exp(log_m1 + params["log_q"]) * ureg.M_sun

    # ``ecs`` is a 2-vector whose norm is the eccentricity and whose polar
    # angle is the argument of periastron, i.e. (e*cos(w), e*sin(w)).
    ecs = params["ecs"]
    ecc = jnp.sqrt(jnp.sum(ecs**2))
    omega_peri = jnp.arctan2(ecs[1], ecs[0])

    primary_surface = Surface(
        u=(params["u1"], params["u2"]), amplitude=1.0, normalize=False
    )
    companion_surface = Surface(amplitude=jnp.exp(params["log_s"]), normalize=False)

    central = keplerian.Central(radius=primary_radius, mass=primary_mass)
    return SurfaceSystem(central, primary_surface).add_body(
        radius=companion_radius,
        mass=companion_mass,
        period=jnp.exp(params["log_period"]) * ureg.day,
        time_transit=params["t0"] * ureg.day,
        eccentricity=ecc,
        omega_peri=omega_peri,
        inclination=params["inclination"] * ureg.rad,
        surface=companion_surface,
    )


def light_curve_model(time, params, *, exposure_time=None):
    """Return a tinygp ``GaussianProcess`` whose mean is the system light curve.

    The mean function works as follows:

    - It takes the exposure-integrated flux of ``system_model(params)`` and
      sums it over bodies.
    - It divides that sum by ``1 + exp(log_s)`` and subtracts 1.
    - It scales the result by 1e3 (parts per thousand, if the flux is
      relative).

    Correlated noise comes from a quasiseparable SHO kernel. White noise
    enters through the diagonal term ``lc_sigma ** 2``.

    Parameters
    ----------
    time : array
        GP input coordinates (presumably in days, matching ``t0`` — confirm).
    params : dict
        Model parameters. Besides the keys used by ``system_model``, this reads
        ``log_s``, ``lc_gp_sigma``, ``lc_gp_omega``, ``lc_gp_quality`` and
        ``lc_sigma``.
    exposure_time : optional, keyword-only
        Exposure time forwarded to ``integrate``. When ``None`` (the default),
        the module-level name ``exposure`` is used, as the original version of
        this function did implicitly.

    Returns
    -------
    GaussianProcess
    """
    if exposure_time is None:
        # Backward compatibility: this function used to read a global named
        # ``exposure`` without declaring it. Keep that fallback, but make it
        # explicit and overridable.
        exposure_time = exposure

    system = system_model(params)

    # Build the exposure-integrated light-curve callable once, instead of
    # rebuilding it on every evaluation of the mean function.
    integrated_flux = integrate(light_curve(system), exposure_time=exposure_time)

    # Normalization constant. This is presumably the out-of-eclipse total:
    # primary amplitude 1.0 plus companion amplitude exp(log_s), as set in
    # system_model. Confirm.
    total_amplitude = 1 + jnp.exp(params["log_s"])

    def flux_function(t):
        # NOTE(review): tinygp appears to apply the mean function per input
        # point, so ``jnp.sum`` reduces over bodies only. Verify against the
        # tinygp docs.
        return 1e3 * (jnp.sum(integrated_flux(t)) / total_amplitude - 1.0)

    kernel = kernels.quasisep.SHO(
        sigma=params["lc_gp_sigma"],
        omega=params["lc_gp_omega"],
        quality=params["lc_gp_quality"],
    )

    return GaussianProcess(
        kernel,
        time,
        mean=flux_function,
        diag=params["lc_sigma"] ** 2,
    )