In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from jorbit.data.constants import IAS15_EPS_Modified

In [2]:
# ruff: noqa
# Fixed PRNG seeds, so every run feeds the same sample inputs to the functions below.
a0 = jax.random.normal(jax.random.PRNGKey(0), shape=(10, 3))
b = jax.random.normal(jax.random.PRNGKey(1), shape=(7, 10, 3))
dt_done = 1.0e-3

def old_version(a0, b, dt_done):
    """Reference step-size estimate, written with hand-unrolled sums.

    Forms three weighted combinations of ``a0`` and ``b[0]`` .. ``b[6]``,
    takes the sum of squares of each along axis 1, and combines them into the
    timescale expression tagged "PRS23" below. The smallest non-NaN value of
    that expression sets the returned step.

    Args:
        a0: Array that broadcasts against ``b[i]``; shape (10, 3) in this
            notebook.
        b: Array indexed along axis 0 for ``b[0]`` .. ``b[6]``; shape
            (7, 10, 3) in this notebook.
        dt_done: Scalar factor applied to the result (presumably the length
            of the step just completed -- confirm against the caller).

    Returns:
        ``sqrt(min timescale**2) * dt_done * IAS15_EPS_Modified``.

    Side effects:
        Prints ``y4.sum()`` to stdout.
    """

    def _sum_sq(v):
        # Sum of squares along axis 1.
        return jnp.sum(v * v, axis=1)

    # Unit weights: a0 plus every b term, added strictly left to right.
    combo_0 = a0 + b[0] + b[1] + b[2] + b[3] + b[4] + b[5] + b[6]
    y2 = _sum_sq(combo_0)

    # Weights 1..7 on b[0]..b[6].
    combo_1 = (
        b[0] + 2.0 * b[1] + 3.0 * b[2] + 4.0 * b[3]
        + 5.0 * b[4] + 6.0 * b[5] + 7.0 * b[6]
    )
    y3 = _sum_sq(combo_1)

    # Weights k * (k + 1) for k = 1..6 on b[1]..b[6]; b[0] does not appear.
    combo_2 = (
        2.0 * b[1] + 6.0 * b[2] + 12.0 * b[3]
        + 20.0 * b[4] + 30.0 * b[5] + 42.0 * b[6]
    )
    y4 = _sum_sq(combo_2)
    print(y4.sum())

    ratio2 = 2.0 * y2 / (y3 + jnp.sqrt(y4 * y2))  # PRS23

    # nanmin skips NaN entries (e.g. from 0 / 0) when picking the smallest value.
    return jnp.sqrt(jnp.nanmin(ratio2)) * dt_done * IAS15_EPS_Modified

# Evaluate the unrolled version on the sample inputs; the function prints
# y4.sum() first, then the returned step size is displayed as the cell result.
old_version(a0, b, dt_done)

78181.74904275755


Array(2.27691114e-05, dtype=float64)

In [4]:
def new_version(a0, b, dt_done):
    """Step-size estimate using weighted reductions over axis 0 of ``b``.

    Computes the same three weighted combinations as ``old_version``, but
    builds each one by broadcasting an integer weight vector against ``b`` and
    summing over axis 0, instead of spelling out every term.

    Args:
        a0: Array that broadcasts against ``b.sum(axis=0)``; shape (10, 3) in
            this notebook.
        b: 3-D array with 7 entries along axis 0 (the weight vectors are
            hard-coded to lengths 7 and 6); shape (7, 10, 3) in this notebook.
        dt_done: Scalar factor applied to the result (presumably the length
            of the step just completed -- confirm against the caller).

    Returns:
        ``sqrt(min timescale**2) * dt_done * IAS15_EPS_Modified``.
    """
    weights_1 = jnp.arange(1, 8)  # 1, 2, ..., 7
    k = jnp.arange(1, 7)
    weights_2 = k * (k + 1)  # 2, 6, 12, 20, 30, 42

    # [:, None, None] lines the weights up with axis 0 of b for broadcasting.
    combo_0 = a0 + jnp.sum(b, axis=0)
    combo_1 = jnp.sum(weights_1[:, None, None] * b, axis=0)
    combo_2 = jnp.sum(weights_2[:, None, None] * b[1:], axis=0)  # b[0] is skipped

    # Sum of squares along axis 1 for each combination.
    y2, y3, y4 = (jnp.sum(c * c, axis=1) for c in (combo_0, combo_1, combo_2))

    ratio2 = 2.0 * y2 / (y3 + jnp.sqrt(y4 * y2))  # PRS23

    # nanmin skips NaN entries (e.g. from 0 / 0) when picking the smallest value.
    return jnp.sqrt(jnp.nanmin(ratio2)) * dt_done * IAS15_EPS_Modified

# Same inputs as the old_version call, so the displayed value can be compared
# directly against that cell's result.
new_version(a0, b, dt_done)

Array(2.27691114e-05, dtype=float64)