In [94]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Local
from ztf_data import load_ztf_nearest_ast, calc_hit_freq, load_ztf_batch,make_ztf_batch
from asteroid_integrate import load_ast_elt, load_ast_pos
from candidate_element import orbital_element_batch, perturb_elts, random_elts

from asteroid_model import make_model_ast_pos, make_model_ast_dir
from asteroid_model import AsteroidPosition, AsteroidDirection
from asteroid_search_model import make_model_asteroid_search, OrbitalElements, TrajectoryScore
from astro_utils import deg2dist, dist2deg, dist2sec

In [95]:
keras = tf.keras

## Load ZTF Data and Batch of Orbital Elements

In [96]:
# Load orbital elements for known asteroids
ast_elt = load_ast_elt()

# Number of asteroids
N_ast = ast_elt.shape[0]

In [97]:
# Load ztf nearest asteroid data
ztf_ast = load_ztf_nearest_ast()

In [98]:
# Asteroid numbers and hit counts
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Sort the hit counts in descending order and find the top batch_size
idx = np.argsort(hit_count)[::-1]

# Extract the asteroid number and hit count for this batch
ast_num_best = ast_num[idx]
hit_count_best = hit_count[idx]

In [99]:
# Set batch size
batch_size = 64
elt_batch_size = batch_size

# Batch of unperturbed elements
elts_ast = orbital_element_batch(ast_nums=ast_num_best[0:batch_size])

In [100]:
elts_ast

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133491,58600.0
1,59244,2.634727,0.262503,0.465045,5.738297,1.766995,-1.601363,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.246069,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.357345,58600.0
4,142999,2.619944,0.191376,0.514017,0.238022,0.946463,-1.299301,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.016580,58600.0
60,134815,2.612770,0.140831,0.513923,0.272689,0.645552,-0.957836,58600.0
61,27860,2.619406,0.096185,0.200633,5.541399,3.266046,3.948770,58600.0
62,85937,2.342292,0.197267,0.439063,5.279693,3.210025,3.947687,58600.0


In [101]:
# Perturb orbital elements
sigma_a = 0.0 
sigma_e = 0.0 
sigma_f_deg = 0.1
sigma_Omega_deg = 0.0
sigma_omega_deg = 0.0
mask_pert = None
random_seed = 42

elts_pert = perturb_elts(elts_ast, sigma_a=sigma_a, sigma_e=sigma_e, sigma_f_deg=sigma_f_deg, 
                         sigma_Omega_deg=sigma_Omega_deg, sigma_omega_deg=sigma_omega_deg,
                         mask_pert=mask_pert, random_seed=random_seed)

In [102]:
elts_pert

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133117,58600.0
1,59244,2.634727,0.262503,0.465045,5.738297,1.766995,-1.603537,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.245767,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.356673,58600.0
4,142999,2.619945,0.191376,0.514017,0.238022,0.946463,-1.300844,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.014978,58600.0
60,134815,2.612770,0.140831,0.513923,0.272689,0.645552,-0.954132,58600.0
61,27860,2.619406,0.096185,0.200633,5.541399,3.266046,3.950572,58600.0
62,85937,2.342292,0.197267,0.439063,5.279693,3.210025,3.945035,58600.0


## Batches of ZTF Data vs. Elements

In [103]:
# Arguments to make_ztf_batch
thresh_deg = 1.0
near_ast = False
regenerate = False

