In [1]:
# ruff: noqa
from typing import Callable

import jax

# Enable float64 right after importing jax and before the remaining
# (JAX-using) imports, so nothing is created in float32 by accident.
jax.config.update("jax_enable_x64", True)

import astropy.units as u
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.time import Time

from jorbit import Ephemeris, Particle
from jorbit.accelerations import (
    create_default_ephemeris_acceleration_func,
    create_static_default_acceleration_func,
)
from jorbit.accelerations.static_helpers import (
    get_all_dynamic_intermediate_dts,
    precompute_perturber_positions,
)
from jorbit.astrometry.sky_projection import on_sky, sky_sep
from jorbit.integrators import (
    ias15_evolve,
    ias15_static_evolve,
    ias15_static_step,
    ias15_step,
    initialize_ias15_integrator_state,
)
from jorbit.utils.horizons import get_observer_positions
from jorbit.utils.states import IAS15IntegratorState, SystemState

# Single step: dynamic (ephemeris-driven) vs. static (precomputed perturbers) IAS15

In [2]:
# Reference run: one IAS15 step from t0 to tf, using the acceleration
# function built from an Ephemeris processor ("dynamic" perturbers).
t0 = Time("2026-01-01")
tf = Time("2026-01-30")
single_step_days = tf.tdb.jd - t0.tdb.jd

p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()

eph = Ephemeris(ssos="default solar system")
acc_func = create_default_ephemeris_acceleration_func(eph.processor)

# Seed the integrator with the acceleration at the initial state and ask
# for a single step spanning the whole interval.
a0_dynamic = acc_func(state)
integrator_init = initialize_ias15_integrator_state(a0_dynamic)
integrator_init.dt = single_step_days

dynamic_state, dynamic_integrator_state = ias15_step(
    initial_system_state=state,
    initial_integrator_state=integrator_init,
    acceleration_func=acc_func,
)
dynamic_x = dynamic_state.tracer_positions
dynamic_x = dynamic_state.tracer_positions

In [3]:
# Same single step, but the perturber states are precomputed up front and
# handed to the "static" integrator instead of using the ephemeris-based
# acceleration function.
p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()
x_init = state.tracer_positions
v_init = state.tracer_velocities

single_step_days = tf.tdb.jd - t0.tdb.jd

(
    planet_pos,
    planet_vel,
    asteroid_pos,
    asteroid_vel,
    gms,
) = precompute_perturber_positions(
    t0=t0,
    dts=jnp.array([single_step_days]),
    de_ephemeris_version="440",
)
# One combined perturber axis: planets first, then asteroids.
perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

# Perturber states used when evaluating the initial acceleration.
state.fixed_perturber_positions = perturber_pos[0, 0]
state.fixed_perturber_velocities = perturber_vel[0, 0]
state.fixed_perturber_log_gms = gms

acc_func = create_static_default_acceleration_func()

a0_static = acc_func(state)
integrator_init = initialize_ias15_integrator_state(a0_static)
integrator_init.dt = single_step_days

# Only one step was precomputed, so pass its slice (index 0) of the table.
static_state, static_integrator_state = ias15_static_step(
    initial_system_state=state,
    initial_integrator_state=integrator_init,
    acceleration_func=acc_func,
    fixed_perturber_positions=perturber_pos[0],
    fixed_perturber_velocities=perturber_vel[0],
    fixed_perturber_log_gms=gms,
)
static_x = static_state.tracer_positions

# Largest absolute position difference between the static and dynamic runs.
jnp.max(jnp.abs(static_x - dynamic_x))

Array(0., dtype=float64)

# Several small steps: `ias15_evolve` vs. `ias15_static_evolve`

In [4]:
# Ten evenly spaced epochs from t0 to tf (both endpoints included).
span_days = tf.tdb.jd - t0.tdb.jd
times = t0 + jnp.linspace(0, span_days, 10) * u.day

In [5]:
# Dynamic (ephemeris-driven) IAS15 integration reporting at every epoch in `times`.
output_jd = times.tdb.jd

p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()

eph = Ephemeris(ssos="default solar system")
acc_func = create_default_ephemeris_acceleration_func(eph.processor)

a0_dynamic = acc_func(state)
integrator_init = initialize_ias15_integrator_state(a0_dynamic)
# Start from the spacing between the first two output epochs.
integrator_init.dt = jnp.diff(output_jd)[0]

(
    dynamic_x,
    dynamic_v,
    dynamic_state,
    dynamic_integrator_state,
) = ias15_evolve(
    initial_system_state=state,
    initial_integrator_state=integrator_init,
    acceleration_func=acc_func,
    times=output_jd,
)

In [6]:
# Static counterpart of the run above: perturber states are precomputed for
# every step between consecutive epochs in `times` and passed to
# `ias15_static_evolve`, which is driven by step sizes rather than epochs.
p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()

