In [None]:
# ruff: noqa
import jax
jax.config.update("jax_enable_x64", True)
from typing import Callable

import astropy.units as u
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.time import Time

from jorbit import Particle
from jorbit.accelerations import create_static_default_acceleration_func
from jorbit.accelerations.static_helpers import precompute_perturber_positions
from jorbit.integrators import (
    ias15_evolve,
    ias15_step,
    ias15_static_evolve,
    ias15_static_step,
    initialize_ias15_integrator_state,
)
from jorbit.utils.states import IAS15IntegratorState, SystemState


In [2]:
# get dts from a standard ias15_evolve run, use those in static_helpers
# then, create a tiny function around each observation time to handle continuous-time
# integrations for light travel time

In [2]:
t0 = Time("2026-01-01")
tf = Time("2036-01-01")

p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()


state.fixed_perturber_positions = jnp.empty((27,3))
state.fixed_perturber_velocities = jnp.empty((27,3))
state.fixed_perturber_log_gms = jnp.empty(27,)

acc_func = create_static_default_acceleration_func()
a0 = acc_func(state)
state.fixed_perturber_positions *= 0.0 # just making sure these initial values aren't being used
state.fixed_perturber_velocities *= 0.0
integrator_init = initialize_ias15_integrator_state(a0)

In [3]:
def evolve(
    initial_system_state: IAS15IntegratorState,
    acceleration_func: Callable,
    final_time: float,
    initial_integrator_state: IAS15IntegratorState,
) -> tuple[SystemState, IAS15IntegratorState]:

    def step_needed(args: tuple) -> tuple:
        system_state, integrator_state, last_meaningful_dt, iter_num = args

        t = system_state.time

        diff = final_time - t
        step_length = jnp.sign(diff) * jnp.min(
            jnp.array([jnp.abs(diff), jnp.abs(integrator_state.dt)])
        )

        integrator_state.dt = step_length

        system_state, integrator_state = ias15_step(
            system_state, acceleration_func, integrator_state
        )
        return system_state, integrator_state, last_meaningful_dt, iter_num + 1

    def cond_func(args: tuple) -> bool:
        system_state, integrator_state, _last_meaningful_dt, iter_num = args
        t = system_state.time

        step_length = jnp.sign(final_time - t) * jnp.min(
            jnp.array([jnp.abs(final_time - t), jnp.abs(integrator_state.dt)])
        )
        return (step_length != 0) & (iter_num < 10_000)

    args = (
        initial_system_state,
        initial_integrator_state,
        initial_integrator_state.dt,
        0,
    )
    dts = []
    while cond_func(args):
        print("goin")
        args = step_needed(args)
        dts.apped(args[1].dt_done)

    return dts

evolve(
        state,
        acc_func,
        jnp.array([tf.tdb.jd]),
        integrator_init,
    )

ValueError: All input arrays must have the same shape.

In [None]:
t0 = Time("2026-01-01")
tf = Time("2036-01-01")

p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()
x, v = p.integrate(tf)

planet_pos, planet_vel, asteroid_pos, asteroid_vel, all_times, obs_indices, gms = precompute_perturber_positions(
    t0=t0,
    times=tf,
    max_step_size=25,
    de_ephemeris_version="440"
)

perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

state.fixed_perturber_positions = perturber_pos[0,0]
state.fixed_perturber_velocities = perturber_vel[0,0]
state.fixed_perturber_log_gms = gms

acc_func = create_static_default_acceleration_func()
a0 = acc_func(state)
state.fixed_perturber_positions *= 0.0 # just making sure these initial values aren't being used
state.fixed_perturber_velocities *= 0.0
integrator_init = initialize_ias15_integrator_state(a0)

In [2]:
t0 = Time("2026-01-01")
tf = Time("2036-01-01")

p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()
x, v = p.integrate(tf)

planet_pos, planet_vel, asteroid_pos, asteroid_vel, all_times, obs_indices, gms = precompute_perturber_positions(
    t0=t0,
    times=tf,
    max_step_size=25,
    de_ephemeris_version="440"
)

perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

state.fixed_perturber_positions = perturber_pos[0,0]
state.fixed_perturber_velocities = perturber_vel[0,0]
state.fixed_perturber_log_gms = gms

acc_func = create_static_default_acceleration_func()
a0 = acc_func(state)
state.fixed_perturber_positions *= 0.0 # just making sure these initial values aren't being used
state.fixed_perturber_velocities *= 0.0
integrator_init = initialize_ias15_integrator_state(a0)

