In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K
keras = tf.keras

# MSE imports
import kepler_sieve
from ztf_ast import load_ztf_nearest_ast, calc_hit_freq
from ztf_element import make_ztf_batch, load_ztf_batch
from asteroid_element import load_ast_elt
from candidate_element import orbital_element_batch, perturb_elts, random_elts
from asteroid_model import make_model_ast_pos, make_model_ast_dir
from asteroid_model import AsteroidPosition, AsteroidDirection
from astro_utils import dist2deg, dist2sec, deg2dist

Found 4 GPUs.  Setting memory growth = True.


## Load ZTF Data and Batch of Orbital Elements

In [2]:
# Table of orbital elements for the known asteroids, as returned by load_ast_elt()
ast_elt = load_ast_elt()

# One row per asteroid, so the row count is the number of asteroids
N_ast = len(ast_elt)

In [3]:
# Load ztf nearest asteroid data
# (the result is passed to calc_hit_freq as its ztf argument in the next cell)
ztf_ast = load_ztf_nearest_ast()

In [4]:
# Asteroid numbers and their hit counts, using thresh_sec=2.0
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Permutation that orders hit_count from largest to smallest
idx = np.argsort(hit_count)[::-1]

# Apply the same permutation to both arrays; the top batch is sliced off in a later cell
ast_num_best, hit_count_best = ast_num[idx], hit_count[idx]

In [5]:
# Number of orbital elements handled together as one batch
elt_batch_size = 64

# Unperturbed elements for the first elt_batch_size asteroids in descending hit-count order
elts_ast = orbital_element_batch(ast_nums=ast_num_best[:elt_batch_size])

In [6]:
# Display the unperturbed element batch
elts_ast

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133491,58600.0
1,59244,2.634727,0.262503,0.465045,5.738297,1.766995,-1.601363,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.246069,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.357345,58600.0
4,142999,2.619944,0.191376,0.514017,0.238022,0.946463,-1.299301,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.016580,58600.0
60,134815,2.612770,0.140831,0.513923,0.272689,0.645552,-0.957836,58600.0
61,27860,2.619406,0.096185,0.200633,5.541399,3.266046,3.948770,58600.0
62,85937,2.342292,0.197267,0.439063,5.279693,3.210025,3.947687,58600.0


In [7]:
# Perturbation settings: sigma_f_deg is the only nonzero sigma
sigma_a = sigma_e = 0.0
sigma_f_deg = 0.1
sigma_Omega_deg = sigma_omega_deg = 0.0
# No perturbation mask is supplied; the seed is passed through to perturb_elts
mask_pert = None
random_seed = 42

# Perturbed copy of the element batch
elts_pert = perturb_elts(
    elts_ast,
    sigma_a=sigma_a,
    sigma_e=sigma_e,
    sigma_f_deg=sigma_f_deg,
    sigma_Omega_deg=sigma_Omega_deg,
    sigma_omega_deg=sigma_omega_deg,
    mask_pert=mask_pert,
    random_seed=random_seed,
)

In [8]:
# Display the perturbed element batch for comparison with elts_ast
elts_pert

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133117,58600.0
1,59244,2.634727,0.262503,0.465045,5.738297,1.766995,-1.603537,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.245767,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.356673,58600.0
4,142999,2.619945,0.191376,0.514017,0.238022,0.946463,-1.300844,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.014978,58600.0
60,134815,2.612770,0.140831,0.513923,0.272689,0.645552,-0.954132,58600.0
61,27860,2.619406,0.096185,0.200633,5.541399,3.266046,3.950572,58600.0
62,85937,2.342292,0.197267,0.439063,5.279693,3.210025,3.945035,58600.0


## Batches of ZTF Data vs. Elements

In [9]:
# Keyword arguments shared by the ZTF batch loading calls
thresh_deg, near_ast, regenerate = 1.0, False, False

In [10]:
# ZTF batch for the unperturbed elements
ztf_elt_ast = load_ztf_batch(
    elts=elts_ast,
    thresh_deg=thresh_deg,
    near_ast=near_ast,
    regenerate=regenerate,
)

In [11]:
# ZTF batch for the perturbed elements, loaded with the same settings as the unperturbed batch
ztf_elt_pert = load_ztf_batch(
    elts=elts_pert,
    thresh_deg=thresh_deg,
    near_ast=near_ast,
    regenerate=regenerate,
)