In [104]:
# Load unperturbed element batch
ztf_elt_ast = load_ztf_batch(elts=elts_ast, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [105]:
# Load perturbed element batch
ztf_elt_pert = load_ztf_batch(elts=elts_pert, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [106]:
ztf_elt_ast

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,vz,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit
0,53851,733,b'ZTF18abnothj',594197584815010004,5501,58348.197581,266.229165,-13.513802,-0.063945,-0.983101,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.010624,2191.371398,0.999944,0.370539,False
1,73604,733,b'ZTF18ablwzmb',594197584815015003,5501,58348.197581,265.761024,-13.509148,-0.071871,-0.982578,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.016809,3467.103003,0.999859,0.927533,False
2,82343,733,b'ZTF18abiydvm',635193253015015018,12089,58389.193252,270.331454,-11.244934,0.005674,-0.977422,...,0.003825,0.000918,-0.977996,0.208622,2.703478,0.005450,1124.142942,0.999985,0.097510,False
3,257221,733,b'ZTF18acakcqg',931471223715015007,39920,58685.471227,29.693832,42.180412,0.643725,0.603886,...,-0.001953,0.639004,0.610779,0.467571,2.175851,0.008712,1797.042210,0.999962,0.249184,False
4,327000,733,b'ZTF18achmdmw',937465970615015011,40837,58691.465972,33.104905,44.059131,0.601970,0.636719,...,-0.002129,0.606278,0.637608,0.475272,2.114865,0.007949,1639.537152,0.999968,0.207418,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90205,5650588,324582,b'ZTF20aaqvhld',1150176701515015008,96618,58904.176701,44.164238,29.650540,0.623416,0.752309,...,-0.001541,0.627640,0.750696,0.206212,2.981799,0.008187,1688.636853,0.999966,0.220027,False
90206,5650589,324582,b'ZTF20aaqvhld',1150176245715015005,96617,58904.176250,44.164062,29.650536,0.623417,0.752307,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.008187,1688.600639,0.999966,0.220018,False
90207,5650665,324582,b'ZTF20aaqvhll',1150176245815015010,96617,58904.176250,44.368640,28.490480,0.628284,0.753618,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.013370,2757.856469,0.999911,0.586871,False
90208,5650697,324582,b'ZTF20aaqvhmb',1150176246015015005,96617,58904.176250,43.296207,29.505908,0.633424,0.743491,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.012388,2555.279465,0.999923,0.503822,False


In [107]:
ztf_elt_ast.columns

Index(['ztf_id', 'element_id', 'ObjectID', 'CandidateID', 'TimeStampID', 'mjd',
       'ra', 'dec', 'ux', 'uy', 'uz', 'qx', 'qy', 'qz', 'vx', 'vy', 'vz',
       'elt_ux', 'elt_uy', 'elt_uz', 'elt_r', 's', 's_sec', 'z', 'v',
       'is_hit'],
      dtype='object')

In [108]:
# Review results
ztf_elt = ztf_elt_ast
element_id_best = ast_num_best[0]
mask = (ztf_elt.element_id == element_id_best)
hits_best = np.sum(ztf_elt[mask].is_hit)
s_sec_min = np.min(ztf_elt[mask].s_sec)
idx = np.argmin(ztf_elt.s)
ztf_id = ztf_elt.ztf_id[idx]
# ztf_elt[mask].iloc[idx:idx+1]
print(f'Best asteroid has element_id = {element_id_best}')
print(f'Hit count: {hits_best}')
print(f'Closest hit: {s_sec_min:0.3f} arc seconds')
# ztf_elt[mask]

Best asteroid has element_id = 51921
Hit count: 158
Closest hit: 0.382 arc seconds


## Load Direction Model

In [109]:
# Data types
dtype = tf.float32
dtype_np = np.float32

In [110]:
# Alias ztf_elt
ztf_elt = ztf_elt_ast.copy()

In [111]:
# Build numpy array of times
ts_np = ztf_elt.mjd.values.astype(dtype_np)

In [112]:
# Build tensor of flattened times
ts = keras.backend.constant(ts_np)

# Get observation count per element
row_lengths = ztf_elt.element_id.groupby(ztf_elt.element_id).count().values.astype(np.int32)

In [113]:
# Observation site
site_name = 'palomar'

In [114]:
# Build direction model
model_dir = make_model_ast_dir(ts=ts, row_lengths=row_lengths, site_name=site_name, elt_batch_size=elt_batch_size)

In [115]:
# Calibration arrays (flat)
cols_q_ast = ['qx', 'qy', 'qz']
cols_v_ast = ['vx', 'vy', 'vz']
q_ast = ztf_elt[cols_q_ast].values.astype(dtype_np)
v_ast = ztf_elt[cols_v_ast].values.astype(dtype_np)

In [116]:
# Run calibration
model_dir.ast_pos_layer.calibrate(elts=elts_ast, q_ast=q_ast, v_ast=v_ast)

## Predict and Test Direction Model

In [117]:
# Stack elements as a dict of numpy arrays for prediction
cols_elt = ['a', 'e', 'inc', 'Omega', 'omega', 'f', 'epoch']
elts_ast_dict = {col : elts_ast[col].values for col in cols_elt}

In [118]:
# Predict calibrated direction model
u_pred, r_pred = model_dir.predict(elts_ast_dict)

In [119]:
# Compare to direction on the ztf_elt frame (should be the same!)
cols_u = ['elt_ux', 'elt_uy', 'elt_uz']
u_exp = ztf_elt[cols_u].values
u_diff = u_pred.values - u_exp
mean_diff_s = np.mean(np.linalg.norm(u_diff, axis=1))
mean_diff_sec = dist2sec(mean_diff_s)

# Report results
print(f'Mean direction difference:')
print(f'Cartesian  : {mean_diff_s:8.2e}')
print(f'Arc Seconds: {mean_diff_sec:6.3f}')

Mean direction difference:
Cartesian  : 7.86e-06
Arc Seconds:  1.620


## Tensors and Layers for Eager Mode

In [120]:
# Orbital elements
a = tf.Variable(initial_value=elts_ast.a, dtype=dtype, name='a')
e = tf.Variable(initial_value=elts_ast.e, dtype=dtype, name='e')
inc = tf.Variable(initial_value=elts_ast.inc, dtype=dtype, name='inc')
Omega = tf.Variable(initial_value=elts_ast.Omega, dtype=dtype, name='Omega')
omega = tf.Variable(initial_value=elts_ast.omega, dtype=dtype, name='omega')
f = tf.Variable(initial_value=elts_ast.f, dtype=dtype, name='f')
epoch = tf.constant(value=elts_ast.epoch, dtype=dtype, name='epoch')

In [121]:
# Mixture probability paramters: h and lam
h = tf.Variable(initial_value=np.ones(elt_batch_size)*0.5, dtype=dtype, name='h')
lam = tf.Variable(initial_value=np.ones(elt_batch_size)*1.0, dtype=dtype, name='lam')

In [122]:
# Observed directions
cols_u_obs = ['ux', 'uy', 'uz']
u_obs_np = ztf_elt[cols_u_obs].values.astype(dtype_np)
u_obs = tf.constant(value=u_obs_np, dtype=dtype, name='u_obs')

In [123]:
# Review shape
print(f'u_obs.shape = {u_obs.shape}')

u_obs.shape = (90210, 3)


In [124]:
# Extract direction layers
ast_pos_layer = model_dir.ast_pos_layer
ast_dir_layer = model_dir.ast_dir_layer

In [125]:
# Predict in eager mode
q_flat, v_flat = ast_pos_layer(a, e, inc, Omega, omega, f, epoch)
u_pred, r_pred = ast_dir_layer(a, e, inc, Omega, omega, f, epoch)

In [126]:
# Review shapes
print(f'q_flat.shape = {q_flat.shape}')
print(f'v_flat.shape = {v_flat.shape}')
print(f'u_pred.shape = {u_pred.shape}')
print(f'r_pred.shape = {r_pred.shape}')

q_flat.shape = (90210, 3)
v_flat.shape = (90210, 3)
u_pred.shape = (90210, 3)
r_pred.shape = (90210, 1)


## TrajectoryScore Layer

In [127]:
# Build score layer
score_layer = TrajectoryScore(row_lengths=row_lengths, thresh_deg=thresh_deg)

In [128]:
# Test score layer
log_like = score_layer(u_pred, u_obs, h, lam)

In [129]:
# Report shape
print(f'log_like.shape = {log_like.shape}')

log_like.shape = (64,)


## Prototype TrajectoryScore

In [130]:
# Thresholds
thresh_s = deg2dist(thresh_deg).astype(dtype_np)
thresh_s2 = (thresh_s**2).astype(dtype_np)
thresh_z = (1.0 - thresh_s2 / 2.0).astype(dtype_np)

In [131]:
# Difference between actual and predicted directions
du = keras.layers.subtract(inputs=[u_pred, u_obs], name='du')
# Squared distance bewteen predicted and observed directions
s2 = K.sum(tf.square(du), axis=(-1))

In [132]:
# Review shapes
print(f'du.shape = {du.shape}')
print(f's2.shape = {s2.shape}')

du.shape = (90210, 3)
s2.shape = (90210,)


In [133]:
# Look at relevant columns
cols_disp = ['ztf_id', 'element_id', 'ux', 'uy', 'uz', 'elt_ux', 'elt_uy', 'elt_uz', 'z', 'v', 'is_hit']
ztf_elt.iloc[0:3][cols_disp]

Unnamed: 0,ztf_id,element_id,ux,uy,uz,elt_ux,elt_uy,elt_uz,z,v,is_hit
0,53851,733,-0.063945,-0.983101,0.17153,-0.0573,-0.982042,0.179751,0.999944,0.370539,False
1,73604,733,-0.071871,-0.982578,0.171389,-0.0573,-0.982042,0.179751,0.999859,0.927533,False
2,82343,733,0.005674,-0.977422,0.211222,0.000918,-0.977996,0.208622,0.999985,0.09751,False


In [134]:
# Filter to only include terms where z2 is within the threshold distance^2
# is_close = s2 < thresh_s2
is_close = tf.math.less(s2, thresh_s2)

In [135]:
# Review shape
close_frac = np.mean(is_close)
print(f'is_close.shape = {is_close.shape}')
print(f'close_frac = {close_frac:8.6f}')

is_close.shape = (90210,)
close_frac = 0.999789


In [136]:
# Relative distance v on data inside threshold
# s2[is_close] / thresh_s2
# v = tf.divide(s2[is_close], thresh_s2, name='v')
v = tf.divide(tf.boolean_mask(tensor=s2, mask=is_close), thresh_s2, name='v')

In [137]:
# Review shape, summary of v
mean_v = np.mean(v)
mean_log_v = np.mean(np.log(v))
print(f'v.shape = {v.shape}')
print(f'mean_v     = {mean_v:9.6f}')
print(f'mean_log_v = {mean_log_v:9.6f}')

v.shape = (90191,)
mean_v     =  0.437990
mean_log_v = -2.683602


In [138]:
# Row_lengths, for close observations only
is_close_r = tf.RaggedTensor.from_row_lengths(is_close, row_lengths=row_lengths)
row_lengths_close = tf.reduce_sum(tf.cast(is_close_r, tf.int32), axis=1)

# Shape of parameters
close_size = tf.reduce_sum(row_lengths_close)
param_shape = (close_size,)

In [139]:
# Review row_lengths
row_lengths_close

<tf.Tensor: shape=(64,), dtype=int32, numpy=
array([1424,  936, 1117, 1087, 1035, 1028,  780, 2744,  925, 2485, 1023,
       1013, 1063,  858, 1943,  949,  636, 1199, 1547, 1477,  991,  790,
       3204, 2719, 1612,  771, 1047, 1253,  845, 3081,  866, 1472, 2677,
        802, 1652,  708, 1041,  855, 2362, 1012, 1048, 1677,  916,  812,
       1100, 1180, 2708,  688, 1929, 2871, 1099,  743,  943, 2018, 2716,
        595,  864,  941, 2628, 3126,  943, 1726,  762, 1129], dtype=int32)>

In [140]:
# Upsample h and lambda
h_vec = tf.reshape(tensor=tf.repeat(h, row_lengths_close), shape=param_shape, name='h_vec')
lam_vec = tf.reshape(tensor=tf.repeat(h, row_lengths_close), shape=param_shape, name='lam_vec')

In [141]:
# Probability according to mixture model
p = h_vec * tf.exp(-lam_vec * v) + (1.0 - h_vec)
log_p = keras.layers.Activation(tf.math.log, name='log_p')(p)

In [142]:
# Rearrange to ragged tensors
log_p_r = tf.RaggedTensor.from_row_lengths(log_p, row_lengths=row_lengths_close)

In [143]:
# Log likelihood by element
log_like = tf.reduce_sum(log_p_r, axis=1)

In [144]:
# Summary statistics
mean_log_p = np.mean(log_like)
std_log_p = np.std(log_like)

# Review shape and summary
print(f'log_p.shape = {log_p.shape}')
print(f'log_like.shape = {log_like.shape}')
print(f'mean_log_p = {mean_log_p:9.6f}')
print(f'std_log_p  = {std_log_p:9.6f}')

log_p.shape = (90191,)
log_like.shape = (64,)
mean_log_p = -141.581116
std_log_p  = 83.229507


## Candidate Elements Layer

In [145]:
# Alias elts
elts = elts_ast

In [146]:
# Build elements layer
elements_layer = OrbitalElements(elts=elts, elt_batch_size=elt_batch_size, h=0.125, lam=1.0, name='candidates')

In [147]:
# Extract elements
a, e, inc, Omega, omega, f, epoch, h, lam = elements_layer(inputs=None)

In [148]:
# Review shapes
print(f'a.shape = {a.shape}')

a.shape = (64,)


## Prototype make_model_asteroid_search()

In [149]:
model = make_model_asteroid_search(elts=elts, 
                                   ztf_elt=ztf_elt, 
                                   site_name=site_name,
                                   h=0.5,
                                   lam=1.0,
                                   thresh_deg=thresh_deg)

ValueError: Duplicate node name in graph: 'score_layer/packed'

In [80]:
from tf_utils import Identity
space_dims = 3

In [81]:
elts

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133491,58600.0
1,59244,2.634727,0.262503,0.465045,5.738297,1.766995,-1.601363,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.246069,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.357345,58600.0
4,142999,2.619944,0.191376,0.514017,0.238022,0.946463,-1.299301,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.016580,58600.0
60,134815,2.612770,0.140831,0.513923,0.272689,0.645552,-0.957836,58600.0
61,27860,2.619406,0.096185,0.200633,5.541399,3.266046,3.948770,58600.0
62,85937,2.342292,0.197267,0.439063,5.279693,3.210025,3.947687,58600.0


In [82]:
ztf_elt

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,vz,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit
0,53851,733,b'ZTF18abnothj',594197584815010004,5501,58348.197581,266.229165,-13.513802,-0.063945,-0.983101,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.010624,2191.371398,0.999944,0.370539,False
1,73604,733,b'ZTF18ablwzmb',594197584815015003,5501,58348.197581,265.761024,-13.509148,-0.071871,-0.982578,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.016809,3467.103003,0.999859,0.927533,False
2,82343,733,b'ZTF18abiydvm',635193253015015018,12089,58389.193252,270.331454,-11.244934,0.005674,-0.977422,...,0.003825,0.000918,-0.977996,0.208622,2.703478,0.005450,1124.142942,0.999985,0.097510,False
3,257221,733,b'ZTF18acakcqg',931471223715015007,39920,58685.471227,29.693832,42.180412,0.643725,0.603886,...,-0.001953,0.639004,0.610779,0.467571,2.175851,0.008712,1797.042210,0.999962,0.249184,False
4,327000,733,b'ZTF18achmdmw',937465970615015011,40837,58691.465972,33.104905,44.059131,0.601970,0.636719,...,-0.002129,0.606278,0.637608,0.475272,2.114865,0.007949,1639.537152,0.999968,0.207418,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90205,5650588,324582,b'ZTF20aaqvhld',1150176701515015008,96618,58904.176701,44.164238,29.650540,0.623416,0.752309,...,-0.001541,0.627640,0.750696,0.206212,2.981799,0.008187,1688.636853,0.999966,0.220027,False
90206,5650589,324582,b'ZTF20aaqvhld',1150176245715015005,96617,58904.176250,44.164062,29.650536,0.623417,0.752307,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.008187,1688.600639,0.999966,0.220018,False
90207,5650665,324582,b'ZTF20aaqvhll',1150176245815015010,96617,58904.176250,44.368640,28.490480,0.628284,0.753618,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.013370,2757.856469,0.999911,0.586871,False
90208,5650697,324582,b'ZTF20aaqvhmb',1150176246015015005,96617,58904.176250,43.296207,29.505908,0.633424,0.743491,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.012388,2555.279465,0.999923,0.503822,False


In [83]:
# Element batch size comes from elts
elt_batch_size = elts.shape[0]

# Numpy array and tensor of observation times; flat, shape (data_size,)
ts_np = ztf_elt.mjd.values.astype(dtype_np)
ts = keras.backend.constant(value=ts_np, shape=ts_np.shape, dtype=dtype, name='ts')

# Get observation count per element
row_lengths_np = ztf_elt.element_id.groupby(ztf_elt.element_id).count()
row_lengths = keras.backend.constant(value=row_lengths_np, shape=row_lengths_np.shape, 
                                     dtype=tf.int32, name='row_lengths')

# Shape of the observed trajectories
data_size = ztf_elt.shape[0]
traj_shape = (data_size, space_dims)

In [84]:
# Input is observed directions as a ragged tensor of shape (elt_batch_size, (num_obs), 3,)
u_obs = keras.Input(shape=(3,), batch_size=data_size, name='u_obs')

In [85]:
# Set of trainable weights with candidate orbital elements; initialize according to elts
elements_layer = OrbitalElements(elts=elts, elt_batch_size=elt_batch_size, h=h, lam=lam, name='candidates')

# Extract the candidate elements and mixture parameters; pass dummy inputs to satisfy keras Layer API
a, e, inc, Omega, omega, f, epoch, h, lam = elements_layer(inputs=None)

# Alias the orbital elements; a, e, inc, Omega, omega, and f are trainable; epoch is fixed
a = Identity(name='a')(a)
e = Identity(name='e')(e)
inc = Identity(name='inc')(inc)
Omega = Identity(name='Omega')(Omega)
omega = Identity(name='omega')(omega)
f = Identity(name='f')(f)
epoch = Identity(name='epoch')(epoch)

In [86]:
# Alias the mixture model parameters
h = Identity(name='h')(h)
lam = Identity(name='lam')(lam)

In [87]:
# The orbital elements; stack to shape (elt_batch_size, 7)
elts_tf = tf.stack(values=[a, e, inc, Omega, omega, f, epoch], axis=1, name='elts')

# The predicted direction
direction_layer = AsteroidDirection(ts=ts, row_lengths=row_lengths, site_name=site_name, name='u_pred')

In [88]:
# Calibration arrays (flat)
cols_q_ast = ['qx', 'qy', 'qz']
cols_v_ast = ['vx', 'vy', 'vz']
q_ast = ztf_elt[cols_q_ast].values.astype(dtype_np)
v_ast = ztf_elt[cols_v_ast].values.astype(dtype_np)

In [89]:
# Run calibration
direction_layer.q_layer.calibrate(elts=elts_ast, q_ast=q_ast, v_ast=v_ast)

In [90]:
# Tensor of predicted directions
u_pred, r_pred = direction_layer(a, e, inc, Omega, omega, f, epoch)

In [91]:
# Review shape
print(f'u_pred.shape = {u_pred.shape}')

u_pred.shape = (90210, 3)


In [92]:
# Score layer for these observations
score_layer = TrajectoryScore(row_lengths=row_lengths, thresh_deg=thresh_deg)

In [93]:
# Compute the score
log_like = score_layer(u_pred, u_obs=u_obs, h=h, lam=lam)
score = np.sum(log_like)

ValueError: Duplicate node name in graph: 'trajectory_score_2/packed'

In [74]:
# Review shape and output
print(f'score = {score:8.6f}')
print(f'log_like.shape = {log_like.shape}')

score = -3745.686523
log_like.shape = (64,)


In [75]:
# Wrap inputs and outputs
inputs = (u_obs)
outputs = (log_like, elts_tf, u_pred)

In [76]:
# Create model with functional API
model = keras.Model(inputs=None, outputs=None)

AttributeError: 'NoneType' object has no attribute 'op'