In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Utility
import time
from datetime import timedelta

# Local imports
from asteroid_search_model import make_model_asteroid_search
from ztf_data import make_ztf_easy_batch
from asteroid_data import make_ztf_dataset, orbital_element_batch
from asteroid_integrate import calc_ast_pos
from asteroid_search_report import report_model, report_training_progress
from utils import print_header

# Typing
from typing import Dict

Found 4 GPUs.  Setting memory growth = True.


In [19]:
from asteroid_search import perturb_elts, test_easy_batch
from search_score_functions import score_mean_var
from astro_utils import deg2dist

In [12]:
# Aliases
# Shorthand for the Keras API bundled with TensorFlow; used when constructing the optimizer
keras = tf.keras

## Generate DataSet Matching an Easy Batch

In [4]:
# Time points per batch; None is replaced by the full trajectory length once the data set is built
time_batch_size = None
# Number of orbital element candidates handled together in one batch
elt_batch_size: int = 128
# Number of epochs passed to model.fit
epochs: int = 5

In [28]:
# Assemble the easy batch of ZTF observations (includes nearest asteroid calculations)
ztf = make_ztf_easy_batch(batch_size=elt_batch_size)

# Wrap the ZTF DataFrame in a TensorFlow DataSet
ds, ts, row_len = make_ztf_dataset(ztf=ztf, batch_size=time_batch_size)

# Number of time points on the trajectory
traj_size = ts.shape[0]
# A time_batch_size of None means every time point goes into a single batch
time_batch_size = traj_size if time_batch_size is None else time_batch_size
# Number of batches required to cover the full trajectory
steps = int(np.ceil(traj_size / time_batch_size))

# Resolution (10 arc seconds) and threshold, both expressed in degrees
R_deg: float = 10.0 / 3600.0
thresh_deg: float = 1.0

# Orbital elements for each distinct asteroid that appears as a nearest match in the ZTF data
ast_nums = np.unique(ztf['nearest_ast_num'])
elts_np = orbital_element_batch(ast_nums)
# Epoch taken from the first entry of the element batch
epoch = elts_np['epoch'][0]

# Pull a single example batch from the data set
batch_in, batch_out = list(ds.take(1))[0]
# Unpack the inputs of this batch.
# NOTE: row_len returned by make_ztf_dataset above is rebound here to the batch's own row_len.
t, idx, row_len, u_obs = (batch_in[key] for key in ('t', 'idx', 'row_len', 'u_obs'))

# Maximum number of observations: second axis of the observed directions
max_obs: int = u_obs.shape[1]
# The observation count is the TOTAL for the whole ZTF data set, not only this easy batch,
# because the easy batch was harvested to contain observations close to known asteroids.
# Summing row_len (np.sum(row_len, dtype=np.float32)) would only count the easy batch,
# so the total is hard coded instead.
num_obs: float = 5.7E6

In [29]:
# Display the easy-batch ZTF DataFrame (pandas abbreviates the middle rows in the rendered output)
ztf

Unnamed: 0,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,uz,mag_app,asteroid_prob,nearest_ast_num,nearest_ast_dist,ast_ra,ast_dec,ast_ux,ast_uy,ast_uz
1415,b'ZTF19abgrbgv',927453854115015005,39178,58681.453854,44.462771,27.414173,0.633558,0.753622,0.175094,17.962000,1.0,4133,0.000005,44.462468,27.414047,0.633562,0.753619,0.175093
32387,b'ZTF19abiquuc',936427393415015008,40501,58690.427396,34.100063,26.325379,0.742182,0.637433,0.206993,18.446899,1.0,47139,0.000005,34.099859,26.325160,0.742185,0.637430,0.206990
37305,b'ZTF19abiqbau',936410851515015008,40482,58690.410856,9.400215,44.580601,0.702699,0.385939,0.597717,19.386600,1.0,313521,0.000007,9.399908,44.580294,0.702704,0.385934,0.597714
38892,b'ZTF19abiqrvx',936426030215015000,40499,58690.426030,25.393628,16.027121,0.868269,0.487978,0.089359,18.346800,1.0,59245,0.000007,25.393386,16.026813,0.868273,0.487974,0.089356
39390,b'ZTF19abiquzc',936427393915015001,40501,58690.427396,33.007043,27.055899,0.746830,0.626028,0.224353,19.067200,1.0,122263,0.000006,33.006788,27.055637,0.746833,0.626025,0.224350
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5441791,b'ZTF20aarbpja',1151112681015015000,96901,58905.112685,29.975418,30.398495,0.747155,0.596662,0.292837,19.142099,1.0,47029,0.000006,29.975023,30.398420,0.747159,0.596657,0.292838
5441935,b'ZTF20aarbqkw',1151121034915015003,96919,58905.121030,25.806967,29.268916,0.785333,0.542902,0.297505,18.497400,1.0,48453,0.000006,25.806544,29.268905,0.785336,0.542897,0.297507
5441946,b'ZTF20aarbqjg',1151121030215015003,96919,58905.121030,26.192790,22.998640,0.825990,0.528196,0.196849,18.916800,1.0,94346,0.000007,26.192344,22.998533,0.825994,0.528190,0.196849
5442136,b'ZTF20aarbqpo',1151121032115015000,96919,58905.121030,24.489749,25.919630,0.818493,0.515939,0.252736,18.606501,1.0,31539,0.000003,24.489554,25.919688,0.818494,0.515937,0.252739


