In [None]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Utility
import time
from datetime import timedelta

# Local imports
from asteroid_search_model import make_model_asteroid_search
from ztf_data import load_ztf_easy_batch
from asteroid_data import make_ztf_dataset, orbital_element_batch
from asteroid_integrate import calc_ast_pos
from asteroid_search_report import report_model, report_training_progress
from utils import print_header

# Typing
from typing import Dict

In [None]:
from asteroid_search import perturb_elts, test_easy_batch
from asteroid_search_report import report_model_attribute
from search_score_functions import score_mean_var
from asteroid_dataframe import calc_ast_data, spline_ast_vec_df, calc_ast_dir
from ztf_data import make_ztf_batch, make_ztf_near_elt
from astro_utils import deg2dist, dist2deg

In [3]:
# Aliases
keras = tf.keras

## Generate DataSet Matching an Easy Batch

In [4]:
# Resolution and threshold in degrees
R_deg: float = 1.0
thresh_deg: float = 1.0
# Whether resolution R is trainable
R_is_trainable = False

# Batch setup
time_batch_size = None
elt_batch_size = 64
epochs = 10
cycles_per_epoch = 10

# Set calibration flag
use_calibration: bool = True

In [5]:
# Load all ZTF data with nearest asteroid calculations
ztf, elts = load_ztf_easy_batch(batch_size=elt_batch_size, thresh_deg=thresh_deg)

# Build a TensorFlow DataSet from ZTF DataFrame
ds, ts, row_len = make_ztf_dataset(ztf=ztf, batch_size=time_batch_size)

# Trajectory size and steps per batch
traj_size = ts.shape[0]
# If time_batch_size was None, use all the time points in each batch
if time_batch_size is None:
    time_batch_size = traj_size
steps = int(np.ceil(traj_size / time_batch_size))

# The element_id is the same as the asteroid number on the easy batch
element_id = elts.element_id.values
# The epoch (scalar)
epoch = elts['epoch'][0]

# The correct orbital elements; also a DataFrame
elts_true = elts.copy()

# Get example batch
batch_in, batch_out = list(ds.take(1))[0]
# Contents of this batch
t = batch_in['t']
idx = batch_in['idx']
row_len = batch_in['row_len']
u_obs = batch_in['u_obs']

# Get max_obs and number of observations
max_obs: int = u_obs.shape[1]
# The number of observations is the TOTAL FOR THE ZTF DATA SET!
# It's not just the size of this easy batch, b/c the easy batch has been harvested to be close!
# num_obs: float = np.sum(row_len, dtype=np.float32)
num_obs: float = 5.69E6

In [6]:
ztf

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,uz,nearest_ast_num,nearest_ast_dist,ast_ra,ast_dec,ast_ux,ast_uy,ast_uz
0,37606,733,b'ZTF18abtdsvv',621336711115010010,9577,58375.336713,350.405941,0.879484,0.985897,-0.146790,0.080371,1136336,0.001973,350.518390,0.891030,0.986219,-0.144934,0.079786
1,37607,733,b'ZTF18abtdsvv',621309991115010019,9572,58375.310000,350.405990,0.879483,0.985897,-0.146789,0.080371,1136336,0.002098,350.525361,0.893904,0.986239,-0.144804,0.079784
2,37610,733,b'ZTF18abtdsvv',614358404915010002,8240,58368.358403,350.405972,0.879528,0.985897,-0.146789,0.080371,1131588,0.001976,350.425686,0.991042,0.985923,-0.145699,0.082020
3,37612,733,b'ZTF18abtdsvv',615324371115015021,8425,58369.324375,350.405987,0.879540,0.985897,-0.146789,0.080371,320064,0.001601,350.497212,0.889307,0.986159,-0.145280,0.079903
4,141833,733,b'ZTF18abwpfho',623339405515010000,9980,58377.339410,348.764546,0.189321,0.980829,-0.177448,0.080534,298985,0.002356,348.831076,0.306747,0.981046,-0.175586,0.081961
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90205,5584064,324582,b'ZTF18aaaknvz',1149162620015010022,96198,58903.162627,67.404504,45.304078,0.270241,0.878532,0.393893,417889,0.006270,67.794274,45.536954,0.264723,0.878880,0.396851
90206,5584155,324582,b'ZTF20aaqjjhn',1149162620115015001,96198,58903.162627,66.164947,45.635778,0.282557,0.871180,0.401505,1009539,0.003937,66.390553,45.474743,0.280840,0.873100,0.398528
90207,5584262,324582,b'ZTF20aaqjltu',1149162621915015008,96198,58903.162627,67.072721,46.398450,0.268658,0.870797,0.411746,153267,0.000034,67.069974,46.398068,0.268690,0.870788,0.411746
90208,5584879,324582,b'ZTF20aaqjltx',1149162621915015001,96198,58903.162627,66.993445,46.563928,0.268718,0.869478,0.414486,155845,0.000004,66.993111,46.563943,0.268721,0.869477,0.414487


