# scratch work

In [16]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from jplephem.spk import SPK
import astropy.units as u
from astropy.time import Time
from astropy.utils.data import download_file

PLANET_EPHEMERIS_URL = "https://ssd.jpl.nasa.gov//ftp/eph/planets/bsp/de440.bsp"
ASTEROID_EPHEMERIS_URL = (
    "https://ssd.jpl.nasa.gov/ftp/eph/small_bodies/asteroids_de441/sb441-n16.bsp"
)

In [17]:
kernel = SPK.open(download_file(PLANET_EPHEMERIS_URL, cache=True))
dir(kernel)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__enter__',
 '__eq__',
 '__exit__',
 '__firstlineno__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__static_attributes__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'close',
 'comments',
 'daf',
 'open',
 'pairs',
 'segments']

In [18]:
dir(kernel.segments[0])

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__firstlineno__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__static_attributes__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_data',
 'center',
 'compute',
 'compute_and_differentiate',
 'daf',
 'data_type',
 'describe',
 'end_i',
 'end_jd',
 'end_second',
 'frame',
 'generate',
 'load_array',
 'source',
 'start_i',
 'start_jd',
 'start_second',
 'target']

In [19]:
[s.target for s in kernel.segments]

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 301, 399, 199, 299]

In [20]:
len(kernel.segments[0]._data)

3

In [21]:
for k in kernel.segments:
    init, intlen, coeff = k._data
    print(init, intlen, coeff.shape)

-14200747200.0 691200.0 (14, 3, 50224)
-14200747200.0 1382400.0 (10, 3, 25112)
-14200747200.0 1382400.0 (13, 3, 25112)
-14200747200.0 2764800.0 (11, 3, 12556)
-14200747200.0 2764800.0 (8, 3, 12556)
-14200747200.0 2764800.0 (7, 3, 12556)
-14200747200.0 2764800.0 (6, 3, 12556)
-14200747200.0 2764800.0 (6, 3, 12556)
-14200747200.0 2764800.0 (6, 3, 12556)
-14200747200.0 1382400.0 (11, 3, 25112)
-14200747200.0 345600.0 (13, 3, 100448)
-14200747200.0 345600.0 (13, 3, 100448)
-14200747200.0 34714828800.0 (2, 3, 1)
-14200747200.0 34714828800.0 (2, 3, 1)


In [22]:
-14200747200.0 * u.s.to(u.year)

-449.9945242984257

In [23]:
for k in kernel.segments:
    print(k.target)
    print(k.describe())

1
1549-12-31..2650-01-25  Type 2  Solar System Barycenter (0) -> Mercury Barycenter (1)
  frame=1 source=DE-0440LE-0440
2
1549-12-31..2650-01-25  Type 2  Solar System Barycenter (0) -> Venus Barycenter (2)
  frame=1 source=DE-0440LE-0440
3
1549-12-31..2650-01-25  Type 2  Solar System Barycenter (0) -> Earth Barycenter (3)
  frame=1 source=DE-0440LE-0440
4
1549-12-31..2650-01-25  Type 2  Solar System Barycenter (0) -> Mars Barycenter (4)
  frame=1 source=DE-0440LE-0440
5
1549-12-31..2650-01-25  Type 2  Solar System Barycenter (0) -> Jupiter Barycenter (5)
  frame=1 source=DE-0440LE-0440
6
1549-12-31..2650-01-25  Type 2  Solar System Barycenter (0) -> Saturn Barycenter (6)
  frame=1 source=DE-0440LE-0440
7
1549-12-31..2650-01-25  Type 2  Solar System Barycenter (0) -> Uranus Barycenter (7)
  frame=1 source=DE-0440LE-0440
8
1549-12-31..2650-01-25  Type 2  Solar System Barycenter (0) -> Neptune Barycenter (8)
  frame=1 source=DE-0440LE-0440
9
1549-12-31..2650-01-25  Type 2  Solar System Ba

In [24]:
kernel.segments[2].compute(Time("2021-01-01T00:00:00", scale="tdb").jd) * u.km.to(u.au)

array([-0.18578476,  0.89254175,  0.38703056])

In [25]:
jnp.linalg.norm(
    kernel.segments[2].compute(Time("2021-01-01T00:00:00", scale="tdb").jd)
    * u.km.to(u.au)
)

Array(0.99042385, dtype=float64)

In [26]:
def eval_cheby(coefficients, x):
    b_ii = jnp.zeros((3, x.shape[0]))
    b_i = jnp.zeros((3, x.shape[0]))

    def scan_func(X, a):
        b_i, b_ii = X
        tmp = b_i
        b_i = a + 2 * x * b_i - b_ii
        b_ii = tmp
        return (b_i, b_ii), b_i

    (b_i, b_ii), s = jax.lax.scan(scan_func, (b_i, b_ii), coefficients[:-1])
    return coefficients[-1] + x * b_i - b_ii, s


def prep_ephemeris(
    init,
    intlen,
    coefficients,
    tdb,
):
    tdb2 = 0.0  # leaving in case we ever decide to increase the time precision and use 2 floats
    _, _, n = coefficients.shape

    # 2451545.0 is the J2000 epoch in TDB
    index1, offset1 = jnp.divmod((tdb - 2451545.0) * 86400.0 - init, intlen)
    index2, offset2 = jnp.divmod(tdb2 * 86400.0, intlen)
    index3, offset = jnp.divmod(offset1 + offset2, intlen)
    index = (index1 + index2 + index3).astype(int)

    omegas = index == n
    index = jnp.where(omegas, index - 1, index)
    offset = jnp.where(omegas, offset + intlen, offset)

    coefficients = coefficients[:, :, index]

    s = 2.0 * offset / intlen - 1.0
    return s, coefficients


