# scratch work

In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from tqdm import tqdm

from astropy.time import Time
import astropy.units as u
from astropy.coordinates import SkyCoord
from astroquery.jplhorizons import Horizons

from jorbit.observation import Observations
from jorbit.utils.states import SystemState
from jorbit.ephemeris import Ephemeris
from jorbit.accelerations import create_newtonian_ephemeris_acceleration_func
from jorbit.integrators import ias15_evolve, initialize_ias15_integrator_state
from jorbit.astrometry.sky_projection import on_sky, sky_sep

t0 = Time("2024-12-01 00:00")
t1 = Time("2025-12-01 00:00")

In [2]:
200 * 5 * 52

52000

In [4]:
obj = Horizons(id="274301", location="@0", epochs=[t0.tdb.jd, t1.tdb.jd])
vecs = obj.vectors(refplane="earth")

x0 = jnp.array([vecs["x"][0], vecs["y"][0], vecs["z"][0]])
v0 = jnp.array([vecs["vx"][0], vecs["vy"][0], vecs["vz"][0]])

x1 = jnp.array([vecs["x"][-1], vecs["y"][-1], vecs["z"][-1]])
v1 = jnp.array([vecs["vx"][-1], vecs["vy"][-1], vecs["vz"][-1]])


obj = Horizons(id="274301", location="695", epochs=[t0.jd, t1.jd])
eph = obj.ephemerides(extra_precision=True, quantities="1")
coord0 = SkyCoord(eph["RA"][0], eph["DEC"][0], unit=(u.deg, u.deg), frame="icrs")
coord1 = SkyCoord(eph["RA"][-1], eph["DEC"][-1], unit=(u.deg, u.deg), frame="icrs")

In [5]:
obs = Observations(
    observed_coordinates=[coord0, coord0],
    times=jnp.array([t0.tdb.jd, t1.tdb.jd]),
    observatories="695@399",
    astrometric_uncertainties=1 * u.arcsec,
    verbose=True,
)

eph = Ephemeris(
    earliest_time=Time("2015-01-01 00:00"),
    latest_time=Time("2035-01-01 00:00"),
    ssos="default planets",
)

acc_func = create_newtonian_ephemeris_acceleration_func(eph.processor)

state0 = SystemState(
    positions=jnp.array([x0]),
    velocities=jnp.array([v0]),
    log_gms=jnp.array([0.0]),
    time=t0.tdb.jd,
    acceleration_func_kwargs=None,
)

ra0, dec0 = on_sky(state0, acc_func, obs.observer_positions[0])
calc_coord0 = SkyCoord(ra0, dec0, unit=(u.rad, u.rad), frame="icrs")
calc_coord0.separation(coord0)

Downloading observer positions from Horizons...


<Angle [3.19443392e-10] deg>

In [6]:
new_positions, new_velocities, state1, new_integrator_state = ias15_evolve(
    initial_system_state=state0,
    acceleration_func=acc_func,
    times=jnp.array([t1.tdb.jd]),
    initial_integrator_state=initialize_ias15_integrator_state(acc_func(state0)),
    n_steps=100,
)

ra1, dec1 = on_sky(state1, acc_func, obs.observer_positions[1])
calc_coord1 = SkyCoord(ra1, dec1, unit=(u.rad, u.rad), frame="icrs")
calc_coord1.separation(coord1)

<Angle [5.47274304e-07] deg>

In [7]:
state1.positions, state0.positions

(Array([[-2.17856334, -1.13866228, -0.35793294]], dtype=float64),
 Array([[-1.78355289,  1.97281979,  0.57967129]], dtype=float64))

In [8]:
x1, x0

(Array([-2.17856337, -1.13866227, -0.35793294], dtype=float64),
 Array([-1.78355289,  1.97281979,  0.57967129], dtype=float64))

In [9]:
new_integrator_state.dt

Array(16.64531375, dtype=float64)

In [10]:
from jorbit.astrometry.transformations import (
    elements_to_cartesian,
    cartesian_to_elements,
)

tess_coord = SkyCoord(
    95.97, 22.75, unit=(u.deg, u.deg), frame="icrs", obstime=Time("2023-11-11 00:00")
)
obs = Observations(
    observed_coordinates=[tess_coord],
    times=jnp.array([tess_coord.obstime.tdb.jd]),
    observatories="@-95",
    astrometric_uncertainties=1 * u.arcsec,
    verbose=True,
)

eph = Ephemeris(
    earliest_time=Time("2015-01-01 00:00"),
    latest_time=Time("2035-01-01 00:00"),
    ssos="default planets",
)
acc_func = create_newtonian_ephemeris_acceleration_func(eph.processor)

