In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K
keras = tf.keras

# MSE imports
import kepler_sieve
from ztf_ast import load_ztf_nearest_ast, calc_hit_freq
from ztf_element import make_ztf_batch, load_ztf_batch
from asteroid_element import load_ast_elt
from candidate_element import asteroid_elts, perturb_elts, random_elts
from asteroid_model import make_model_ast_pos, make_model_ast_dir
from asteroid_model import AsteroidPosition, AsteroidDirection
from astro_utils import dist2deg, dist2sec, deg2dist

## Load ZTF Data and Batch of Orbital Elements

In [2]:
# Orbital elements of the known asteroids
ast_elt = load_ast_elt()

# Size of the catalogue: length along the first axis
N_ast = len(ast_elt)

In [3]:
# Load ztf nearest asteroid data (produced by ztf_ast.load_ztf_nearest_ast);
# it is the input to the hit-frequency calculation in the next cell
ztf_ast = load_ztf_nearest_ast()

In [4]:
# Per-asteroid hit statistics from calc_hit_freq, called with thresh_sec=2.0
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Rank by hit count: ascending argsort, then reversed so the largest count comes first
idx_ascending = np.argsort(hit_count)
idx = idx_ascending[::-1]

# Asteroid numbers and hit counts rearranged into that ranking
ast_num_best, hit_count_best = ast_num[idx], hit_count[idx]

In [5]:
# Number of candidate elements handled together
elt_batch_size = 64

# Unperturbed elements for the elt_batch_size highest-ranked asteroids
elts_ast = asteroid_elts(ast_nums=ast_num_best[:elt_batch_size])

In [6]:
elts_ast

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133491,58600.0
1,59244,2.634727,0.262503,0.465045,5.738298,1.766995,-1.601363,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.246069,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.357345,58600.0
4,142999,2.619944,0.191376,0.514017,0.238022,0.946463,-1.299301,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.016580,58600.0
60,134815,2.612770,0.140831,0.513922,0.272689,0.645552,-0.957836,58600.0
61,27860,2.619406,0.096185,0.200633,5.541400,3.266046,3.948770,58600.0
62,85937,2.342291,0.197267,0.439063,5.279693,3.210025,3.947687,58600.0


In [7]:
# Perturbation scales: only sigma_f_deg is non-zero
sigma_a = 0.0
sigma_e = 0.0
sigma_f_deg = 0.1
sigma_Omega_deg = 0.0
sigma_omega_deg = 0.0
# No perturbation mask is supplied
mask_pert = None
# Seed handed to perturb_elts (presumably for reproducible draws -- confirm)
random_seed = 42

# Collect the keyword arguments once, then perturb the unperturbed batch
perturb_kwargs = dict(
    sigma_a=sigma_a,
    sigma_e=sigma_e,
    sigma_f_deg=sigma_f_deg,
    sigma_Omega_deg=sigma_Omega_deg,
    sigma_omega_deg=sigma_omega_deg,
    mask_pert=mask_pert,
    random_seed=random_seed,
)
elts_pert = perturb_elts(elts_ast, **perturb_kwargs)

In [8]:
elts_pert

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133117,58600.0
1,59244,2.634727,0.262503,0.465045,5.738298,1.766995,-1.603537,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.245767,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.356673,58600.0
4,142999,2.619944,0.191376,0.514017,0.238022,0.946463,-1.300844,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.014978,58600.0
60,134815,2.612770,0.140831,0.513922,0.272689,0.645552,-0.954132,58600.0
61,27860,2.619406,0.096185,0.200633,5.541400,3.266046,3.950572,58600.0
62,85937,2.342291,0.197267,0.439063,5.279693,3.210025,3.945035,58600.0


## Batches of ZTF Data vs. Elements

In [9]:
# Arguments shared by the load_ztf_batch calls for the unperturbed and perturbed batches
thresh_deg = 1.0
near_ast = False
regenerate = False