In [12]:
# Display the ZTF batch loaded for the unperturbed elements
ztf_elt_ast

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,vz,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit
0,53851,733,b'ZTF18abnothj',594197584815010004,5501,58348.197581,266.229165,-13.513802,-0.063945,-0.983101,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.010624,2191.371398,0.999944,0.370539,False
1,73604,733,b'ZTF18ablwzmb',594197584815015003,5501,58348.197581,265.761024,-13.509148,-0.071871,-0.982578,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.016809,3467.103003,0.999859,0.927533,False
2,82343,733,b'ZTF18abiydvm',635193253015015018,12089,58389.193252,270.331454,-11.244934,0.005674,-0.977422,...,0.003825,0.000918,-0.977996,0.208622,2.703478,0.005450,1124.142942,0.999985,0.097510,False
3,257221,733,b'ZTF18acakcqg',931471223715015007,39920,58685.471227,29.693832,42.180412,0.643725,0.603886,...,-0.001953,0.639004,0.610779,0.467571,2.175851,0.008712,1797.042210,0.999962,0.249184,False
4,327000,733,b'ZTF18achmdmw',937465970615015011,40837,58691.465972,33.104905,44.059131,0.601970,0.636719,...,-0.002129,0.606278,0.637608,0.475272,2.114865,0.007949,1639.537152,0.999968,0.207418,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90205,5650588,324582,b'ZTF20aaqvhld',1150176701515015008,96618,58904.176701,44.164238,29.650540,0.623416,0.752309,...,-0.001541,0.627640,0.750696,0.206212,2.981799,0.008187,1688.636853,0.999966,0.220027,False
90206,5650589,324582,b'ZTF20aaqvhld',1150176245715015005,96617,58904.176250,44.164062,29.650536,0.623417,0.752307,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.008187,1688.600639,0.999966,0.220018,False
90207,5650665,324582,b'ZTF20aaqvhll',1150176245815015010,96617,58904.176250,44.368640,28.490480,0.628284,0.753618,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.013370,2757.856469,0.999911,0.586871,False
90208,5650697,324582,b'ZTF20aaqvhmb',1150176246015015005,96617,58904.176250,43.296207,29.505908,0.633424,0.743491,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.012388,2555.279465,0.999923,0.503822,False


In [13]:
# List the columns available on the batch frame
ztf_elt_ast.columns

Index(['ztf_id', 'element_id', 'ObjectID', 'CandidateID', 'TimeStampID', 'mjd',
       'ra', 'dec', 'ux', 'uy', 'uz', 'qx', 'qy', 'qz', 'vx', 'vy', 'vz',
       'elt_ux', 'elt_uy', 'elt_uz', 'elt_r', 's', 's_sec', 'z', 'v',
       'is_hit'],
      dtype='object')

In [14]:
# Review results for the asteroid with the highest hit count
ztf_elt = ztf_elt_ast
element_id_best = ast_num_best[0]
mask = (ztf_elt.element_id == element_id_best)

# Apply the mask once so every statistic below describes the same rows
ztf_elt_masked = ztf_elt[mask]
hits_best = np.sum(ztf_elt_masked.is_hit)
s_sec_min = np.min(ztf_elt_masked.s_sec)

# ZTF detection with the smallest s among this asteroid's rows.
# The position is taken on the masked rows (not the whole frame) and looked up with iloc,
# so it does not depend on index labels and no longer overwrites the sort permutation `idx`.
i_closest = int(np.argmin(ztf_elt_masked.s.values))
ztf_id = ztf_elt_masked.ztf_id.iloc[i_closest]

print(f'Best asteroid has element_id = {element_id_best}')
print(f'Hit count: {hits_best}')
print(f'Closest hit: {s_sec_min:0.3f} arc seconds')

Best asteroid has element_id = 51921
Hit count: 158
Closest hit: 0.382 arc seconds


## Load Position and Direction Models

In [15]:
# Data types: TensorFlow dtype for tensors and the matching NumPy dtype for arrays
dtype = tf.float32
dtype_np = np.float32

In [16]:
# Rebind ztf_elt to an independent copy of the unperturbed batch (a copy, not an alias)
ztf_elt = ztf_elt_ast.copy()