Downloading observer positions from Horizons...


In [11]:
@jax.jit
def best_orb(variable_params, fixed_params, target_ra, target_dec, observer_pos):
    x, v = elements_to_cartesian(**fixed_params, **variable_params)
    state = SystemState(
        positions=x,
        velocities=v,
        log_gms=jnp.array([0.0]),
        time=0.0,
        acceleration_func_kwargs=None,
    )
    calc_ra, calc_dec = on_sky(state, acc_func, observer_pos)

    return sky_sep(calc_ra, calc_dec, target_ra, target_dec)[0]


best_orb(
    {
        "Omega": jnp.array([180.0]),
        "omega": jnp.array([180.0]),
        "nu": jnp.array([180.0]),
    },
    {"a": jnp.array([30.0]), "ecc": jnp.array([0.1]), "inc": jnp.array([30.0])},
    tess_coord.ra.rad,
    tess_coord.dec.rad,
    obs.observer_positions[0],
)

jax.jacfwd(best_orb, argnums=0)(
    {
        "Omega": jnp.array([180.0]),
        "omega": jnp.array([180.0]),
        "nu": jnp.array([180.0]),
    },
    {"a": jnp.array([30.0]), "ecc": jnp.array([0.1]), "inc": jnp.array([30.0])},
    tess_coord.ra.rad,
    tess_coord.dec.rad,
    obs.observer_positions[0],
)

{'Omega': Array([3251.19227403], dtype=float64),
 'nu': Array([2130.19201191], dtype=float64),
 'omega': Array([2130.19215495], dtype=float64)}

In [12]:
def scipy_objective(x):
    return best_orb(
        {
            "Omega": jnp.array([x[0]]),
            "omega": jnp.array([x[1]]),
            "nu": jnp.array([x[2]]),
        },
        {"a": jnp.array([30.0]), "ecc": jnp.array([0.1]), "inc": jnp.array([30.0])},
        tess_coord.ra.rad,
        tess_coord.dec.rad,
        obs.observer_positions[0],
    )


# grad = jax.jacfwd(best_orb, argnums=0)
def scipy_gradient(x):
    g = jax.jacfwd(best_orb, argnums=0)(
        {
            "Omega": jnp.array([x[0]]),
            "omega": jnp.array([x[1]]),
            "nu": jnp.array([x[2]]),
        },
        {"a": jnp.array([30.0]), "ecc": jnp.array([0.1]), "inc": jnp.array([30.0])},
        tess_coord.ra.rad,
        tess_coord.dec.rad,
        obs.observer_positions[0],
    )
    return jnp.array([g["Omega"][0], g["omega"][0], g["nu"][0]])


from scipy.optimize import minimize

res = minimize(
    scipy_objective,
    [180.0, 180.0, 180.0],
    jac=scipy_gradient,
    method="BFGS",
    options={"disp": True},
)

         Current function value: 0.000000
         Iterations: 50
         Function evaluations: 114
         Gradient evaluations: 108


  res = _minimize_bfgs(fun, x0, args, jac, callback, **options)


In [13]:
res

  message: Desired error not necessarily achieved due to precision loss.
  success: False
   status: 2
      fun: 4.0910197415231247e-10
        x: [-3.871e+01  2.468e+02  2.424e+02]
      nit: 50
      jac: [-3.139e+03 -2.722e+03 -2.729e+03]
 hess_inv: [[ 2.130e-06  1.032e-03 -1.032e-03]
            [ 1.032e-03  5.000e-01 -4.999e-01]
            [-1.032e-03 -4.999e-01  4.997e-01]]
     nfev: 114
     njev: 108

In [14]:
el = {
    "a": jnp.array([30.0]),
    "ecc": jnp.array([0.1]),
    "inc": jnp.array([30.0]),
    "Omega": jnp.array([res.x[0]]),
    "omega": jnp.array([res.x[1]]),
    "nu": jnp.array([res.x[2]]),
}
x, v = elements_to_cartesian(**el)
state = SystemState(
    positions=x,
    velocities=v,
    log_gms=jnp.array([0.0]),
    time=0.0,
    acceleration_func_kwargs=None,
)
calc_ra, calc_dec = on_sky(state, acc_func, obs.observer_positions[0])
calc_coord = SkyCoord(calc_ra, calc_dec, unit=(u.rad, u.rad), frame="icrs")
calc_coord.separation(tess_coord)

<Angle [1.12700756e-13] deg>