positions, velocities, _, _ = ias15_static_evolve(
    state,
    acc_func,
    jnp.diff(all_times),
    perturber_pos,
    perturber_vel,
    gms,
    integrator_init,
)

positions[-1] - x

Array([[4.21973567e-12, 4.91118257e-12, 1.43129952e-12]], dtype=float64)

In [3]:
%%timeit
positions, velocities, _, _ = ias15_static_evolve(
    state,
    acc_func,
    jnp.diff(all_times),
    perturber_pos,
    perturber_vel,
    gms,
    integrator_init,
)
positions.block_until_ready()

35.4 ms ± 523 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [4]:
%%timeit
x, v = p.integrate(tf)
x.block_until_ready()

84.8 ms ± 1.64 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
from jorbit.data.constants import SPEED_OF_LIGHT
from jorbit.utils.states import KeplerianState

t0_num = t0.tdb.jd
def grad_wrapper(a,e,nu,Omega,inc,omega):
    s = KeplerianState(
        semi=a,
        ecc=e,
        nu=nu,
        Omega=Omega,
        inc=inc,
        omega=omega,
        acceleration_func_kwargs={"c2": SPEED_OF_LIGHT**2},
        time=t0_num,
    )
    state = s.to_system()
    state.fixed_perturber_positions = jnp.empty((27,3))
    state.fixed_perturber_velocities = jnp.empty((27,3))
    state.fixed_perturber_log_gms = jnp.empty((27,))

    positions, velocities, _, _ = ias15_static_evolve(
        state,
        acc_func,
        jnp.diff(all_times),
        perturber_pos,
        perturber_vel,
        gms,
        integrator_init,
    )
    return positions[-1]

t0 = Time("2026-01-01")
tf = Time("2036-01-01")

p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()

planet_pos, planet_vel, asteroid_pos, asteroid_vel, all_times, obs_indices, gms = precompute_perturber_positions(
    t0=t0,
    times=tf,
    max_step_size=25,
    de_ephemeris_version="440"
)

perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

state.fixed_perturber_positions = perturber_pos[0,0]
state.fixed_perturber_velocities = perturber_vel[0,0]
state.fixed_perturber_log_gms = gms
k = p.keplerian_state

j1 = jax.jit(jax.jacfwd(grad_wrapper))
j1(k.semi, k.ecc, k.nu, k.Omega, k.inc, k.omega)

Array([[[15.20479439],
        [15.48402737],
        [ 4.75829435]]], dtype=float64)

In [6]:
%timeit j1(k.semi, k.ecc, k.nu, k.Omega, k.inc, k.omega).block_until_ready()

265 ms ± 10 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
from jorbit.data.constants import SPEED_OF_LIGHT
from jorbit.utils.states import KeplerianState

t0_num = t0.tdb.jd
def grad_wrapper(a,e,nu,Omega,inc,omega):
    s = KeplerianState(
        semi=a,
        ecc=e,
        nu=nu,
        Omega=Omega,
        inc=inc,
        omega=omega,
        acceleration_func_kwargs={"c2": SPEED_OF_LIGHT**2},
        time=t0_num,
    )
    state = s.to_system()
    state.fixed_perturber_positions = jnp.empty((27,3))
    state.fixed_perturber_velocities = jnp.empty((27,3))
    state.fixed_perturber_log_gms = jnp.empty((27,))

    positions, velocities, _, _ = ias15_evolve(
        state,
        acc_func,
        jnp.array([tf.tdb.jd]),
        integrator_init,
    )
    return positions

t0 = Time("2026-01-01")
tf = Time("2036-01-01")

p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()

planet_pos, planet_vel, asteroid_pos, asteroid_vel, all_times, obs_indices, gms = precompute_perturber_positions(
    t0=t0,
    times=tf,
    max_step_size=10,
    de_ephemeris_version="440"
)

perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

state.fixed_perturber_positions = perturber_pos[0,0]
state.fixed_perturber_velocities = perturber_vel[0,0]
state.fixed_perturber_log_gms = gms
k = p.keplerian_state

j2 = jax.jit(jax.jacfwd(grad_wrapper))
j2(k.semi, k.ecc, k.nu, k.Omega, k.inc, k.omega)

Array([[[[nan],
         [nan],
         [nan]]]], dtype=float64)

In [8]:
%timeit j2(k.semi, k.ecc, k.nu, k.Omega, k.inc, k.omega).block_until_ready()

1.46 s ± 46.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