## Orbital Element Batch

In [7]:
# True orbital elements as a 2D array; columns ordered a, e, inc, Omega, omega, f, epoch
elts_true = np.array(
    [elts_np[name] for name in ('a', 'e', 'inc', 'Omega', 'omega', 'f', 'epoch')]
).T

# First half of the batch keeps its true elements ("good"); the second half is perturbed ("bad")
mask_good = np.arange(elt_batch_size) < elt_batch_size // 2
mask_bad = np.logical_not(mask_good)
# Perturb only the elements selected by mask_bad, using the default perturbation sizes.
# (Pass sigma_a=0.00, sigma_e=0.00, sigma_f_deg=0.0 to perturb_elts to switch those perturbations off.)
elts_np2 = perturb_elts(elts_np, mask=mask_bad)

# Calibration orbits. The integration only runs when q_cal is not already defined in the
# kernel, so q_cal is NOT refreshed when elts_np2 changes; delete q_cal to force a recompute.
if 'q_cal' not in globals():
    print(f'Numerically integrating calibration trajectories q_cal...')
    q_cal = calc_ast_pos(elts=elts_np2, epoch=epoch, ts=ts)

Numerically integrating calibration trajectories q_cal...


## Build Model

In [30]:
# Calibration flag handed to the model builder together with q_cal
use_calibration: bool = True

# Alpha and beta parameters for the objective function
alpha = 1.0
beta = 0.0

# Functional model for the asteroid search score, built on the perturbed elements
model = make_model_asteroid_search(
    ts=ts,
    elts_np=elts_np2,
    max_obs=max_obs,
    num_obs=num_obs,
    elt_batch_size=elt_batch_size,
    time_batch_size=time_batch_size,
    R_deg=R_deg,
    thresh_deg=thresh_deg,
    alpha=alpha,
    beta=beta,
    q_cal=q_cal,
    use_calibration=use_calibration,
)

# Adam hyperparameters; the trailing comments record the Keras defaults
learning_rate = 2.0E-5  # default 1.0E-3
beta_1 = 0.900          # default 0.900
beta_2 = 0.999          # default 0.999
epsilon = 1.0E-7        # default 1.0E-7
amsgrad = False         # default False
clipvalue = 5.0         # default not used

# Adam optimizer. Gradient clipping is currently switched off: clipvalue is defined above
# but deliberately not passed to the constructor (add clipvalue=clipvalue to enable it).
opt = keras.optimizers.Adam(
    learning_rate=learning_rate,
    beta_1=beta_1,
    beta_2=beta_2,
    epsilon=epsilon,
    amsgrad=amsgrad,
)
model.compile(optimizer=opt)

# Whether the report functions print their results
display = True

## Initial Losses & Gradients

In [32]:
# Resolution and threshold converted from degrees to Cartesian distance
R = deg2dist(R_deg)
thresh = deg2dist(thresh_deg)
# A is the inverse squared resolution, 1 / R^2
A = 1.0 / R**2
# Mean and variance of the score contributed by a single observation
mu_per_obs, sigma2_per_obs = score_mean_var(A, thresh=thresh)
# Total observation count: reuse num_obs (the value handed to the model) instead of
# repeating the literal, so these statistics cannot drift away from the model's setting
N_obs = num_obs
# Mean and variance over the whole data set scale linearly with the observation count
mu = N_obs * mu_per_obs
sigma2 = N_obs * sigma2_per_obs