In [10]:
# Load the ZTF batch for the unperturbed elements (elts_ast), using the shared
# thresh_deg / near_ast / regenerate arguments
ztf_elt_ast = load_ztf_batch(elts=elts_ast, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [11]:
# Load the ZTF batch for the perturbed elements (elts_pert), using the same
# thresh_deg / near_ast / regenerate arguments as the unperturbed batch
ztf_elt_pert = load_ztf_batch(elts=elts_pert, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [12]:
ztf_elt_ast

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,mag_app,ux,...,vz,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit
0,53851,733,b'ZTF18abnothj',594197584815010004,5501,58348.197581,266.229165,-13.513802,16.755600,-0.063945,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.010624,2191.408734,0.999944,0.370552,False
1,73604,733,b'ZTF18ablwzmb',594197584815015003,5501,58348.197581,265.761024,-13.509148,16.035999,-0.071871,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.016809,3467.151428,0.999859,0.927559,False
2,82343,733,b'ZTF18abiydvm',635193253015015018,12089,58389.193252,270.331454,-11.244934,17.196199,0.005674,...,0.003825,0.000919,-0.977996,0.208622,2.703478,0.005450,1124.103915,0.999985,0.097503,False
3,257221,733,b'ZTF18acakcqg',931471223715015007,39920,58685.471227,29.693832,42.180412,19.289200,0.643725,...,-0.001953,0.639004,0.610779,0.467571,2.175851,0.008713,1797.091521,0.999962,0.249197,False
4,327000,733,b'ZTF18achmdmw',937465970615015011,40837,58691.465972,33.104905,44.059131,17.725201,0.601970,...,-0.002129,0.606278,0.637608,0.475272,2.114866,0.007949,1639.539679,0.999968,0.207419,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90206,5650588,324582,b'ZTF20aaqvhld',1150176701515015008,96618,58904.176701,44.164238,29.650540,18.084700,0.623416,...,-0.001541,0.627640,0.750696,0.206212,2.981799,0.008187,1688.638104,0.999966,0.220027,False
90207,5650589,324582,b'ZTF20aaqvhld',1150176245715015005,96617,58904.176250,44.164062,29.650536,18.165199,0.623417,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.008187,1688.601889,0.999966,0.220018,False
90208,5650665,324582,b'ZTF20aaqvhll',1150176245815015010,96617,58904.176250,44.368640,28.490480,19.025200,0.628284,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.013370,2757.856412,0.999911,0.586871,False
90209,5650697,324582,b'ZTF20aaqvhmb',1150176246015015005,96617,58904.176250,43.296207,29.505908,19.852800,0.633424,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.012388,2555.278205,0.999923,0.503822,False


In [13]:
ztf_elt_ast.columns

Index(['ztf_id', 'element_id', 'ObjectID', 'CandidateID', 'TimeStampID', 'mjd',
       'ra', 'dec', 'mag_app', 'ux', 'uy', 'uz', 'qx', 'qy', 'qz', 'vx', 'vy',
       'vz', 'elt_ux', 'elt_uy', 'elt_uz', 'elt_r', 's', 's_sec', 'z', 'v',
       'is_hit'],
      dtype='object')

In [14]:
# Review results for the top-ranked asteroid (first entry of ast_num_best)
ztf_elt = ztf_elt_ast
element_id_best = ast_num_best[0]
mask = (ztf_elt.element_id == element_id_best)

# Restrict to this asteroid once, so every statistic below describes the same rows
ztf_elt_best = ztf_elt[mask]
hits_best = np.sum(ztf_elt_best.is_hit)
s_sec_min = np.min(ztf_elt_best.s_sec)

# Closest observation is searched within the masked rows only, so that idx and
# ztf_id refer to this asteroid rather than to the closest row of the whole batch.
# idx is a position inside ztf_elt_best, hence the positional .iloc lookup.
idx = np.argmin(ztf_elt_best.s.values)
ztf_id = ztf_elt_best.ztf_id.iloc[idx]

print(f'Best asteroid has element_id = {element_id_best}')
print(f'Hit count: {hits_best}')
print(f'Closest hit: {s_sec_min:0.3f} arc seconds')

Best asteroid has element_id = 51921
Hit count: 158
Closest hit: 0.381 arc seconds


## Load Position and Direction Models

In [15]:
# Data types: the TensorFlow float type and the matching numpy float type
# used when casting arrays in the cells below
dtype = tf.float32
dtype_np = np.float32

In [16]:
# Bind ztf_elt to a copy of the unperturbed batch (a new object, not an alias of ztf_elt_ast)
ztf_elt = ztf_elt_ast.copy()

In [17]:
# Build numpy array of times: the mjd column, flat, cast to dtype_np
# NOTE(review): if dtype_np is float32, MJD values near 58,000 are only resolved to
# about 0.004 day (roughly 5-6 minutes) -- confirm this precision is acceptable
ts_np = ztf_elt.mjd.values.astype(dtype_np)

# Build tensor of flattened times (constant tensor holding the same values)
ts = keras.backend.constant(ts_np)

In [18]:
# Observation count per element; groupby returns the groups sorted by element_id
# NOTE(review): assumes the rows of ztf_elt and the element batch follow that same
# order -- confirm against load_ztf_batch
row_lengths_np = ztf_elt.groupby('element_id').size().values.astype(np.int32)

# The same counts as an int32 tensor
row_lengths = keras.backend.constant(row_lengths_np, dtype=tf.int32)

In [19]:
row_lengths

<tf.Tensor: shape=(64,), dtype=int32, numpy=
array([1424,  936, 1117, 1087, 1035, 1028,  780, 2747,  926, 2485, 1023,
       1013, 1063,  858, 1943,  950,  636, 1199, 1547, 1477,  991,  790,
       3205, 2719, 1613,  771, 1047, 1253,  845, 3084,  866, 1474, 2678,
        802, 1653,  708, 1043,  855, 2362, 1012, 1048, 1677,  916,  813,
       1101, 1180, 2708,  688, 1929, 2871, 1099,  743,  943, 2018, 2717,
        595,  864,  941, 2628, 3127,  943, 1726,  762, 1129], dtype=int32)>

In [20]:
# Build ragged tensor of input times: row i receives row_lengths[i] consecutive
# entries of the flat time array.
# The flat array ts_np is reused instead of converting the mjd column a second time,
# so the flat tensor ts and the ragged tensor ts_r are built from the same values.
ts_r = tf.RaggedTensor.from_row_lengths(values=ts_np, row_lengths=row_lengths)

In [21]:
# Distinct observation times and distinct TimeStampID values in the unperturbed batch
ts_unq = np.unique(ztf_elt_ast['mjd'])
TimeStampID_unq = np.unique(ztf_elt_ast['TimeStampID'])

# Epoch of the entry labelled 0 in elts_ast
epoch0 = elts_ast['epoch'][0]

In [22]:
# Report time tensor shapes: flat tensor, ragged tensor, and unique times.
# Each label names the object whose shape is printed.
print(f'ts.shape={ts.shape}')
print(f'ts_r.shape={ts_r.shape}')
print(f'ts_unq.shape={ts_unq.shape}')

ts.shape=(90211,)
ts.shape=(64, None)
ts_unq.shape=(6383,)


In [23]:
# Observation site name; passed as site_name when building the direction model and layer
site_name = 'palomar'

In [24]:
# Build position model from the flat observation times (ts_np) and the
# per-element observation counts (row_lengths_np)
model_pos = make_model_ast_pos(ts_np=ts_np, row_lengths_np=row_lengths_np)

In [25]:
# Build direction model from the same times and row lengths as the position model,
# plus the observation site name
model_dir = make_model_ast_dir(ts_np=ts_np, row_lengths_np=row_lengths_np, site_name=site_name)

In [26]:
# Names of the element columns fed to the models
cols_elt = ['a', 'e', 'inc', 'Omega', 'omega', 'f', 'epoch']

# Map each column name to that column's numpy values, ready for model.predict
elts_ast_dict = dict((name, elts_ast[name].values) for name in cols_elt)

## Run and Calibrate Position Model

In [27]:
model_pos.summary()

Model: "model_asteroid_pos"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
a (InputLayer)                  [(64,)]              0                                            
__________________________________________________________________________________________________
Omega (InputLayer)              [(64,)]              0                                            
__________________________________________________________________________________________________
e (InputLayer)                  [(64,)]              0                                            
__________________________________________________________________________________________________
epoch (InputLayer)              [(64,)]              0                                            
_________________________________________________________________________________

In [28]:
# Predict position model on the dict of element arrays; the result is a pair
# unpacked as (q_pred, v_pred)
q_pred, v_pred = model_pos.predict(elts_ast_dict)

In [29]:
# Review shape of predictions
print(f'q_pred.shape = {q_pred.shape}')
print(f'v_pred.shape = {v_pred.shape}')

q_pred.shape = (64, None, 3)
v_pred.shape = (64, None, 3)


In [30]:
# Column groups for the q and v components stored on ztf_elt
cols_q_ast = ['qx', 'qy', 'qz']
cols_v_ast = ['vx', 'vy', 'vz']

# Flat calibration arrays (one row per observation), cast to dtype_np
q_ast, v_ast = (
    ztf_elt[cols].values.astype(dtype_np)
    for cols in (cols_q_ast, cols_v_ast)
)

In [31]:
# Run calibration of the position model's ast_pos_layer, passing the element batch
# and the reference arrays q_ast / v_ast taken from ztf_elt
model_pos.ast_pos_layer.calibrate(elts=elts_ast, q_ast=q_ast, v_ast=v_ast)

In [32]:
# Sanity check on calibration: mean of the row-wise norms of the layer's dq and dv
mean_dq = np.mean(tf.linalg.norm(model_pos.ast_pos_layer.dq, axis=1))
mean_dv = np.mean(tf.linalg.norm(model_pos.ast_pos_layer.dv, axis=1))

# One print call, one line per item
print(
    'Mean calibration adjustments:',
    f'mean_dq = {mean_dq:6.2e}',
    f'mean_dv = {mean_dv:6.2e}',
    sep='\n',
)

Mean calibration adjustments:
mean_dq = 9.85e-05
mean_dv = 9.30e-07


## Predict Direction Model

In [None]:
model_dir.summary()

In [None]:
# Calibrate direction model's position layer with the same element batch and
# reference arrays (q_ast, v_ast) used for the position model
model_dir.ast_pos_layer.calibrate(elts=elts_ast, q_ast=q_ast, v_ast=v_ast)

In [None]:
# Predict calibrated direction model on the dict of element arrays; the result is
# a pair unpacked as (u_pred, r_pred)
u_pred, r_pred = model_dir.predict(elts_ast_dict)

In [None]:
# Review shape of predictions
print(f'u_pred.shape = {u_pred.shape}')
print(f'r_pred.shape = {r_pred.shape}')

In [None]:
# Expected directions are the elt_u* columns of ztf_elt; predictions should match them
cols_u = ['elt_ux', 'elt_uy', 'elt_uz']
u_exp = ztf_elt[cols_u].values

# Row-wise difference between the flat predicted directions and the expected ones
u_diff = u_pred.values - u_exp

# Mean Euclidean norm of the difference, also converted with dist2sec
mean_diff_s = np.linalg.norm(u_diff, axis=1).mean()
mean_diff_sec = dist2sec(mean_diff_s)

# Report results in a single print call, one line per item
print(
    'Mean direction difference:',
    f'Cartesian  : {mean_diff_s:8.2e}',
    f'Arc Seconds: {mean_diff_sec:6.3f}',
    sep='\n',
)

## Assemble Tensors for Prototype Model

In [None]:
# Orbital elements: the six elements a, e, inc, Omega, omega, f are created as
# tf.Variable objects initialised from the unperturbed batch
a = tf.Variable(initial_value=elts_ast.a, dtype=dtype, name='a')
e = tf.Variable(initial_value=elts_ast.e, dtype=dtype, name='e')
inc = tf.Variable(initial_value=elts_ast.inc, dtype=dtype, name='inc')
Omega = tf.Variable(initial_value=elts_ast.Omega, dtype=dtype, name='Omega')
omega = tf.Variable(initial_value=elts_ast.omega, dtype=dtype, name='omega')
f = tf.Variable(initial_value=elts_ast.f, dtype=dtype, name='f')
# The epoch is created as a tf.constant rather than a variable
epoch = tf.constant(value=elts_ast.epoch, dtype=dtype, name='epoch')

In [None]:
# Review element shapes (all the same)
print(f'a.shape = {a.shape}')

In [None]:
# Rows of ztf_elt whose element_id equals the best candidate
mask_best = ztf_elt.element_id.eq(element_id_best)
ztf_elt_best = ztf_elt.loc[mask_best]
ztf_elt_best

In [None]:
# Build ragged tensor of u_obs (columns ux, uy, uz), partitioned by row_lengths
# NOTE(review): unlike ts_np, q_ast and v_ast, these values are not cast to dtype_np,
# so they keep the DataFrame's own dtype -- confirm this is intended
cols_u_obs = ['ux', 'uy', 'uz']
u_obs_r = tf.RaggedTensor.from_row_lengths(values=ztf_elt[cols_u_obs].values, row_lengths=row_lengths)

# Review shape
print(f'u_obs_r.shape = {u_obs_r.shape}')

In [None]:
# Largest observation time in the batch
t_max = ztf_elt.mjd.values.max()

# Filler time for padded entries: one day past the latest observation
t_pad = 1.0 + t_max

In [None]:
# Convert ragged to 2D: dense tensor in which missing entries of shorter rows
# are filled with t_pad
ts_2d = ts_r.to_tensor(default_value=t_pad)
# Numpy version of the dense times
ts_2d_np = ts_2d.numpy()

print(f'ts_2d.shape = {ts_2d.shape}')

## Calculate Position and Direction Using Layers in asteroid_model.py

In [None]:
# Build position layer directly (outside a model) from the flat times and row lengths
ast_pos_layer = AsteroidPosition(ts_np=ts_np, row_lengths_np=row_lengths_np, name='ast_pos_layer')

In [None]:
# Build direction layer directly (outside a model) from the flat times, row lengths
# and observation site; the layer is named 'u_pred'
direction_layer = AsteroidDirection(ts_np=ts_np, row_lengths_np=row_lengths_np, site_name=site_name, name='u_pred')

In [None]:
# Predict position using the layer, called on the element variables and the epoch
# constant; the result is unpacked as (q_pred, v_pred)
q_pred, v_pred = ast_pos_layer(a, e, inc, Omega, omega, f, epoch)

In [None]:
# Review output: each line reports the shape of the tensor named in its label
print(f'q_pred.shape = {q_pred.shape}')
print(f'v_pred.shape = {v_pred.shape}')

In [None]:
# Run calibration of the standalone position layer with the element batch and the
# reference arrays q_ast / v_ast
ast_pos_layer.calibrate(elts=elts_ast, q_ast=q_ast, v_ast=v_ast)

In [None]:
# Check that corrections during calibration aren't too large:
# mean of the row-wise norms of the layer's dq and dv adjustments
mean_dq = np.mean(tf.linalg.norm(ast_pos_layer.dq, axis=1))
mean_dv = np.mean(tf.linalg.norm(ast_pos_layer.dv, axis=1))

# Each value is reported under its own name
print('Mean calibration adjustments:')
print(f'mean_dq = {mean_dq:6.2e}')
print(f'mean_dv = {mean_dv:6.2e}')

In [None]:
# Convert q, v to ragged tensors matching the element batch, partitioned by row_lengths
# NOTE(review): assumes q_pred and v_pred are flat, with one leading entry per
# observation -- confirm against AsteroidPosition's output
q_r = tf.RaggedTensor.from_row_lengths(values=q_pred, row_lengths=row_lengths, name='q_r')
v_r = tf.RaggedTensor.from_row_lengths(values=v_pred, row_lengths=row_lengths, name='v_r')

In [None]:
# Report shapes
print(f'q_r.shape={q_r.shape}')

In [None]:
# Predict direction using the layer, called on the element variables and the epoch
# constant; the result is unpacked as (u_pred, r_pred)
u_pred, r_pred = direction_layer(a, e, inc, Omega, omega, f, epoch)

In [None]:
# Report shapes
print(f'u_pred.shape={u_pred.shape}')
print(f'r_pred.shape={r_pred.shape}')

## Create Input Tensors for Functional API Model Debugging

In [None]:
def _element_input(name):
    """Return a Keras input named `name` with shape=() and batch size elt_batch_size."""
    return keras.Input(shape=(), batch_size=elt_batch_size, name=name)


# One symbolic input per orbital element, plus the epoch
in_a = _element_input('a')
in_e = _element_input('e')
in_inc = _element_input('inc')
in_Omega = _element_input('Omega')
in_omega = _element_input('omega')
in_f = _element_input('f')
in_epoch = _element_input('epoch')

# All inputs as one tuple, in the order the layers take them
inputs = (in_a, in_e, in_inc, in_Omega, in_omega, in_f, in_epoch)

In [None]:
in_a

In [None]:
a

In [None]:
# Total number of observations: the sum of the per-element row lengths (a tensor)
data_size = tf.reduce_sum(row_lengths)
# Two-entry shape tuple (data_size, 1)
elt_shape = (data_size, 1,)

In [None]:
# Create output tensor
# NOTE: this rebinds ast_pos_layer to a newly constructed AsteroidPosition layer, so
# the name no longer refers to the layer built and calibrated earlier in the notebook
ast_pos_layer = AsteroidPosition(ts_np=ts_np, row_lengths_np=row_lengths_np)
# Apply the new layer to the symbolic Keras inputs
q_flat, v_flat = ast_pos_layer(in_a, in_e, in_inc, in_Omega, in_omega, in_f, in_epoch)

In [None]:
# Predict direction by applying the existing direction_layer to the symbolic Keras
# inputs; this rebinds u_pred and r_pred to the layer's outputs on those inputs
u_pred, r_pred = direction_layer(in_a, in_e, in_inc, in_Omega, in_omega, in_f, in_epoch)

In [None]:
# Review output
print(f'u_pred.shape = {u_pred.shape}')
print(f'r_pred.shape = {r_pred.shape}')

In [None]:
# Build test position model mapping the tuple of element inputs to [q_flat, v_flat]
model_pos_test = keras.Model(inputs=inputs, outputs=[q_flat, v_flat])

In [None]:
# Build test direction model mapping the tuple of element inputs to [u_pred, r_pred]
model_dir_test = keras.Model(inputs=inputs, outputs=[u_pred, r_pred])

In [None]:
model_dir_test([a, e, inc, Omega, omega, f, epoch])

## Imports for Step by Step Calculations

In [None]:
import astropy
from astropy.units import au, day, year

# Local imports
from orbital_element import MeanToTrueAnomaly, TrueToMeanAnomaly
from asteroid_data import get_earth_pos, get_sun_pos_vel
from asteroid_model import ElementToPosition
from ra_dec import calc_topos

In [None]:
# Constants used by the step-by-step calculation

# The gravitational constant in ('day', 'AU', 'Msun') coordinates
# Hard code G (a literal value, not computed here)
G_ = 0.00029591220828559104
# The gravitational field strength mu = G * (m0 + m1)
# For massless asteroids orbiting the sun with units Msun, m0=1.0, m1=0.0, and mu = G
mu = tf.constant(G_)

# Speed of light; express this in AU / day
# NOTE(review): only `import astropy` and `from astropy.units import ...` appear above;
# this presumably relies on astropy.constants being loaded as a side effect -- confirm
light_speed_au_day = astropy.constants.c.to(au / day).value

# Number of spatial dimensions
space_dims = 3

# Data types (same values as the dtype / dtype_np assigned earlier in the notebook)
dtype = tf.float32
dtype_np = np.float32

In [None]:
# Build direction layer
# NOTE: rebinds direction_layer to a new layer constructed with the same arguments
# (and the same name 'u_pred') as the one built earlier in the notebook
direction_layer = AsteroidDirection(ts_np=ts_np, row_lengths_np=row_lengths_np, site_name=site_name, name='u_pred')

In [None]:
# Predict direction layer: call the rebuilt layer on the element variables and the
# epoch constant, unpacking the result as (u_pred, r_pred)
u_pred, r_pred = direction_layer(a, e, inc, Omega, omega, f, epoch)

In [None]:
u_pred.shape

## Step by Step Calculation of Asteroid Position

In [None]:
epoch.shape

In [None]:
# Target shape handed to the keras.layers.Reshape calls in the following cells
target_shape = (1,)

In [None]:
# Time relative to epoch
# Repeat each element's epoch row_lengths[i] times (once per observation), then
# reshape with target_shape
epoch_t  = keras.layers.Reshape(target_shape, name='epoch_t')(tf.repeat(epoch, row_lengths))
# t = ts - epoch_t, computed with a Keras subtract layer
t = keras.layers.subtract([ts, epoch_t], name='t')

In [None]:
# Report
print(f'ts.shape = {ts.shape}')
print(f'epoch.shape = {epoch.shape}')
print(f't.shape = {t.shape}')

In [None]:
# Compute the mean anomaly M from the true anomaly f and the eccentricity e
M = TrueToMeanAnomaly(name='TrueToMeanAnomaly')([f, e])

# Compute mean motion N from mu and a: N = sqrt(mu / a^3)
a3 = tf.math.pow(a, 3, name='a3')
mu_over_a3 = tf.divide(mu, a3, name='mu_over_a3')
N = tf.sqrt(mu_over_a3, name='N')

In [None]:
# Repeat the constant orbital elements once per observation: element i is repeated
# row_lengths[i] times, then reshaped with target_shape, giving flat column vectors
# of shape (data_size, 1) where data_size is the sum of row_lengths
a_t = keras.layers.Reshape(target_shape, name='a_t')(tf.repeat(a, row_lengths))
e_t = keras.layers.Reshape(target_shape, name='e_t')(tf.repeat(e, row_lengths))
inc_t = keras.layers.Reshape(target_shape, name='inc_t')(tf.repeat(inc, row_lengths))
Omega_t = keras.layers.Reshape(target_shape, name='Omega_t')(tf.repeat(Omega, row_lengths))
omega_t = keras.layers.Reshape(target_shape, name='omega_t')(tf.repeat(omega, row_lengths))

In [None]:
# Report
print(f'a_t.shape = {a_t.shape}')

In [None]:
# Repeat initial mean anomaly M0 and mean motion N0 to match shape of outputs
# (same repeat-then-reshape pattern as the orbital elements)
M0_t  = keras.layers.Reshape(target_shape, name='M0_t')(tf.repeat(M, row_lengths))
N0_t  = keras.layers.Reshape(target_shape, name='N0_t')(tf.repeat(N, row_lengths))
# Compute the mean anomaly M(t) as a function of time: M(t) = M0 + N0 * t
N_mult_t = keras.layers.multiply(inputs=[N0_t, t])
M_t = keras.layers.add(inputs=[M0_t, N_mult_t])

In [None]:
# Compute the true anomaly from the mean anomaly and eccentricity
f_t = MeanToTrueAnomaly(name='mean_to_true_anomaly')([M_t, e_t])

# Wrap orbital elements into one tuple of inputs for layer converting to cartesian coordinates
# Order: a, e, inc, Omega, omega, f -- all per-observation tensors
elt_t = (a_t, e_t, inc_t, Omega_t, omega_t, f_t,)

In [None]:
# Report
print(f'M_t.shape = {M_t.shape}')
print(f'f_t.shape = {f_t.shape}')

In [None]:
# Convert orbital elements to heliocentric cartesian coordinates; the layer's result
# is a pair unpacked as (q_helio, v_helio)
q_helio, v_helio = ElementToPosition(name='q_helio')(elt_t)

In [None]:
# Report
q_helio.shape