In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from jorbit.integrators.ias15 import IAS15Helper, initialize_ias15_helper
from jorbit.data.constants import IAS15_BEZIER_COEFFS

In [12]:
# ruff: noqa
def old_version(ratio: jnp.ndarray, _e: IAS15Helper, _b: IAS15Helper) -> tuple:
    """Attribute-by-attribute reference implementation, kept for comparison.

    For ``ratio <= 20.0`` the new ``e`` coefficients are fixed integer-weighted
    sums of the ``_b`` coefficients, with slot ``p{i}`` scaled by
    ``ratio ** (i + 1)``; the weight of ``_b.p{j}`` in slot ``p{i}`` is the
    binomial coefficient ``C(j + 1, i + 1)``.  The new ``b`` coefficients are
    ``e_new + (_b - _e)``, slot by slot.  For ``ratio > 20.0`` every slot of
    both results is a float64 zero array shaped like the matching input slot.

    Both ``jax.lax.cond`` branches are pure: they read only the operands that
    ``cond`` hands them and build brand-new helpers.  Earlier revisions mutated
    helpers captured from the enclosing scope while being traced, which made
    the ``ratio > 20.0`` result depend on the order in which JAX traced the
    two branches.

    Args:
        ratio: Scalar step-size ratio; used as the ``cond`` predicate
            (``ratio > 20.0``) and raised to powers 1..7.
        _e: Helper whose ``p0``..``p6`` hold the current ``e`` coefficients.
        _b: Helper whose ``p0``..``p6`` hold the current ``b`` coefficients.

    Returns:
        Tuple ``(e, b)`` of two ``IAS15Helper`` objects with the updated
        coefficients.
    """

    def _zeroed(template: IAS15Helper) -> IAS15Helper:
        # Fresh helper with float64 zeros shaped like each slot of `template`.
        return IAS15Helper(
            p0=jnp.zeros_like(template.p0, dtype=jnp.float64),
            p1=jnp.zeros_like(template.p1, dtype=jnp.float64),
            p2=jnp.zeros_like(template.p2, dtype=jnp.float64),
            p3=jnp.zeros_like(template.p3, dtype=jnp.float64),
            p4=jnp.zeros_like(template.p4, dtype=jnp.float64),
            p5=jnp.zeros_like(template.p5, dtype=jnp.float64),
            p6=jnp.zeros_like(template.p6, dtype=jnp.float64),
        )

    def large_ratio(ratio: jnp.ndarray, er: IAS15Helper, br: IAS15Helper) -> tuple:
        # Taken when ratio > 20.0: both coefficient sets are reset to zero.
        return _zeroed(er), _zeroed(br)

    def reasonable_ratio(ratio: jnp.ndarray, er: IAS15Helper, br: IAS15Helper) -> tuple:
        # Powers ratio**1 .. ratio**7, built by repeated multiplication.
        q1 = ratio
        q2 = q1 * q1
        q3 = q1 * q2
        q4 = q2 * q2
        q5 = q2 * q3
        q6 = q3 * q3
        q7 = q3 * q4

        # Summation order is kept exactly as written so floating-point
        # rounding matches the historical implementation.
        e_new = IAS15Helper(
            p0=q1
            * (
                br.p6 * 7.0
                + br.p5 * 6.0
                + br.p4 * 5.0
                + br.p3 * 4.0
                + br.p2 * 3.0
                + br.p1 * 2.0
                + br.p0
            ),
            p1=q2
            * (
                br.p6 * 21.0
                + br.p5 * 15.0
                + br.p4 * 10.0
                + br.p3 * 6.0
                + br.p2 * 3.0
                + br.p1
            ),
            p2=q3 * (br.p6 * 35.0 + br.p5 * 20.0 + br.p4 * 10.0 + br.p3 * 4.0 + br.p2),
            p3=q4 * (br.p6 * 35.0 + br.p5 * 15.0 + br.p4 * 5.0 + br.p3),
            p4=q5 * (br.p6 * 21.0 + br.p5 * 6.0 + br.p4),
            p5=q6 * (br.p6 * 7.0 + br.p5),
            p6=q7 * br.p6,
        )

        # Carry the previous (b - e) offset over onto the new e coefficients.
        b_new = IAS15Helper(
            p0=e_new.p0 + (br.p0 - er.p0),
            p1=e_new.p1 + (br.p1 - er.p1),
            p2=e_new.p2 + (br.p2 - er.p2),
            p3=e_new.p3 + (br.p3 - er.p3),
            p4=e_new.p4 + (br.p4 - er.p4),
            p5=e_new.p5 + (br.p5 - er.p5),
            p6=e_new.p6 + (br.p6 - er.p6),
        )

        return e_new, b_new

    e, b = jax.lax.cond(ratio > 20.0, large_ratio, reasonable_ratio, ratio, _e, _b)
    return e, b

