In [None]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import mpmath

mpmath.mp.dps = 75
from jorbit.data.constants import IAS15_RR
from jorbit.utils.generate_coefficients import create_iasnn_constants

In [2]:
at = jax.random.normal(jax.random.PRNGKey(0), (10, 3)) * 1_000
a0 = jax.random.normal(jax.random.PRNGKey(1), (10, 3)) * 1_000
g = jax.random.normal(jax.random.PRNGKey(2), (7, 10, 3)) * 1_000
# r = 1/IAS15_RR

In [None]:
# ruff: noqa
def old_scheme(at, a0, g):
    gk = at - a0
    gp0 = gk / IAS15_RR[0]
    gp1 = (gk / IAS15_RR[1] - g[0]) / IAS15_RR[2]
    gp2 = ((gk / IAS15_RR[3] - g[0]) / IAS15_RR[4] - g[1]) / IAS15_RR[5]
    gp3 = (
        ((gk / IAS15_RR[6] - g[0]) / IAS15_RR[7] - g[1]) / IAS15_RR[8] - g[2]
    ) / IAS15_RR[9]
    gp4 = (
        (((gk / IAS15_RR[10] - g[0]) / IAS15_RR[11] - g[1]) / IAS15_RR[12] - g[2])
        / IAS15_RR[13]
        - g[3]
    ) / IAS15_RR[14]
    gp5 = (
        (
            (((gk / IAS15_RR[15] - g[0]) / IAS15_RR[16] - g[1]) / IAS15_RR[17] - g[2])
            / IAS15_RR[18]
            - g[3]
        )
        / IAS15_RR[19]
        - g[4]
    ) / IAS15_RR[20]
    gp6 = (
        (
            (
                (
                    ((gk / IAS15_RR[21] - g[0]) / IAS15_RR[22] - g[1]) / IAS15_RR[23]
                    - g[2]
                )
                / IAS15_RR[24]
                - g[3]
            )
            / IAS15_RR[25]
            - g[4]
        )
        / IAS15_RR[26]
        - g[5]
    ) / IAS15_RR[27]
    return jnp.stack([gp0, gp1, gp2, gp3, gp4, gp5, gp6])


old_gs = old_scheme(at, a0, g)
old_gs.shape

(7, 10, 3)

In [4]:
r = 1 / IAS15_RR


def new_scheme(at, a0, g, r, substep_num):
    substep_num -= 1

    def scan_body(carry: tuple, idx: int) -> tuple:
        result, start_pos = carry
        result = (result - g[idx]) * r[start_pos + idx + 1]
        return (result, start_pos), result

    start_pos = (substep_num * (substep_num + 1)) // 2
    initial_result = (at - a0) * r[start_pos]
    indices = jnp.arange(substep_num)
    (new_gp_substep, _), _ = jax.lax.scan(
        scan_body, (initial_result, start_pos), indices
    )
    return new_gp_substep


new_gs = []
for substep_num in range(1, 8):
    new_gs.append(new_scheme(at, a0, g, r, substep_num))
new_gs = jnp.stack(new_gs)

# errors
jnp.max(jnp.abs(old_gs - new_gs)), jnp.max(jnp.abs((old_gs / new_gs) - 1))

(Array(2.32830644e-10, dtype=float64), Array(2.88657986e-15, dtype=float64))

In [5]:
def two_sum(a, b):
    s = a + b
    bb = s - a
    err = (a - (s - bb)) + (b - bb)
    return s, err


def two_diff(a, b):
    s = a - b
    bb = s - a
    err = (a - (s - bb)) - (b + bb)
    return s, err


def two_prod(x, y):
    constant = 2**27 + 1
    p = x * constant
    hx = x - p + p
    tx = x - hx

    p = y * constant
    hy = y - p + p
    ty = y - hy

    p = hx * hy
    q = hx * ty + tx * hy
    z = p + q
    zz = p - z + q + tx * ty
    return z, zz


def newest_scheme(at, a0, g, r, substep_num):
    substep_num -= 1

    def scan_body(carry, idx):
        x, x_err, start_pos = carry

        # Step 1: subtraction with compensation
        y, y_err = two_diff(x, g[idx])
        y_err = y_err + x_err

        # Step 2: multiplication with compensation
        r_val = r[start_pos + idx + 1]
        z, z_err = two_prod(y, r_val)

        # Propagate previous error through multiplication
        z_err = z_err + y_err * r_val

        return (z, z_err, start_pos), z

    start_pos = (substep_num * (substep_num + 1)) // 2
    initial_result = (at - a0) * r[start_pos]
    indices = jnp.arange(substep_num)
    (new_gp_substep, errs, _), _ = jax.lax.scan(
        scan_body, (initial_result, jnp.zeros((10, 3)), start_pos), indices
    )
    return new_gp_substep, errs