In [17]:
# Flat array of observation times (mjd column) cast to dtype_np
# NOTE(review): if dtype_np is float32, values near 58,000 are spaced 2**-8 day (~5.6 minutes)
# apart, so nearby observation times can collapse to the same value — confirm this is acceptable
ts_np = ztf_elt['mjd'].values.astype(dtype_np)

# The same flat times as a constant tensor
ts = keras.backend.constant(ts_np)

In [18]:
# Number of rows per element_id; groupby returns the groups in ascending element_id order
# NOTE(review): downstream cells pair these counts positionally with the flat rows of ztf_elt and
# with the element batch — presumably both follow this same order; confirm
row_lengths_np = ztf_elt.groupby('element_id').size().values.astype(np.int32)

# The same counts as an int32 constant tensor
row_lengths = keras.backend.constant(row_lengths_np, dtype=tf.int32)

In [19]:
# Inspect the per-element observation counts
row_lengths

<tf.Tensor: shape=(64,), dtype=int32, numpy=
array([1424,  936, 1117, 1087, 1035, 1028,  780, 2747,  926, 2485, 1023,
       1013, 1063,  858, 1943,  949,  636, 1199, 1547, 1477,  991,  790,
       3205, 2719, 1612,  771, 1047, 1253,  845, 3085,  866, 1474, 2678,
        802, 1653,  708, 1043,  855, 2362, 1012, 1048, 1677,  916,  813,
       1101, 1180, 2708,  688, 1929, 2871, 1099,  743,  943, 2018, 2717,
        595,  864,  941, 2628, 3127,  943, 1726,  762, 1129], dtype=int32)>

In [20]:
# Ragged tensor of observation times: one row per element, row sizes given by row_lengths
ts_r = tf.RaggedTensor.from_row_lengths(
    values=ztf_elt['mjd'].values.astype(dtype_np),
    row_lengths=row_lengths,
)

In [21]:
# Distinct observation times and distinct time stamp IDs in the unperturbed batch
ts_unq = np.unique(ztf_elt_ast['mjd'])
TimeStampID_unq = np.unique(ztf_elt_ast['TimeStampID'])

# Epoch of the element at index label 0 of the batch
epoch0 = elts_ast['epoch'][0]

In [22]:
# Report time tensor shapes: flat tensor, ragged tensor, and unique times.
# The second line is labelled ts_r (it previously repeated the ts label).
print(f'ts.shape={ts.shape}')
print(f'ts_r.shape={ts_r.shape}')
print(f'ts_unq.shape={ts_unq.shape}')

ts.shape=(90210,)
ts.shape=(64, None)
ts_unq.shape=(6383,)


In [23]:
# Observation site; passed as site_name when building the direction model and direction layers
site_name = 'palomar'

In [24]:
# Position model built on the flat observation times and per-element row lengths
model_pos = make_model_ast_pos(
    ts_np=ts_np,
    row_lengths_np=row_lengths_np,
)

In [25]:
# Direction model built on the same times and row lengths, for the chosen observation site
model_dir = make_model_ast_dir(
    ts_np=ts_np,
    row_lengths_np=row_lengths_np,
    site_name=site_name,
)

In [26]:
# Element columns fed to the models, in this order
cols_elt = ['a', 'e', 'inc', 'Omega', 'omega', 'f', 'epoch']
# Map each column name to its numpy array of values, for use with model.predict
elts_ast_dict = dict((col, elts_ast[col].values) for col in cols_elt)

## Run and Calibrate Position Model

In [27]:
# Print the layer summary of the position model
model_pos.summary()

Model: "model_asteroid_pos"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
a (InputLayer)                  [(64,)]              0                                            
__________________________________________________________________________________________________
Omega (InputLayer)              [(64,)]              0                                            
__________________________________________________________________________________________________
e (InputLayer)                  [(64,)]              0                                            
__________________________________________________________________________________________________
epoch (InputLayer)              [(64,)]              0                                            
_________________________________________________________________________________

In [28]:
# Predict position model: feed the dict of element arrays; the two outputs are unpacked as q_pred, v_pred
q_pred, v_pred = model_pos.predict(elts_ast_dict)

In [29]:
# Shapes of the two position-model outputs, one per line
print(
    f'q_pred.shape = {q_pred.shape}',
    f'v_pred.shape = {v_pred.shape}',
    sep='\n',
)