# Spacing (days, TDB Julian dates) between consecutive output epochs.
step_dts = jnp.diff(times.tdb.jd)

planet_pos, planet_vel, asteroid_pos, asteroid_vel, gms = precompute_perturber_positions(
    t0=t0,
    dts=step_dts,
    de_ephemeris_version="440",
)
# Merge planets and asteroids into one perturber axis (axis 2).
perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

# Perturber states for the initial acceleration; [0, 0] is presumably the
# first step's first sub-stage (i.e. t0) -- confirm against
# precompute_perturber_positions.
state.fixed_perturber_positions = perturber_pos[0, 0]
state.fixed_perturber_velocities = perturber_vel[0, 0]
state.fixed_perturber_log_gms = gms

acc_func = create_static_default_acceleration_func()

a0_static = acc_func(state)
integrator_init = initialize_ias15_integrator_state(a0_static)
# Use the first output spacing, exactly like the dynamic run above, rather
# than the whole t0 -> tf span, so both integrators start from the same setup.
integrator_init.dt = step_dts[0]

static_x, static_v, static_state, static_integrator_state = ias15_static_evolve(
    initial_system_state=state,
    acceleration_func=acc_func,
    dts=step_dts,
    initial_integrator_state=integrator_init,
    perturber_positions=perturber_pos,
    perturber_velocities=perturber_vel,
    perturber_log_gms=gms,
)

In [7]:
# Element-wise position residuals (static minus dynamic). The first dynamic
# row is dropped so the arrays line up: the dynamic run was given all of
# `times`, while the static run was given only the steps between them, so
# dynamic_x presumably has an extra leading row for the initial epoch.
static_x - dynamic_x[1:]

Array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[ 0.00000000e+00, -2.22044605e-16,  0.00000000e+00]],

       [[ 0.00000000e+00,  0.00000000e+00,  5.55111512e-17]],

       [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[ 0.00000000e+00, -2.22044605e-16,  0.00000000e+00]],

       [[ 0.00000000e+00, -2.22044605e-16,  0.00000000e+00]],

       [[ 0.00000000e+00, -2.22044605e-16,  0.00000000e+00]]],      dtype=float64)

# Longer integration (10 years): accuracy and timing of dynamic vs. static

In [8]:
# Ten evenly spaced epochs covering ten Julian years (10 * 365.25 = 3652.5 days).
TEN_JULIAN_YEARS_DAYS = 10 * 365.25
times = t0 + jnp.linspace(0, TEN_JULIAN_YEARS_DAYS, 10) * u.day

In [9]:
# Dynamic (ephemeris-driven) IAS15 integration over the long `times` grid.
# Besides producing results, this first call absorbs any one-off setup cost
# before the timing cell below.
p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()

eph = Ephemeris(ssos="default solar system")
acc_func = create_default_ephemeris_acceleration_func(eph.processor)

a0_dynamic = acc_func(state)
integrator_init = initialize_ias15_integrator_state(a0_dynamic)

output_jd = times.tdb.jd
integrator_init.dt = jnp.diff(output_jd)[0]

(
    dynamic_x,
    dynamic_v,
    dynamic_state,
    dynamic_integrator_state,
) = ias15_evolve(
    initial_system_state=state,
    initial_integrator_state=integrator_init,
    acceleration_func=acc_func,
    times=output_jd,
)

In [10]:
%%timeit
# Time the dynamic (ephemeris-driven) integration over `times`. The previous
# cell already ran this exact call, so one-off tracing/compilation cost is
# presumably excluded from the measurement.
# NOTE: names assigned inside a %%timeit cell are local to IPython's timing
# harness; they do not overwrite the results from the previous cell.

dynamic_x, dynamic_v, dynamic_state, dynamic_integrator_state = ias15_evolve(
    initial_system_state=state,
    acceleration_func=acc_func,
    times=times.tdb.jd,
    initial_integrator_state=integrator_init,
)

# JAX dispatches work asynchronously; wait for the result so the actual
# computation falls inside the timed region.
dynamic_x.block_until_ready()

96 ms ± 2.88 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
# Dynamic pass: record the step sizes the ephemeris-driven integrator takes
# to reach every epoch in `times`.
p = Particle.from_horizons("274301", t0)
state = p.keplerian_state.to_system()

eph = Ephemeris(ssos="default solar system")
acc_func = create_default_ephemeris_acceleration_func(eph.processor)
a0_dynamic = acc_func(state)
integrator_init = initialize_ias15_integrator_state(a0_dynamic)
integrator_init.dt = jnp.diff(times.tdb.jd)[0]