comp_gs = []
for substep_num in range(1, 8):
    comp_gs.append(newest_scheme(at, a0, g, r, substep_num))
comp_gs = jnp.array(comp_gs)

In [6]:
jnp.where(jnp.abs(old_gs - new_gs) == jnp.max(jnp.abs(old_gs - new_gs)))

(Array([6], dtype=int64), Array([9], dtype=int64), Array([2], dtype=int64))

In [7]:
h_mpf, r_mpf, c_mpf, d_mpf = create_iasnn_constants(7)


def high_prec_scheme(at, a0, g):
    gk = at - a0
    gp0 = gk / r_mpf[0]
    gp1 = (gk / r_mpf[1] - g[0]) / r_mpf[2]
    gp2 = ((gk / r_mpf[3] - g[0]) / r_mpf[4] - g[1]) / r_mpf[5]
    gp3 = (((gk / r_mpf[6] - g[0]) / r_mpf[7] - g[1]) / r_mpf[8] - g[2]) / r_mpf[9]
    gp4 = (
        (((gk / r_mpf[10] - g[0]) / r_mpf[11] - g[1]) / r_mpf[12] - g[2]) / r_mpf[13]
        - g[3]
    ) / r_mpf[14]
    gp5 = (
        (
            (((gk / r_mpf[15] - g[0]) / r_mpf[16] - g[1]) / r_mpf[17] - g[2])
            / r_mpf[18]
            - g[3]
        )
        / r_mpf[19]
        - g[4]
    ) / r_mpf[20]
    gp6 = (
        (
            (
                (((gk / r_mpf[21] - g[0]) / r_mpf[22] - g[1]) / r_mpf[23] - g[2])
                / r_mpf[24]
                - g[3]
            )
            / r_mpf[25]
            - g[4]
        )
        / r_mpf[26]
        - g[5]
    ) / r_mpf[27]
    return [gp0, gp1, gp2, gp3, gp4, gp5, gp6]


particle, dim = 9, 2
high_prec_gs = high_prec_scheme(
    at=mpmath.mpf(float(at[particle, dim])),
    a0=mpmath.mpf(float(a0[particle, dim])),
    g=[mpmath.mpf(float(g[i, particle, dim])) for i in range(7)],
)

100%|██████████| 998/998 [00:00<00:00, 1349.21it/s]


In [9]:
diffs_old = []
for i in range(7):
    diffs_old.append(abs(mpmath.mpf(float(old_gs[i, particle, dim])) - high_prec_gs[i]))

diffs_new = []
for i in range(7):
    diffs_new.append(abs(mpmath.mpf(float(new_gs[i, particle, dim])) - high_prec_gs[i]))

diff_comp = []
for i in range(7):
    diff_comp.append(
        abs(
            mpmath.mpf(float(comp_gs[:, 0, :, :][i, particle, dim])) - high_prec_gs[i]
        )
    )

for i in range(7):
    internal = abs(
        mpmath.mpf(float(old_gs[i, particle, dim]))
        - mpmath.mpf(float(new_gs[i, particle, dim]))
    )
    print(
        f"comp: {mpmath.nstr(diff_comp[i])}, new: {mpmath.nstr(diffs_new[i])}, old: {mpmath.nstr(diffs_old[i])}, internal: {mpmath.nstr(internal)}"
    )

comp: 2.3655e-12, new: 2.3655e-12, old: 2.3655e-12, internal: 0.0
comp: 9.19659e-12, new: 9.19659e-12, old: 9.19659e-12, internal: 0.0
comp: 1.03265e-11, new: 1.03265e-11, old: 1.87773e-11, internal: 2.91038e-11
comp: 1.33885e-11, new: 1.16344e-12, old: 1.16344e-12, internal: 0.0
comp: 1.3392e-11, new: 1.3392e-11, old: 1.3392e-11, internal: 0.0
comp: 3.9894e-11, new: 1.83136e-11, old: 1.83136e-11, internal: 0.0
comp: 4.85512e-11, new: 1.84279e-10, old: 4.85512e-11, internal: 2.32831e-10


In [10]:
diff_comp = []
for i in range(7):
    diff_comp.append(
        mpmath.mpf(float(comp_gs[:, 0, :, :][i, particle, dim])) - high_prec_gs[i]
    )
    print(
        f"{mpmath.nstr(mpmath.mpf(float(comp_gs[:,1,:,:][i, particle, dim])))}, {mpmath.nstr(diff_comp[i])}"
    )
    print(
        mpmath.nstr(
            mpmath.mpf(float(comp_gs[:, 1, :, :][i, particle, dim])) + diff_comp[i]
        )
    )
    print()

