In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# MSE imports
import kepler_sieve
from ztf_ast import load_ztf_nearest_ast, calc_hit_freq
from ztf_element import load_ztf_batch, make_ztf_batch
from asteroid_element import load_ast_elt
from asteroid_integrate import load_ast_pos
from candidate_element import asteroid_elts, perturb_elts, random_elts

# MSE imports for prototyping / testing
from asteroid_model import make_model_ast_pos, make_model_ast_dir
from asteroid_model import AsteroidPosition, AsteroidDirection
from asteroid_search_layers import CandidateElements, TrajectoryScore, R2lam
from asteroid_search_model import AsteroidSearchModel
from astro_utils import deg2dist, dist2deg, dist2sec

Found 4 GPUs.  Setting memory growth = True.


In [2]:
# Short alias for the Keras API bundled with TensorFlow
keras = tf.keras

## Load ZTF Data and Batch of Orbital Elements

In [3]:
# Load orbital elements for known asteroids
ast_elt = load_ast_elt()

# Number of asteroids
N_ast = ast_elt.shape[0]

In [4]:
# Load ztf nearest asteroid data
ztf_ast = load_ztf_nearest_ast()

In [5]:
# Asteroid numbers and hit counts
# (thresh_sec=2.0 is presumably the hit radius in arc seconds -- TODO confirm in calc_hit_freq)
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Sort all asteroids by hit count in descending order; the top batch_size are taken below
idx = np.argsort(hit_count)[::-1]

# Asteroid numbers and hit counts, reordered from most to fewest hits
ast_num_best = ast_num[idx]
hit_count_best = hit_count[idx]

In [6]:
# Set batch size
batch_size = 64
elt_batch_size = batch_size

# Batch of unperturbed elements for the batch_size asteroids with the most hits;
# the displayed rows appear in the same order as ast_num_best (51921 first)
elts_ast = asteroid_elts(ast_nums=ast_num_best[0:batch_size])

In [7]:
elts_ast

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133491,58600.0
1,59244,2.634727,0.262503,0.465045,5.738297,1.766995,-1.601363,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.246069,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.357345,58600.0
4,142999,2.619944,0.191376,0.514017,0.238022,0.946463,-1.299301,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.016580,58600.0
60,134815,2.612770,0.140831,0.513923,0.272689,0.645552,-0.957836,58600.0
61,27860,2.619406,0.096185,0.200633,5.541399,3.266046,3.948770,58600.0
62,85937,2.342292,0.197267,0.439063,5.279693,3.210025,3.947687,58600.0


In [8]:
# Perturb orbital elements
# Only the true anomaly f is perturbed (sigma 0.1 degree); every other sigma is zero
sigma_a = 0.0 
sigma_e = 0.0 
sigma_f_deg = 0.1
sigma_Omega_deg = 0.0
sigma_omega_deg = 0.0
# None: no mask passed to perturb_elts
mask_pert = None
# Fixed seed passed to perturb_elts (presumably so the perturbation is reproducible)
random_seed = 42

elts_pert = perturb_elts(elts_ast, sigma_a=sigma_a, sigma_e=sigma_e, sigma_f_deg=sigma_f_deg, 
                         sigma_Omega_deg=sigma_Omega_deg, sigma_omega_deg=sigma_omega_deg,
                         mask_pert=mask_pert, random_seed=random_seed)

In [9]:
elts_pert

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133117,58600.0
1,59244,2.634727,0.262503,0.465045,5.738297,1.766995,-1.603537,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.245767,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.356673,58600.0
4,142999,2.619945,0.191376,0.514017,0.238022,0.946463,-1.300844,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.014978,58600.0
60,134815,2.612770,0.140831,0.513923,0.272689,0.645552,-0.954132,58600.0
61,27860,2.619406,0.096185,0.200633,5.541399,3.266046,3.950572,58600.0
62,85937,2.342292,0.197267,0.439063,5.279693,3.210025,3.945035,58600.0


## Batches of ZTF Data vs. Elements

In [10]:
# Arguments to load_ztf_batch (thresh_deg is also reused by the score layers below)
thresh_deg = 1.0
near_ast = False
regenerate = False