# report
print(f'Resolution: {R_deg} degrees / {R:8.6} cartesian')
print(f'A         : {A:8.3f}')
print('\nPer Observation:')
print(f'mu:     {mu_per_obs:5.3e}')
print(f'sigma2: {sigma2_per_obs:5.3e}')
print('\nFor Data:')
print(f'mu:     {mu:10.6f}')
print(f'sigma2: {sigma2:10.6f}')

Resolution: 0.002777777777777778 degrees / 4.84814e-05 cartesian
A         : 425451703.045

Per Observation:
mu:     1.175e-09
sigma2: 5.876e-10

For Data:
mu:       0.006699
sigma2:   0.003349


In [33]:
# Baseline report: evaluate the model before any training so later progress can be compared to it
print(f'Processed easy batch with {elt_batch_size} asteroids. Perturbed second half.')
if display:
    print_header('Model Before Training:')

# Raw model outputs before training; the fifth output is discarded
pred0 = model.predict_on_batch(ds)
elts0, R0, u_pred0, z0, _ = pred0

# Scores, trajectory errors and orbital element errors of the untrained model
scores0, traj_err0, elt_err0 = report_model(
    model=model,
    ds=ds,
    R_deg=R_deg,
    mask_good=mask_good,
    batch_size=elt_batch_size,
    steps=steps,
    elts_true=elts_true,
    display=display,
)

Processed easy batch with 128 asteroids. Perturbed second half.

********************************************************************************
Model Before Training:
********************************************************************************

Mean & Std Trajectory Error (AU) vs. True Elements by Category:
Good:     0.00 +/-     0.00
Bad:      0.31 +/-     0.17
All:      0.15 +/-     0.19

Error in orbital elements:
(Angles shown in degrees)
      a          e         inc       Omega      omega     f
Good: 0.000000,  0.000000, 0.000005, 0.000000,  0.000002, 0.000002, 
Bad : 0.156913,  0.007847, 0.000002, 0.000000,  0.000001, 3.216398, 

Mean & Std Raw Score by Category:
Good:   147.29 +/-    12.40
Bad:      0.00 +/-     0.00
All:     73.64 +/-    74.16

Mean & Std Mu by Category:
Good:     0.01 +/-     0.00
Bad:      0.01 +/-     0.00
All:      0.01 +/-     0.00

Mean & Std Effective Observations by Category:
Good:   147.28 +/-    12.40
Bad:     -0.01 +/-     0.00
All:     73.64

In [34]:
# Get initial gradients on entire data set.
# The tape is persistent because gradient() is called once per control variable below.
with tf.GradientTape(persistent=True) as gt:
    # Explicitly watch these control variables
    gt.watch([model.elements.e_, model.elements.inc_, model.elements.R_])
    pred = model.predict_on_batch(ds.take(traj_size))
    # NOTE: R rebinds the Cartesian resolution set in the score-statistics cell
    elts, R, u_pred, z, scores = pred
    # unpack elements; columns are a, e, inc, Omega, omega, f, epoch
    a = elts[:,0]
    e = elts[:,1]
    inc = elts[:,2]
    Omega = elts[:,3]
    omega = elts[:,4]
    f = elts[:,5]
    # Bound to elt_epoch rather than epoch: the scalar epoch defined with the data set is
    # what calc_ast_pos needs to integrate q_cal, and must not be replaced by this column
    elt_epoch = elts[:,6]
    # unpack scores (mu and sigma2 rebind the data set totals from the score-statistics cell)
    raw_score = scores[:, 0]
    mu = scores[:, 1]
    sigma2 = scores[:, 2]
    objective = scores[:, 3]
    # total loss function: negative objective summed over the element batch
    loss = tf.reduce_sum(-objective)

