In [None]:
# test tle2sat function 

# sgp4 python package imports
from sgp4.api import Satrec

# jax_sgp4 imports
from functions import tle2sat

import jax.numpy as jnp
import numpy as np

# Reference TLE: International Space Station (ZARYA)
tle1 = '1 25544U 98067A   25316.22557474  .00019835  00000-0  36009-3 0  9990'
tle2 = '2 25544  51.6334 293.5281 0004132  61.4658 298.6746 15.49560899538121'

# Build the same satellite twice: once with the sgp4 package (reference) and once with jax_sgp4.
sat = Satrec.twoline2rv(tle1, tle2)
sat_test = tle2sat(tle1, tle2)


def _rad_to_deg(angle_rad):
    """Convert a Satrec angle (radians) to degrees so it lines up with the jax_sgp4 field."""
    return angle_rad * 180.0 / jnp.pi


def _as_is(value):
    """Return the value unchanged (no unit conversion needed for this field)."""
    return value


# Rows of (Satrec attribute, jax_sgp4 attribute, conversion applied to the Satrec value).
# Each row prints "<sgp4 value> <jax_sgp4 value>" on its own line, in this order.
_ELEMENT_PAIRS = (
    ("no_kozai", "n0", _as_is),          # units differ between the two libraries
    ("ecco", "e0", _as_is),
    ("inclo", "i0", _rad_to_deg),
    ("argpo", "w0", _rad_to_deg),
    ("nodeo", "Omega0", _rad_to_deg),    # NOTE(review): possible rounding difference here — confirm
    ("mo", "M0", _rad_to_deg),
    ("bstar", "Bstar", _as_is),
    ("epochdays", "epochdays", _as_is),
    ("epochyr", "epochyr", _as_is),
)

for ref_attr, jax_attr, to_comparable in _ELEMENT_PAIRS:
    print(to_comparable(getattr(sat, ref_attr)), getattr(sat_test, jax_attr))

In [None]:
# test sgp4 function 

# sgp4 python package imports
from sgp4.api import jday

# jax_sgp4 imports
from functions import sgp4

# TLE for the International Space Station (ZARYA)
tle1 = '1 25544U 98067A   25316.22557474  .00019835  00000-0  36009-3 0  9990'
tle2 = '2 25544  51.6334 293.5281 0004132  61.4658 298.6746 15.49560899538121' 

# Create Satrec object
sat = Satrec.twoline2rv(tle1, tle2)

# Choose a time since epoch to propagate to (minutes)
tsince = 500.0

# Propagate using standard SGP4 implementation
e, r, v = sat.sgp4_tsince(tsince)

print(r, v)  # position (km), velocity (km/s)
print("SGP4 error code:", e)

# Propagate using JAX SGP4 implementation
satjax = tle2sat(tle1, tle2)

result = sgp4(satjax, tsince)
rjax = result[:3]
vjax = result[3:]

print(rjax, vjax)  # position (km), velocity (km/s)

# Compare results
r = jnp.array(r)
v = jnp.array(v)
print(rjax - r)
print(vjax - v)

# time sgp4 unjitted
%timeit sgp4(n0_test, e0_test, i0_test, w0_test, Omega0_test, M0_test, t0_test, Bstar_test, tsince).block_until_ready()

# time sgp4 jitted
jaxsgp4 = jax.jit(sgp4)
%timeit jaxsgp4(n0_test, e0_test, i0_test, w0_test, Omega0_test, M0_test, t0_test, Bstar_test, tsince).block_until_ready()

# time sgp4 python package C++ implementation

# (Verify it is using the fast C++ implementation rather than the slow python)
from sgp4.api import accelerated
print(accelerated)

%timeit sat.sgp4_tsince(tsince)

In [None]:
# test sgp4_many_times function

from functions import sgp4_many_times

tsince_array = jnp.linspace(0, 1440, num=10000)  # 10 time points over one day
results_many_times = sgp4_many_times(satjax, tsince_array)

%timeit sgp4_many_times(satjax, tsince_array).block_until_ready()

# convert array of times to jd, fr for sgp4 package
jd, fr = sat.jdsatepoch, sat.jdsatepochF
jd_array = jd + np.array(tsince_array) / 1440.0
fr_array = fr * np.ones_like(tsince_array)