def _as_helper(coeffs):
    """Copy the seven leading-axis slices of ``coeffs`` into a new helper's p0..p6."""
    helper = initialize_ias15_helper(100)
    for idx in range(7):
        setattr(helper, f"p{idx}", coeffs[idx])
    return helper


def _as_stack(helper):
    """Stack a helper's p0..p6 attributes along a new leading axis."""
    return jnp.stack([getattr(helper, f"p{idx}") for idx in range(7)], axis=0)


# Fixed-seed random inputs of shape (7, 100, 3): one (100, 3) array per p-slot.
_e = jax.random.normal(jax.random.PRNGKey(0), (7, 100, 3))
_b = jax.random.normal(jax.random.PRNGKey(1), (7, 100, 3))

# The same numbers wrapped in helper objects, as old_version expects.
e = _as_helper(_e)
b = _as_helper(_b)

# 2.0 is below the 20.0 cutoff, so only the rescaling branch is exercised.
ratio = 2.0

# Run the reference implementation and flatten each result back to one array
# so it can be compared element-wise against the array-based version.
old_e, old_b = map(_as_stack, old_version(ratio, e, b))

In [13]:
def new_version(ratio, e, b) -> tuple:
    """Array-based update of the ``e`` and ``b`` coefficient stacks.

    Args:
        ratio: Scalar; compared against 20.0 and raised to powers 1..7.
        e: Array whose leading axis indexes the ``e`` coefficients.
        b: Array whose leading axis indexes the ``b`` coefficients; the
            leading axis is contracted against ``IAS15_BEZIER_COEFFS``.

    Returns:
        Tuple ``(e_new, b_new)``.  When ``ratio > 20.0`` both are zero arrays
        shaped like the inputs.  Otherwise ``e_new[i]`` is
        ``ratio ** (i + 1) * sum_j IAS15_BEZIER_COEFFS[i, j] * b[j]`` and
        ``b_new`` is ``e_new + (b - e)``.
    """

    def _reset(step_ratio, e_prev, b_prev) -> tuple:
        # Above the cutoff: both outputs are zeros matching the input arrays.
        return jnp.zeros_like(e_prev), jnp.zeros_like(b_prev)

    def _rescale(step_ratio, e_prev, b_prev) -> tuple:
        # step_ratio**1 .. step_ratio**7, one factor per output coefficient.
        powers = step_ratio ** jnp.arange(1, 8)
        # Offset between the incoming stacks, re-applied to the new e below.
        offset = b_prev - e_prev
        e_next = jnp.einsum(
            "i,ij,j...->i...", powers, IAS15_BEZIER_COEFFS, b_prev
        )
        return e_next, e_next + offset

    e_out, b_out = jax.lax.cond(ratio > 20.0, _reset, _rescale, ratio, e, b)
    return e_out, b_out

# Feed the raw (7, 100, 3) stacks (not the helper objects) to the array-based
# implementation, using the same ratio that was passed to old_version.
new_e, new_b = new_version(ratio, _e, _b)

In [14]:
# Largest absolute element-wise disagreement between the two implementations,
# shown as (error in e, error in b).
tuple(
    jnp.max(jnp.abs(fresh - legacy))
    for fresh, legacy in ((new_e, old_e), (new_b, old_b))
)

(Array(4.54747351e-13, dtype=float64), Array(4.54747351e-13, dtype=float64))