In [1]:
import mpmath as mpm
from mpmath import matrix, mp, mpf
import numpy as np
from tqdm import tqdm

from jorbit.integrators.mpmath_testing_only import *

mp.dps = 50

In [2]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jorbit.integrators.ias15_dd import _estimate_x_v_from_b, setup_iasnn_integrator
from jorbit.utils.doubledouble import DoubleDouble

from jorbit.integrators.ias15_dd import refine_intermediate_g as rigdd

b_x_denoms, b_v_denoms, h, r, c, d = setup_iasnn_integrator(n_internal_points=7)


x0 = DoubleDouble(jnp.array([[1.0, 0, 0]]))
v0 = DoubleDouble(jnp.array([[0, 1.0, 0]]))
a0 = -x0
b = DoubleDouble(jnp.ones((7, 1, 3), dtype=jnp.float64))
g = DoubleDouble(jnp.ones((7, 1, 3), dtype=jnp.float64))
dt = DoubleDouble(0.01)

rq = rigdd(substep_num=4, g=g, r=r, at=a0 * DoubleDouble(0.9), a0=a0)
rq

100%|██████████| 998/998 [00:00<00:00, 1407.00it/s]


DoubleDouble([[-42.47575008 -47.69200031 -47.69200031]], [[ 2.53971562e-15 -2.26124626e-15 -2.26124626e-15]])

In [3]:
pre = precompute(7)
b_x_denoms, b_v_denoms, h, r, c, d = pre

dt = mpf("0.01")
x0 = matrix([["1.0", "0", "0"]])
v0 = matrix([["0", "1.0", "0"]])
a0 = matrix([["-1.0", "0", "0"]])
b = mpm.ones(7, 3)  # missing the dimension for nparticles to work w/ mpmath
g = mpm.ones(7, 3)

q = refine_intermediate_g(substep_num=4, g=g, r=r, at=a0 * 0.9, a0=a0)
q

100%|██████████| 998/998 [00:00<00:00, 1365.97it/s]


matrix(
[['-42.475750081141580808836201308054150808740185515630119954534339007943949426', '-47.6920003067568833853739027019785430508873489205952718351524434014057729376', '-47.6920003067568833853739027019785430508873489205952718351524434014057729376']])

In [10]:
print((q[0, 0] - float(rq[0, 0].hi) - float(rq[0, 0].lo)))
print((q[0, 1] - float(rq[0, 1].hi) - float(rq[0, 1].lo)))
print((q[0, 2] - float(rq[0, 2].hi) - float(rq[0, 2].lo)))

0.000000000000000443025612243332432104024931231521336707902873985877462118073104528191912161
0.000000000000000276233085841311643473284298411717976250938410112779333183628079770057068431
0.000000000000000276233085841311643473284298411717976250938410112779333183628079770057068431


In [None]:
# hm. the doubledouble refine_g is only single precision

In [13]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jorbit.integrators.ias15_dd import _estimate_x_v_from_b, setup_iasnn_integrator
from jorbit.utils.doubledouble import DoubleDouble

b_x_denoms, b_v_denoms, h, r, c, d = setup_iasnn_integrator(n_internal_points=7)


x0 = DoubleDouble(jnp.array([[1.0, 0, 0]]))
v0 = DoubleDouble(jnp.array([[0, 1.0, 0]]))
a0 = -x0
b = DoubleDouble(jnp.ones((7, 1, 3), dtype=jnp.float64))
dt = DoubleDouble(0.01)

x_dd, v_dd = _estimate_x_v_from_b(
    a0=a0,
    v0=v0,
    x0=x0,
    dt=dt,
    b_x_denoms=b_x_denoms,
    b_v_denoms=b_v_denoms,
    h=h[1],
    bp=b,
)
x_dd, v_dd

100%|██████████| 998/998 [00:00<00:00, 1400.67it/s]


(DoubleDouble([[9.99999845e-01 5.62628660e-04 3.05472711e-09]], [[ 3.08383424e-17 -1.34093184e-20 -1.03341420e-25]]),
 DoubleDouble([[-5.46178334e-04  1.00001645e+00  1.64472715e-05]], [[ 3.86470038e-21  7.09607810e-17 -2.51140453e-22]]))

In [14]:
pre = precompute(7)
b_x_denoms, b_v_denoms, h, r, c, d = pre

dt = mpf("0.01")
x0 = matrix([["1.0", "0", "0"]])
v0 = matrix([["0", "1.0", "0"]])
a0 = matrix([["-1.0", "0", "0"]])
b = mpm.ones(7, 3)  # missing the dimension for nparticles to work w/ mpmath
g = mpm.ones(7, 3)

x_mp, v_mp = estimate_x_v_from_b(
    a0=a0,
    v0=v0,
    x0=x0,
    dt=dt,
    b_x_denoms=b_x_denoms,
    b_v_denoms=b_v_denoms,
    h=h[1],
    bp=b,
)
x_mp, v_mp

100%|██████████| 998/998 [00:00<00:00, 1371.81it/s]


(matrix(
 [['0.99999984478094120252301594486387171374795735893814', '0.00056262866009633252913496629819679835588232069132449', '0.0000000030547271110644784443878736865980846655438782619988']]),
 matrix(
 [['-0.00054617833388839319509525385463194026215904313563138', '1.0000164472714808282695612680556911714956386120118', '0.000016447271480828269561268055691171495638612011814846']]))

In [15]:
print(x_mp[0, 0] - float(x_dd[0, 0].hi) - float(x_dd[0, 0].lo))
print(v_mp[0, 0] - float(v_dd[0, 0].hi) - float(v_dd[0, 0].lo))

-3.2519074011945340865175046615255634651902208870089e-25
-3.1673570803014496430939158035056559171404018326311e-22


In [11]:
dt = mpf("0.01")
x0 = matrix([["1.0", "0", "0"]])
v0 = matrix([["0", "1.0", "0"]])
a0 = matrix([["-1.0", "0", "0"]])
b = mpm.zeros(7, 3)  # missing the dimension for nparticles to work w/ mpmath
g = mpm.zeros(7, 3)
b

matrix(
[['0.0', '0.0', '0.0'],
 ['0.0', '0.0', '0.0'],
 ['0.0', '0.0', '0.0'],
 ['0.0', '0.0', '0.0'],
 ['0.0', '0.0', '0.0'],
 ['0.0', '0.0', '0.0'],
 ['0.0', '0.0', '0.0']])

In [12]:
steps_per_orbit = 100
n_orbits = 1

dt = mpf("2") * mp.pi / mpf(str(steps_per_orbit))

for i in tqdm(range(steps_per_orbit * n_orbits)):
    x0, v0, b = step(x0, v0, b, dt, pre)

x0

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:03<00:00, 32.87it/s]


matrix(
[['-11.684767359688071197001402094240177226994507749059', '-1371.7225586856334469822337741625392712067685699736', '255.34455581170960926135647247347388758524543358365']])