q_pred.shape = (64, None, 3)
v_pred.shape = (64, None, 3)


In [30]:
# Flat calibration arrays taken from the q* and v* columns of the batch frame
cols_q_ast = ['qx', 'qy', 'qz']
cols_v_ast = ['vx', 'vy', 'vz']
q_ast, v_ast = (
    ztf_elt[cols].values.astype(dtype_np)
    for cols in (cols_q_ast, cols_v_ast)
)

In [31]:
# Calibrate the position model's position layer against the q/v arrays from the batch frame
model_pos.ast_pos_layer.calibrate(
    elts=elts_ast,
    q_ast=q_ast,
    v_ast=v_ast,
)

In [32]:
# Check that corrections during calibration aren't too large:
# mean row-wise norm of the dq and dv arrays held by the position layer
mean_dq = np.mean(tf.linalg.norm(model_pos.ast_pos_layer.dq, axis=1))
mean_dv = np.mean(tf.linalg.norm(model_pos.ast_pos_layer.dv, axis=1))

print(f'Mean calibration adjustments:')
print(f'mean_dq = {mean_dq:6.2e}')
# Label this line mean_dv; it previously repeated the mean_dq label
print(f'mean_dv = {mean_dv:6.2e}')

Mean calibration adjustments:
mean_dq = 9.84e-05
mean_dq = 9.31e-07


## Predict Direction Model

In [33]:
# Print the layer summary of the direction model
model_dir.summary()

Model: "model_asteroid_dir"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
a (InputLayer)                  [(64,)]              0                                            
__________________________________________________________________________________________________
Omega (InputLayer)              [(64,)]              0                                            
__________________________________________________________________________________________________
e (InputLayer)                  [(64,)]              0                                            
__________________________________________________________________________________________________
epoch (InputLayer)              [(64,)]              0                                            
_________________________________________________________________________________

In [34]:
# The direction model has its own position layer; calibrate it with the same q/v arrays
model_dir.ast_pos_layer.calibrate(
    elts=elts_ast,
    q_ast=q_ast,
    v_ast=v_ast,
)

In [35]:
# Predict calibrated direction model; the two outputs are unpacked as u_pred, r_pred
u_pred, r_pred = model_dir.predict(elts_ast_dict)

In [36]:
# Shapes of the two direction-model outputs, one per line
print(
    f'u_pred.shape = {u_pred.shape}',
    f'r_pred.shape = {r_pred.shape}',
    sep='\n',
)

u_pred.shape = (64, None, 3)
r_pred.shape = (64, None, 1)


In [37]:
# Directions stored on the ztf_elt frame; the calibrated model should reproduce them
cols_u = ['elt_ux', 'elt_uy', 'elt_uz']
u_exp = ztf_elt[cols_u].values
# u_pred.values is differenced row by row against the frame's directions
u_diff = u_pred.values - u_exp
# Mean row-wise norm of the difference, then the same quantity converted by dist2sec
mean_diff_s = np.mean(np.linalg.norm(u_diff, axis=1))
mean_diff_sec = dist2sec(mean_diff_s)

# Report results
print(
    'Mean direction difference:',
    f'Cartesian  : {mean_diff_s:8.2e}',
    f'Arc Seconds: {mean_diff_sec:6.3f}',
    sep='\n',
)

Mean direction difference:
Cartesian  : 7.86e-06
Arc Seconds:  1.622


## Assemble Tensors for Prototype Model

In [38]:
# Six orbital elements as TensorFlow variables initialised from the unperturbed batch
a = tf.Variable(initial_value=elts_ast['a'], dtype=dtype, name='a')
e = tf.Variable(initial_value=elts_ast['e'], dtype=dtype, name='e')
inc = tf.Variable(initial_value=elts_ast['inc'], dtype=dtype, name='inc')
Omega = tf.Variable(initial_value=elts_ast['Omega'], dtype=dtype, name='Omega')
omega = tf.Variable(initial_value=elts_ast['omega'], dtype=dtype, name='omega')
f = tf.Variable(initial_value=elts_ast['f'], dtype=dtype, name='f')
# The epoch is created as a constant rather than a variable
epoch = tf.constant(value=elts_ast['epoch'], dtype=dtype, name='epoch')

In [39]:
# Review element shapes (all the same): a is printed as the representative
print(f'a.shape = {a.shape}')

