# Scratch work: comparing IAS15 b coefficients with the double-double IASNN step

In [1]:
import jax

# Switch JAX to 64-bit mode right after importing it, so the float64 arrays
# requested in the cells below are not silently downcast to float32.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from functools import partial


from decimal import Decimal, getcontext

# Working precision of the decimal module: 50 significant digits.
getcontext().prec = 50


import mpmath

# Working precision of mpmath: 75 decimal digits.
mpmath.mp.dps = 75

# Float64 IAS15 integrator and the system-state container it operates on.
from jorbit.integrators import initialize_ias15_integrator_state
from jorbit.integrators.ias15 import ias15_step
from jorbit.utils.states import SystemState


# Double-double arithmetic helpers and the double-double integrator.
# NOTE(review): partial, Decimal, dd_sum, dd_sqrt, create_iasnn_constants and the
# IAS15_* constants are not referenced in the cells visible here; they may be
# used further down the notebook, so they are left in place.
from jorbit.utils.doubledouble import DoubleDouble, dd_sum, dd_sqrt
from jorbit.utils.generate_coefficients import create_iasnn_constants
from jorbit.data.constants import IAS15_D, IAS15_H, IAS15_RR, IAS15_C
from jorbit.integrators.ias15_dd import setup_iasnn_integrator, step

In [2]:
# Initial system: no massive bodies (zero-row arrays) and four tracer particles
# that all share the same position (0.123) and velocity (0.2) components.
init_state = SystemState(
    time=0.0,
    log_gms=jnp.empty(0),
    massive_positions=jnp.empty((0, 3)),
    massive_velocities=jnp.empty((0, 3)),
    tracer_positions=jnp.full((4, 3), 0.123, dtype=jnp.float64),
    tracer_velocities=jnp.full((4, 3), 0.2, dtype=jnp.float64),
    acceleration_func_kwargs=None,
)


def _acc_func(s):
    return -jnp.concatenate([s.massive_positions, s.tracer_positions], axis=0)


# Wrap the acceleration function in jax.tree_util.Partial (JAX's
# pytree-compatible counterpart of functools.partial) and evaluate it once on
# the initial state to obtain the initial acceleration a0.
acc_func = jax.tree_util.Partial(_acc_func)
a0 = acc_func(init_state)

# Build the IAS15 integrator state from a0, then set its dt field to 0.01
# (presumably the step size -- confirm against the integrator state definition).
init_integrator_state = initialize_ias15_integrator_state(a0)
init_integrator_state.dt = 0.01


# Run a single ias15_step call. Of the five returned values only `b` is used by
# the comparison cell further down; the others are kept for inspection.
b, csb, g, predictor_corrector_error, predictor_corrector_error_last = ias15_step(
    initial_system_state=init_state,
    acceleration_func=acc_func,
    initial_integrator_state=init_integrator_state,
)

In [3]:
# Precompute the IASNN integrator setup for 7 internal points.
# NOTE(review): "precompued" looks like a typo for "precomputed"; the name is
# reused in a later cell, so a rename has to be applied everywhere at once.
precompued_setup = setup_iasnn_integrator(n_internal_points=7)

# Same initial conditions as the float64 run, wrapped as DoubleDouble values:
# position 0.123, velocity 0.2, acceleration = -position, and an all-zero b
# array of shape (7, 4, 3).
x0_init_dd = DoubleDouble(jnp.ones((4, 3), dtype=jnp.float64) * 0.123)
v0_init_dd = DoubleDouble(jnp.ones((4, 3), dtype=jnp.float64) * 0.2)
a0_init_dd = -x0_init_dd
b_init_dd = DoubleDouble(jnp.zeros((7, 4, 3), dtype=jnp.float64))

# NOTE(review): no step size is passed to step(); presumably it uses an
# internal value matching the dt = 0.01 of the float64 run -- TODO confirm.
b_dd, g_dd, g_diff = step(
    x0_init_dd, v0_init_dd, a0_init_dd, b_init_dd, precompued_setup
)

100%|██████████| 998/998 [00:00<00:00, 1370.66it/s]


In [4]:
# For each of the seven b coefficients, print the `.hi` part of the
# double-double result minus the matching float64 IAS15 field (b.p0 ... b.p6).
for coeff_idx in range(7):
    print(b_dd.hi[coeff_idx] - getattr(b, f"p{coeff_idx}"))

[[1.73472348e-16 1.73472348e-16 1.73472348e-16]
 [1.73472348e-16 1.73472348e-16 1.73472348e-16]
 [1.73472348e-16 1.73472348e-16 1.73472348e-16]
 [1.73472348e-16 1.73472348e-16 1.73472348e-16]]
[[-2.37775108e-15 -2.37775108e-15 -2.37775108e-15]
 [-2.37775108e-15 -2.37775108e-15 -2.37775108e-15]
 [-2.37775108e-15 -2.37775108e-15 -2.37775108e-15]
 [-2.37775108e-15 -2.37775108e-15 -2.37775108e-15]]
[[1.19980601e-14 1.19980601e-14 1.19980601e-14]
 [1.19980601e-14 1.19980601e-14 1.19980601e-14]
 [1.19980601e-14 1.19980601e-14 1.19980601e-14]
 [1.19980601e-14 1.19980601e-14 1.19980601e-14]]
[[-3.0980925e-14 -3.0980925e-14 -3.0980925e-14]
 [-3.0980925e-14 -3.0980925e-14 -3.0980925e-14]
 [-3.0980925e-14 -3.0980925e-14 -3.0980925e-14]
 [-3.0980925e-14 -3.0980925e-14 -3.0980925e-14]]
