In [1]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from typing import Callable
import chex
import matplotlib.pyplot as plt

import astropy.units as u
from astropy.time import Time
from jorbit import System, Particle
from jorbit.utils.states import SystemState
from jorbit.integrators import create_leapfrog_times

In [None]:
# Compare the fixed-step Y8 integrator (p1) against adaptive IAS15 (p2) for the same
# object, epoch and force model. The shared settings live in one place so the two
# runs cannot silently drift apart when one of them is edited.
epoch = Time("2026-01-01")
shared_kwargs = dict(
    name="274301",
    time=epoch,
    gravity="default solar system",
)

p1 = Particle.from_horizons(
    **shared_kwargs,
    integrator="Y8",
    max_step_size=0.1 * u.day,  # fixed step for Y8
)
p2 = Particle.from_horizons(
    **shared_kwargs,
    integrator="ias15",
)

# 11 daily samples spanning 10 days from the epoch, observed from the same site for both runs.
ts = epoch + jnp.linspace(0, 10, 11) * u.day
eph1 = p1.ephemeris(times=ts, observer="kitt peak")
eph2 = p2.ephemeris(times=ts, observer="kitt peak")

expanding to 108 times for leapfrog integration
positions size: (Array(11, dtype=int64), Array(1, dtype=int64), Array(3, dtype=int64))
positions size: (Array(11, dtype=int64), Array(1, dtype=int64), Array(3, dtype=int64))


In [6]:
# Angular offset between the Y8 and IAS15 ephemerides at each sample time in `ts`.
sep_y8_vs_ias15 = eph1.separation(eph2)
sep_y8_vs_ias15

<Angle [0.00000000e+00, 4.89617553e-14, 1.97034800e-13, 3.98177296e-13,
        6.93688508e-13, 1.04201158e-12, 1.53624979e-12, 1.98583165e-12,
        2.67458987e-12, 3.40836347e-12, 4.19892045e-12] deg>

In [3]:
# Display the Y8 ephemeris. This previously called `p.ephemeris`, but `p` is not defined
# anywhere in the notebook (NameError on a fresh kernel). The leapfrog log line in the
# saved output presumably means the Y8 particle `p1` was intended. `ts` is reused from
# the setup cell rather than redefined, so both cells always use the same epochs.
p1.ephemeris(times=ts, observer="kitt peak")

expanding to 108 times for leapfrog integration
positions size: (Array(11, dtype=int64), Array(1, dtype=int64), Array(3, dtype=int64))


<SkyCoord (ICRS): (ra, dec) in deg
    [(231.79303109, -15.74049002), (232.23814192, -15.81833182),
     (232.68292331, -15.89468045), (233.12735445, -15.96952629),
     (233.57141235, -16.04285911), (234.01507142, -16.11466818),
     (234.4583034 , -16.18494236), (234.90107753, -16.25367029),
     (235.34336097, -16.32084055), (235.78511915, -16.38644188),
     (236.22631611, -16.45046322)]>

In [None]:
# Two-body sanity check: one massive body at rest at the origin, and one tracer starting
# at (1, 0, 0) with velocity (0, 1, 0). NOTE(review): log_gms = 0 presumably means GM = 1,
# which would make this a unit circular orbit with period 2*pi; confirm jorbit's convention.
tracer_start_pos = jnp.array([[1.0, 0.0, 0.0]])
tracer_start_vel = jnp.array([[0.0, 1.0, 0.0]])

s0 = SystemState(
    time=0.0,
    log_gms=jnp.zeros(1),
    massive_positions=jnp.zeros((1, 3)),
    massive_velocities=jnp.zeros((1, 3)),
    tracer_positions=tracer_start_pos,
    tracer_velocities=tracer_start_vel,
    acceleration_func_kwargs={},
)

s = System(
    state=s0,
    gravity="generic newtonian",
    integrator="Y4",
    max_step_size=0.01 * u.day,
)

In [None]:
# Sample the test orbit at pi/2, pi, 3*pi/2 and 2*pi days. If GM = 1 (see the setup cell),
# these are quarter-period checkpoints, and the last one should return to the start.
quarter_period_times = jnp.array([jnp.pi / 2, jnp.pi, 3 * jnp.pi / 2, 2 * jnp.pi]) * u.day
s.integrate(times=quarter_period_times)

(Array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [ 5.34898903e-09,  1.00000000e+00,  0.00000000e+00]],
 
        [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [-1.00000001e+00,  2.48618800e-08,  0.00000000e+00]],
 
        [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [-4.43747871e-08, -1.00000000e+00,  0.00000000e+00]],
 
        [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [ 1.00000000e+00, -4.97237911e-08,  0.00000000e+00]]],      dtype=float64),
 Array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [-9.99999996e-01,  8.88996470e-09,  0.00000000e+00]],
 
        [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [-2.48618836e-08, -9.99999993e-01,  0.00000000e+00]],
 
        [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [ 9.99999996e-01, -4.08338104e-08,  0.00000000e+00]],
 
        [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [ 4.97237893e-08,  1