# check results against sgp4 package
_, true_r_many_times, true_v_many_times = sat.sgp4_array(jd_array, fr_array)
print(true_r_many_times - results_many_times[:, :3])  # compare positions
print(true_v_many_times - results_many_times[:, 3:6])  # compare velocities

%timeit sat.sgp4_array(jd_array, fr_array)  # SGP4 C++ implementation for multiple times

# test over array for plot
jaxtimes = []
oldsgp4times = []
nums = []

for i in range(5):
  num = 10**i
  nums.append(num)
  tsince_array = jnp.linspace(0, 1440, num=num) 
  jaxtime = %timeit -o sgp4_many_times(n0_test, e0_test, i0_test, w0_test, Omega0_test, M0_test, t0_test, Bstar_test, tsince_array).block_until_ready()
  jaxtimes.append(jaxtime.average)
  jd_array = jd + np.array(tsince_array) / 1440.0
  fr_array = fr * np.ones_like(tsince_array)
  oldsgp4time = %timeit -o sat.sgp4_array(jd_array, fr_array)
  oldsgp4times.append(oldsgp4time.average)

# Plot results

import matplotlib.pyplot as plt
plt.loglog(nums, jaxtimes, label='JAX SGP4')
plt.loglog(nums, oldsgp4times, label='SGP4 C++')
plt.xlabel('Number of Time Points')
plt.ylabel('Time (s)')
plt.legend()

In [None]:
# test tle2sat_array and sgp4_many_sats functions 

from functions import tle2sat_array, sgp4_many_sats

# test propagation over multiple satellites

# took raw gnss TLE data from https://celestrak.org/NORAD/elements/ which is in file gnss.txt 
# gnss actually doesn't work because these are deep space 
# use space stations instead gives 32 satellites to test with
# or use starlink gives 9341 LEO sats to test with

from importlib.resources import files

data = files('jax_sgp4').joinpath('data/starlink.txt').read_text()
lines = iter(data.splitlines())

tle1_list = []
tle2_list = []

for line1 in lines:
    
    if not line1.startswith('1 '):
        continue 
        # note: Will get an error here if the satellite name starts with a '1 '

    line2 = next(lines)

    line1 = line1[:69]
    line2 = line2[:69]

    tle1_list.append(line1)
    tle2_list.append(line2)

print(f"Number of TLEs: {len(tle1_list)}")

# test for jax sgp4
sat_array_jax = tle2sat_array(tle1_list, tle2_list)

# Choose a time since epoch to propagate to (minutes)
tsince = 120.0
results_many_sats = sgp4_many_sats(sat_array_jax, tsince)

print(results_many_sats.shape)  # (num_sats, 6)
#%timeit sgp4_many_sats(sat_array_jax, tsince).block_until_ready()

# compare against python sgp4 package for multiple satellites SatrecArray
# (don't use this for timing as it's not the fast C++ implementation)
# just here to check against for correctness

true_results = []
for tle1, tle2 in zip(tle1_list, tle2_list):
    sat = Satrec.twoline2rv(tle1, tle2)
    e, r, v = sat.sgp4_tsince(tsince)
    true_results.append((*r, *v))

true_results = np.array(true_results)

# Compare
print(true_results[:, :3] - results_many_sats[:, :3])  # positions
print(true_results[:, 3:] - results_many_sats[:, 3:])  # velocities


# see how time scales with number of satellites  ######### the below needs fixing
# test over array 
manysatjaxtimes = []
satnums = []

for i in range(4): # change this for starlink data
  num = 10**i
  satnums.append(num)
  tsince = 120.0 # choose arbitrary time
  jaxtime = %timeit -o sgp4_many_sats(sat_array_jax[:num], tsince).block_until_ready() # need to fix sat_array_jax slicing
  manysatjaxtimes.append(jaxtime.average)

# Plot results
import matplotlib.pyplot as plt
plt.plot(satnums, manysatjaxtimes, label='JAX SGP4 Many Sats')
plt.xlabel('Number of Satellites')
plt.ylabel('Time (s)')
plt.legend()