In [15]:
def produce_valid_orbits(semis, eccs, incs, target_ra, target_dec, observer_pos):

    def scipy_objective(x, fixed, target_ra, target_dec, observer_pos):
        return best_orb(
            {
                "Omega": jnp.array([x[0]]),
                "omega": jnp.array([x[1]]),
                "nu": jnp.array([x[2]]),
            },
            fixed,
            target_ra,
            target_dec,
            observer_pos,
        )

    def scipy_gradient(x, fixed, target_ra, target_dec, observer_pos):
        g = jax.jacfwd(best_orb, argnums=0)(
            {
                "Omega": jnp.array([x[0]]),
                "omega": jnp.array([x[1]]),
                "nu": jnp.array([x[2]]),
            },
            fixed,
            target_ra,
            target_dec,
            observer_pos,
        )
        return jnp.array([g["Omega"][0], g["omega"][0], g["nu"][0]])

    fixed_params = []
    for a in semis:
        for e in eccs:
            for i in incs:
                fixed_params.append(
                    {"a": jnp.array([a]), "ecc": jnp.array([e]), "inc": jnp.array([i])}
                )

    valid_orbits = []
    for fixed in tqdm(fixed_params):
        res = minimize(
            scipy_objective,
            [180.0, 180.0, 180.0],
            jac=scipy_gradient,
            method="BFGS",
            args=(fixed, target_ra, target_dec, observer_pos),
            # options={"disp": True},
        )
        if res.fun < 10.0:
            el = {
                "a": fixed["a"],
                "ecc": fixed["ecc"],
                "inc": fixed["inc"],
                "Omega": jnp.array([res.x[0]]),
                "omega": jnp.array([res.x[1]]),
                "nu": jnp.array([res.x[2]]),
            }
            valid_orbits.append((el, res.fun))

    return valid_orbits


orbs = produce_valid_orbits(
    jnp.array([30.0, 40.0, 50.0]),
    jnp.array([0.1, 0.2, 0.3]),
    jnp.array([20.0, 30.0, 40.0]),
    tess_coord.ra.rad,
    tess_coord.dec.rad,
    obs.observer_positions[0],
)

100%|██████████| 27/27 [00:06<00:00,  4.10it/s]


In [16]:
jnp.arange(5, 120, 5).size * jnp.arange(0.0, 0.9, 0.1).size * jnp.arange(
    0, 180, 10
).size / 4 / 60

15.525

In [17]:
len(orbs)

18

In [22]:
xs, vs = [], []
for o in orbs:
    x, v = elements_to_cartesian(**o[0])
    xs.append(x[0])
    vs.append(v[0])
xs = jnp.array(xs)
vs = jnp.array(vs)

state = SystemState(
    positions=xs,
    velocities=vs,
    log_gms=jnp.ones_like(xs[:, 0]) * -jnp.inf,
    time=tess_coord.obstime.tdb.jd,
    acceleration_func_kwargs=None,
)

sector_times = (tess_coord.obstime + jnp.arange(0.1, 27, 0.1) * u.day).tdb.jd


new_positions, new_velocities, state1, new_integrator_state = ias15_evolve(
    initial_system_state=state,
    acceleration_func=acc_func,
    times=sector_times,
    initial_integrator_state=initialize_ias15_integrator_state(acc_func(state)),
    n_steps=100,
)

In [23]:
new_positions.shape

(269, 18, 3)

In [26]:
xs.shape

(18, 3)

In [32]:
s = SystemState(
    positions=new_positions[0],
    velocities=new_velocities[0],
    log_gms=jnp.ones_like(new_positions[0, :, 0]) * -jnp.inf,
    time=sector_times[0],
    acceleration_func_kwargs=None,
)

r, d = on_sky(s, acc_func, obs.observer_positions[0])
c = SkyCoord(r, d, unit=(u.rad, u.rad), frame="icrs")
c

<SkyCoord (ICRS): (ra, dec) in deg
    [(95.96493464, 22.75175144), (95.9659758 , 22.74756711),
     (95.96528858, 22.75162466), (95.9666757 , 22.74798943),
     (95.96628255, 22.75127783), (95.96722327, 22.7483208 ),
     (95.96725538, 22.74909687), (95.96759315, 22.7485369 ),
     (95.96772637, 22.74925129), (95.96802916, 22.74880153),
     (95.96809635, 22.74937295), (95.96836947, 22.74900851),
     (95.96819584, 22.74940333), (95.96841961, 22.74903611),
     (95.96851789, 22.74950954), (95.96871814, 22.74921796),
     (95.96876869, 22.74959245), (95.96895029, 22.7493596 )]>

