In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import rebound

from jorbit.accelerations import newtonian_gravity
from jorbit.integrators.ias15 import ias15_step, initialize_ias15_integrator_state
from jorbit.integrators.ias15_clean import _step
from jorbit.utils.states import SystemState


In [7]:
tracer_positions = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
tracer_velocities = jax.random.normal(jax.random.PRNGKey(1), (5, 3))
massive_positions = jax.random.normal(jax.random.PRNGKey(2), (10, 3))
massive_velocities = jax.random.normal(jax.random.PRNGKey(3), (10, 3))
log_gms = jax.random.normal(jax.random.PRNGKey(4), (10,))


init_state = SystemState(
    tracer_positions=tracer_positions,
    tracer_velocities=tracer_velocities,
    massive_positions=massive_positions,
    massive_velocities=massive_velocities,
    log_gms=log_gms,
    acceleration_func_kwargs={},
    time=0.0,
)
a0 = newtonian_gravity(init_state)

integrator_state_old = initialize_ias15_integrator_state(a0)
integrator_state_new = initialize_ias15_integrator_state(a0)

integrator_state_old.dt = 1e-2
integrator_state_new.dt = 1e-2

In [8]:
sim = rebound.Simulation()
for p, v in zip(tracer_positions, tracer_velocities):
    sim.add(m=0.0, x=p[0], y=p[1], z=p[2], vx=v[0], vy=v[1], vz=v[2])
for p, v, log_gm in zip(massive_positions, massive_velocities, log_gms):
    sim.add(m=jnp.exp(log_gm), x=p[0], y=p[1], z=p[2], vx=v[0], vy=v[1], vz=v[2])

sim.integrate(integrator_state_old.dt)
rebound_pos = jnp.array([[p.x, p.y, p.z] for p in sim.particles])
rebound_vel = jnp.array([[p.vx, p.vy, p.vz] for p in sim.particles])

In [9]:
new = _step(
    init_state,
    jax.tree_util.Partial(newtonian_gravity),
    integrator_state_new,
)

pc err: 1e+300
pc err: 0.0005976427514210323
pc err: 0.0005846773799080454
pc err: 6.589989688739858e-06
pc err: 1.6107340014655208e-08
pc err: 1.9107949176529905e-11
pc err: 0.0
just chillin
pc err: 0.0
just chillin
pc err: 0.0
just chillin
pc err: 0.0
just chillin


In [10]:
old, _ = ias15_step(
    init_state,
    jax.tree_util.Partial(newtonian_gravity),
    integrator_state_old,
)
import time  # jax.debug.printing keeps bumping to the next cell otherwise

time.sleep(1)

pc err: 1e+300
pc err: 0.0005976427578464939
pc err: 0.0005846773812093832
pc err: 6.589991056108137e-06
pc err: 1.6107573770849236e-08
pc err: 1.801303899115905e-11
pc err: 0.0
just chillin
pc err: 0.0
just chillin
pc err: 0.0
just chillin
pc err: 0.0
just chillin


In [11]:
print("old v. new: ", jnp.max(jnp.abs(old.massive_positions - new.massive_positions)), jnp.max(jnp.abs(old.massive_velocities - new.massive_velocities)))
print("rebound v. new: ", jnp.max(jnp.abs(rebound_pos[5:] - new.massive_positions)), jnp.max(jnp.abs(rebound_vel[5:] - new.massive_velocities)))

old v. new:  3.469446951953614e-18 1.1102230246251565e-16
rebound v. new:  1.734723475976807e-18 1.1102230246251565e-16
