In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import mpmath

mpmath.mp.dps = 75
from jorbit.data.constants import IAS15_C

In [2]:
@jax.jit
def add_cs(p: jnp.ndarray, csp: jnp.ndarray, inp: jnp.ndarray) -> tuple:
    """Compensated summation.

    Args:
        p (jnp.ndarray):
            The current sum.
        csp (jnp.ndarray):
            The current compensation.
        inp (jnp.ndarray):
            The input to add.

    Returns:
        tuple:
            The new sum and compensation.
    """
    y = inp - csp
    t = p + y
    csp = (t - p) - y
    p = t
    return p, csp

In [3]:
at = jax.random.normal(jax.random.PRNGKey(0), (10, 3)) * 1_000
a0 = jax.random.normal(jax.random.PRNGKey(1), (10, 3)) * 1_000
g = jax.random.normal(jax.random.PRNGKey(2), (7, 10, 3)) * 1_000
# r = 1/IAS15_RR

In [None]:
# ruff: noqa
def old_scheme(b, csb, tmp):
    n1 = (add_cs(b[0], csb[0], tmp),)

    n2 = (
        add_cs(b[0], csb[0], tmp * IAS15_C[0]),
        add_cs(b[1], csb[1], tmp)
    )

    n3 = (
        add_cs(b[0], csb[0], tmp * IAS15_C[1]),
        add_cs(b[1], csb[1], tmp * IAS15_C[2]),
        add_cs(b[2], csb[2], tmp)
    )

    n4 = (
        add_cs(b[0], csb[0], tmp * IAS15_C[3]),
        add_cs(b[1], csb[1], tmp * IAS15_C[4]),
        add_cs(b[2], csb[2], tmp * IAS15_C[5]),
        add_cs(b[3], csb[3], tmp)
    )

    n5 = (
        add_cs(b[0], csb[0], tmp * IAS15_C[6]),
        add_cs(b[1], csb[1], tmp * IAS15_C[7]),
        add_cs(b[2], csb[2], tmp * IAS15_C[8]),
        add_cs(b[3], csb[3], tmp * IAS15_C[9]),
        add_cs(b[4], csb[4], tmp)
    )

    n6 = (
        add_cs(b[0], csb[0], tmp * IAS15_C[10]),
        add_cs(b[1], csb[1], tmp * IAS15_C[11]),
        add_cs(b[2], csb[2], tmp * IAS15_C[12]),
        add_cs(b[3], csb[3], tmp * IAS15_C[13]),
        add_cs(b[4], csb[4], tmp * IAS15_C[14]),
        add_cs(b[5], csb[5], tmp)
    )

    n7 = (
        add_cs(b[0], csb[0], tmp * IAS15_C[15]),
        add_cs(b[1], csb[1], tmp * IAS15_C[16]),
        add_cs(b[2], csb[2], tmp * IAS15_C[17]),
        add_cs(b[3], csb[3], tmp * IAS15_C[18]),
        add_cs(b[4], csb[4], tmp * IAS15_C[19]),
        add_cs(b[5], csb[5], tmp * IAS15_C[20]),
        add_cs(b[6], csb[6], tmp)
    )

    return [n1, n2, n3, n4, n5, n6, n7]

b = jax.random.normal(jax.random.PRNGKey(3), (7, 10, 3)) * 1_000
csb = jax.random.normal(jax.random.PRNGKey(4), (7, 10, 3)) * 1_000
tmp = jax.random.normal(jax.random.PRNGKey(5), (10, 3)) * 1_000

old_version = old_scheme(b, csb, tmp)

In [9]:
ias15_c1 = jnp.array([1.0])
ias15_c2 = jnp.concatenate([IAS15_C[0:1], jnp.array([1.0])])
ias15_c3 = jnp.concatenate([IAS15_C[1:3], jnp.array([1.0])])
ias15_c4 = jnp.concatenate([IAS15_C[3:6], jnp.array([1.0])])
ias15_c5 = jnp.concatenate([IAS15_C[6:10], jnp.array([1.0])])
ias15_c6 = jnp.concatenate([IAS15_C[10:15], jnp.array([1.0])])
ias15_c7 = jnp.concatenate([IAS15_C[15:21], jnp.array([1.0])])

def update_bs(current_bs, current_csbs, g_diff, c):
    return add_cs(current_bs,  current_csbs, (g_diff[None, :] * c[:, None, None]))



In [10]:
new_version = []
for i, c in enumerate([ias15_c1, ias15_c2, ias15_c3, ias15_c4, ias15_c5, ias15_c6, ias15_c7]):
    b_, csb_ = update_bs(b[:i+1], csb[:i+1], tmp, c)
    new_version.append((b_, csb_))

In [19]:
for substep in range(7):
    for b in range(substep):
        new = new_version[substep][0][b]
        old = old_version[substep][b][0]
        print(jnp.max(jnp.abs(new - old)))
        new = new_version[substep][1][b]
        old = old_version[substep][b][1]
        print(jnp.max(jnp.abs(new - old)))
        print()


0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

0.0
0.0

