In [1]:
# test tle2sat function 

# sgp4 python package imports
from sgp4.api import Satrec

# jax_sgp4 imports
from jax_sgp4 import tle2sat

import jax.numpy as jnp
import numpy as np

# TLE for the International Space Station (ZARYA)
tle1 = '1 25544U 98067A   25316.22557474  .00019835  00000-0  36009-3 0  9990'
tle2 = '2 25544  51.6334 293.5281 0004132  61.4658 298.6746 15.49560899538121' 

# Create Satrec object
sat = Satrec.twoline2rv(tle1, tle2)

sat_test = tle2sat(tle1, tle2)

print(sat.no_kozai, sat_test.n0) # diff units
print(sat.ecco, sat_test.e0)
print(sat.inclo * 180.0 / jnp.pi, sat_test.i0)
print(sat.argpo * 180.0 / jnp.pi, sat_test.w0)
print(sat.nodeo * 180.0 / jnp.pi, sat_test.Omega0) # rounding error?
print(sat.mo * 180.0 / jnp.pi, sat_test.M0)
print(sat.bstar, sat_test.Bstar)
print(sat.epochdays, sat_test.epochdays)
print(sat.epochyr, sat_test.epochyr)  

0.06761234911928327 15.49560899
0.0004132 0.0004132
51.6334 51.6334
61.4658 61.4658
293.52809999999994 293.5281
298.6746 298.6746
0.00036009000000000003 0.00036009000000000003
316.22557474 316.22557474
25 2025


In [None]:
# test sgp4 function (to compare jitted vs unjitted but jaxsgp4 has basically replaced this now, can ignore)

from sgp4.api import jday, Satrec
from jax_sgp4 import sgp4, tle2sat
import jax
import jax.numpy as jnp

# TLE for the International Space Station (ZARYA)
tle1 = '1 25544U 98067A   25316.22557474  .00019835  00000-0  36009-3 0  9990'
tle2 = '2 25544  51.6334 293.5281 0004132  61.4658 298.6746 15.49560899538121' 

# Create Satrec object
sat = Satrec.twoline2rv(tle1, tle2)

# Choose a time since epoch to propagate to (minutes)
tsince = 500.0

# Propagate using standard SGP4 implementation
e, r, v = sat.sgp4_tsince(tsince)

print(r, v)  # position (km), velocity (km/s)
print("SGP4 error code:", e)

# Propagate using JAX SGP4 implementation
satjax = tle2sat(tle1, tle2)

result = sgp4(satjax, tsince)
rjax = result[:3]
vjax = result[3:]

print(rjax, vjax)  # position (km), velocity (km/s)

# Compare results
r = jnp.array(r)
v = jnp.array(v)
print(rjax - r)
print(vjax - v)

# # time sgp4 unjitted
# %timeit sgp4(satjax, tsince).block_until_ready()

# # time sgp4 jitted
# jaxsgp4 = jax.jit(sgp4)
# %timeit jaxsgp4(satjax, tsince).block_until_ready()

# # time sgp4 python package C++ implementation
# # (Verify it is using the fast C++ implementation rather than the slow python)
# from sgp4.api import accelerated
# print(accelerated)
# %timeit sat.sgp4_tsince(tsince)

(705.08880216296, 5760.37365185177, 3527.6041455633244) (-5.192094694881382, 3.401162766577282, -4.496405865253736)
SGP4 error code: 0
[ 705.08880216 5760.37365185 3527.60414556] [-5.19209469  3.40116277 -4.49640587]
[ 5.50244295e-11 -3.63797881e-11  4.82032192e-11]
[9.76996262e-15 7.81597009e-14 4.79616347e-14]


In [3]:
# test jaxsgp4
from sgp4.api import Satrec
from jax_sgp4 import jaxsgp4, tle2sat
import jax.numpy as jnp

# for a single time 

# TLE for the International Space Station (ZARYA)
tle1 = '1 25544U 98067A   25316.22557474  .00019835  00000-0  36009-3 0  9990'
tle2 = '2 25544  51.6334 293.5281 0004132  61.4658 298.6746 15.49560899538121'

tsince = 500.0

# Propagate using standard SGP4 implementation
sat = Satrec.twoline2rv(tle1, tle2)
e, r, v = sat.sgp4_tsince(tsince)

# Propagate using JAX SGP4 implementation
satjax = tle2sat(tle1, tle2)
result = jaxsgp4(satjax, tsince)
rjax = result[:3]
vjax = result[3:]

print(r)
print(v)
print(rjax)
print(vjax)
print(rjax - jnp.array(r))
print(vjax - jnp.array(v))

(705.08880216296, 5760.37365185177, 3527.6041455633244)
(-5.192094694881382, 3.401162766577282, -4.496405865253736)
[ 705.08880216 5760.37365185 3527.60414556]
[-5.19209469  3.40116277 -4.49640587]
[ 5.04769559e-11 -3.63797881e-11  4.82032192e-11]
[6.21724894e-15 7.46069873e-14 4.79616347e-14]


In [10]:
# test spg4_jdfr function 

from jax_sgp4 import sgp4_jdfr, tle2sat, sgp4
from sgp4.api import jday, Satrec
import jax.numpy as jnp

# TLE for the International Space Station (ZARYA)
tle1 = '1 25544U 98067A   25316.22557474  .00019835  00000-0  36009-3 0  9990'
tle2 = '2 25544  51.6334 293.5281 0004132  61.4658 298.6746 15.49560899538121' 

# Create Satrec object
sat = Satrec.twoline2rv(tle1, tle2)

satjax = tle2sat(tle1, tle2)

