In [None]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import rebound

from jorbit.accelerations import newtonian_gravity
from jorbit.integrators.ias15 import ias15_step, initialize_ias15_integrator_state
from jorbit.integrators.ias15_clean import _step
from jorbit.utils.states import SystemState


In [29]:
# Problem size and step length, named once so the array shapes and both
# integrator states below cannot drift out of agreement.
N_TRACERS = 50
N_MASSIVE = 100
DT = 1e-3

# Random initial conditions drawn from a standard normal. Each array uses its
# own fixed PRNG key, so the draws are reproducible on a fresh kernel.
tracer_positions = jax.random.normal(jax.random.PRNGKey(0), (N_TRACERS, 3))
tracer_velocities = jax.random.normal(jax.random.PRNGKey(1), (N_TRACERS, 3))
massive_positions = jax.random.normal(jax.random.PRNGKey(2), (N_MASSIVE, 3))
massive_velocities = jax.random.normal(jax.random.PRNGKey(3), (N_MASSIVE, 3))
log_gms = jax.random.normal(jax.random.PRNGKey(4), (N_MASSIVE,))


init_state = SystemState(
    tracer_positions=tracer_positions,
    tracer_velocities=tracer_velocities,
    massive_positions=massive_positions,
    massive_velocities=massive_velocities,
    log_gms=log_gms,
    acceleration_func_kwargs={},
    time=0.0,
)
# Accelerations at the initial state; both integrator states are seeded from
# the same values so the two step implementations start from equal footing.
a0 = newtonian_gravity(init_state)

integrator_state_old = initialize_ias15_integrator_state(a0)
integrator_state_new = initialize_ias15_integrator_state(a0)

# Give both integrators the same step size.
integrator_state_old.dt = DT
integrator_state_new.dt = DT

In [30]:
def _add_body(simulation, mass, pos, vel):
    """Append one particle with the given mass, position and velocity to `simulation`."""
    simulation.add(
        m=mass, x=pos[0], y=pos[1], z=pos[2], vx=vel[0], vy=vel[1], vz=vel[2]
    )


# Build the same system in REBOUND: massless tracers first, then the massive
# bodies with m = exp(log_gm). The particle order matters for later slicing.
sim = rebound.Simulation()
for pos, vel in zip(tracer_positions, tracer_velocities):
    _add_body(sim, 0.0, pos, vel)
for pos, vel, log_gm in zip(massive_positions, massive_velocities, log_gms):
    _add_body(sim, jnp.exp(log_gm), pos, vel)

# Integrate to t = 1e-3, the same value assigned to `dt` on the jorbit
# integrator states.
sim.integrate(1e-3)

# Collect the final Cartesian state of every particle, in insertion order.
rebound_pos = jnp.array([(body.x, body.y, body.z) for body in sim.particles])
rebound_vel = jnp.array([(body.vx, body.vy, body.vz) for body in sim.particles])

In [31]:
# Take one step with `_step` from `ias15_clean`, starting from the shared initial state.
new = _step(init_state, jax.tree_util.Partial(newtonian_gravity), integrator_state_new)

In [32]:
# Take one step with `ias15_step` from the same initial state; only the first
# returned value (the stepped system state) is kept for the comparison.
old, _ = ias15_step(init_state, jax.tree_util.Partial(newtonian_gravity), integrator_state_old)

In [33]:
def _max_abs_diff(a, b):
    """Return the largest element-wise absolute difference between `a` and `b`."""
    return jnp.max(jnp.abs(a - b))


# The REBOUND arrays hold tracers first and massive bodies after them, so the
# massive block starts at the tracer count. Derive that offset from the data
# rather than hard-coding it, so changing the number of tracers cannot
# silently misalign the comparison.
n_tracers = tracer_positions.shape[0]
rebound_massive_pos = rebound_pos[n_tracers:]
rebound_massive_vel = rebound_vel[n_tracers:]

# Each line prints the maximum position discrepancy, then the maximum velocity
# discrepancy, for the massive bodies only.
print(
    "old v. new: ",
    _max_abs_diff(old.massive_positions, new.massive_positions),
    _max_abs_diff(old.massive_velocities, new.massive_velocities),
)
print(
    "rebound v. new: ",
    _max_abs_diff(rebound_massive_pos, new.massive_positions),
    _max_abs_diff(rebound_massive_vel, new.massive_velocities),
)

old v. new:  6.938893903907228e-18 2.3592239273284576e-16
rebound v. new:  0.0 4.440892098500626e-16