[[4.32880188e-14 4.32880188e-14 4.32880188e-14]
 [4.32880188e-14 4.32880188e-14 4.32880188e-14]
 [4.32880188e-14 4.32880188e-14 4.32880188e-14]
 [4.32880188e-14 4.32880188e-14 4.32880188e-14]]
[[-3.10031316e-14 -3.10

In [5]:
# Initial system: no massive bodies (zero-row arrays) and four tracer particles
# that all share the same position (0.123) and velocity (0.2) components.
init_state = SystemState(
    time=0.0,
    log_gms=jnp.empty(0),
    massive_positions=jnp.empty((0, 3)),
    massive_velocities=jnp.empty((0, 3)),
    tracer_positions=jnp.full((4, 3), 0.123, dtype=jnp.float64),
    tracer_velocities=jnp.full((4, 3), 0.2, dtype=jnp.float64),
    acceleration_func_kwargs=None,
)


def _acc_func(s):
    return -jnp.concatenate([s.massive_positions, s.tracer_positions], axis=0)


# Wrap the acceleration function in jax.tree_util.Partial and evaluate it once
# on the initial state to obtain the initial acceleration a0.
acc_func = jax.tree_util.Partial(_acc_func)
a0 = acc_func(init_state)

# Build the IAS15 integrator state from a0 and set its dt field to 0.01.
init_integrator_state = initialize_ias15_integrator_state(a0)
init_integrator_state.dt = 0.01
# Offset the starting b coefficients p0..p6 by 1.0..7.0 respectively, so the
# step begins from modified b values instead of the freshly initialised ones.
# The double-double run below applies the same offsets to its own b array.
init_integrator_state.b.p0 += 1.0
init_integrator_state.b.p1 += 2.0
init_integrator_state.b.p2 += 3.0
init_integrator_state.b.p3 += 4.0
init_integrator_state.b.p4 += 5.0
init_integrator_state.b.p5 += 6.0
init_integrator_state.b.p6 += 7.0


# Run a single ias15_step call from the perturbed integrator state; only `b`
# is used by the comparison printed at the end of this cell.
b, csb, g, predictor_corrector_error, predictor_corrector_error_last = ias15_step(
    initial_system_state=init_state,
    acceleration_func=acc_func,
    initial_integrator_state=init_integrator_state,
)

# Precomputed IASNN setup for 7 internal points (variable name kept as-is,
# including its spelling, because it is what the rest of the notebook uses).
precompued_setup = setup_iasnn_integrator(n_internal_points=7)

# Double-double initial conditions: position 0.123, velocity 0.2,
# acceleration = -position, and a zero b array of shape (7, 4, 3).
x0_init_dd = DoubleDouble(jnp.full((4, 3), 0.123, dtype=jnp.float64))
v0_init_dd = DoubleDouble(jnp.full((4, 3), 0.2, dtype=jnp.float64))
a0_init_dd = -x0_init_dd
b_init_dd = DoubleDouble(jnp.zeros((7, 4, 3), dtype=jnp.float64))

# Apply the same offsets as the float64 run: slice k of b gets k + 1 added
# (1.0 for index 0 up to 7.0 for index 6).
for coeff_idx in range(7):
    b_init_dd[coeff_idx] += DoubleDouble(float(coeff_idx + 1))

b_dd, g_dd, g_diff = step(
    x0_init_dd, v0_init_dd, a0_init_dd, b_init_dd, precompued_setup
)

# For each of the seven b coefficients, print the `.hi` part of the
# double-double result minus the matching float64 IAS15 field (b.p0 ... b.p6).
for coeff_idx in range(7):
    print(b_dd.hi[coeff_idx] - getattr(b, f"p{coeff_idx}"))

100%|██████████| 998/998 [00:00<00:00, 1354.47it/s]


[[-6.37510877e-17 -6.37510877e-17 -6.37510877e-17]
 [-6.37510877e-17 -6.37510877e-17 -6.37510877e-17]
 [-6.37510877e-17 -6.37510877e-17 -6.37510877e-17]
 [-6.37510877e-17 -6.37510877e-17 -6.37510877e-17]]
[[2.45901796e-15 2.45901796e-15 2.45901796e-15]
 [2.45901796e-15 2.45901796e-15 2.45901796e-15]
 [2.45901796e-15 2.45901796e-15 2.45901796e-15]
 [2.45901796e-15 2.45901796e-15 2.45901796e-15]]
[[-1.01370277e-14 -1.01370277e-14 -1.01370277e-14]
 [-1.01370277e-14 -1.01370277e-14 -1.01370277e-14]
 [-1.01370277e-14 -1.01370277e-14 -1.01370277e-14]
 [-1.01370277e-14 -1.01370277e-14 -1.01370277e-14]]
[[5.37831363e-15 5.37831363e-15 5.37831363e-15]
 [5.37831363e-15 5.37831363e-15 5.37831363e-15]
 [5.37831363e-15 5.37831363e-15 5.37831363e-15]
 [5.37831363e-15 5.37831363e-15 5.37831363e-15]]
[[5.39282161e-16 5.39282161e-16 5.39282161e-16]
 [5.39282161e-16 5.39282161e-16 5.39282161e-16]
 [5.39282161e-16 5.39282161e-16 5.39282161e-16]
 [5.39282161e-16 5.39282161e-16 5.39282161e-16]]
[[-5.193504