In [7]:
elts

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,733,3.398871,0.058840,0.354075,5.951576,3.315651,3.703184,58600.0
1,1476,2.280890,0.189703,0.110415,5.768549,6.107239,0.309986,58600.0
2,1803,2.349173,0.247823,0.376183,5.886310,4.431771,2.468021,58600.0
3,2015,2.335363,0.103840,0.207837,6.012107,4.788967,1.897608,58600.0
4,2294,2.581424,0.116455,0.109954,5.075755,0.753231,0.373019,58600.0
...,...,...,...,...,...,...,...,...
59,203722,3.171154,0.291647,0.426091,5.309571,1.789398,-1.162711,58600.0
60,253246,2.677164,0.227275,0.446592,5.260714,1.455787,-0.777289,58600.0
61,306781,2.534977,0.209075,0.228417,5.551083,0.846400,-0.451545,58600.0
62,313521,2.637372,0.254435,0.524255,4.487565,2.113770,-1.015696,58600.0


## Orbital Element Batch

In [8]:
# Mask where data perturbed vs not
mask_good = np.arange(elt_batch_size) < (elt_batch_size//2)
mask_bad = ~mask_good
# Perturb second half of orbital elements
# elts_np2 = perturb_elts(elts_np, sigma_a=0.00, sigma_e=0.00, sigma_f_deg=0.0, mask=mask_bad)
elts_pert = perturb_elts(elts, mask=mask_bad)

# Orbits for calibration
if use_calibration:
    print(f'Numerically integrating calibration trajectories q_cal...')
    q_cal = calc_ast_pos(elts=elts_pert, epoch=epoch, ts=ts)
else:
    q_cal = None

Numerically integrating calibration trajectories q_cal...


## Build Model

In [9]:
# Alpha and beta parameters for the objective function
alpha = 1.0
beta = 0.0

# Build functional model for asteroid score
model = make_model_asteroid_search(\
    ts=ts, elts=elts_pert, max_obs=max_obs, num_obs=num_obs,
    elt_batch_size=elt_batch_size, time_batch_size=time_batch_size,
    R_deg=R_deg, thresh_deg=thresh_deg, R_is_trainable=R_is_trainable, alpha=alpha, beta=beta, 
    q_cal=q_cal, use_calibration=use_calibration)

# Use Adam optimizer with gradient clipping
# learning_rate = 2.0e-5
learning_rate = 2.0E-5  # default 1.0E-3
beta_1 = 0.900          # default 0.900
beta_2 = 0.999          # default 0.999
epsilon = 1.0E-7        # default 1.0E-7
amsgrad = False         # default False
clipvalue = 5.0         # default not used
opt = keras.optimizers.Adam(learning_rate=learning_rate, 
                            beta_1=beta_1, 
                            beta_2=beta_2,
                            epsilon=epsilon, 
                            amsgrad=amsgrad,
                            # clipvalue=clipvalue, 
                            )
model.compile(optimizer=opt)
# Whether to display results
display = False

## Initial Losses

In [10]:
R = deg2dist(R_deg)
A = 1.0 / R**2
thresh = deg2dist(thresh_deg)
mu_per_obs, sigma2_per_obs = score_mean_var(A, thresh=thresh)
N_obs = 5.69e6
mu = N_obs * mu_per_obs
sigma2 = N_obs * sigma2_per_obs
sigma = np.sqrt(sigma2)

# report
print(f'Resolution: {R_deg} degrees / {R:8.6} cartesian')
print(f'A         : {A:8.3f}')
print('\nPer Observation:')
print(f'mu:     {mu_per_obs:5.3e}')
print(f'sigma2: {sigma2_per_obs:5.3e}')
print('\nFor Data:')
print(f'mu:     {mu:10.6f}')
print(f'sigma2: {sigma2:10.6f}')
print(f'simga:  {sigma:10.6f}')

Resolution: 1.0 degrees / 0.0174531 cartesian
A         : 3282.890

Per Observation:
mu:     5.993e-05
sigma2: 4.813e-05

For Data:
mu:     340.986259
sigma2: 273.882006
simga:   16.549381


In [11]:
# Report losses before training
print(f'Processed easy batch with {elt_batch_size} asteroids. Perturbed second half.')
if display:
    print_header('Model Before Training:')
pred0 = model.predict_on_batch(ds)
elts0, R0, u_pred0, z0, _ = pred0
scores0, traj_err0, elt_err0 = \
    report_model(model=model, ds=ds, R_deg=R_deg, mask_good=mask_good, 
                 batch_size=elt_batch_size, steps=steps, elts_true=elts_true, display=display)

Processed easy batch with 64 asteroids. Perturbed second half.


## Check that untrained model predictions match expected results

In [12]:
# Model on unperturbed elements
model_true = make_model_asteroid_search(\
    ts=ts, elts=elts_true, max_obs=max_obs, num_obs=num_obs,
    elt_batch_size=elt_batch_size, time_batch_size=time_batch_size,
    R_deg=R_deg, thresh_deg=thresh_deg, R_is_trainable=R_is_trainable, alpha=alpha, beta=beta, 
    q_cal=q_cal, use_calibration=use_calibration)

In [13]:
# Predictions of model on unperturbed elements
pred_true = model.predict_on_batch(ds)

# Unpack predictions on unperturbed elts; only want u
_, _, u_pred_tf, _, _ = pred_true

In [14]:
# The results predicted by the TF model
u_pred = u_pred_tf.numpy()

u_pred.shape

(64, 2371, 3)

In [18]:
# Table of ZTF frames that are close to selected elements
ztf_batch_true = make_ztf_batch(elts=elts_true, thresh_deg=thresh_deg, near_ast=True)

HBox(children=(FloatProgress(value=0.0, max=64.0), HTML(value='')))




In [19]:
# Select table for the first element
elt0 = element_id[0]
mask = ztf_batch_true.element_id == elt0
ztf_i = ztf_batch_true[mask]
ztf_i

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,uz,nearest_ast_num,nearest_ast_dist,ast_ra,ast_dec,ast_ux,ast_uy,ast_uz
0,37606,733,b'ZTF18abtdsvv',621336711115010010,9577,58375.336713,350.405941,0.879484,0.985897,-0.146790,0.080371,1136336,0.001973,350.518390,0.891030,0.986219,-0.144934,0.079786
1,37607,733,b'ZTF18abtdsvv',621309991115010019,9572,58375.310000,350.405990,0.879483,0.985897,-0.146789,0.080371,1136336,0.002098,350.525361,0.893904,0.986239,-0.144804,0.079784
2,37610,733,b'ZTF18abtdsvv',614358404915010002,8240,58368.358403,350.405972,0.879528,0.985897,-0.146789,0.080371,1131588,0.001976,350.425686,0.991042,0.985923,-0.145699,0.082020
3,37612,733,b'ZTF18abtdsvv',615324371115015021,8425,58369.324375,350.405987,0.879540,0.985897,-0.146789,0.080371,320064,0.001601,350.497212,0.889307,0.986159,-0.145280,0.079903
4,141833,733,b'ZTF18abwpfho',623339405515010000,9980,58377.339410,348.764546,0.189321,0.980829,-0.177448,0.080534,298985,0.002356,348.831076,0.306747,0.981046,-0.175586,0.081961
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
683,5461148,733,b'ZTF18adccvqy',1145163955415010017,95282,58899.163958,40.031822,35.961235,0.619758,0.711255,0.331687,1001209,0.007646,40.435941,35.670498,0.618331,0.715384,0.325412
684,5461229,733,b'ZTF20aapeoax',1145163955215015011,95282,58899.163958,41.260103,36.441613,0.604733,0.723029,0.333956,191919,0.003513,41.290945,36.641338,0.602888,0.723185,0.336941
685,5646481,733,b'ZTF19abjairu',1142179350215010035,94593,58896.179352,40.098806,37.053481,0.610474,0.711312,0.348363,1148569,0.001315,40.192635,37.045226,0.609698,0.712234,0.347838
686,5646482,733,b'ZTF19abjairu',1145163030215010041,95280,58899.163032,40.099006,37.053438,0.610473,0.711314,0.348361,40232,0.004097,40.226914,37.264863,0.607622,0.712412,0.351090


In [20]:
mask = (ztf_i.nearest_ast_num == elt0)
ztf_i[mask]

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,uz,nearest_ast_num,nearest_ast_dist,ast_ra,ast_dec,ast_ux,ast_uy,ast_uz
9,395815,733,b'ZTF19abjphnr',939445724715015021,41296,58693.445729,44.285479,33.579335,0.596406,0.753718,0.276059,733,0.000004,44.285291,33.579191,0.596409,0.753717,0.276058
17,420554,733,b'ZTF19abjqzrj',939485674715015012,41349,58693.485671,44.293555,33.585205,0.596283,0.753793,0.276120,733,0.000004,44.293377,33.585051,0.596286,0.753791,0.276118
18,430852,733,b'ZTF19abkkmjx',934476402915015003,40253,58688.476401,43.232256,32.855154,0.612042,0.743715,0.268871,733,0.000004,43.232065,32.855006,0.612045,0.743713,0.268869
19,433118,733,b'ZTF19abkkmjx',934476861615015007,40254,58688.476863,43.232370,32.855227,0.612041,0.743716,0.268872,733,0.000004,43.232164,32.855074,0.612044,0.743714,0.268870
23,549984,733,b'ZTF19ablpetd',942451704715015001,42140,58696.451701,44.889994,34.015204,0.587237,0.759236,0.280557,733,0.000004,44.889810,34.015059,0.587240,0.759235,0.280555
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
667,5242925,733,b'ZTF20aanbfry',1139159645615015005,93432,58893.159641,38.972055,36.841868,0.622190,0.700314,0.349914,733,0.000003,38.971815,36.841881,0.622192,0.700311,0.349915
668,5364815,733,b'ZTF20aaoeplv',1142123785315015008,94493,58896.123785,39.695137,36.790572,0.616202,0.707510,0.346013,733,0.000003,39.694888,36.790585,0.616204,0.707508,0.346014
672,5381516,733,b'ZTF20aaoggbf',1142182525315015008,94597,58896.182523,39.709609,36.789589,0.616080,0.707654,0.345936,733,0.000004,39.709351,36.789603,0.616082,0.707651,0.345937
675,5446528,733,b'ZTF20aapdbbp',1145110625315015005,95182,58899.110625,40.456356,36.751514,0.609662,0.715006,0.342168,733,0.000003,40.456109,36.751523,0.609665,0.715003,0.342169


In [21]:
# Alias tensorflow of times to numpy array mjd
mjd = ts.numpy()

# Show that mjd is sorted
np.min(np.diff(mjd))

0.00390625

In [22]:
# indices on left and right side side; want idx s/t mjd[idx] == ztf_i.mjd
idx_l = np.searchsorted(a=mjd, v=ztf_i.mjd)-1
idx_r = idx_l + 1

# Find the side that is closer
dist_l = ztf_i.mjd.values - mjd[idx_l]
dist_r = mjd[idx_r] - ztf_i.mjd.values
is_left = (dist_l < dist_r)
idx = idx_r.copy()

# The index that is closest
idx[is_left] = idx_l[is_left]
np.max(mjd[idx] - ztf_i.mjd.values)

0.0019212998013244942

In [23]:
# TensorFlow predictions on the ZTF dates for the first element_id
u_pred_tf = u_pred[0, idx, :]

# MSE predictions in Numpy
cols_ast_dir = ['ast_ux', 'ast_uy', 'ast_uz']
u_pred_mse = ztf_i[cols_ast_dir].values

# Difference between TensorFlow and 
u_diff = u_pred_tf - u_pred_mse

In [24]:
mean_err_dist = np.mean(np.linalg.norm(u_diff, axis=1))
mean_err_deg = dist2deg(mean_err_dist)
print(f'Mean direction error for element_id {elt0}:')
print(f'Cartesian: {mean_err_dist:6.3e}')
print(f'Degrees:   {mean_err_deg:8.8f}')

Mean direction error for element_id 733:
Cartesian: 9.073e-03
Degrees:   0.51985418


In [None]:
for k, (element_id, ztf_i) in enumerate(ztf_tbl.items()):
    pass


In [None]:
u_pred_ztf.shape

In [None]:
u_pred_mse.shape

## Initial Gradients

In [None]:
# Get intital gradients on entire data set
with tf.GradientTape(persistent=True) as gt:
    gt.watch([model.elements.e_, model.elements.inc_, model.elements.R_])
    pred = model.predict_on_batch(ds.take(traj_size))
    elts, R, u_pred, z, scores = pred
    # unpack elements
    a = elts[:,0]
    e = elts[:,1]
    inc = elts[:,2]
    Omega = elts[:,3]
    omega = elts[:,4]
    f = elts[:,5]
    epoch = elts[:,6]
    # unpack scores
    raw_score = scores[:, 0]
    mu = scores[:, 1]
    sigma2 = scores[:, 2]
    objective = scores[:, 3]
    # total loss function
    loss = tf.reduce_sum(-objective)

# Derivatives of loss w.r.t. control variables for elements and R
dL_da = gt.gradient(loss, model.elements.a_) / steps
dL_de = gt.gradient(loss, model.elements.e_) / steps
dL_dinc = gt.gradient(loss, model.elements.inc_) / steps
dL_dOmega = gt.gradient(loss, model.elements.Omega_) / steps
dL_domega = gt.gradient(loss, model.elements.omega_) / steps
dL_df = gt.gradient(loss, model.elements.f_) / steps
dL_dR = gt.gradient(loss, model.elements.R_) / steps
del gt

In [None]:
# Report gradients
report_model_attribute(att=np.abs(dL_da), mask_good=mask_good, att_name='abs(grad_a)')
report_model_attribute(att=np.abs(dL_de), mask_good=mask_good, att_name='abs(grad_e)')
report_model_attribute(att=np.abs(dL_dinc), mask_good=mask_good, att_name='abs(grad_inc)')

In [None]:
dL_da[mask_good].numpy()

In [None]:
dL_da[mask_bad].numpy()

## Train Model

In [None]:
# Train model
steps_per_epoch = steps*cycles_per_epoch
rows_per_epoch = steps_per_epoch * time_batch_size
print_header(f'Training for {epochs} Epochs of Size {rows_per_epoch} Observation Days (Cycles/Epoch = {cycles_per_epoch})...')
print(f'R (degrees):    {R_deg:5.1f}')
print(f'R_is_trainable: {R_is_trainable}')
print(f'alpha:          {alpha:5.1f}')
print(f'beta:           {beta:5.1f}')
print(f'learning_rate:  {learning_rate:7.2e}')
print(f'clipvalue:      {clipvalue:5.2f}')

train_time_0 = time.time()
hist = model.fit(ds, epochs=epochs, steps_per_epoch=steps_per_epoch)
train_time_1 = time.time()
train_time = train_time_1 - train_time_0
print(f'Elapsed Time: {str(timedelta(seconds=train_time))}')

# Report results
if display:
    print_header('Model After Training:')
pred1 = model.predict_on_batch(ds)
elts1, R1, u_pred1, z1, _ = pred1
scores1, traj_err1, elt_err1 = \
    report_model(model=model, ds=ds, R_deg=R_deg, mask_good=mask_good, 
                 batch_size=elt_batch_size, steps=steps, elts_true=elts_true, display=display)

# Unpack scores after training
raw_score = scores1[:,0]
mu = scores1[:,1]
sigma2 = scores1[:,2]
objective = scores1[:,3]
sigma = scores1[:, 4]
t_score = scores1[:, 5]
eff_obs = raw_score - mu

## Change in scores
d_scores = scores1 - scores0
d_traj_err = traj_err1 - traj_err0
d_elt_err = elt_err1 - elt_err0
d_R = R1 - R0

# Report training progress: scores, orbital element errors, and resolution
print_header('Progress During Training:')
scores_01 = (scores0, scores1)
traj_err_01 = (traj_err0, traj_err1)
elt_err_01 = (elt_err0, elt_err1)
R_01 = (R0, R1)
report_training_progress(scores_01, traj_err_01, elt_err_01, R_01, mask_good)

## Test Different Resolutions

In [None]:
thresh_deg = 1.0
R_is_trainable = False
elt_batch_size = 64
time_batch_size = None
epochs = 10
cycles_per_epoch = 10

In [None]:
R_deg_cand = np.power(2.0, -np.arange(0.0, 8.0))
R_deg_cand

In [None]:
# Results table
scores_tbl = dict()
traj_err_tbl = dict()
elt_err_tbl = dict()
R_tbl = dict()

# Iterate through different resolutions
R_deg_cand = np.power(2.0, -np.arange(0.0, 8.0))

for R_deg in R_deg_cand:
    # Generate scores
    scores_01, traj_err_01, elt_err_01, R_01, mask_good = \
        test_easy_batch(R_deg=R_deg, 
                        R_is_trainable=R_is_trainable,
                        elt_batch_size=elt_batch_size,
                        time_batch_size=time_batch_size,
                        epochs=epochs,
                        cycles_per_epoch=cycles_per_epoch)
    # Save to table
    scores_tbl[R_deg] = scores_01
    traj_err_tbl[R_deg] = traj_err_01
    elt_err_tbl[R_deg] = elt_err_01
    R_tbl[R_deg] = R_01

In [None]:
scores_01[0].shape