a.shape = (64,)


In [40]:
# Rows of the batch frame whose element_id equals element_id_best
mask_best = (ztf_elt.element_id == element_id_best)
ztf_elt_best = ztf_elt.loc[mask_best]
# Display those rows
ztf_elt_best

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,vz,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit
49064,254639,51921,b'ZTF17aaaexyh',929453032715010002,39524,58683.453032,30.512085,25.122180,0.780026,0.590635,...,0.003049,0.773781,0.598746,0.206801,2.425591,0.010237,2111.633673,0.999948,3.440641e-01,False
49065,254654,51921,b'ZTF17aaaexyh',929479172715010000,39555,58683.479178,30.512152,25.122183,0.780025,0.590636,...,0.003049,0.773681,0.598853,0.206866,2.425228,0.010383,2141.648464,0.999946,3.539146e-01,False
49066,285559,51921,b'ZTF19abiquuf',936427393415015002,40501,58690.427396,33.557434,26.797350,0.743838,0.632026,...,0.002976,0.746833,0.626025,0.224350,2.329539,0.009682,1997.086804,0.999953,3.077489e-01,False
49067,285642,51921,b'ZTF19abiquzf',936427393915015000,40501,58690.427396,32.542958,27.063322,0.750685,0.620481,...,0.002976,0.746833,0.626025,0.224350,2.329539,0.007210,1487.255810,0.999974,1.706772e-01,False
49068,292364,51921,b'ZTF19abiquzd',936427393915015002,40501,58690.427396,32.772683,26.984702,0.749282,0.623062,...,0.002976,0.746833,0.626025,0.224350,2.329539,0.003845,793.025303,0.999993,4.852657e-02,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
49767,5447565,51921,b'ZTF20aapdeiy',1145119391715015002,95200,58899.119398,40.107769,39.932949,0.586472,0.708557,...,-0.001161,0.586477,0.708552,0.392427,1.969148,0.000007,1.543181,1.000000,1.837551e-07,True
49768,5461054,51921,b'ZTF18abnzokb',1145163031715010010,95280,58899.163032,40.387826,40.141830,0.582264,0.710898,...,-0.001163,0.586279,0.708772,0.392325,1.969530,0.005016,1034.707623,0.999987,8.261148e-02,False
49769,5461101,51921,b'ZTF20aapendd',1145163031715015008,95280,58899.163032,40.131138,39.932511,0.586274,0.708777,...,-0.001163,0.586279,0.708772,0.392325,1.969530,0.000007,1.500867,1.000000,1.738164e-07,True
49770,5646723,51921,b'ZTF20aaqulnc',1150144742815015009,96582,58904.144745,42.885851,39.904980,0.562069,0.734151,...,-0.001294,0.562074,0.734146,0.380923,2.012406,0.000007,1.436572,1.000000,1.592431e-07,True


In [41]:
# Ragged tensor of the ux, uy, uz columns, partitioned by row_lengths
cols_u_obs = ['ux', 'uy', 'uz']
u_obs_r = tf.RaggedTensor.from_row_lengths(
    values=ztf_elt[cols_u_obs].values,
    row_lengths=row_lengths,
)

# Review shape
print(f'u_obs_r.shape = {u_obs_r.shape}')

u_obs_r.shape = (64, None, 3)


In [42]:
# Largest observation time in the batch
t_max = ztf_elt.mjd.values.max()

# Padding time: one unit later than every real observation time
t_pad = t_max + 1.0

In [43]:
# Convert ragged to 2D: rows shorter than the longest one are filled with t_pad
ts_2d = ts_r.to_tensor(default_value=t_pad)
# NumPy copy of the padded tensor
ts_2d_np = ts_2d.numpy()

print(f'ts_2d.shape = {ts_2d.shape}')

ts_2d.shape = (64, 3205)


## Calculate Position and Direction Using Layers in asteroid_model.py

In [44]:
# Standalone position layer on the flat times and row lengths
ast_pos_layer = AsteroidPosition(
    ts_np=ts_np,
    row_lengths_np=row_lengths_np,
    name='ast_pos_layer',
)

In [45]:
# Standalone direction layer on the same times and row lengths, for the chosen site
direction_layer = AsteroidDirection(
    ts_np=ts_np,
    row_lengths_np=row_lengths_np,
    site_name=site_name,
    name='u_pred',
)