# Choose a Julian date to propagate to
print(jday(2026, 1, 14, 12, 0, 0))  # Example Julian date 14th jan 2026 at 12:00:00 UTC
jd, fr = jday(2026, 1, 14, 12, 0, 0)

result = sgp4_jdfr(satjax, jd, fr)
r_jax = result[:3]
v_jax = result[3:]

print(r_jax)  # position (km)
print(v_jax)  # velocity (km/s)

# compare against sgp4 package
e, r_ref, v_ref = sat.sgp4(jd, fr)
print(r_ref)
print(v_ref)
print(r_jax - jnp.array(r_ref))
print(v_jax - jnp.array(v_ref))



(2461054.5, 0.5)
[ 5824.93922536 -3249.62417854 -1263.90351831]
[3.32215798 3.68968644 5.84040313]
(5824.939225362986, -3249.624178538239, -1263.9035183074157)
(3.3221579839118673, 3.689686441887596, 5.840403125891872)
[9.06766218e-10 1.01226760e-09 1.58524927e-09]
[-2.02371453e-12  1.12931886e-12  4.38760139e-13]


In [3]:
# test sgp4_many_times function

from jax_sgp4 import sgp4_many_times

tsince_array = jnp.linspace(0, 1440, num=10000)  # 10 time points over one day
results_many_times = sgp4_many_times(satjax, tsince_array)

%timeit sgp4_many_times(satjax, tsince_array).block_until_ready()

# convert array of times to jd, fr for sgp4 package
jd, fr = sat.jdsatepoch, sat.jdsatepochF
jd_array = jd + np.array(tsince_array) / 1440.0
fr_array = fr * np.ones_like(tsince_array)

# check results against sgp4 package
_, true_r_many_times, true_v_many_times = sat.sgp4_array(jd_array, fr_array)
print(true_r_many_times - results_many_times[:, :3])  # compare positions
print(true_v_many_times - results_many_times[:, 3:6])  # compare velocities

%timeit sat.sgp4_array(jd_array, fr_array)  # SGP4 C++ implementation for multiple times

# test over array for plot
jaxtimes = []
oldsgp4times = []
nums = []

for i in range(5):
  num = 10**i
  nums.append(num)
  tsince_array = jnp.linspace(0, 1440, num=num) 
  jaxtime = %timeit -o sgp4_many_times(n0_test, e0_test, i0_test, w0_test, Omega0_test, M0_test, t0_test, Bstar_test, tsince_array).block_until_ready()
  jaxtimes.append(jaxtime.average)
  jd_array = jd + np.array(tsince_array) / 1440.0
  fr_array = fr * np.ones_like(tsince_array)
  oldsgp4time = %timeit -o sat.sgp4_array(jd_array, fr_array)
  oldsgp4times.append(oldsgp4time.average)

# Plot results

import matplotlib.pyplot as plt
plt.loglog(nums, jaxtimes, label='JAX SGP4')
plt.loglog(nums, oldsgp4times, label='SGP4 C++')
plt.xlabel('Number of Time Points')
plt.ylabel('Time (s)')
plt.legend()

KeyboardInterrupt: 

In [None]:
# test tle2sat_array and sgp4_many_sats functions 

from jax_sgp4 import tle2sat_array, sgp4_many_sats
from sgp4.api import Satrec
from importlib.resources import files
import numpy as np

# test propagation over multiple satellites

# took raw gnss TLE data from https://celestrak.org/NORAD/elements/ which is in file gnss.txt 
# gnss actually doesn't work because these are deep space 
# use space stations instead gives 32 satellites to test with
# or use starlink gives 9341 LEO sats to test with

data = files('jax_sgp4').joinpath('data/starlink.txt').read_text()
lines = iter(data.splitlines())

tle1_list = []
tle2_list = []

for line1 in lines:
    
    if not line1.startswith('1 '):
        continue 
        # note: Will get an error here if the satellite name starts with a '1 '

    line2 = next(lines)

    line1 = line1[:69]
    line2 = line2[:69]

    tle1_list.append(line1)
    tle2_list.append(line2)

print(f"Number of TLEs: {len(tle1_list)}")

# test for jax sgp4
sat_array_jax = tle2sat_array(tle1_list, tle2_list)

# Choose a time since epoch to propagate to (minutes)
tsince = 120.0
results_many_sats = sgp4_many_sats(sat_array_jax, tsince)

print(results_many_sats.shape)  # (num_sats, 6)
#%timeit sgp4_many_sats(sat_array_jax, tsince).block_until_ready()

# compare against python sgp4 package for multiple satellites SatrecArray
# (don't use this for timing as it's not the fast C++ implementation)
# just here to check against for correctness

true_results = []
for tle1, tle2 in zip(tle1_list, tle2_list):
    sat = Satrec.twoline2rv(tle1, tle2)
    e, r, v = sat.sgp4_tsince(tsince)
    true_results.append((*r, *v))

true_results = np.array(true_results)

# Compare
print(true_results[:, :3] - results_many_sats[:, :3])  # positions
print(true_results[:, 3:] - results_many_sats[:, 3:])  # velocities


# see how time scales with number of satellites  ######### the below needs fixing
# test over array 
manysatjaxtimes = []
satnums = []

for i in range(4): # change this for starlink data
  num = 10**i
  satnums.append(num)
  tsince = 120.0 # choose arbitrary time
  jaxtime = %timeit -o sgp4_many_sats(sat_array_jax[:num], tsince).block_until_ready() # need to fix sat_array_jax slicing
  manysatjaxtimes.append(jaxtime.average)

# Plot results
import matplotlib.pyplot as plt
plt.plot(satnums, manysatjaxtimes, label='JAX SGP4 Many Sats')
plt.xlabel('Number of Satellites')
plt.ylabel('Time (s)')
plt.legend()