# Derivatives of loss w.r.t. control variables for elements and R, averaged over the steps.
# NOTE(review): tape.gradient returns None when the loss is not connected to a variable on the
# tape, which would make the division fail; confirm predict_on_batch keeps outputs as tensors.
dL_da = gt.gradient(loss, model.elements.a_) / steps
dL_de = gt.gradient(loss, model.elements.e_) / steps
dL_dinc = gt.gradient(loss, model.elements.inc_) / steps
dL_dOmega = gt.gradient(loss, model.elements.Omega_) / steps
dL_domega = gt.gradient(loss, model.elements.omega_) / steps
dL_df = gt.gradient(loss, model.elements.f_) / steps
dL_dR = gt.gradient(loss, model.elements.R_) / steps
# Release the resources held by the persistent tape
del gt

## Train Model

In [35]:
# Train model
# Each epoch covers the full set of time batches step_multiplier times
step_multiplier = 5
steps_per_epoch = steps*step_multiplier
rows_per_epoch = steps_per_epoch * time_batch_size
print_header(f'Training for {epochs} Epochs of Size {rows_per_epoch} Observation Days...')
print(f'alpha:         {alpha:5.1f}')
print(f'beta:          {beta:5.1f}')
# R_deg is a few thousandths of a degree, so one decimal place would always print 0.0
print(f'R (degrees):   {R_deg:8.6f}')
print(f'learning_rate:   {learning_rate:5.2e}')
# Report the clipping the optimizer actually holds, not the clipvalue constant, which is
# only in effect when it was passed to the optimizer constructor
clip_applied = getattr(opt, 'clipvalue', None)
clip_text = 'not applied' if clip_applied is None else f'{clip_applied:5.2f}'
print(f'clipvalue:      {clip_text}')

# Fit the model and time the training run
train_time_0 = time.time()
hist = model.fit(ds, epochs=epochs, steps_per_epoch=steps_per_epoch)
train_time_1 = time.time()
train_time = train_time_1 - train_time_0
print(f'Elapsed Time: {str(timedelta(seconds=train_time))}')

# Report results
if display:
    print_header('Model After Training:')
# Raw model outputs after training; the fifth output is discarded
pred1 = model.predict_on_batch(ds)
elts1, R1, u_pred1, z1, _ = pred1
scores1, traj_err1, elt_err1 = \
    report_model(model=model, ds=ds, R_deg=R_deg, mask_good=mask_good, 
                 batch_size=elt_batch_size, steps=steps, elts_true=elts_true, display=display)

# Unpack scores after training
raw_score = scores1[:,0]
mu = scores1[:,1]
sigma2 = scores1[:,2]
objective = scores1[:,3]
sigma = scores1[:, 4]
t_score = scores1[:, 5]
# Effective observations: raw score in excess of its mean
eff_obs = raw_score - mu

## Change in scores, errors and resolution between the untrained and trained model
d_scores = scores1 - scores0
d_traj_err = traj_err1 - traj_err0
d_elt_err = elt_err1 - elt_err0
d_R = R1 - R0

# Report training progress: scores, orbital element errors, and resolution.
# report_training_progress receives (before, after) pairs rather than the differences above.
print_header('Progress During Training:')
scores_01 = (scores0, scores1)
traj_err_01 = (traj_err0, traj_err1)
elt_err_01 = (elt_err0, elt_err1)
R_01 = (R0, R1)
report_training_progress(scores_01, traj_err_01, elt_err_01, R_01, mask_good)


********************************************************************************
Training for 5 Epochs of Size 12560 Observation Days...
********************************************************************************
alpha:           1.0
beta:            0.0
R (degrees):     0.0
learning_rate:   2.00e-05
clipvalue:       5.00
Train for 5 steps
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Elapsed Time: 0:00:02.592670

********************************************************************************
Model After Training:
********************************************************************************

Mean & Std Trajectory Error (AU) vs. True Elements by Category:
Good:     0.01 +/-     0.00
Bad:      0.31 +/-     0.17
All:      0.16 +/-     0.19

Error in orbital elements:
(Angles shown in degrees)
      a          e         inc       Omega      omega     f
Good: 0.001397,  0.000111, 0.020048, 0.040102,  0.040093, 0.040083, 
Bad : 0.156913,  0.007847, 0.000002, 0.000000,  0.000001

In [38]:
# List the keys of the orbital element dictionary elts_np
elts_np.keys()

dict_keys(['a', 'e', 'inc', 'Omega', 'omega', 'f', 'epoch'])