0.0, -2.3655e-12
-2.3655e-12

4.90811e-12, 9.19659e-12
1.41047e-11

-5.73603e-12, 1.03265e-11
4.59048e-12

1.07803e-11, -1.33885e-11
-2.60813e-12

1.43376e-11, -1.3392e-11
9.45626e-13

3.42126e-11, -3.9894e-11
-5.68139e-12

1.91801e-11, 4.85512e-11
6.77314e-11



In [11]:
def _new_scheme(at, a0, g, r, substep_num):
    print("starting for substep_num =", substep_num)
    substep_num -= 1

    def scan_body(carry: tuple, idx: int) -> tuple:
        result, start_pos = carry
        result = (result - g[idx]) * r[start_pos + idx + 1]
        return (result, start_pos), result

    start_pos = (substep_num * (substep_num + 1)) // 2
    initial_result = (at - a0) * r[start_pos]
    indices = jnp.arange(substep_num)
    print("start pos = ", start_pos)
    print(len(indices))
    print()
    (new_gp_substep, _), _ = jax.lax.scan(
        scan_body, (initial_result, start_pos), indices
    )
    return new_gp_substep


new_gs = []
for substep_num in range(1, 8):
    new_gs.append(_new_scheme(at, a0, g, r, substep_num))
new_gs = jnp.stack(new_gs)

starting for substep_num = 1
start pos =  0
0

starting for substep_num = 2
start pos =  1
1

starting for substep_num = 3
start pos =  3
2

starting for substep_num = 4
start pos =  6
3

starting for substep_num = 5
start pos =  10
4

starting for substep_num = 6
start pos =  15
5

starting for substep_num = 7
start pos =  21
6



In [12]:
rr = 1 / IAS15_RR

ias15_r1 = jnp.array([rr[0]])
ias15_r2 = rr[1:1+2]
ias15_r3 = rr[3:3+3]
ias15_r4 = rr[6:6+4]
ias15_r5 = rr[10:10+5]
ias15_r6 = rr[15:15+6]
ias15_r7 = rr[21:21+7]

In [None]:
def refine_sub_g(at, a0, previous_gs, r):  # noqa: ANN201

    def scan_body(carry: tuple, scan_over: tuple):
        result = carry
        g, r_sub = scan_over
        result = (result - g) * r_sub
        return result, None

    initial_result = (at - a0) * r[0]
    new_g, _ = jax.lax.scan(
        scan_body, initial_result, (previous_gs, r[1:])
    )
    return new_g

refine_sub_g(at, a0, g[:3], ias15_r4) - old_gs[3]

Array([[ 0.00000000e+00, -4.36557457e-11,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00, -1.81898940e-12],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  7.27595761e-12],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00, -1.81898940e-12,  3.63797881e-12],
       [-3.63797881e-12,  0.00000000e+00,  4.36557457e-11],
       [ 1.45519152e-11,  5.82076609e-11,  0.00000000e+00],
       [ 0.00000000e+00,  3.63797881e-12, -1.45519152e-11],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],      dtype=float64)

In [14]:
final_gs = []
for i, r in enumerate([ias15_r1, ias15_r2, ias15_r3, ias15_r4, ias15_r5, ias15_r6, ias15_r7]):
    final_gs.append(refine_sub_g(at, a0, g[:i], r))
final_gs = jnp.stack(final_gs)

final_gs - old_gs

Array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [-1.81898940e-12,  1.81898940e-12,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  3.63797881e-12],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[-7.27595761e-12,  0.00000000e+00, -1.81898940e-12],
        [ 7.27595761e-12, -3.63797881e-12,  0.00000000e+00],
        [ 0.00000000e+00, -1.81898940e-12,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00, -7.27595761e-12],
        [-1.45519152e-11, -2.91038305e-11,  7.27595761e-12],
        [-7.27595761e-12,  3.63797881e-12,  3.63797881e-12],
        [-1.81898940e-

In [17]:
print(jnp.max(jnp.abs(final_gs - old_gs)), jnp.max(jnp.abs((final_gs / old_gs) - 1)))
print(jnp.max(jnp.abs(new_gs - old_gs)), jnp.max(jnp.abs((new_gs / old_gs) - 1)))
print(jnp.max(jnp.abs(new_gs - final_gs)), jnp.max(jnp.abs((new_gs / final_gs) - 1)))

2.3283064365386963e-10 2.886579864025407e-15
2.3283064365386963e-10 2.886579864025407e-15
0.0 0.0
