In [None]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Utility
import time
from datetime import timedelta

# Local imports
from asteroid_search_model import make_model_asteroid_search
from ztf_data import load_ztf_easy_batch
from asteroid_data import make_ztf_dataset, orbital_element_batch
from asteroid_integrate import calc_ast_pos
from asteroid_search_report import report_model, report_training_progress
from utils import print_header

# Typing
from typing import Dict

In [2]:
from asteroid_search import perturb_elts, test_easy_batch
from asteroid_search_report import report_model_attribute
from search_score_functions import score_mean_var
from astro_utils import deg2dist

In [3]:
# Aliases
keras = tf.keras

## Generate DataSet Matching an Easy Batch

In [4]:
# Resolution and threshold in degrees
R_deg: float = 1.0
thresh_deg: float = 1.0
# Whether resolution R is trainable
R_is_trainable = False
# Batch setup
time_batch_size = None
elt_batch_size = 64
epochs = 10
cycles_per_epoch = 10

In [5]:

# Load all ZTF data with nearest asteroid calculations
ztf, elts_np = load_ztf_easy_batch(batch_size=elt_batch_size, thresh_deg=thresh_deg)

# Build a TensorFlow DataSet from ZTF DataFrame
ds, ts, row_len = make_ztf_dataset(ztf=ztf, batch_size=time_batch_size)

# Trajectory size and steps per batch
traj_size = ts.shape[0]
# If time_batch_size was None, use all the time points in each batch
if time_batch_size is None:
    time_batch_size = traj_size
steps = int(np.ceil(traj_size / time_batch_size))

# Pop asteroid number from and epoch from elts DataFrame
element_id = elts_np.pop('element_id')
# Extract epoch, leaving it on
epoch = elts_np['epoch'][0]

# The correct orbital elements as an array of shape Nx6
elts_true = elts_np.values

# Get example batch
batch_in, batch_out = list(ds.take(1))[0]
# Contents of this batch
t = batch_in['t']
idx = batch_in['idx']
row_len = batch_in['row_len']
u_obs = batch_in['u_obs']

# Get max_obs and number of observations
max_obs: int = u_obs.shape[1]
# The number of observations is the TOTAL FOR THE ZTF DATA SET!
# It's not just the size of this easy batch, b/c the easy batch has been harvested to be close!
# num_obs: float = np.sum(row_len, dtype=np.float32)
num_obs: float = 5.69E6

In [6]:
ztf

Unnamed: 0,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,uz,nearest_ast_num,nearest_ast_dist,ast_ra,ast_dec,ast_ux,ast_uy,ast_uz
37606,b'ZTF18abtdsvv',621336711115010010,9577,58375.336713,350.405941,0.879484,0.985897,-0.146790,0.080371,1136336,0.001973,350.518390,0.891030,0.986219,-0.144934,0.079786
37607,b'ZTF18abtdsvv',621309991115010019,9572,58375.310000,350.405990,0.879483,0.985897,-0.146789,0.080371,1136336,0.002098,350.525361,0.893904,0.986239,-0.144804,0.079784
37610,b'ZTF18abtdsvv',614358404915010002,8240,58368.358403,350.405972,0.879528,0.985897,-0.146789,0.080371,1131588,0.001976,350.425686,0.991042,0.985923,-0.145699,0.082020
37612,b'ZTF18abtdsvv',615324371115015021,8425,58369.324375,350.405987,0.879540,0.985897,-0.146789,0.080371,320064,0.001601,350.497212,0.889307,0.986159,-0.145280,0.079903
141833,b'ZTF18abwpfho',623339405515010000,9980,58377.339410,348.764546,0.189321,0.980829,-0.177448,0.080534,298985,0.002356,348.831076,0.306747,0.981046,-0.175586,0.081961
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5584064,b'ZTF18aaaknvz',1149162620015010022,96198,58903.162627,67.404504,45.304078,0.270241,0.878532,0.393893,417889,0.006270,67.794274,45.536954,0.264723,0.878880,0.396851
5584155,b'ZTF20aaqjjhn',1149162620115015001,96198,58903.162627,66.164947,45.635778,0.282557,0.871180,0.401505,1009539,0.003937,66.390553,45.474743,0.280840,0.873100,0.398528
5584262,b'ZTF20aaqjltu',1149162621915015008,96198,58903.162627,67.072721,46.398450,0.268658,0.870797,0.411746,153267,0.000034,67.069974,46.398068,0.268690,0.870788,0.411746
5584879,b'ZTF20aaqjltx',1149162621915015001,96198,58903.162627,66.993445,46.563928,0.268718,0.869478,0.414486,155845,0.000004,66.993111,46.563943,0.268721,0.869477,0.414487


## Orbital Element Batch