In [None]:
aws s3 cp --no-sign-request s3://stpubdata/tess/public/mast/tess-s0072-1-1-cube.fits sector72_cubes/ccd1_cam1.fits

rclone ls tess:stpubdata/tess/public/mast/

rclone cp tess://stpubdata/tess/public/mast/tess-s0072-1-1-cube.fits sector72_cubes/ccd1_cam1.fits

In [None]:
for x, v in zip(new_positions, new_velocities):
    ra, dec = on_sky(
        SystemState(
            positions=x,
            velocities=v,
            log_gms=jnp.ones_like(x) * -jnp.inf,
            time=0.0,
            acceleration_func_kwargs=None,
        ),
        acc_func,
        obs.observer_positions[0],
    )
    print(SkyCoord(ra, dec, unit=(u.rad, u.rad), frame="icrs").separation(tess_coord))

TypeError: Cannot concatenate arrays with different numbers of dimensions: got (18, 3), (10,).

In [None]:
from astropy.io import fits
from astropy.wcs import WCS

with fits.open("tess2023327004148-s0072-1-1-0267-s_ffic.fits") as hdul:
    w = WCS(hdul[1].header)
    t = Time(hdul[1].header["DATE-OBS"])
t

Set MJD-END to 60271.035557 from DATE-END'. [astropy.wcs.wcs]


<Time object: scale='utc' format='isot' value=2023-11-23T00:47:52.158>

In [None]:
w.pixel_to_world(1000, 1000)

<SkyCoord (ICRS): (ra, dec) in deg
    (102.24287734, 15.63388622)>

In [None]:
356 / 64

5.5625

In [None]:
16 * 16 / 60

4.266666666666667

In [33]:
helio_r = 30.0
e = 0.1
nu = 40.0
Omega = 42.0

semi = helio_r / (1 - e**2) * (1 + e * jnp.cos(nu))
semi

Array(28.28200587, dtype=float64, weak_type=True)

In [181]:
def produce_valid_orbits(
    helio_rs, eccs, nus, Omegas, target_ra, target_dec, observer_pos
):

    def scipy_objective(x, fixed, target_ra, target_dec, observer_pos):
        return best_orb(
            {"omega": jnp.array([x[0]]), "inc": jnp.array([x[1]])},
            fixed,
            target_ra,
            target_dec,
            observer_pos,
        )

    def scipy_gradient(x, fixed, target_ra, target_dec, observer_pos):
        g = jax.jacfwd(best_orb, argnums=0)(
            {"omega": jnp.array([x[0]]), "inc": jnp.array([x[1]])},
            fixed,
            target_ra,
            target_dec,
            observer_pos,
        )
        return jnp.array([g["omega"][0], g["inc"][0]])

    fixed_params = []
    for r in helio_rs:
        for e in eccs:
            for nu in nus:
                for Omega in Omegas:
                    semi = r / (1 - e**2) * (1 + e * jnp.cos(nu))
                    fixed_params.append(
                        {
                            "a": semi,
                            "ecc": jnp.array([e]),
                            "nu": jnp.array([nu]),
                            "Omega": jnp.array([Omega]),
                        }
                    )

    valid_orbits = []
    for fixed in tqdm(fixed_params):
        res = minimize(
            scipy_objective,
            [20.0, 120.0],
            jac=scipy_gradient,
            method="BFGS",
            args=(fixed, target_ra, target_dec, observer_pos),
            # options={"disp": True},
        )
        if res.fun < 10.0:
            el = {
                "a": fixed["a"],
                "ecc": fixed["ecc"],
                "inc": jnp.array([res.x[1]]),
                "Omega": fixed["Omega"],
                "omega": jnp.array([res.x[0]]),
                "nu": fixed["nu"],
            }
            valid_orbits.append((el, res.fun))

    return valid_orbits


orbs = produce_valid_orbits(
    helio_rs=jnp.array([30.0, 40.0, 50.0]),
    eccs=jnp.array([0.1, 0.2, 0.3]),
    nus=jnp.array([20.0, 30.0, 40.0]),
    Omegas=jnp.array([20.0, 30.0, 40.0]),
    target_ra=tess_coord.ra.rad,
    target_dec=tess_coord.dec.rad,
    observer_pos=obs.observer_positions[0],
)

  0%|          | 0/81 [00:00<?, ?it/s]


ValueError: shapes (2,1) and (2,1) not aligned: 1 (dim 1) != 2 (dim 0)

In [178]:
jnp.arange(5, 120, 5)[:, None]