In [46]:
# Predict position using the layer, called directly on the element variables and the epoch constant
# (rebinds q_pred and v_pred, which previously held the model.predict outputs)
q_pred, v_pred = ast_pos_layer(a, e, inc, Omega, omega, f, epoch)

In [47]:
# Review output: shapes of both position-layer outputs.
# The second line reports v_pred's own shape (it previously printed q_pred.shape under the v_pred label).
print(f'q_pred.shape = {q_pred.shape}')
print(f'v_pred.shape = {v_pred.shape}')

q_pred.shape = (90210, 3)
v_pred.shape = (90210, 3)


In [48]:
# Calibrate the standalone position layer with the q/v arrays from the batch frame
ast_pos_layer.calibrate(
    elts=elts_ast,
    q_ast=q_ast,
    v_ast=v_ast,
)

In [49]:
# Check that corrections during calibration aren't too large:
# mean row-wise norm of the dq and dv arrays held by the standalone position layer
mean_dq = np.mean(tf.linalg.norm(ast_pos_layer.dq, axis=1))
mean_dv = np.mean(tf.linalg.norm(ast_pos_layer.dv, axis=1))

print(f'Mean calibration adjustments:')
print(f'mean_dq = {mean_dq:6.2e}')
# Label this line mean_dv; it previously repeated the mean_dq label
print(f'mean_dv = {mean_dv:6.2e}')

Mean calibration adjustments:
mean_dq = 9.84e-05
mean_dq = 9.31e-07


In [50]:
# Ragged views of the flat position and velocity predictions, partitioned by row_lengths
q_r = tf.RaggedTensor.from_row_lengths(
    values=q_pred,
    row_lengths=row_lengths,
    name='q_r',
)
v_r = tf.RaggedTensor.from_row_lengths(
    values=v_pred,
    row_lengths=row_lengths,
    name='v_r',
)

In [51]:
# Report shapes: the ragged position tensor
print(f'q_r.shape={q_r.shape}')

q_r.shape=(64, None, 3)


In [52]:
# Predict direction using the layer, called directly on the element variables and the epoch constant
# (rebinds u_pred and r_pred, which previously held the model.predict outputs)
u_pred, r_pred = direction_layer(a, e, inc, Omega, omega, f, epoch)

In [53]:
# Shapes of the two direction-layer outputs, one per line
print(
    f'u_pred.shape={u_pred.shape}',
    f'r_pred.shape={r_pred.shape}',
    sep='\n',
)

u_pred.shape=(90210, 3)
r_pred.shape=(90210, 1)


## Create Input Tensors for Functional API Model Debugging

In [54]:
# One Keras input per orbital element plus the epoch, created in this order,
# each with an empty per-sample shape and a fixed batch size of elt_batch_size
in_a, in_e, in_inc, in_Omega, in_omega, in_f, in_epoch = (
    keras.Input(shape=(), batch_size=elt_batch_size, name=input_name)
    for input_name in ('a', 'e', 'inc', 'Omega', 'omega', 'f', 'epoch')
)

# Tuple of all inputs in the order the layers take them
inputs = (in_a, in_e, in_inc, in_Omega, in_omega, in_f, in_epoch)

In [55]:
# Inspect the symbolic Keras input for a
in_a

<tf.Tensor 'a_2:0' shape=(64,) dtype=float32>

In [56]:
# Inspect the tf.Variable a for comparison with the symbolic input in_a
a

<tf.Variable 'a:0' shape=(64,) dtype=float32, numpy=
array([2.6693056, 2.634727 , 1.8832273, 2.5563872, 2.6199443, 2.407715 ,
       2.3491726, 2.3106558, 2.2718794, 2.6123784, 2.6039326, 3.0796754,
       3.1138892, 2.331416 , 3.0063655, 2.344482 , 2.286487 , 2.5829718,
       2.5813491, 2.6285136, 2.6182237, 3.2066405, 2.9078937, 2.6719072,
       2.3353634, 3.1711535, 2.3273304, 2.6373718, 2.4007137, 2.3050103,
       2.6771643, 2.6571858, 2.42876  , 1.903306 , 2.5840602, 2.4018168,
       2.5349774, 2.6773167, 3.0751357, 3.1739578, 3.2142053, 2.6185794,
       3.1158593, 1.9388875, 2.7592251, 2.275916 , 2.3057158, 3.3988707,
       2.391135 , 2.2808902, 2.240677 , 1.9278089, 1.968155 , 2.5760422,
       2.581424 , 2.3393567, 2.9165266, 2.6267219, 3.2064335, 2.2196503,
       2.6127703, 2.6194055, 2.3422916, 3.155569 ], dtype=float32)>