In [7]:
# Mask where data perturbed vs not
mask_good = np.arange(elt_batch_size) < (elt_batch_size//2)
mask_bad = ~mask_good
# Perturb second half of orbital elements
# elts_np2 = perturb_elts(elts_np, sigma_a=0.00, sigma_e=0.00, sigma_f_deg=0.0, mask=mask_bad)
elts_np2 = perturb_elts(elts_np, mask=mask_bad)

# Orbits for calibration
if 'q_cal' not in globals():
    print(f'Numerically integrating calibration trajectories q_cal...')
    q_cal = calc_ast_pos(elts=elts_np2, epoch=epoch, ts=ts)
# q_cal = None

Numerically integrating calibration trajectories q_cal...


## Build Model

In [8]:
# Set calibration flag
use_calibration: bool = True

# Alpha and beta parameters for the objective function
alpha = 1.0
beta = 0.0

# Build functional model for asteroid score
model = make_model_asteroid_search(\
    ts=ts, elts_np=elts_np2, max_obs=max_obs, num_obs=num_obs,
    elt_batch_size=elt_batch_size, time_batch_size=time_batch_size,
    R_deg=R_deg, thresh_deg=thresh_deg, R_is_trainable=R_is_trainable, alpha=alpha, beta=beta, 
    q_cal=q_cal, use_calibration=use_calibration)

# Use Adam optimizer with gradient clipping
# learning_rate = 2.0e-5
learning_rate = 2.0E-5  # default 1.0E-3
beta_1 = 0.900          # default 0.900
beta_2 = 0.999          # default 0.999
epsilon = 1.0E-7        # default 1.0E-7
amsgrad = False         # default False
clipvalue = 5.0         # default not used
opt = keras.optimizers.Adam(learning_rate=learning_rate, 
                            beta_1=beta_1, 
                            beta_2=beta_2,
                            epsilon=epsilon, 
                            amsgrad=amsgrad,
                            # clipvalue=clipvalue, 
                            )
model.compile(optimizer=opt)
# Whether to display results
display = False

## Initial Losses & Gradients

In [9]:
R = deg2dist(R_deg)
A = 1.0 / R**2
thresh = deg2dist(thresh_deg)
mu_per_obs, sigma2_per_obs = score_mean_var(A, thresh=thresh)
N_obs = 5.69e6
mu = N_obs * mu_per_obs
sigma2 = N_obs * sigma2_per_obs

# report
print(f'Resolution: {R_deg} degrees / {R:8.6} cartesian')
print(f'A         : {A:8.3f}')
print('\nPer Observation:')
print(f'mu:     {mu_per_obs:5.3e}')
print(f'sigma2: {sigma2_per_obs:5.3e}')
print('\nFor Data:')
print(f'mu:     {mu:10.6f}')
print(f'sigma2: {sigma2:10.6f}')

Resolution: 1.0 degrees / 0.0174531 cartesian
A         : 3282.890

Per Observation:
mu:     5.993e-05
sigma2: 4.813e-05

For Data:
mu:     340.986259
sigma2: 273.882006


In [10]:
# Report losses before training
print(f'Processed easy batch with {elt_batch_size} asteroids. Perturbed second half.')
if display:
    print_header('Model Before Training:')
pred0 = model.predict_on_batch(ds)
elts0, R0, u_pred0, z0, _ = pred0
scores0, traj_err0, elt_err0 = \
    report_model(model=model, ds=ds, R_deg=R_deg, mask_good=mask_good, 
                 batch_size=elt_batch_size, steps=steps, elts_true=elts_true, display=display)

Processed easy batch with 64 asteroids. Perturbed second half.


In [11]:
# Get intital gradients on entire data set
with tf.GradientTape(persistent=True) as gt:
    gt.watch([model.elements.e_, model.elements.inc_, model.elements.R_])
    pred = model.predict_on_batch(ds.take(traj_size))
    elts, R, u_pred, z, scores = pred
    # unpack elements
    a = elts[:,0]
    e = elts[:,1]
    inc = elts[:,2]
    Omega = elts[:,3]
    omega = elts[:,4]
    f = elts[:,5]
    epoch = elts[:,6]
    # unpack scores
    raw_score = scores[:, 0]
    mu = scores[:, 1]
    sigma2 = scores[:, 2]
    objective = scores[:, 3]
    # total loss function
    loss = tf.reduce_sum(-objective)

# Derivatives of loss w.r.t. control variables for elements and R
dL_da = gt.gradient(loss, model.elements.a_) / steps
dL_de = gt.gradient(loss, model.elements.e_) / steps
dL_dinc = gt.gradient(loss, model.elements.inc_) / steps
dL_dOmega = gt.gradient(loss, model.elements.Omega_) / steps
dL_domega = gt.gradient(loss, model.elements.omega_) / steps
dL_df = gt.gradient(loss, model.elements.f_) / steps
dL_dR = gt.gradient(loss, model.elements.R_) / steps
del gt

In [12]:
# Report gradients
report_model_attribute(att=np.abs(dL_da), mask_good=mask_good, att_name='abs(grad_a)')
report_model_attribute(att=np.abs(dL_de), mask_good=mask_good, att_name='abs(grad_e)')
report_model_attribute(att=np.abs(dL_dinc), mask_good=mask_good, att_name='abs(grad_inc)')


Mean & Std abs(grad_a) by Category:
Good: 20346.64 +/- 26002.40
Bad:   9192.30 +/- 15219.63
All:  14769.47 +/- 22022.39

Mean & Std abs(grad_e) by Category:
Good:  5121.73 +/-  7266.40
Bad:   2641.09 +/-  5178.73
All:   3881.41 +/-  6430.26

Mean & Std abs(grad_inc) by Category:
Good:  6409.71 +/-  6441.77
Bad:   2356.57 +/-  3180.55
All:   4383.14 +/-  5469.29


In [13]:
dL_da[mask_good].numpy()

array([   6358.932  ,    8780.761  ,    5846.5327 ,    3052.6455 ,
         29515.115  ,    8058.57   ,  -33225.8    ,   35585.33   ,
         15536.563  ,     173.07379,    6094.67   ,    4249.4507 ,
         18456.44   ,   11300.358  ,   -4659.456  ,   65316.92   ,
         -6364.0933 ,    1470.2394 ,   14391.808  ,   19684.38   ,
        -27798.596  ,  -28614.352  ,   57122.098  , -127493.36   ,
          2101.0498 ,   -1628.559  ,   -1227.6288 ,   -8258.63   ,
         28621.734  ,    8879.207  ,     883.169  ,   60343.133  ],
      dtype=float32)

In [14]:
dL_da[mask_bad].numpy()

array([ -4261.484  ,   7345.37   ,   -501.7567 ,  13816.406  ,
        -7941.4165 ,    249.37183,      0.     , -11514.715  ,
        -4419.9995 ,  40459.36   ,      0.     ,   -282.43674,
        -2899.2373 ,  -2411.7175 ,  -1946.2804 ,  25964.748  ,
        -5674.3496 ,   -708.3224 , -12106.656  ,  -8593.382  ,
         6584.795  ,  15095.895  ,  -8941.217  ,   -539.3009 ,
         4836.5537 ,  -7878.742  ,    748.54   , -17015.94   ,
        -1678.1724 ,  79533.5    ,    203.87889,      0.     ],
      dtype=float32)

## Train Model

In [15]:
# Train model
steps_per_epoch = steps*cycles_per_epoch
rows_per_epoch = steps_per_epoch * time_batch_size
print_header(f'Training for {epochs} Epochs of Size {rows_per_epoch} Observation Days (Cycles/Epoch = {cycles_per_epoch})...')
print(f'R (degrees):    {R_deg:5.1f}')
print(f'R_is_trainable: {R_is_trainable}')
print(f'alpha:          {alpha:5.1f}')
print(f'beta:           {beta:5.1f}')
print(f'learning_rate:  {learning_rate:7.2e}')
print(f'clipvalue:      {clipvalue:5.2f}')

train_time_0 = time.time()
hist = model.fit(ds, epochs=epochs, steps_per_epoch=steps_per_epoch)
train_time_1 = time.time()
train_time = train_time_1 - train_time_0
print(f'Elapsed Time: {str(timedelta(seconds=train_time))}')

# Report results
if display:
    print_header('Model After Training:')
pred1 = model.predict_on_batch(ds)
elts1, R1, u_pred1, z1, _ = pred1
scores1, traj_err1, elt_err1 = \
    report_model(model=model, ds=ds, R_deg=R_deg, mask_good=mask_good, 
                 batch_size=elt_batch_size, steps=steps, elts_true=elts_true, display=display)

# Unpack scores after training
raw_score = scores1[:,0]
mu = scores1[:,1]
sigma2 = scores1[:,2]
objective = scores1[:,3]
sigma = scores1[:, 4]
t_score = scores1[:, 5]
eff_obs = raw_score - mu

## Change in scores
d_scores = scores1 - scores0
d_traj_err = traj_err1 - traj_err0
d_elt_err = elt_err1 - elt_err0
d_R = R1 - R0

# Report training progress: scores, orbital element errors, and resolution
print_header('Progress During Training:')
scores_01 = (scores0, scores1)
traj_err_01 = (traj_err0, traj_err1)
elt_err_01 = (elt_err0, elt_err1)
R_01 = (R0, R1)
report_training_progress(scores_01, traj_err_01, elt_err_01, R_01, mask_good)


********************************************************************************
Training for 10 Epochs of Size 23710 Observation Days (Cycles/Epoch = 10)...
********************************************************************************
R (degrees):      1.0
R_is_trainable: False
alpha:            1.0
beta:             0.0
learning_rate:  2.00e-05
clipvalue:       5.00
Train for 10 steps
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Elapsed Time: 0:00:10.489894

********************************************************************************
Progress During Training:
********************************************************************************

Mean & Std Change in Trajectory Error (AU) by Category:
Good: +0.006866 +/- 0.005200 -- from 0.000000 to 0.006866     (+1593630.5%)
Bad:  -0.010718 +/- 0.035835 -- from 0.194696 to 0.183978     (-5.5%)
All:  -0.001926 +/- 0.027072 -- from 0.097348 to 0.095422     (-2.0%)

Mea