Array([[  5],
       [ 10],
       [ 15],
       [ 20],
       [ 25],
       [ 30],
       [ 35],
       [ 40],
       [ 45],
       [ 50],
       [ 55],
       [ 60],
       [ 65],
       [ 70],
       [ 75],
       [ 80],
       [ 85],
       [ 90],
       [ 95],
       [100],
       [105],
       [110],
       [115]], dtype=int64)

In [39]:
jnp.arange(30, 100, 5).size * jnp.arange(0.0, 0.9, 0.1).size * jnp.arange(
    0, 360, 30
).size * jnp.arange(0, 360, 30).size

18144

In [36]:
orbs

[({'a': Array(31.53964261, dtype=float64),
   'ecc': Array([0.1], dtype=float64),
   'inc': Array([23.59572514], dtype=float64),
   'Omega': Array([20.], dtype=float64),
   'omega': Array([55.65032646], dtype=float64),
   'nu': Array([20.], dtype=float64)},
  1.4231832929851027e-10),
 ({'a': Array(31.53964261, dtype=float64),
   'ecc': Array([0.1], dtype=float64),
   'inc': Array([25.00921594], dtype=float64),
   'Omega': Array([30.], dtype=float64),
   'omega': Array([46.53186474], dtype=float64),
   'nu': Array([20.], dtype=float64)},
  7.734942444512327e-11),
 ({'a': Array(31.53964261, dtype=float64),
   'ecc': Array([0.1], dtype=float64),
   'inc': Array([27.3574672], dtype=float64),
   'Omega': Array([40.], dtype=float64),
   'omega': Array([37.55168761], dtype=float64),
   'nu': Array([20.], dtype=float64)},
  4.3437838991034765e-11),
 ({'a': Array(30.77045894, dtype=float64),
   'ecc': Array([0.1], dtype=float64),
   'inc': Array([23.59973833], dtype=float64),
   'Omega': Array(

In [55]:
@jax.jit
def best_orb(variable_params, fixed_params, target_ra, target_dec, observer_pos):
    x, v = elements_to_cartesian(**fixed_params, **variable_params)
    state = SystemState(
        positions=x,
        velocities=v,
        log_gms=jnp.ones(x.shape[0]) * -jnp.inf,
        time=0.0,
        acceleration_func_kwargs=None,
    )
    calc_ra, calc_dec = on_sky(state, acc_func, observer_pos)
    jax.debug.print("{x}", x=calc_ra)

    return sky_sep(calc_ra, calc_dec, target_ra, target_dec)

In [15]:
def tmp_acc(s):
    return jnp.ones_like(s.positions) * jnp.finfo(jnp.float64).eps


acc_func = jax.tree_util.Partial(lambda s: tmp_acc(s))

fixed_params = {
    "a": jnp.array([34.49259676]),
    "ecc": jnp.array([0.3]),
    "inc": jnp.array([23.62803909]),
    "Omega": jnp.array([20.0]),
    "omega": jnp.array([45.44197637]),
    "nu": jnp.array([30.0]),
}
x, v = elements_to_cartesian(**fixed_params)
state = SystemState(
    positions=x,
    velocities=v,
    log_gms=jnp.ones(x.shape[0]) * -jnp.inf,
    time=0.0,
    acceleration_func_kwargs=None,
)
calc_ra1, calc_dec1 = on_sky(state, acc_func, obs.observer_positions[0])


fixed_params = {
    "a": jnp.array([34.49259676, 34.49259676]),
    "ecc": jnp.array([0.3, 0.3]),
    "inc": jnp.array([23.62803909, 23.62803909]),
    "Omega": jnp.array([20.0, 20.0]),
    "omega": jnp.array([45.44197637, 45.44197637]),
    "nu": jnp.array([30.0, 30.0]),
}
x, v = elements_to_cartesian(**fixed_params)
x = x.at[0].set(x[0] + 1.2e-16)
state = SystemState(
    positions=x,
    velocities=v,
    log_gms=jnp.ones(x.shape[0]) * -jnp.inf,
    time=0.0,
    acceleration_func_kwargs=None,
)
calc_ra2, calc_dec2 = on_sky(state, acc_func, obs.observer_positions[0])

fixed_params = {
    "a": jnp.array([[34.49259676], [34.49259676]]),
    "ecc": jnp.array([[0.3], [0.3]]),
    "inc": jnp.array([[23.62803909], [23.62803909]]),
    "Omega": jnp.array([[20.0], [20.0]]),
    "omega": jnp.array([[45.44197637], [45.44197637]]),
    "nu": jnp.array([[30.0], [30.0]]),
}


def tmp(f):
    x, v = elements_to_cartesian(**f)
    state = SystemState(
        positions=x,
        velocities=v,
        log_gms=jnp.ones(x.shape[0]) * -jnp.inf,
        time=0.0,
        acceleration_func_kwargs=None,
    )
    return on_sky(state, acc_func, obs.observer_positions[0])


calc_ra3, calc_dec3 = jax.vmap(tmp)(fixed_params)

calc_ra1 - calc_ra2, calc_ra1 - calc_ra3

(Array([9.95484292e-06, 9.95484292e-06], dtype=float64),
 Array([[0.],
        [0.]], dtype=float64))

In [175]:
helio_rs = jnp.array([30.0, 40.0, 50.0])
eccs = jnp.array([0.1, 0.2, 0.3])
nus = jnp.array([20.0, 30.0, 40.0])
Omegas = jnp.array([20.0, 30.0, 40.0])


target_ra = tess_coord.ra.rad
target_dec = tess_coord.dec.rad
observer_pos = obs.observer_positions[0]

fixed_params = []
for r in helio_rs:
    for e in eccs:
        for nu in nus:
            for Omega in Omegas:
                semi = r / (1 - e**2) * (1 + e * jnp.cos(nu))
                fixed_params.append(
                    [
                        jnp.array([semi]),
                        jnp.array([e]),
                        jnp.array([nu]),
                        jnp.array([Omega]),
                    ]
                )
fixed_params = jnp.array(fixed_params)


@jax.jit
def best_orb(variable_params, fixed_params, target_ra, target_dec, observer_pos):
    x, v = elements_to_cartesian(**fixed_params, **variable_params)
    state = SystemState(
        positions=x,
        velocities=v,
        log_gms=jnp.ones(x.shape[0]) * -jnp.inf,
        time=0.0,
        acceleration_func_kwargs=None,
    )
    calc_ra, calc_dec = on_sky(state, acc_func, observer_pos)

    return sky_sep(calc_ra, calc_dec, target_ra, target_dec)


@jax.jit
def translate_fixed(f):
    return {
        "a": f[0],
        "ecc": f[1],
        "nu": f[2],
        "Omega": f[3],
    }


@jax.jit
def translate_variable(v):
    return {
        "omega": v[0],
        "inc": v[1],
    }


@jax.jit
def objective(x, fixed, target_ra, target_dec, observer_pos):
    return best_orb(
        translate_variable(x),
        translate_fixed(fixed),
        target_ra,
        target_dec,
        observer_pos,
    )[0]


@jax.jit
def simul_objective(x):
    return jnp.sum(
        jax.vmap(objective, in_axes=(0, 0, None, None, None))(
            x,
            fixed_params,
            target_ra,
            target_dec,
            observer_pos,
        )
    )


@jax.jit
def grad_simul_objective(x):
    return jnp.ravel(jax.jacfwd(simul_objective)(x))


fixed_params = fixed_params
init_inc = 20.0
init_omega = 20.0

init_x = jnp.array(
    [jnp.array([[init_inc], [init_omega]]) for i in range(fixed_params.shape[0])]
)
# simul_objective(init_x), grad_simul_objective(init_x)


def scipy_objective(x):
    x = x.reshape((fixed_params.shape[0], 2))
    x = x[..., None]
    return simul_objective(x)


def scipy_gradient(x):
    x = x.reshape((fixed_params.shape[0], 2))
    x = x[..., None]
    return grad_simul_objective(x)


res = minimize(
    scipy_objective,
    init_x.ravel(),
    jac=scipy_gradient,
    method="BFGS",
    options={"disp": True},
)

KeyboardInterrupt: 

In [174]:
init_x.shape

(5, 2, 1)

In [170]:
%timeit simul_objective(init_x).block_until_ready()

6.42 ms ± 130 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [171]:
%timeit grad_simul_objective(init_x).block_until_ready()

75 ms ± 194 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [143]:
jax.jacfwd(objective, argnums=0)(
    jnp.array([[20.0], [23.0]]), fixed_params[0], target_ra, target_dec, observer_pos
)

Array([[-3724.90392458],
       [  -38.30653406]], dtype=float64)

In [147]:
jnp.array([[[20.0], [23.0]], [[20.0], [23.0]]])[0]

Array([[20.],
       [23.]], dtype=float64)

In [154]:
jnp.array([[[20.0], [23.0]], [[20.0], [23.0]]]).shape

(2, 2, 1)

In [155]:
jnp.repeat(jnp.array([[20.0], [23.0]]), 2)

Array([20., 20., 23., 23.], dtype=float64)

In [157]:
def tot(x):
    return jnp.sum(
        jax.vmap(objective, in_axes=(0, 0, None, None, None))(
            x,
            fixed_params[:2],
            target_ra,
            target_dec,
            observer_pos,
        )
    )


tot(jnp.array([[[20.0], [23.0]], [[20.0], [23.0]]]))

jnp.ravel(jax.jacfwd(tot)(jnp.array([[[20.0], [23.0]], [[20.0], [23.0]]])))

Array([-3724.90392458,   -38.30653406, -3709.61477897,  -168.10468421],      dtype=float64)

In [None]:
x, v = elements_to_cartesian(
    **translate_fixed(fixed_params[0]),
    **translate_variable(jnp.array([[20.0], [23.0]])),
)
state = SystemState(
    positions=x,
    velocities=v,
    log_gms=jnp.ones(x.shape[0]) * -jnp.inf,
    time=0.0,
    acceleration_func_kwargs=None,
)
calc_ra, calc_dec = on_sky(state, acc_func, observer_pos)
sky_sep(calc_ra, calc_dec, target_ra, target_dec)

Array([132260.09134417], dtype=float64)

In [131]:
translate_variable(jnp.array([20.0, 23.0]))
# translate_fixed(fixed_params[0])

{'omega': Array(20., dtype=float64), 'inc': Array(23., dtype=float64)}

In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from tqdm import tqdm
from scipy.optimize import minimize

from astropy.time import Time
import astropy.units as u
from astropy.coordinates import SkyCoord
from astroquery.jplhorizons import Horizons

from jorbit.observation import Observations
from jorbit.utils.states import SystemState
from jorbit.ephemeris import Ephemeris
from jorbit.accelerations import create_newtonian_ephemeris_acceleration_func
from jorbit.integrators import ias15_evolve, initialize_ias15_integrator_state
from jorbit.astrometry.sky_projection import on_sky, sky_sep
from jorbit.astrometry.transformations import (
    elements_to_cartesian,
    cartesian_to_elements,
)

In [50]:
tess_coord = SkyCoord(
    95.97, 22.75, unit=(u.deg, u.deg), frame="icrs", obstime=Time("2023-11-11 00:00")
)

sector_times = (tess_coord.obstime + jnp.arange(0.0, 27, 0.1) * u.day).tdb.jd

obs = Observations(
    observed_coordinates=[tess_coord for i in range(sector_times.size)],
    times=jnp.array(sector_times),
    observatories="@-95",
    astrometric_uncertainties=1 * u.arcsec,
    verbose=True,
)

eph = Ephemeris(
    earliest_time=Time("2015-01-01 00:00"),
    latest_time=Time("2035-01-01 00:00"),
    ssos="default planets",
)
acc_func = create_newtonian_ephemeris_acceleration_func(eph.processor)

Downloading observer positions from Horizons...


In [52]:
@jax.jit
def best_orb(variable_params, fixed_params, target_ra, target_dec, observer_pos):
    x, v = elements_to_cartesian(**fixed_params, **variable_params)
    state = SystemState(
        positions=x,
        velocities=v,
        log_gms=jnp.array([0.0]),
        time=0.0,
        acceleration_func_kwargs=None,
    )
    calc_ra, calc_dec = on_sky(state, acc_func, observer_pos)

    return sky_sep(calc_ra, calc_dec, target_ra, target_dec)[0]


def produce_valid_orbits(
    helio_rs, eccs, nus, Omegas, target_ra, target_dec, observer_pos
):

    def scipy_objective(x, fixed, target_ra, target_dec, observer_pos):
        return best_orb(
            {"omega": jnp.array([x[0]]), "inc": jnp.array([x[1]])},
            fixed,
            target_ra,
            target_dec,
            observer_pos,
        )

    def scipy_gradient(x, fixed, target_ra, target_dec, observer_pos):
        g = jax.jacfwd(best_orb, argnums=0)(
            {"omega": jnp.array([x[0]]), "inc": jnp.array([x[1]])},
            fixed,
            target_ra,
            target_dec,
            observer_pos,
        )
        return jnp.array([g["omega"][0], g["inc"][0]])

    fixed_params = []
    for r in helio_rs:
        for e in eccs:
            for nu in nus:
                for Omega in Omegas:
                    semi = r / (1 - e**2) * (1 + e * jnp.cos(nu))
                    fixed_params.append(
                        {
                            "a": semi,
                            "ecc": jnp.array([e]),
                            "nu": jnp.array([nu]),
                            "Omega": jnp.array([Omega]),
                        }
                    )

    valid_orbits = []
    for fixed in tqdm(fixed_params):
        res = minimize(
            scipy_objective,
            [20.0, 120.0],
            jac=scipy_gradient,
            method="BFGS",
            args=(fixed, target_ra, target_dec, observer_pos),
            # options={"disp": True},
        )
        if res.fun < 10.0:
            el = {
                "a": fixed["a"],
                "ecc": fixed["ecc"],
                "inc": jnp.array([res.x[1]]),
                "Omega": fixed["Omega"],
                "omega": jnp.array([res.x[0]]),
                "nu": fixed["nu"],
            }
            valid_orbits.append((el, res.fun))

    return valid_orbits


l = 2
orbs = produce_valid_orbits(
    helio_rs=jnp.concatenate((jnp.arange(30.0, 50, 2), jnp.arange(50, 100.0, 10)))[:l],
    eccs=jnp.arange(0.0, 0.9, 0.1)[:l],
    nus=jnp.arange(0, 360.0, 40)[:l],
    Omegas=jnp.arange(0, 360.0, 40)[:l],
    target_ra=tess_coord.ra.rad,
    target_dec=tess_coord.dec.rad,
    observer_pos=obs.observer_positions[0],
)

xs, vs = [], []
for o in orbs:
    x, v = elements_to_cartesian(**o[0])
    xs.append(x[0])
    vs.append(v[0])
xs = jnp.array(xs)[:, None, :]
vs = jnp.array(vs)[:, None, :]

100%|██████████| 16/16 [00:10<00:00,  1.46it/s]


In [54]:
def evolved(x, v, observer_positions):
    state = SystemState(
        positions=x,
        velocities=v,
        log_gms=jnp.array([0.0]),
        time=tess_coord.obstime.tdb.jd,
        acceleration_func_kwargs=None,
    )

    new_positions, new_velocities, state1, new_integrator_state = ias15_evolve(
        initial_system_state=state,
        acceleration_func=acc_func,
        times=sector_times,
        initial_integrator_state=initialize_ias15_integrator_state(acc_func(state)),
        n_steps=3,
    )

    # this is silly, but I haven't fixed the tracer acceleration bug yet
    def tmp(x, v, obs_pos):
        state = SystemState(
            positions=x,
            velocities=v,
            log_gms=jnp.ones(x.shape[0]) * -jnp.inf,
            time=0.0,
            acceleration_func_kwargs=None,
        )
        return on_sky(state, acc_func, obs_pos)

    ras, decs = jax.vmap(tmp)(new_positions, new_velocities, observer_positions)
    return ras, decs

    return new_positions, new_velocities


# surprisingly faster than vmap
ras, decs = [], []
for i in tqdm(range(xs.shape[0])):
    ra, dec = evolved(xs[i], vs[i], obs.observer_positions)
    ras.append(ra)
    decs.append(dec)
ras = jnp.array(ras)[..., 0]
decs = jnp.array(decs)[..., 0]

100%|██████████| 16/16 [00:08<00:00,  1.91it/s]


In [55]:
ras.shape

(16, 270)

In [56]:
SkyCoord(ras[0], decs[0], unit=(u.rad, u.rad), frame="icrs")

<SkyCoord (ICRS): (ra, dec) in deg
    [(95.97      , 22.75      ), (95.96820624, 22.75009748),
     (95.96643005, 22.75021254), (95.96467057, 22.75034452),
     (95.96292585, 22.75049188), (95.96119306, 22.75065243),
     (95.95880357, 22.75084833), (95.95641904, 22.75105203),
     (95.9540361 , 22.7512609 ), (95.95165199, 22.75147259),
     (95.94926791, 22.75168501), (95.94694184, 22.75189436),
     (95.94563179, 22.75206379), (95.94738067, 22.75211647),
     (95.94496092, 22.75232032), (95.94253113, 22.75252032),
     (95.94009094, 22.75271615), (95.93764061, 22.75290764),
     (95.93518754, 22.75309442), (95.93285169, 22.75327201),
     (95.93255292, 22.75336927), (95.93305505, 22.75343211),
     (95.93121674, 22.75357708), (95.92936805, 22.75371787),
     (95.92750913, 22.75385457), (95.92497427, 22.75401195),
     (95.92242941, 22.75416544), (95.91987471, 22.75431514),
     (95.91731054, 22.75446116), (95.91474082, 22.75460346),
     (95.91222569, 22.75473993), (95.91072569, 22.