In [1]:
import jax

jax.config.update("jax_enable_x64", True)
from types import SimpleNamespace

import jax.numpy as jnp
from jax import random

# Set up test data.
#
# Every random array gets its own subkey from a single 12-way split. Do not
# index the default two-way `random.split(key)` with i >= 2. JAX clamps
# out-of-bounds indices instead of raising, so all such draws would share one
# subkey and produce identical arrays. The old-vs-new comparison below could
# then not detect mistakes such as swapped coefficients between those terms.
key = random.PRNGKey(0)
N = 10  # number of rows in each (N, 3) test array
n = 1  # index into IAS15_H selecting the evaluation point h

keys = random.split(key, 12)  # one independent subkey per array below

b = SimpleNamespace(
    p0=random.normal(keys[0], (N, 3)),
    p1=random.normal(keys[1], (N, 3)),
    p2=random.normal(keys[2], (N, 3)),
    p3=random.normal(keys[3], (N, 3)),
    p4=random.normal(keys[4], (N, 3)),
    p5=random.normal(keys[5], (N, 3)),
    p6=random.normal(keys[6], (N, 3)),
)

# Node positions within a step; h = IAS15_H[n] scales dt in every term below.
IAS15_H = jnp.array(
    [
        0.0,
        0.0562625605369221,
        0.180240691736892,
        0.352624717113169,
        0.547153626330556,
        0.734210177215410,
        0.885320946839095,
        1.0,
    ]
)
dt = 0.01
x0 = random.normal(keys[7], (N, 3))
v0 = random.normal(keys[8], (N, 3))
a0 = random.normal(keys[9], (N, 3))
# Presumably compensated-summation error terms for x and v (TODO confirm).
csx = random.normal(keys[10], (N, 3))
csv = random.normal(keys[11], (N, 3))
h = IAS15_H[n]  # evaluation point; multiplies dt in every term below

# OLD METHOD: nested (Horner-form) evaluation, unrolled one nesting level per
# line. Each step multiplies the running sum by a ratio involving h and then
# adds the next lower-order coefficient, starting from b.p6.
# csx / csv are subtracted from the initial values x0 / v0.

# Position series.
acc_x = b.p6 * 7.0 * h / 9.0 + b.p5
acc_x = acc_x * 3.0 * h / 4.0 + b.p4
acc_x = acc_x * 5.0 * h / 7.0 + b.p3
acc_x = acc_x * 2.0 * h / 3.0 + b.p2
acc_x = acc_x * 3.0 * h / 5.0 + b.p1
acc_x = acc_x * h / 2.0 + b.p0
acc_x = acc_x * h / 3.0 + a0
acc_x = acc_x * dt * h / 2.0 + v0
x_old = x0 - csx + acc_x * dt * h

# Velocity series.
acc_v = b.p6 * 7.0 * h / 8.0 + b.p5
acc_v = acc_v * 6.0 * h / 7.0 + b.p4
acc_v = acc_v * 5.0 * h / 6.0 + b.p3
acc_v = acc_v * 4.0 * h / 5.0 + b.p2
acc_v = acc_v * 3.0 * h / 4.0 + b.p1
acc_v = acc_v * 2.0 * h / 3.0 + b.p0
acc_v = acc_v * h / 2.0 + a0
v_old = v0 - csv + acc_v * dt * h

# NEW METHOD: build the same two polynomials as explicit coefficient stacks,
# then evaluate each with a single jnp.polyval call instead of a hand-nested
# Horner expression.
#
# In ascending powers of h the coefficients are:
#   x: x0, v0*dt, a0*dt^2/2, then p_k*dt^2/((k+2)(k+3)) for k = 0..6
#   v: v0, a0*dt,            then p_k*dt/(k+2)          for k = 0..6
# jnp.polyval expects the highest power first, which is why each stack is
# flipped. Each "coefficient" is itself an (N, 3) array. This relies on
# jnp.polyval iterating over the leading axis of its coefficient argument.
b_len = 7
k = jnp.arange(b_len)
b_x_denoms = ((k + 2) * (k + 3))[:, None, None]  # 6, 12, 20, 30, 42, 56, 72
b_v_denoms = (k + 2)[:, None, None]  # 2, 3, 4, 5, 6, 7, 8
b_stack = jnp.stack(tuple(getattr(b, f"p{i}") for i in range(b_len)), axis=0)

init_cond_x_stack = jnp.stack((x0, v0 * dt, a0 * dt * dt / 2.0))
init_cond_v_stack = jnp.stack((v0, a0 * dt))

x_coeffs = jnp.flip(
    jnp.concatenate((init_cond_x_stack, b_stack * dt * dt / b_x_denoms), axis=0),
    axis=0,
)
v_coeffs = jnp.flip(
    jnp.concatenate((init_cond_v_stack, b_stack * dt / b_v_denoms), axis=0),
    axis=0,
)

x2 = jnp.polyval(x_coeffs, h) - csx
v2 = jnp.polyval(v_coeffs, h) - csv

# Largest absolute disagreement between the two methods. This is expected to
# be at floating-point round-off level.
max_err_x = jnp.max(jnp.abs(x_old - x2))
max_err_v = jnp.max(jnp.abs(v_old - v2))
max_err_x, max_err_v

(Array(9.71445147e-17, dtype=float64), Array(1.02565526e-16, dtype=float64))