def single_perturber_states(
    init,
    intlen,
    coefficients,
    tdb,
):
    s, coefficients = prep_ephemeris(init, intlen, coefficients, tdb)

    # Position
    x, As = eval_cheby(coefficients, s)  # in km here

    return x

In [27]:
init, intlen, coefficients = kernel.segments[2]._data
tdb = jnp.array([Time("2021-01-01T00:00:00", scale="tdb").jd])
s, coefficients = prep_ephemeris(init, intlen, coefficients, tdb)
x, As = eval_cheby(coefficients, s)  # in km here
(x * u.km.to(u.au))

Array([[-0.18578476],
       [ 0.89254175],
       [ 0.38703056]], dtype=float64)

In [28]:
kernel.segments[2].compute(Time("2021-01-01T00:00:00", scale="tdb").jd) * u.km.to(u.au)

array([-0.18578476,  0.89254175,  0.38703056])

In [29]:
@jax.tree_util.register_pytree_node_class
class Test:
    def __init__(self, s):
        self.s = s
        print("some stuff")
        self.r = jnp.zeros(100)
        assert self.s < 10
        del self.r

    def tree_flatten(self):
        children = (self.s,)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jax.jit
    def process(self, t):
        return t**2 + t


t = Test(1.0)
_ = t.process(1.0)


@jax.jit
def q(t):
    return t.process(t.s)


q(t)

some stuff
some stuff
some stuff


TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function process at /var/folders/mj/qxz5chg95r53_2nlv9f86qhm0000gn/T/ipykernel_65050/1229738365.py:21 for jit. This concrete value was not available in Python because it depends on the value of the argument self[<flat index 0>].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [30]:
from astropy.time import Time
from jorbit.ephemeris import Ephemeris

e = Ephemeris()

In [31]:
e.state(Time("2021-01-01T00:00:00"))

{10: {'x': Array([-0.00665194,  0.00546611,  0.00248513], dtype=float64),
  'v': Array([-6.83985927e-06, -5.37276641e-06, -2.09406947e-06], dtype=float64),
  'a': Array([ 6.14347531e-09, -8.92051287e-09, -3.98993029e-09], dtype=float64)},
 1: {'x': Array([ 0.23037548, -0.30268229, -0.18669482], dtype=float64),
  'v': Array([0.01783733, 0.01561044, 0.00649018], dtype=float64),
  'a': Array([-0.00086786,  0.00112827,  0.00069267], dtype=float64)},
 2: {'x': Array([-0.45305527, -0.52456435, -0.20775883], dtype=float64),
  'v': Array([ 0.01577462, -0.01110446, -0.0059947 ], dtype=float64),
  'a': Array([0.00034785, 0.000413  , 0.00016382], dtype=float64)},
 3: {'x': Array([-0.18579853,  0.8925394 ,  0.38702954], dtype=float64),
  'v': Array([-0.01720373, -0.00294078, -0.00127453], dtype=float64),
  'a': Array([ 5.57645088e-05, -2.76110661e-04, -1.19693461e-04], dtype=float64)},
 4: {'x': Array([0.61421366, 1.26221831, 0.56217515], dtype=float64),
  'v': Array([-0.01223161,  0.00619036,  0.

In [None]:
%timeit e.state(Time("2021-01-01T00:00:00"))[10]["x"].block_until_ready()

1.07 ms ± 19.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [83]:
from astroquery.jplhorizons import Horizons

obj = Horizons(
    id="3",
    epochs=Time("2021-01-01T00:00:00").tdb.jd,
    id_type="majorbody",
    location="@0",
)
horizons_results = obj.vectors(refplane="earth")
horizons_calc = jnp.array(
    [horizons_vec["x"][0], horizons_vec["y"][0], horizons_vec["z"][0]]
)
horizons_results



targetname,datetime_jd,datetime_str,x,y,z,vx,vy,vz,lighttime,range,range_rate
---,d,---,AU,AU,AU,AU / d,AU / d,AU / d,d,AU,AU / d
str25,float64,str30,float64,float64,float64,float64,float64,float64,float64,float64,float64
Earth-Moon Barycenter (3),2459215.50080074,A.D. 2021-Jan-01 00:01:09.1839,-0.1857985333448159,0.8925393953089126,0.3870295356607404,-0.0172037276289486,-0.002940779545245,-0.0012745336545634,0.0057202114689326,0.9904239136073262,7.914146143481989e-05


In [84]:
e = Ephemeris()
jorbit_calc = e.state(Time("2021-01-01T00:00:00"))
jorbit_calc = jorbit_calc[3]["x"]
jorbit_calc

Array([-0.18579853,  0.8925394 ,  0.38702954], dtype=float64)

In [85]:
jnp.linalg.norm(jorbit_calc - horizons_calc) * u.au.to(u.m)

Array(0.19863445, dtype=float64)

In [86]:
# 0.2 meters! what about jplephem?
jplephem_calc = kernel.segments[2].compute(
    Time("2021-01-01T00:00:00").tdb.jd
) * u.km.to(u.au)
jnp.linalg.norm(jplephem_calc - horizons_calc) * u.au.to(u.m)

Array(0.20586389, dtype=float64)

In [88]:
jnp.linalg.norm(jplephem_calc - jorbit_calc) * u.au.to(u.mm)

Array(7.22946136, dtype=float64)

In [89]:
# 7mm difference between jplephem and jorbit, 0.2m between each and Horizons