In [57]:
# Sum of row_lengths (total number of observation rows), as a tensor
data_size = tf.reduce_sum(row_lengths)
# NOTE(review): neither data_size nor elt_shape is used in the cells visible here — confirm they are still needed
elt_shape = (data_size, 1,)

In [58]:
# Create output tensor
# Rebinds ast_pos_layer to a new AsteroidPosition instance with no explicit name;
# calibrate() is not called on this new instance in this cell
ast_pos_layer = AsteroidPosition(ts_np=ts_np, row_lengths_np=row_lengths_np)
# Apply the layer to the symbolic Keras inputs to obtain the two symbolic outputs
q_flat, v_flat = ast_pos_layer(in_a, in_e, in_inc, in_Omega, in_omega, in_f, in_epoch)

In [59]:
# Predict direction
# Applies the existing direction_layer to the symbolic inputs; u_pred and r_pred are rebound to its outputs
u_pred, r_pred = direction_layer(in_a, in_e, in_inc, in_Omega, in_omega, in_f, in_epoch)

In [60]:
# Shapes of the symbolic direction outputs, one per line
print(
    f'u_pred.shape = {u_pred.shape}',
    f'r_pred.shape = {r_pred.shape}',
    sep='\n',
)

u_pred.shape = (90210, 3)
r_pred.shape = (90210, 1)


In [61]:
# Build test position model: maps the tuple of symbolic inputs to the outputs q_flat and v_flat
model_pos_test = keras.Model(inputs=inputs, outputs=[q_flat, v_flat])

In [62]:
# Build test direction model: maps the tuple of symbolic inputs to the outputs u_pred and r_pred
model_dir_test = keras.Model(inputs=inputs, outputs=[u_pred, r_pred])

In [63]:
# Call the test direction model on the element variables and epoch constant, and display the outputs
model_dir_test([a, e, inc, Omega, omega, f, epoch])

[<tf.Tensor: shape=(90210, 3), dtype=float32, numpy=
 array([[-5.7236522e-02, -9.8204035e-01,  1.7977960e-01],
        [-5.7236522e-02, -9.8204035e-01,  1.7977960e-01],
        [ 9.4392640e-04, -9.7799307e-01,  2.0863535e-01],
        ...,
        [ 6.2769508e-01,  7.5064498e-01,  2.0623037e-01],
        [ 6.2769508e-01,  7.5064498e-01,  2.0623037e-01],
        [ 6.2769508e-01,  7.5064498e-01,  2.0623037e-01]], dtype=float32)>,
 <tf.Tensor: shape=(90210, 1), dtype=float32, numpy=
 array([[2.234171 ],
        [2.234171 ],
        [2.7035184],
        ...,
        [2.9816544],
        [2.9816544],
        [2.9816544]], dtype=float32)>]

## Imports for Step by Step Calculations

In [64]:
import astropy
from astropy.units import au, day, year

# Local imports
from orbital_element import MeanToTrueAnomaly, TrueToMeanAnomaly
from asteroid_data import get_earth_pos, get_sun_pos_vel
from asteroid_model import ElementToPosition
from ra_dec import calc_topos

In [65]:
# Constants

# The gravitational constant in ('day', 'AU', 'Msun') coordinates
# Hard code G
G_ = 0.00029591220828559104
# The gravitational field strength mu = G * (m0 + m1)
# For massless asteroids orbiting the sun with units Msun, m0=1.0, m1=0.0, and mu = G
# No dtype is passed, so tf.constant picks its default for a Python float (float32)
mu = tf.constant(G_)

# Speed of light; express this in AU / day
light_speed_au_day = astropy.constants.c.to(au / day).value

# Number of spatial dimensions
space_dims = 3

# Data types (the same values assigned earlier in In [15], repeated here)
dtype = tf.float32
dtype_np = np.float32

In [66]:
# Fresh direction layer with the same arguments as the one built in In [45]
direction_layer = AsteroidDirection(
    ts_np=ts_np,
    row_lengths_np=row_lengths_np,
    site_name=site_name,
    name='u_pred',
)