# `dts` are presumably the individual internal step sizes; `inds` selects the
# steps that land on the requested `times` (used to align outputs below).
# NOTE(review): unlike the ias15_evolve calls, this passes the astropy Time
# object rather than `times.tdb.jd` -- confirm that is the intended input.
dts, inds = get_all_dynamic_intermediate_dts(
    initial_system_state=state,
    acceleration_func=acc_func,
    times=times,
    initial_integrator_state=integrator_init,
)

# Static pass: replay exactly those steps with precomputed perturbers.
# Build a fresh system state from the particle already fetched above instead
# of querying Horizons a second time.
state = p.keplerian_state.to_system()

planet_pos, planet_vel, asteroid_pos, asteroid_vel, gms = precompute_perturber_positions(
    t0=t0,
    dts=dts,
    de_ephemeris_version="440",
)
# Merge planets and asteroids into one perturber axis (axis 2).
perturber_pos = jnp.concatenate((planet_pos, asteroid_pos), axis=2)
perturber_vel = jnp.concatenate((planet_vel, asteroid_vel), axis=2)

# Perturber states used for the initial acceleration evaluation.
state.fixed_perturber_positions = perturber_pos[0, 0]
state.fixed_perturber_velocities = perturber_vel[0, 0]
state.fixed_perturber_log_gms = gms

acc_func = create_static_default_acceleration_func()

a0_static = acc_func(state)
integrator_init = initialize_ias15_integrator_state(a0_static)
integrator_init.dt = dts[0]

static_x, static_v, static_state, static_integrator_state = ias15_static_evolve(
    initial_system_state=state,
    acceleration_func=acc_func,
    dts=dts,
    initial_integrator_state=integrator_init,
    perturber_positions=perturber_pos,
    perturber_velocities=perturber_vel,
    perturber_log_gms=gms,
)


In [12]:
# Position residuals (static minus dynamic) at the requested epochs only:
# `inds` picks, out of every step the static run took, the ones that
# correspond to the entries of `times`.
static_x[inds] - dynamic_x

Array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[ 2.66453526e-15, -6.66133815e-15, -2.13717932e-15]],

       [[ 2.88657986e-14,  4.44089210e-15,  1.44328993e-15]],

       [[ 4.44089210e-15,  1.61537450e-14,  5.27355937e-15]],

       [[-1.66533454e-14, -1.59872116e-14, -4.38538095e-15]],

       [[ 3.99680289e-14, -1.73194792e-14, -4.88498131e-15]],

       [[-2.66453526e-15,  4.05231404e-14,  1.16850973e-14]],

       [[-1.54876112e-13, -4.08562073e-14, -1.47659662e-14]],

       [[ 1.37445610e-13, -2.06501483e-13, -6.25055563e-14]],

       [[ 1.39444012e-13,  1.13020704e-13,  3.39728246e-14]]],      dtype=float64)

In [13]:
%%timeit
# Time the static integration over the precomputed steps. This depends on
# hidden state from the cell above: `acc_func` must still be the static
# acceleration function and `state`/`integrator_init` its static setup --
# re-run that cell first if any later cell rebinds these names.
# NOTE: names assigned inside a %%timeit cell are local to IPython's timing
# harness; they do not overwrite the results from the cell above.

static_x, static_v, static_state, static_integrator_state = ias15_static_evolve(
    initial_system_state=state,
    acceleration_func=acc_func,
    dts=dts,
    initial_integrator_state=integrator_init,
    perturber_positions=perturber_pos,
    perturber_velocities=perturber_vel,
    perturber_log_gms=gms,
)
# Wait for JAX's asynchronous dispatch so the computation is actually timed.
static_x.block_until_ready()

27.5 ms ± 1.35 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [14]:
# Project the final-epoch state of both runs onto the sky as seen from
# Kitt Peak and compare them with sky_sep.
# The ephemeris-based acceleration function gets its own name so `acc_func`
# stays bound to the static one; otherwise re-running the static timing cell
# after this one would silently time it with the wrong acceleration function.
eph = Ephemeris(ssos="default solar system")
ephemeris_acc_func = create_default_ephemeris_acceleration_func(eph.processor)

final_time = times[-1]
obspos = get_observer_positions(
    observatories="kitt peak",
    times=final_time,
    de_ephemeris_version="440",
)

ra_dynamic, dec_dynamic = on_sky(
    x=dynamic_x[-1, 0],
    v=dynamic_v[-1, 0],
    time=final_time.tdb.jd,
    observer_position=obspos[0],
    acc_func=ephemeris_acc_func,
)
ra_static, dec_static = on_sky(
    x=static_x[-1, 0],
    v=static_v[-1, 0],
    time=final_time.tdb.jd,
    observer_position=obspos[0],
    acc_func=ephemeris_acc_func,
)

sky_sep(ra_dynamic, dec_dynamic, ra_static, dec_static)

Array(1.72812242e-08, dtype=float64)