In [11]:
# Load unperturbed element batch
ztf_elt_ast = load_ztf_batch(elts=elts_ast, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [12]:
# Load perturbed element batch
ztf_elt_pert = load_ztf_batch(elts=elts_pert, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [13]:
ztf_elt_ast

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,vz,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit
0,53851,733,b'ZTF18abnothj',594197584815010004,5501,58348.197581,266.229165,-13.513802,-0.063945,-0.983101,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.010624,2191.371398,0.999944,0.370539,False
1,73604,733,b'ZTF18ablwzmb',594197584815015003,5501,58348.197581,265.761024,-13.509148,-0.071871,-0.982578,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.016809,3467.103003,0.999859,0.927533,False
2,82343,733,b'ZTF18abiydvm',635193253015015018,12089,58389.193252,270.331454,-11.244934,0.005674,-0.977422,...,0.003825,0.000918,-0.977996,0.208622,2.703478,0.005450,1124.142942,0.999985,0.097510,False
3,257221,733,b'ZTF18acakcqg',931471223715015007,39920,58685.471227,29.693832,42.180412,0.643725,0.603886,...,-0.001953,0.639004,0.610779,0.467571,2.175851,0.008712,1797.042210,0.999962,0.249184,False
4,327000,733,b'ZTF18achmdmw',937465970615015011,40837,58691.465972,33.104905,44.059131,0.601970,0.636719,...,-0.002129,0.606278,0.637608,0.475272,2.114865,0.007949,1639.537152,0.999968,0.207418,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90205,5650588,324582,b'ZTF20aaqvhld',1150176701515015008,96618,58904.176701,44.164238,29.650540,0.623416,0.752309,...,-0.001541,0.627640,0.750696,0.206212,2.981799,0.008187,1688.636853,0.999966,0.220027,False
90206,5650589,324582,b'ZTF20aaqvhld',1150176245715015005,96617,58904.176250,44.164062,29.650536,0.623417,0.752307,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.008187,1688.600639,0.999966,0.220018,False
90207,5650665,324582,b'ZTF20aaqvhll',1150176245815015010,96617,58904.176250,44.368640,28.490480,0.628284,0.753618,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.013370,2757.856469,0.999911,0.586871,False
90208,5650697,324582,b'ZTF20aaqvhmb',1150176246015015005,96617,58904.176250,43.296207,29.505908,0.633424,0.743491,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.012388,2555.279465,0.999923,0.503822,False


In [14]:
# Column names available on the ZTF batch frame
ztf_elt_ast.columns

Index(['ztf_id', 'element_id', 'ObjectID', 'CandidateID', 'TimeStampID', 'mjd',
       'ra', 'dec', 'ux', 'uy', 'uz', 'qx', 'qy', 'qz', 'vx', 'vy', 'vz',
       'elt_ux', 'elt_uy', 'elt_uz', 'elt_r', 's', 's_sec', 'z', 'v',
       'is_hit'],
      dtype='object')

In [15]:
# Review results for the asteroid with the most hits (first entry of ast_num_best)
ztf_elt = ztf_elt_ast
element_id_best = ast_num_best[0]
mask = (ztf_elt.element_id == element_id_best)
ztf_best = ztf_elt[mask]
hits_best = np.sum(ztf_best.is_hit)
s_sec_min = np.min(ztf_best.s_sec)
# Closest observation for THIS asteroid only (previously the argmin ran over the whole
# frame). Use a positional argmin with .iloc so the lookup does not depend on index labels.
idx = int(np.argmin(ztf_best.s.values)) if len(ztf_best) > 0 else None
ztf_id = ztf_best.ztf_id.iloc[idx] if idx is not None else None
# ztf_best.iloc[idx:idx+1]
print(f'Best asteroid has element_id = {element_id_best}')
print(f'Hit count: {hits_best}')
print(f'Closest hit: {s_sec_min:0.3f} arc seconds')
# ztf_best

Best asteroid has element_id = 51921
Hit count: 158
Closest hit: 0.382 arc seconds


## Load Direction Model

In [16]:
# Data types: float32 for both TensorFlow and numpy arrays
dtype = tf.float32
dtype_np = np.float32

In [17]:
# Work on a copy of the unperturbed ZTF batch (not an alias)
ztf_elt = ztf_elt_ast.copy()

In [18]:
# Build numpy array of times
ts_np = ztf_elt.mjd.values.astype(dtype_np)

# Get observation count per element
# NOTE(review): groupby sorts keys, so these counts are in ascending element_id order.
# The ztf_elt rows displayed above also run in ascending element_id (733 first), but
# elts_ast is in descending hit-count order (51921 first). Confirm that the layers
# pair row_lengths with elts in the same order.
row_lengths_np = ztf_elt.element_id.groupby(ztf_elt.element_id).count().values.astype(np.int32)

In [19]:
# Build tensor of flattened times
ts = keras.backend.constant(value=ts_np, shape=ts_np.shape, dtype=dtype)

# Tensor of row lengths
row_lengths = keras.backend.constant(value=row_lengths_np, shape=row_lengths_np.shape, dtype=tf.int32)

In [20]:
# Observation site
site_name = 'palomar'

In [21]:
# Build direction model
model_dir = make_model_ast_dir(ts_np=ts_np, row_lengths_np=row_lengths_np, site_name=site_name)

In [22]:
# Calibration arrays (flat): reference positions and velocities from the ZTF frame
cols_q_ast = ['qx', 'qy', 'qz']
cols_v_ast = ['vx', 'vy', 'vz']
q_ast = ztf_elt[cols_q_ast].values.astype(dtype_np)
v_ast = ztf_elt[cols_v_ast].values.astype(dtype_np)

In [23]:
# Run calibration
model_dir.ast_pos_layer.calibrate(elts=elts_ast, q_ast=q_ast, v_ast=v_ast)

## Predict and Test Direction Model

In [24]:
# Stack elements as a dict of numpy arrays for prediction
cols_elt = ['a', 'e', 'inc', 'Omega', 'omega', 'f', 'epoch']
elts_ast_dict = {col : elts_ast[col].values for col in cols_elt}

In [25]:
# Predict calibrated direction model
u_pred, r_pred = model_dir.predict(elts_ast_dict)

In [26]:
# Compare to direction on the ztf_elt frame (should be the same!)
# Mean Euclidean norm of the per-row difference, then converted to arc seconds
cols_u = ['elt_ux', 'elt_uy', 'elt_uz']
u_exp = ztf_elt[cols_u].values
u_diff = u_pred.values - u_exp
mean_diff_s = np.mean(np.linalg.norm(u_diff, axis=1))
mean_diff_sec = dist2sec(mean_diff_s)

# Report results
print(f'Mean direction difference:')
print(f'Cartesian  : {mean_diff_s:8.2e}')
print(f'Arc Seconds: {mean_diff_sec:6.3f}')

Mean direction difference:
Cartesian  : 7.86e-06
Arc Seconds:  1.622


## Tensors and Layers for Eager Mode

In [27]:
# Orbital elements as eager tf.Variables (epoch is a constant); one entry per row of elts_ast
a = tf.Variable(initial_value=elts_ast.a, dtype=dtype, name='a')
e = tf.Variable(initial_value=elts_ast.e, dtype=dtype, name='e')
inc = tf.Variable(initial_value=elts_ast.inc, dtype=dtype, name='inc')
Omega = tf.Variable(initial_value=elts_ast.Omega, dtype=dtype, name='Omega')
omega = tf.Variable(initial_value=elts_ast.omega, dtype=dtype, name='omega')
f = tf.Variable(initial_value=elts_ast.f, dtype=dtype, name='f')
epoch = tf.constant(value=elts_ast.epoch, dtype=dtype, name='epoch')

In [28]:
# Thresholds: thresh_s = deg2dist(thresh_deg), presumably a chord length on the unit sphere.
# For unit vectors |u - w|^2 = 2 - 2 u.w, so thresh_z is the matching dot-product threshold.
thresh_s = deg2dist(thresh_deg).astype(dtype_np)
thresh_s2 = (thresh_s**2).astype(dtype_np)
thresh_z = (1.0 - thresh_s2 / 2.0).astype(dtype_np)

In [29]:
# Convert resolution from degrees to Cartesian (1 degree for every element)
R_deg_np = np.ones(elt_batch_size, dtype=dtype_np)*1.0
R_np = deg2dist(R_deg_np)
# lam computed from the resolution R and the threshold via R2lam
lam_np = R2lam(R=R_np, thresh_s=thresh_s)

In [30]:
# Mixture probability parameters: h and lam (R is kept alongside lam)
h = tf.Variable(initial_value=np.ones(elt_batch_size)*0.5, dtype=dtype, name='h')
# lam = tf.Variable(initial_value=np.ones(elt_batch_size)*1.0, dtype=dtype, name='lam')
R = tf.Variable(initial_value=R_np, dtype=dtype, name='R')
lam = tf.Variable(initial_value=lam_np, dtype=dtype, name='lam')

In [31]:
# Observed directions from the ZTF frame, one row per observation
cols_u_obs = ['ux', 'uy', 'uz']
u_obs_np = ztf_elt[cols_u_obs].values.astype(dtype_np)
u_obs = tf.constant(value=u_obs_np, dtype=dtype, name='u_obs')

In [32]:
# Review shape
print(f'u_obs.shape = {u_obs.shape}')

u_obs.shape = (90210, 3)


In [33]:
# Extract direction layers from the calibrated direction model
ast_pos_layer = model_dir.ast_pos_layer
ast_dir_layer = model_dir.ast_dir_layer

In [34]:
# Predict in eager mode; this rebinds u_pred / r_pred from the predict() cell above
q_flat, v_flat = ast_pos_layer(a, e, inc, Omega, omega, f, epoch)
u_pred, r_pred = ast_dir_layer(a, e, inc, Omega, omega, f, epoch)

In [35]:
# Review shapes (flat: one row per observation)
print(f'q_flat.shape = {q_flat.shape}')
print(f'v_flat.shape = {v_flat.shape}')
print(f'u_pred.shape = {u_pred.shape}')
print(f'r_pred.shape = {r_pred.shape}')

q_flat.shape = (90210, 3)
v_flat.shape = (90210, 3)
u_pred.shape = (90210, 3)
r_pred.shape = (90210, 1)


## TrajectoryScore Layer

In [36]:
# Build score layer from the observed directions and per-element row lengths
score_layer = TrajectoryScore(u_obs_np=u_obs_np, row_lengths_np=row_lengths_np, thresh_deg=thresh_deg)

In [37]:
# Test score layer on the eager predicted directions and mixture parameters
log_like = score_layer(u_pred, h, lam)

In [38]:
# Report shape (one log likelihood per candidate element)
print(f'log_like.shape = {log_like.shape}')

log_like.shape = (64,)


## Prototype TrajectoryScore

In [39]:
# Thresholds (same computation as the eager-tensor cell; repeated so this prototype
# section is self-contained)
thresh_s = deg2dist(thresh_deg).astype(dtype_np)
thresh_s2 = (thresh_s**2).astype(dtype_np)
thresh_z = (1.0 - thresh_s2 / 2.0).astype(dtype_np)

In [40]:
# Difference between predicted and observed directions (u_pred - u_obs)
du = keras.layers.subtract(inputs=[u_pred, u_obs], name='du')
# Squared distance between predicted and observed directions
s2 = K.sum(tf.square(du), axis=(-1))

In [41]:
# Review shapes
print(f'du.shape = {du.shape}')
print(f's2.shape = {s2.shape}')

du.shape = (90210, 3)
s2.shape = (90210,)


In [42]:
# Look at relevant columns for the first three observations
cols_disp = ['ztf_id', 'element_id', 'ux', 'uy', 'uz', 'elt_ux', 'elt_uy', 'elt_uz', 'z', 'v', 'is_hit']
ztf_elt.iloc[0:3][cols_disp]

Unnamed: 0,ztf_id,element_id,ux,uy,uz,elt_ux,elt_uy,elt_uz,z,v,is_hit
0,53851,733,-0.063945,-0.983101,0.17153,-0.0573,-0.982042,0.179751,0.999944,0.370539,False
1,73604,733,-0.071871,-0.982578,0.171389,-0.0573,-0.982042,0.179751,0.999859,0.927533,False
2,82343,733,0.005674,-0.977422,0.211222,0.000918,-0.977996,0.208622,0.999985,0.09751,False


In [43]:
# Filter to only include terms where s2 is within the threshold distance^2
# is_close = s2 < thresh_s2
is_close = tf.math.less(s2, thresh_s2)

In [44]:
# Review shape and the fraction of observations inside the threshold
close_frac = np.mean(is_close)
print(f'is_close.shape = {is_close.shape}')
print(f'close_frac = {close_frac:8.6f}')

is_close.shape = (90210,)
close_frac = 0.999789


In [45]:
# Relative distance v on data inside threshold
# Since s2 < thresh_s2 on every kept entry, v lies in [0, 1)
# s2[is_close] / thresh_s2
# v = tf.divide(s2[is_close], thresh_s2, name='v')
v = tf.divide(tf.boolean_mask(tensor=s2, mask=is_close), thresh_s2, name='v')

In [46]:
# Review shape, summary of v
# (np.log(v) is -inf if any v == 0, i.e. an exact direction match)
mean_v = np.mean(v)
mean_log_v = np.mean(np.log(v))
print(f'v.shape = {v.shape}')
print(f'mean_v     = {mean_v:9.6f}')
print(f'mean_log_v = {mean_log_v:9.6f}')

v.shape = (90191,)
mean_v     =  0.437990
mean_log_v = -2.683210


In [47]:
# Row_lengths, for close observations only:
# split the flat is_close mask per element, then count True entries in each row
is_close_r = tf.RaggedTensor.from_row_lengths(is_close, row_lengths=row_lengths)
row_lengths_close = tf.reduce_sum(tf.cast(is_close_r, tf.int32), axis=1)

# Shape of parameters: one entry per close observation (same length as v)
close_size = tf.reduce_sum(row_lengths_close)
param_shape = (close_size,)

In [48]:
# Review row_lengths
row_lengths_close

<tf.Tensor: shape=(64,), dtype=int32, numpy=
array([1424,  936, 1117, 1087, 1035, 1028,  780, 2744,  925, 2485, 1023,
       1013, 1063,  858, 1943,  949,  636, 1199, 1547, 1477,  991,  790,
       3204, 2719, 1612,  771, 1047, 1253,  845, 3081,  866, 1472, 2677,
        802, 1652,  708, 1041,  855, 2362, 1012, 1048, 1677,  916,  812,
       1100, 1180, 2708,  688, 1929, 2871, 1099,  743,  943, 2018, 2716,
        595,  864,  941, 2628, 3126,  943, 1726,  762, 1129], dtype=int32)>

In [49]:
# Upsample h and lambda
h_vec = tf.reshape(tensor=tf.repeat(h, row_lengths_close), shape=param_shape, name='h_vec')
lam_vec = tf.reshape(tensor=tf.repeat(h, row_lengths_close), shape=param_shape, name='lam_vec')

In [50]:
# Probability according to mixture model
p = h_vec * tf.exp(-lam_vec * v) + (1.0 - h_vec)
log_p = keras.layers.Activation(tf.math.log, name='log_p')(p)

In [51]:
# Rearrange to ragged tensors
log_p_r = tf.RaggedTensor.from_row_lengths(log_p, row_lengths=row_lengths_close)

In [52]:
# Log likelihood by element
log_like = tf.reduce_sum(log_p_r, axis=1)

In [53]:
# Summary statistics
mean_log_p = np.mean(log_like)
std_log_p = np.std(log_like)

# Review shape and summary
print(f'log_p.shape = {log_p.shape}')
print(f'log_like.shape = {log_like.shape}')
print(f'mean_log_p = {mean_log_p:9.6f}')
print(f'std_log_p  = {std_log_p:9.6f}')

log_p.shape = (90191,)
log_like.shape = (64,)
mean_log_p = -141.581116
std_log_p  = 83.229546


## Candidate Elements Layer

In [54]:
# Alias elts (same DataFrame object as elts_ast, not a copy)
elts = elts_ast

In [55]:
# Build elements layer
elements_layer = CandidateElements(elts=elts, h=0.125, R=R_np, name='candidates')

In [56]:
# Extract elements; the layer is called with inputs=None.
# NOTE: this rebinds a..epoch, h, lam and R from the eager-tensor cells above.
a, e, inc, Omega, omega, f, epoch, h, lam, R = elements_layer(inputs=None)

In [57]:
# Review shapes and first few a values
print(f'a.shape = {a.shape}')
print(a[0:5])

a.shape = (64,)
tf.Tensor([2.669305  2.634727  1.8832275 2.5563874 2.6199443], shape=(5,), dtype=float32)


In [58]:
def make_model_elts(elts: pd.DataFrame, h: float = 0.5, R_deg: float = 1.0):
    """
    Build a small Keras model that exposes candidate orbital elements.

    The model has one dummy input, which is passed to CandidateElements only to
    satisfy the Keras Layer API. It returns the candidate a, e and inc tensors.

    INPUTS:
        elts:  DataFrame of initial orbital elements, one row per candidate
        h:     initial mixture parameter handed to CandidateElements
        R_deg: resolution in degrees, converted to Cartesian with deg2dist
    RETURNS:
        model: keras.Model named 'model_elements'
    """
    # The dummy input's batch size must equal the number of candidates
    n_candidates = elts.shape[0]
    x_dummy = keras.Input(shape=(), batch_size=n_candidates, name='x')

    # Trainable candidate elements initialized from elts; resolution converted to Cartesian
    R_cart = deg2dist(R_deg)
    candidates = CandidateElements(elts=elts, h=h, R=R_cart, name='candidates')

    # The layer yields all elements plus mixture parameters; distinct names avoid
    # shadowing the h argument. Only a, e, inc are exposed as model outputs.
    (a, e, inc, Omega, omega, f, epoch,
     h_out, lam_out, R_out) = candidates(inputs=x_dummy)

    return keras.Model(inputs=(x_dummy, ), outputs=(a, e, inc, ), name='model_elements')

In [59]:
# Dummy inputs: numpy array of length elt_batch_size
# (note: `x` is rebound to a keras.Input in the prototype section below)
x = np.ones(elt_batch_size)

In [60]:
# Build the elements-only model from the candidate elements
model_elts = make_model_elts(elts=elts)

In [61]:
# model_elts(x)

## Test AsteroidSearchModel

In [62]:
# Parameters to build model
# (learning_rate and clipnorm go to AsteroidSearchModel; a new optimizer is set at compile below)
h = 1.0 / 64.0
R_deg = 1.0
learning_rate = 1.0E-4
clipnorm = 1.0

In [63]:
# Build asteroid search model
model = AsteroidSearchModel(
        elts=elts_ast, ztf_elt=ztf_elt, site_name=site_name,
        thresh_deg=thresh_deg, h=h, R_deg=R_deg,
        learning_rate=learning_rate, clipnorm=clipnorm)

In [64]:
# Dummy inputs for search model; any array with shape [elt_batch_size,] is good
x_np = np.ones(elt_batch_size)

In [65]:
# Run model
log_like, elts_tf, mixture = model(x_np)

In [66]:
# Review shapes (elts_tf and mixture appear flattened: 448 = 64 x 7, 192 = 64 x 3)
# NOTE(review): u_pred here is still the eager tensor from an earlier cell, not a model output
print(f'log_like.shape = {log_like.shape}')
print(f'elts_tf.shape = {elts_tf.shape}')
print(f'mixture.shape = {mixture.shape}')
print(f'u_pred.shape = {u_pred.shape}')

log_like.shape = (64,)
elts_tf.shape = (448,)
mixture.shape = (192,)
u_pred.shape = (90210, 3)


In [67]:
# Review log likelihoods of the first five candidate elements
print(log_like[0:5])

tf.Tensor([0.903787  0.7684468 1.0551798 0.6134815 0.6689773], shape=(5,), dtype=float32)


In [68]:
# Review mixture parameters (first five entries of the mixture output)
print(mixture[0:5])

tf.Tensor([0.015625 0.015625 0.015625 0.015625 0.015625], shape=(5,), dtype=float32)


In [69]:
# Use Adam optimizer; clipping is added only if clipvalue is set
# NOTE(review): clipvalue is None, so no gradient clipping is applied here, and the
# clipnorm from the model-building cell is not passed to this optimizer.
# learning_rate = 2.0e-5
learning_rate = 0.001   # default 1.0E-3
beta_1 = 0.900          # default 0.900
beta_2 = 0.999          # default 0.999
epsilon = 1.0E-7        # default 1.0E-7
amsgrad = False         # default False
clipvalue = None        # default not used
# Optimizer arguments
opt_args = {
    'learning_rate': learning_rate,
    'beta_1': beta_1,
    'beta_2': beta_2,
    'epsilon': epsilon,
    'amsgrad': amsgrad,
}
# Add the clip value if it was set
if clipvalue is not None:
    opt_args['clipvalue'] = clipvalue
# Build the optimizer
opt = keras.optimizers.Adam(**opt_args)

In [70]:
# Compile the model
model.compile(optimizer=opt)

In [71]:
# Training schedule (steps_per_epoch is only used by the commented-out fit call)
k = 5
steps_per_epoch = elt_batch_size * k

# model.fit(x=np.ones(steps_per_epoch), epochs=10, steps_per_epoch=steps_per_epoch)

In [72]:
# elts_tf.shape

## Prototype make_model_asteroid_search()

In [73]:
# Identity layer used below to give named aliases to tensors; 3 spatial dimensions
from tf_utils import Identity
space_dims = 3

In [74]:
elts

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133491,58600.0
1,59244,2.634727,0.262503,0.465045,5.738297,1.766995,-1.601363,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.246069,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.357345,58600.0
4,142999,2.619944,0.191376,0.514017,0.238022,0.946463,-1.299301,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.016580,58600.0
60,134815,2.612770,0.140831,0.513923,0.272689,0.645552,-0.957836,58600.0
61,27860,2.619406,0.096185,0.200633,5.541399,3.266046,3.948770,58600.0
62,85937,2.342292,0.197267,0.439063,5.279693,3.210025,3.947687,58600.0


In [75]:
# Review the ZTF observation frame used by the prototype
ztf_elt

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,vz,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit
0,53851,733,b'ZTF18abnothj',594197584815010004,5501,58348.197581,266.229165,-13.513802,-0.063945,-0.983101,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.010624,2191.371398,0.999944,0.370539,False
1,73604,733,b'ZTF18ablwzmb',594197584815015003,5501,58348.197581,265.761024,-13.509148,-0.071871,-0.982578,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.016809,3467.103003,0.999859,0.927533,False
2,82343,733,b'ZTF18abiydvm',635193253015015018,12089,58389.193252,270.331454,-11.244934,0.005674,-0.977422,...,0.003825,0.000918,-0.977996,0.208622,2.703478,0.005450,1124.142942,0.999985,0.097510,False
3,257221,733,b'ZTF18acakcqg',931471223715015007,39920,58685.471227,29.693832,42.180412,0.643725,0.603886,...,-0.001953,0.639004,0.610779,0.467571,2.175851,0.008712,1797.042210,0.999962,0.249184,False
4,327000,733,b'ZTF18achmdmw',937465970615015011,40837,58691.465972,33.104905,44.059131,0.601970,0.636719,...,-0.002129,0.606278,0.637608,0.475272,2.114865,0.007949,1639.537152,0.999968,0.207418,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90205,5650588,324582,b'ZTF20aaqvhld',1150176701515015008,96618,58904.176701,44.164238,29.650540,0.623416,0.752309,...,-0.001541,0.627640,0.750696,0.206212,2.981799,0.008187,1688.636853,0.999966,0.220027,False
90206,5650589,324582,b'ZTF20aaqvhld',1150176245715015005,96617,58904.176250,44.164062,29.650536,0.623417,0.752307,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.008187,1688.600639,0.999966,0.220018,False
90207,5650665,324582,b'ZTF20aaqvhll',1150176245815015010,96617,58904.176250,44.368640,28.490480,0.628284,0.753618,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.013370,2757.856469,0.999911,0.586871,False
90208,5650697,324582,b'ZTF20aaqvhmb',1150176246015015005,96617,58904.176250,43.296207,29.505908,0.633424,0.743491,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.012388,2555.279465,0.999923,0.503822,False


In [82]:
h = 0.125
R = deg2dist(1.0)
lam = 1.0

In [77]:
# Element batch size comes from elts
elt_batch_size = elts.shape[0]

In [78]:
# Dummy input; this does absolutely nothing, but otherwise keras pukes
# Did I mention that sometimes I HATE keras?
x = keras.Input(shape=(), batch_size=elt_batch_size, name='x')

# Wrap up inputs
inputs = (x, )

In [79]:
# Numpy array and tensor of observation times; flat, shape (data_size,)
ts_np = ztf_elt.mjd.values.astype(dtype_np)

# Get observation count per element
row_lengths_np = ztf_elt.element_id.groupby(ztf_elt.element_id).count()

# Shape of the observed trajectories
data_size = ztf_elt.shape[0]
traj_shape = (data_size, space_dims)

In [84]:
# Observed directions; extract from ztf_elt DataFrame
cols_u_obs = ['ux', 'uy', 'uz']
u_obs_np = ztf_elt[cols_u_obs].values.astype(dtype_np)    

# Set of trainable weights with candidate orbital elements; initialize according to elts
elements_layer = CandidateElements(elts=elts, h=h, R=R, name='candidates')

# Extract the candidate elements and mixture parameters; pass dummy inputs to satisfy keras Layer API
a, e, inc, Omega, omega, f, epoch, h, lam, R = elements_layer(inputs=x)

In [85]:
# Alias the orbital elements; a, e, inc, Omega, omega, and f are trainable; epoch is fixed
a = Identity(name='a')(a)
e = Identity(name='e')(e)
inc = Identity(name='inc')(inc)
Omega = Identity(name='Omega')(Omega)
omega = Identity(name='omega')(omega)
f = Identity(name='f')(f)
epoch = Identity(name='epoch')(epoch)

In [86]:
# Alias the mixture model parameters
h = Identity(name='h')(h)
lam = Identity(name='lam')(lam)

# The orbital elements; stack to shape (elt_batch_size, 7)
elts_tf = tf.stack(values=[a, e, inc, Omega, omega, f, epoch], axis=1, name='elts')

In [87]:
# Test model 1
outputs_1 = (a, )
model_test_1 = keras.Model(inputs=inputs, outputs=outputs_1)

In [88]:
# Run test model 1
model_test_1(x_np)

(<tf.Tensor: shape=(64,), dtype=float32, numpy=
 array([2.669305 , 2.634727 , 1.8832275, 2.5563874, 2.6199443, 2.407715 ,
        2.3491728, 2.3106558, 2.2718794, 2.6123786, 2.6039326, 3.0796757,
        3.1138892, 2.331416 , 3.0063653, 2.344482 , 2.2864869, 2.582972 ,
        2.5813491, 2.6285136, 2.6182234, 3.2066405, 2.907894 , 2.6719072,
        2.3353634, 3.1711535, 2.32733  , 2.6373715, 2.4007137, 2.3050103,
        2.6771643, 2.6571858, 2.42876  , 1.9033059, 2.5840602, 2.4018168,
        2.5349774, 2.6773167, 3.0751357, 3.1739578, 3.2142055, 2.6185794,
        3.1158595, 1.9388875, 2.7592254, 2.275916 , 2.3057158, 3.398871 ,
        2.3911347, 2.2808902, 2.240677 , 1.9278091, 1.968155 , 2.5760424,
        2.581424 , 2.3393567, 2.9165263, 2.6267216, 3.2064333, 2.2196505,
        2.61277  , 2.6194053, 2.3422916, 3.1555693], dtype=float32)>,)

In [89]:
# The predicted direction layer, built from flat times and per-element row lengths
direction_layer = AsteroidDirection(ts_np=ts_np, row_lengths_np=row_lengths_np, 
                                    site_name=site_name, name='direction_layer')

# Calibration arrays (flat)
cols_q_ast = ['qx', 'qy', 'qz']
cols_v_ast = ['vx', 'vy', 'vz']
q_ast = ztf_elt[cols_q_ast].values.astype(dtype_np)
v_ast = ztf_elt[cols_v_ast].values.astype(dtype_np)

In [90]:
# Run calibration
direction_layer.q_layer.calibrate(elts=elts, q_ast=q_ast, v_ast=v_ast)

# Tensor of predicted directions
u_pred, r_pred = direction_layer(a, e, inc, Omega, omega, f, epoch)

In [91]:
# Test model 2: dummy input -> predicted directions
outputs_2 = (u_pred, )
model_test_2 = keras.Model(inputs=inputs, outputs=outputs_2)

In [92]:
# Run test model 2
model_test_2(x_np)

(<tf.Tensor: shape=(90210, 3), dtype=float32, numpy=
 array([[-5.7307154e-02, -9.8204172e-01,  1.7974921e-01],
        [-5.7307154e-02, -9.8204172e-01,  1.7974921e-01],
        [ 9.1700378e-04, -9.7799540e-01,  2.0862426e-01],
        ...,
        [ 6.2764210e-01,  7.5069404e-01,  2.0621328e-01],
        [ 6.2764210e-01,  7.5069404e-01,  2.0621328e-01],
        [ 6.2764210e-01,  7.5069404e-01,  2.0621328e-01]], dtype=float32)>,)

In [93]:
# Score layer for these observations
score_layer = TrajectoryScore(row_lengths_np=row_lengths_np, u_obs_np=u_obs_np,
                              thresh_deg=thresh_deg, name='score_layer')

# Compute the log likelihood by element from the predicted direction and mixture model parameters
log_like = score_layer(u_pred, h=h, lam=lam)

In [94]:
# Test model 3: dummy input -> per-element log likelihood
outputs_3 = (log_like, )
model_test_3 = keras.Model(inputs=inputs, outputs=outputs_3)

In [95]:
# Run test model 3
model_test_3(x_np)

(<tf.Tensor: shape=(64,), dtype=float32, numpy=
 array([6.9650297, 5.9499645, 8.21375  , 4.682707 , 5.140408 , 5.6145296,
        5.4914594, 9.588548 , 6.1326394, 5.3171554, 6.057453 , 5.160606 ,
        6.2296643, 5.5941014, 8.768156 , 5.4015975, 6.1190042, 5.287321 ,
        6.362869 , 7.106543 , 5.802117 , 7.226463 , 5.9347825, 5.277159 ,
        6.303598 , 4.8148537, 5.533498 , 5.4938703, 6.2355227, 3.9153745,
        5.3333926, 6.7027745, 4.534996 , 5.143604 , 3.4614468, 4.9295664,
        2.8134315, 6.1517406, 7.405118 , 5.27056  , 6.1381526, 6.7151227,
        4.1302013, 6.339996 , 5.7333446, 6.5054455, 6.0524893, 4.852806 ,
        3.4832158, 5.10293  , 6.2081876, 5.3100243, 5.571174 , 4.7640467,
        1.7110062, 3.3318555, 4.2579346, 2.828771 , 5.718561 , 9.448475 ,
        5.0833254, 6.0138373, 4.9748497, 4.242923 ], dtype=float32)>,)

In [96]:
# Wrap inputs and outputs
inputs = (x,)
# outputs = (log_like, elts_tf, u_pred)
outputs = (log_like, elts_tf, u_pred,)

In [97]:
# Create model with functional API (rebinds `model`, replacing the AsteroidSearchModel above)
model = keras.Model(inputs=inputs, outputs=outputs)

In [98]:
# Run model
model(x_np)

(<tf.Tensor: shape=(64,), dtype=float32, numpy=
 array([6.9650297, 5.949962 , 8.21375  , 4.682707 , 5.1404104, 5.614532 ,
        5.491459 , 9.588542 , 6.1326385, 5.317159 , 6.057454 , 5.160606 ,
        6.2296615, 5.5941014, 8.768161 , 5.4015965, 6.119004 , 5.2873235,
        6.362866 , 7.106541 , 5.8021135, 7.226465 , 5.934785 , 5.277146 ,
        6.303595 , 4.814853 , 5.5334964, 5.4938684, 6.2355194, 3.915373 ,
        5.3333936, 6.702776 , 4.534997 , 5.1436067, 3.4614463, 4.9295664,
        2.8134334, 6.1517453, 7.4051194, 5.2705617, 6.138152 , 6.715118 ,
        4.130203 , 6.3399954, 5.7333455, 6.505447 , 6.0524855, 4.852804 ,
        3.483217 , 5.1029286, 6.2081895, 5.310025 , 5.5711737, 4.7640457,
        1.7110058, 3.3318565, 4.257934 , 2.8287692, 5.718564 , 9.448479 ,
        5.083329 , 6.013835 , 4.974848 , 4.2429233], dtype=float32)>,
 <tf.Tensor: shape=(64, 7), dtype=float32, numpy=
 array([[ 2.66930509e+00,  2.17360646e-01,  4.99553442e-01,
          4.69970322e+00,  2.450