In [1]:
import jax

jax.config.update("jax_enable_x64", True)

from jaxoplanet.experimental.starry import Map, Ylm
from jaxoplanet.orbits import keplerian

params = {
    "central": keplerian.Central(
        mass=1.3,
        radius=1.1,
    ),
    "central_surface_map": Map(
        y=Ylm.from_dense([1, 0.005, 0.05, 0.09, 0.0, 0.1, 0.03, 0.04, 0.4, 0.2, 0.1]),
        inc=0.9,
        obl=0.3,
        period=1.2,
        u=(0.1, 0.1),
    ),
    "bodies": [
        {
            "radius": 0.5,
            "mass": 0.1,
            "period": 1.5,
            "surface_map": Map(
                y=Ylm.from_dense([1, 0.005, 0.05, 0.09, 0.0, 0.1, 0.03]),
                inc=-0.3,
                period=0.8,
                u=(0.2, 0.3),
            ),
        }
    ],
}

In [2]:
import starry

# starry.config.lazy = False
import numpy as np


from jaxoplanet.experimental.starry.light_curves import light_curve, map_light_curve
from jaxoplanet.experimental.starry.orbit import SurfaceMapSystem

from jaxoplanet.test_utils import assert_allclose


keplerian_system = SurfaceMapSystem(
    params["central"],
    central_surface_map=params.get("central_surface_map", None),
)

for body in params["bodies"]:
    keplerian_system = keplerian_system.add_body(**body)


def jaxoplanet2starry(body, surface_map=None):
    cls = starry.Primary if isinstance(body, keplerian.Central) else starry.Secondary
    if surface_map is None or surface_map.period is None:
        prot = 1e15
    else:
        prot = surface_map.period

    if surface_map is None:
        map_kwargs = dict(ydeg=0, rv=False, reflected=False)
    else:
        map_kwargs = dict(
            ydeg=surface_map.ydeg,
            udeg=surface_map.udeg,
            inc=np.rad2deg(surface_map.inc),
            obl=np.rad2deg(surface_map.obl),
            amp=surface_map.amplitude,
        )

    body_kwargs = dict(
        r=body.radius.magnitude,
        m=body.mass.magnitude,
        prot=prot,
    )

    if isinstance(body, keplerian.OrbitalBody):
        body_kwargs["porb"] = body.period.magnitude
        map_kwargs["amp"] = surface_map.amplitude if surface_map else 0.0

    starry_body = cls(starry.Map(**map_kwargs), **body_kwargs)

    if surface_map and surface_map.u:
        starry_body.map[1:] = surface_map.u
    if surface_map and surface_map.deg > 0:
        starry_body.map[1:, :] = np.asarray(surface_map.y.todense())[1:]

    return starry_body



In [4]:
time = np.linspace(-1.5, 1.0, 300)

# jaxoplanet
jaxoplanet_flux = light_curve(keplerian_system)(time)

# starry
primary = jaxoplanet2starry(
    keplerian_system.central, keplerian_system.central_surface_map
)
secondaries = [
    jaxoplanet2starry(body, surface_map)
    for body, surface_map in zip(
        keplerian_system.bodies, keplerian_system.bodies_surface_maps
    )
]

starry_system = starry.System(primary, *secondaries)
starry_flux = starry_system.flux(time, total=False)

# assert_allclose(jaxoplanet_flux.T, np.array(starry_flux))

Pre-computing some matrices... Done.
Pre-computing some matrices... Done.


In [15]:
starry_flux

[dot.0, dot.0]

In [13]:
%timeit starry_system.flux(time, total=False)

45.6 ms ± 514 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [14]:
from jax import block_until_ready as bur

f = jax.jit(light_curve(keplerian_system))
bur(f(time))

%timeit bur(f(time))

882 µs ± 7.97 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
