In [1]:
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

from jorbit.data.constants import IAS15_D, IAS15_D_MATRIX
from jorbit.integrators.ias15 import (
    initialize_ias15_helper,
    initialize_ias15_integrator_state,
)

In [None]:
def _test_new_system(seed: int, n_particles: int = 31) -> None:
    """Check the matrix form of the IAS15 b -> g transform against the explicit sum.

    Seven random "b" levels ``p0``..``p6`` (each of shape ``(n_particles, 3)``)
    are drawn from ``seed``. The corresponding "g" levels are then computed twice:

    * once term by term from the flat ``IAS15_D`` coefficients (the reference);
    * once as a single ``einsum`` contraction of ``IAS15_D_MATRIX`` with the
      stacked b levels (the candidate).

    Every level must agree within ``jnp.allclose`` default tolerances.

    Args:
        seed: Seed for ``jax.random.PRNGKey``; each seed gives an independent draw.
        n_particles: Number of rows per level. Defaults to 31.

    Raises:
        AssertionError: If any g level differs between the two computations. The
            message names the seed, the level and the largest absolute difference.
    """
    n_levels = 7
    shape = (n_particles, 3)
    # One key per b level plus one for the integrator-state input.
    keys = jax.random.split(jax.random.PRNGKey(seed), n_levels + 1)

    b = initialize_ias15_helper(n_particles)
    for level in range(n_levels):
        setattr(b, f"p{level}", jax.random.normal(keys[level], shape))

    # NOTE(review): the integrator state appears to be built only to obtain a
    # ready-made container (`.g`) for the matrix-based result -- TODO confirm.
    initial_integrator_state = initialize_ias15_integrator_state(
        jax.random.normal(keys[n_levels], shape)
    )

    # Reference: explicit expansion. When forming g.pk, the coefficient applied
    # to b.pj (j > k) is IAS15_D[j * (j - 1) // 2 + k], and b.pk enters with
    # weight 1.
    g = initialize_ias15_helper(n_particles)
    g.p0 = (
        b.p6 * IAS15_D[15]
        + b.p5 * IAS15_D[10]
        + b.p4 * IAS15_D[6]
        + b.p3 * IAS15_D[3]
        + b.p2 * IAS15_D[1]
        + b.p1 * IAS15_D[0]
        + b.p0
    )
    g.p1 = (
        b.p6 * IAS15_D[16]
        + b.p5 * IAS15_D[11]
        + b.p4 * IAS15_D[7]
        + b.p3 * IAS15_D[4]
        + b.p2 * IAS15_D[2]
        + b.p1
    )
    g.p2 = (
        b.p6 * IAS15_D[17]
        + b.p5 * IAS15_D[12]
        + b.p4 * IAS15_D[8]
        + b.p3 * IAS15_D[5]
        + b.p2
    )
    g.p3 = b.p6 * IAS15_D[18] + b.p5 * IAS15_D[13] + b.p4 * IAS15_D[9] + b.p3
    g.p4 = b.p6 * IAS15_D[19] + b.p5 * IAS15_D[14] + b.p4
    g.p5 = b.p6 * IAS15_D[20] + b.p5
    g.p6 = b.p6

    # Candidate: the same transform as one contraction over the stacked levels
    # (axis 0 of b_stack indexes the level).
    g2 = initial_integrator_state.g
    b_stack = jnp.stack(
        [getattr(b, f"p{level}") for level in range(n_levels)], axis=0
    )
    g_stack = jnp.einsum("ij,jnk->ink", IAS15_D_MATRIX, b_stack)
    for level in range(n_levels):
        setattr(g2, f"p{level}", g_stack[level])

    for level in range(n_levels):
        expected = getattr(g, f"p{level}")
        actual = getattr(g2, f"p{level}")
        # The message is only built on failure, so the diff costs nothing when
        # the assertion passes.
        assert jnp.allclose(expected, actual), (
            f"seed={seed}, n_particles={n_particles}: g.p{level} mismatch "
            f"(max abs diff {float(jnp.max(jnp.abs(expected - actual)))})"
        )


# Run the b -> g consistency check over 100 independent random draws; any
# mismatch surfaces as an AssertionError from _test_new_system.
for seed in range(0, 100):
    _test_new_system(seed=seed)