In [67]:
# Predict direction layer: call the rebuilt layer on the element variables and epoch constant
# (rebinds u_pred and r_pred away from the symbolic outputs assigned in In [59])
u_pred, r_pred = direction_layer(a, e, inc, Omega, omega, f, epoch)

In [68]:
# Shape of the predicted directions
u_pred.shape

TensorShape([90210, 3])

## Step by Step Calculation of Asteroid Position

In [69]:
# Shape of the epoch constant (one entry per element in the batch)
epoch.shape

TensorShape([64])

In [70]:
# Target shape handed to the keras Reshape layers in the step-by-step cells below
target_shape = (1,)

In [71]:
# Time relative to epoch
# Repeat each element's epoch row_lengths times and reshape with target_shape
epoch_t  = keras.layers.Reshape(target_shape, name='epoch_t')(tf.repeat(epoch, row_lengths))
# Subtract the repeated epoch from the flat observation times
t = keras.layers.subtract([ts, epoch_t], name='t')

In [72]:
# Shapes of the flat times, the epoch constant, and the time relative to epoch, one per line
print(
    f'ts.shape = {ts.shape}',
    f'epoch.shape = {epoch.shape}',
    f't.shape = {t.shape}',
    sep='\n',
)

ts.shape = (90210,)
epoch.shape = (64,)
t.shape = (90210, 1)


In [73]:
# Compute mean anomaly M from true anomaly f and eccentricity e
M = TrueToMeanAnomaly(name='TrueToMeanAnomaly')([f, e])

# Compute mean motion N from mu and a: N = sqrt(mu / a^3)
a3 = tf.math.pow(a, 3, name='a3')
mu_over_a3 = tf.divide(mu, a3, name='mu_over_a3')
N = tf.sqrt(mu_over_a3, name='N')

In [74]:
# Repeat each constant orbital element row_lengths times (once per observation) and reshape
# with target_shape, giving one value per flat observation row
a_t = keras.layers.Reshape(target_shape, name='a_t')(tf.repeat(a, row_lengths))
e_t = keras.layers.Reshape(target_shape, name='e_t')(tf.repeat(e, row_lengths))
inc_t = keras.layers.Reshape(target_shape, name='inc_t')(tf.repeat(inc, row_lengths))
Omega_t = keras.layers.Reshape(target_shape, name='Omega_t')(tf.repeat(Omega, row_lengths))
omega_t = keras.layers.Reshape(target_shape, name='omega_t')(tf.repeat(omega, row_lengths))

In [75]:
# Report the shape of one repeated element
print(f'a_t.shape = {a_t.shape}')

a_t.shape = (90210, 1)


In [76]:
# Repeat initial mean anomaly M0 and mean motion N0 to match shape of outputs
M0_t  = keras.layers.Reshape(target_shape, name='M0_t')(tf.repeat(M, row_lengths))
N0_t  = keras.layers.Reshape(target_shape, name='N0_t')(tf.repeat(N, row_lengths))
# Compute the mean anomaly M(t) as a function of time: M(t) = M0 + N0 * t
N_mult_t = keras.layers.multiply(inputs=[N0_t, t])
M_t = keras.layers.add(inputs=[M0_t, N_mult_t])

In [77]:
# Compute the true anomaly from the mean anomaly and eccentricity
f_t = MeanToTrueAnomaly(name='mean_to_true_anomaly')([M_t, e_t])

# Wrap orbital elements into one tuple of inputs for layer converting to cartesian coordinates
elt_t = (a_t, e_t, inc_t, Omega_t, omega_t, f_t,)

In [78]:
# Shapes of the mean anomaly and true anomaly tensors, one per line
print(
    f'M_t.shape = {M_t.shape}',
    f'f_t.shape = {f_t.shape}',
    sep='\n',
)

M_t.shape = (90210, 1)
f_t.shape = (90210, 1)


In [79]:
# Convert orbital elements to heliocentric cartesian coordinates
# The layer takes the tuple of repeated elements and its two outputs are unpacked as q_helio, v_helio
q_helio, v_helio = ElementToPosition(name='q_helio')(elt_t)

In [80]:
# Report the shape of the heliocentric position tensor
q_helio.shape

TensorShape([90210, 3])