# scratch work

In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from astropy.time import Time
import astropy.units as u
from astropy.coordinates import SkyCoord
from astroquery.jplhorizons import Horizons

from jorbit import Particle

In [2]:
# Two epochs, one year apart.
t0, t1 = Time("2024-12-01 00:00"), Time("2025-12-01 00:00")

# Ask JPL Horizons for state vectors of object 274301 at both epochs (given as
# TDB Julian dates), relative to location "@0", in the "earth" reference plane.
obj = Horizons(
    id="274301",
    location="@0",
    epochs=[epoch.tdb.jd for epoch in (t0, t1)],
)
vecs = obj.vectors(refplane="earth")

# Position and velocity taken from row 0 of the returned table
# (presumably the first requested epoch, t0 -- confirm the row ordering).
x0 = jnp.array([vecs[axis][0] for axis in ("x", "y", "z")])
v0 = jnp.array([vecs[axis][0] for axis in ("vx", "vy", "vz")])

In [3]:
# x0 / v0 were read from row 0 of the Horizons table, i.e. the state at the
# first requested epoch (t0). The particle's reference time must therefore be
# t0; stamping it with t1 would pair a t0 state with a t1 epoch.
# log_gm=-inf presumably marks a massless tracer (GM = exp(-inf) = 0) -- confirm
# against the jorbit Particle documentation.
p = Particle(x=x0, v=v0, log_gm=-jnp.inf, time=t0)
p

Particle: unnamed

In [4]:
# NOTE(review): per the recorded output, this call fails with
# "AttributeError: 'float' object has no attribute 'time'". Presumably `gravity`
# expects an object carrying a `.time` attribute rather than a bare float --
# confirm the intended argument type against the jorbit API before re-running.
p.gravity(1.0)

AttributeError: 'float' object has no attribute 'time'

In [6]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import chex


@chex.dataclass
class Parameters:
    """Two arrays bundled in a chex dataclass, with a jitted element-wise sum."""

    # First operand of `sum`.
    x: chex.ArrayDevice
    # Second operand of `sum`.
    y: chex.ArrayDevice

    @jax.jit
    def sum(self):
        """Return ``self.x + self.y``.

        The method is wrapped in ``jax.jit``, so ``self`` is the argument of the
        jitted function.
        """
        total = self.x + self.y
        return total


# Example instance with two length-3 float arrays; used by the cells below.
params = Parameters(x=jnp.array([1.0, 2.0, 3.0]), y=jnp.array([4.0, 5.0, 6.0]))

In [7]:
params.sum()

Array([5., 7., 9.], dtype=float64)

In [8]:
@jax.jit
def test(p):
    return p.sum()


# Evaluate the jitted function on `params`, and its reverse-mode Jacobian with
# respect to `params`. As the output shows, the Jacobian is returned as a
# Parameters object holding one 3x3 block per field (x and y).
test(params), jax.jacrev(test)(params)

(Array([5., 7., 9.], dtype=float64),
 Parameters(x=Array([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]], dtype=float64), y=Array([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]], dtype=float64)))

In [9]:
params

Parameters(x=Array([1., 2., 3.], dtype=float64), y=Array([4., 5., 6.], dtype=float64))

In [10]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import optimistix as optx


def fn(y, args):
    """Residual function handed to the root finder.

    Args:
        y: a pair ``(a, b)`` that is unpacked into two values.
        args: unused.

    Returns:
        A pair ``(c, d)`` with ``c = tanh(sum(b)) - a`` and
        ``d = a**2 - sinh(b + 1)``.
    """
    first, second = y
    first_squared = first**2
    sinh_shifted = jnp.sinh(second + 1)
    tanh_total = jnp.tanh(jnp.sum(second))
    return tanh_total - first, first_squared - sinh_shifted


# Newton root finder with relative and absolute tolerances of 1e-8.
solver = optx.Newton(rtol=1e-8, atol=1e-8)
# Initial guess with the same structure fn unpacks: a scalar and a 2x2 array.
y0 = (jnp.array(0.0), jnp.zeros((2, 2)))
# Run the root find on fn starting from y0; the result object is inspected below.
sol = optx.root_find(fn, solver, y0)

In [12]:
sol

Solution(
  value=(f64[], f64[2,2]),
  result=EnumerationItem(
    _value=i32[],
    _enumeration=<class 'optimistix._solution.RESULTS'>
  ),
  aux=None,
  stats={'max_steps': 256, 'num_steps': weak_i64[]},
  state=_NewtonChordState(
    f=(f64[], f64[2,2]),
    linear_state=None,
    diff=(f64[], f64[2,2]),
    diffsize=f64[],
    diffsize_prev=weak_f64[],
    result=EnumerationItem(
      _value=i32[],
      _enumeration=<class 'optimistix._solution.RESULTS'>
    ),
    step=weak_i64[]
  )
)

In [13]:
sol.value

(Array(-0.85650715, dtype=float64),
 Array([[-0.32002086, -0.32002086],
        [-0.32002086, -0.32002086]], dtype=float64))

In [None]:
def fun(y, args):
    """Objective for the minimiser: the square of ``y`` (``args`` is unused)."""
    squared = y**2
    return squared


# NOTE(review): this cell rebinds `solver`, `y0` and `sol`, which the Newton
# root-find cell above also defines; after running it, those names refer to the
# minimisation objects instead.
# BFGS minimiser with relative and absolute tolerances of 1e-8.
solver = optx.BFGS(rtol=1e-8, atol=1e-8)
# Start far from the minimum of y**2.
y0 = jnp.array(100.0)
sol = optx.minimise(fun, solver, y0)
# Display the minimiser found (recorded output: 0.).
sol.value

Array(0., dtype=float64)