In [None]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Utility
import time
from datetime import timedelta

# Local imports
from asteroid_search_model import make_model_asteroid_search
from ztf_data import load_ztf_easy_batch
from asteroid_data import make_ztf_dataset, orbital_element_batch
from asteroid_integrate import calc_ast_pos
from asteroid_search_report import report_model, report_training_progress
from utils import print_header

# Typing
from typing import Optional

In [None]:
from asteroid_search import perturb_elts, test_easy_batch
from asteroid_search_report import report_model_attribute
from search_score_functions import score_mean_var
from asteroid_dataframe import calc_ast_data, spline_ast_vec_df, calc_ast_dir
from ztf_data import make_ztf_batch, make_ztf_near_elt
from astro_utils import deg2dist, dist2deg

In [3]:
# Aliases
# Shorthand for the Keras API bundled with TensorFlow; used below to build the optimizer
keras = tf.keras

## Generate DataSet Matching an Easy Batch

In [4]:
# Resolution and threshold in degrees
R_deg: float = 0.2
thresh_deg: float = 1.0
# Whether resolution R is trainable
R_is_trainable: bool = True

# Batch setup
# time_batch_size=None is replaced by the full trajectory size once the data is loaded
time_batch_size: Optional[int] = None
# Number of candidate orbital elements in the batch
elt_batch_size: int = 64
# Training length: each epoch runs steps * cycles_per_epoch optimizer steps
epochs: int = 10
cycles_per_epoch: int = 10

# Set calibration flag
# When True, calibration trajectories q_cal are numerically integrated and passed to the model
use_calibration: bool = True

In [5]:
# Load all ZTF data with nearest asteroid calculations
# Returns the observations DataFrame (ztf) and the DataFrame of orbital elements (elts)
ztf, elts = load_ztf_easy_batch(batch_size=elt_batch_size, thresh_deg=thresh_deg)

# Build a TensorFlow DataSet from ZTF DataFrame
# NOTE(review): the row_len returned here is overwritten further down by batch_in['row_len']
ds, ts, row_len = make_ztf_dataset(ztf=ztf, batch_size=time_batch_size)

# Trajectory size and steps per batch
traj_size = ts.shape[0]
# If time_batch_size was None, use all the time points in each batch
if time_batch_size is None:
    time_batch_size = traj_size
# Number of batches needed to cover all the time points once
steps = int(np.ceil(traj_size / time_batch_size))

# The element_id is the same as the asteroid number on the easy batch
element_id = elts.element_id.values
# The epoch (scalar)
# Taken from the row labelled 0; the elts table displayed below shows the same epoch on every row
epoch = elts['epoch'][0]

# The correct orbital elements; also a DataFrame
# Copied so that later changes to elts do not affect the reference values
elts_true = elts.copy()

# Get example batch
batch_in, batch_out = list(ds.take(1))[0]
# Contents of this batch
t = batch_in['t']
idx = batch_in['idx']
row_len = batch_in['row_len']
u_obs = batch_in['u_obs']

# Get max_obs and number of observations
max_obs: int = u_obs.shape[1]
# The number of observations is the TOTAL FOR THE ZTF DATA SET!
# It's not just the size of this easy batch, b/c the easy batch has been harvested to be close!
# num_obs: float = np.sum(row_len, dtype=np.float32)
# Hard-coded total; the same value is used again when computing the expected score statistics
num_obs: float = 5.69E6

In [6]:
ztf

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,uz,elt_ux,elt_uy,elt_uz,elt_r,elt_sec,score,is_hit
0,37606,733,b'ZTF18abtdsvv',621336711115010010,9577,58375.336713,350.405941,0.879484,0.985897,-0.146790,0.080371,0.983072,-0.161029,0.087400,2.595640,3326.835799,0.000023,False
1,37607,733,b'ZTF18abtdsvv',621309991115010019,9572,58375.310000,350.405990,0.879483,0.985897,-0.146789,0.080371,0.983090,-0.160937,0.087369,2.595627,3306.916479,0.000026,False
2,37610,733,b'ZTF18abtdsvv',614358404915010002,8240,58368.358403,350.405972,0.879528,0.985897,-0.146789,0.080371,0.987293,-0.137772,0.079197,2.599708,1897.631312,0.031018,False
3,37612,733,b'ZTF18abtdsvv',615324371115015021,8425,58369.324375,350.405987,0.879540,0.985897,-0.146789,0.080371,0.986748,-0.140962,0.080359,2.598268,1214.514216,0.241065,False
4,141833,733,b'ZTF18abwpfho',623339405515010000,9980,58377.339410,348.764546,0.189321,0.980829,-0.177448,0.080534,0.981751,-0.167709,0.089660,2.597194,2759.450676,0.000646,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90205,5584064,324582,b'ZTF18aaaknvz',1149162620015010022,96198,58903.162627,67.404504,45.304078,0.270241,0.878532,0.393893,0.271240,0.872415,0.406597,1.545974,2915.647647,0.000275,False
90206,5584155,324582,b'ZTF20aaqjjhn',1149162620115015001,96198,58903.162627,66.164947,45.635778,0.282557,0.871180,0.401505,0.271240,0.872415,0.406597,1.545974,2572.286861,0.001692,False
90207,5584262,324582,b'ZTF20aaqjltu',1149162621915015008,96198,58903.162627,67.072721,46.398450,0.268658,0.870797,0.411746,0.271240,0.872415,0.406597,1.545974,1234.232609,0.230095,False
90208,5584879,324582,b'ZTF20aaqjltx',1149162621915015001,96198,58903.162627,66.993445,46.563928,0.268718,0.869478,0.414486,0.271240,0.872415,0.406597,1.545974,1812.635532,0.042044,False


In [7]:
elts

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,733,3.398871,0.058840,0.354075,5.951576,3.315651,3.703184,58600.0
1,1476,2.280890,0.189703,0.110415,5.768549,6.107239,0.309986,58600.0
2,1803,2.349173,0.247823,0.376183,5.886310,4.431771,2.468021,58600.0
3,2015,2.335363,0.103840,0.207837,6.012107,4.788967,1.897608,58600.0
4,2294,2.581424,0.116455,0.109954,5.075755,0.753231,0.373019,58600.0
...,...,...,...,...,...,...,...,...
59,203722,3.171154,0.291647,0.426091,5.309571,1.789398,-1.162711,58600.0
60,253246,2.677164,0.227275,0.446592,5.260714,1.455787,-0.777289,58600.0
61,306781,2.534977,0.209075,0.228417,5.551083,0.846400,-0.451545,58600.0
62,313521,2.637372,0.254435,0.524255,4.487565,2.113770,-1.015696,58600.0


## Review Scores on Unperturbed ZTF Easy Batch

In [18]:
def report_ztf_hit_noise(ztf, batch_size: Optional[int] = None):
    """
    Print the mean hit count and mean score per candidate element, split into hits vs. noise.

    INPUTS:
        ztf:        DataFrame of ZTF observations; must have a boolean column `is_hit`
                    and a numeric column `score`
        batch_size: number of candidate elements to average over.
                    When None (default), the notebook-level `elt_batch_size` is used,
                    which reproduces the behavior of the original one-argument call.
    OUTPUTS:
        None. The report is printed to stdout.
    """
    # Fall back to the notebook-level batch size so existing calls keep working
    if batch_size is None:
        batch_size = elt_batch_size

    # Mean hits and score per element
    is_hit = ztf.is_hit
    mean_hits = np.sum(is_hit) / batch_size
    mean_score = np.sum(ztf.score) / batch_size
    mean_score_hits = np.sum(ztf[is_hit].score) / batch_size
    mean_score_noise = np.sum(ztf[~is_hit].score) / batch_size

    # Share of the score due to hits and to noise.
    # With a zero total score the shares are undefined; report nan explicitly
    # instead of relying on a 0/0 division.
    if mean_score != 0:
        score_hit_pct = mean_score_hits / mean_score * 100.0
        score_noise_pct = mean_score_noise / mean_score * 100.0
    else:
        score_hit_pct = score_noise_pct = float('nan')

    # Report
    print('Mean for elements in easy batch:')
    print(f'Hits:  {mean_hits:6.2f}')
    print(f'Score: {mean_score:6.2f}')
    print('\nScore due to Hits vs. Noise')
    print(f'Hits:  {mean_score_hits:6.2f} / {score_hit_pct:5.2f}%')
    print(f'Noise: {mean_score_noise:6.2f} / {score_noise_pct:5.2f}%')

In [19]:
report_ztf_hit_noise(ztf)

Mean for elements in easy batch:
Hits:  161.42
Score: 265.86

Score due to Hits vs. Noise
Hits:  161.42 / 60.72%
Noise: 104.44 / 39.28%


## Orbital Element Batch

In [8]:
# Mask where data perturbed vs not
# The first half of the batch is left alone ("good"); the second half is perturbed ("bad")
mask_good = np.arange(elt_batch_size) < (elt_batch_size//2)
mask_bad = ~mask_good
# Perturb second half of orbital elements
# sigma_a and sigma_e are zero here, so only sigma_f_deg=0.1 is non-zero
# NOTE(review): presumably this perturbs only the true anomaly f; confirm against perturb_elts
elts_pert = perturb_elts(elts, sigma_a=0.00, sigma_e=0.00, sigma_f_deg=0.1, mask=mask_bad)
# elts_pert = perturb_elts(elts, mask=mask_bad)

# Orbits for calibration on perturbed elements
# q_cal is passed to the model builder; it is None when calibration is switched off
if use_calibration:
    print(f'Numerically integrating calibration trajectories on perturbed elements, q_cal...')
    q_cal = calc_ast_pos(elts=elts_pert, epoch=epoch, ts=ts)
else:
    q_cal = None

Numerically integrating calibration trajectories on perturbed elements, q_cal...


In [9]:
# ZTF batch of perturbed elements, cached on disk so it is only built once
# NOTE: earlier versions of this cell saved the unperturbed `ztf` frame to this path;
# delete any cache file written by that version so it is rebuilt from elts_pert.
file_path = '../data/ztf/ztf-easy-batch-pert.h5'
try:
    ztfp = pd.read_hdf(file_path)
except FileNotFoundError:
    # Cache miss: build the batch from the perturbed elements
    ztfp = make_ztf_batch(elts=elts_pert, thresh_deg=thresh_deg, near_ast=False)
    # Save the perturbed batch (ztfp, not ztf) so the next run loads the right data
    ztfp.to_hdf(file_path, key='ztf', mode='w')

HBox(children=(FloatProgress(value=0.0, max=64.0), HTML(value='')))




In [None]:
# ztfp

## Build Model

In [None]:
# The observation site for ZTF data is Palomar Mountain.
site_name = 'palomar'

# Alpha and beta parameters for the objective function
alpha = 1.0
beta = 0.0

# Build functional model for asteroid score
# The model is initialized with the perturbed elements and, when enabled, the calibration trajectories q_cal
model = make_model_asteroid_search(\
    ts=ts, elts=elts_pert, max_obs=max_obs, num_obs=num_obs, site_name=site_name,
    elt_batch_size=elt_batch_size, time_batch_size=time_batch_size,
    R_deg=R_deg, thresh_deg=thresh_deg, R_is_trainable=R_is_trainable, alpha=alpha, beta=beta, 
    q_cal=q_cal, use_calibration=use_calibration)

# Use Adam optimizer; gradient clipping is currently disabled (clipvalue is commented out below)
# learning_rate = 2.0e-5
learning_rate = 2.0E-5  # default 1.0E-3
beta_1 = 0.900          # default 0.900
beta_2 = 0.999          # default 0.999
epsilon = 1.0E-7        # default 1.0E-7
amsgrad = False         # default False
# clipvalue is defined but NOT passed to the optimizer; it is only printed in the training cell
clipvalue = 5.0         # default not used
opt = keras.optimizers.Adam(learning_rate=learning_rate, 
                            beta_1=beta_1, 
                            beta_2=beta_2,
                            epsilon=epsilon, 
                            amsgrad=amsgrad,
                            # clipvalue=clipvalue, 
                            )
# No loss is passed to compile
# NOTE(review): presumably the model adds its own loss internally; confirm in make_model_asteroid_search
model.compile(optimizer=opt)
# Whether to display results
display = False

## Initial Losses

In [None]:
# Resolution and threshold converted from degrees with deg2dist
R = deg2dist(R_deg)
# A = 1 / R^2 is the argument handed to score_mean_var
A = 1.0 / R**2
thresh = deg2dist(thresh_deg)
# Mean and variance of the score contributed by a single observation
mu_per_obs, sigma2_per_obs = score_mean_var(A, thresh=thresh)
# Total number of observations; reuse num_obs (the value the model was built with)
# rather than repeating the literal 5.69e6
N_obs = num_obs
# Scale the per-observation statistics up to the whole data set
mu = N_obs * mu_per_obs
sigma2 = N_obs * sigma2_per_obs
sigma = np.sqrt(sigma2)

# report
print(f'Resolution: {R_deg} degrees / {R:8.6} cartesian')
print(f'A         : {A:8.3f}')
print('\nPer Observation:')
print(f'mu:     {mu_per_obs:5.3e}')
print(f'sigma2: {sigma2_per_obs:5.3e}')
print('\nFor Data:')
print(f'mu:     {mu:10.6f}')
print(f'sigma2: {sigma2:10.6f}')
print(f'sigma:  {sigma:10.6f}')

In [None]:
# Report losses before training
print(f'Processed easy batch with {elt_batch_size} asteroids. Perturbed second half.')
if display:
    print_header('Model Before Training:')
# Baseline predictions of the untrained model; the fifth output is discarded here
pred0 = model.predict_on_batch(ds)
elts0, R0, u_pred0, z0, _ = pred0
# Baseline scores and errors, kept for comparison with the values after training
scores0, traj_err0, elt_err0 = \
    report_model(model=model, ds=ds, R_deg=R_deg, mask_good=mask_good, 
                 batch_size=elt_batch_size, steps=steps, elts_true=elts_true, display=display)

## Check that untrained model predictions match expected results

In [None]:
# The results predicted by the TF model, as a numpy array of shape (batch_size=64, N_t=2371, 3)
# Converted to numpy so it can be fancy-indexed by time index in the comparison loop
u_pred_tf = u_pred0.numpy()

In [None]:
# Numpy copy of the TensorFlow tensor of times ts, named mjd
mjd = ts.numpy()

# Show that mjd is sorted (required by the np.searchsorted lookup in the comparison loop)
# np.min(np.diff(mjd))

In [None]:
# Arrays of element errors in Cartesian distance; converted to degrees after the loop
obs_dist = np.zeros(elt_batch_size)
elt_dist = np.zeros(elt_batch_size)
ast_dist = np.zeros(elt_batch_size)

# Columns to extract observed, integrated-asteroid and element-predicted directions
# NOTE(review): ztfp is built with near_ast=False; confirm that the nearest_ast_num
# and ast_u* columns used below are present in that frame.
cols_obs_dir = ['ux', 'uy', 'uz']
cols_ast_dir = ['ast_ux', 'ast_uy', 'ast_uz']
cols_elt_dir = ['elt_ux', 'elt_uy', 'elt_uz']

# Largest valid index into mjd; used to keep neighbor indices inside the array
idx_max = mjd.shape[0] - 1

for i, elt_id in enumerate(element_id):
    # Select entries for this element that are hits
    mask = (ztfp.element_id == elt_id) & (ztfp.nearest_ast_num == elt_id)
    ztf_i = ztfp[mask]
    
    # Time indices on left and right side; want idx s/t mjd[idx] == ztf_i.mjd
    # searchsorted returns the first index with mjd[idx] >= t, so that is the right neighbor.
    # Clip both neighbors into [0, idx_max]: without this, an observation at or before
    # mjd[0] gets idx_l == -1, which wraps around to the LAST time point and is then
    # selected as "closest"; an observation after mjd[-1] would index out of bounds.
    idx_l = np.clip(np.searchsorted(a=mjd, v=ztf_i.mjd) - 1, 0, idx_max)
    idx_r = np.clip(np.searchsorted(a=mjd, v=ztf_i.mjd), 0, idx_max)

    # Find the side that is closer
    dist_l = ztf_i.mjd.values - mjd[idx_l]
    dist_r = mjd[idx_r] - ztf_i.mjd.values
    is_left = (dist_l < dist_r)
    idx = idx_r.copy()

    # The index that is closest
    idx[is_left] = idx_l[is_left]

    # TensorFlow predictions on the ZTF dates for this element_id
    u_pred_tf_i = u_pred_tf[i, idx, :]

    # Directions stored in the DataFrame: observation, element prediction, integrated asteroid
    u_obs_np_i = ztf_i[cols_obs_dir].values
    u_elt_np_i = ztf_i[cols_elt_dir].values
    u_ast_np_i = ztf_i[cols_ast_dir].values

    # Difference between TensorFlow and Numpy
    u_diff_obs_i = u_pred_tf_i - u_obs_np_i
    u_diff_elt_i = u_pred_tf_i - u_elt_np_i
    u_diff_ast_i = u_pred_tf_i - u_ast_np_i
    
    # Save the mean Cartesian distance for this element to obs_dist, elt_dist and ast_dist
    obs_dist[i] = np.mean(np.linalg.norm(u_diff_obs_i, axis=1))
    elt_dist[i] = np.mean(np.linalg.norm(u_diff_elt_i, axis=1))
    ast_dist[i] = np.mean(np.linalg.norm(u_diff_ast_i, axis=1))

# Distance in degrees
obs_deg = dist2deg(obs_dist)
elt_deg = dist2deg(elt_dist)
ast_deg = dist2deg(ast_dist)

In [None]:
# ztf_i

In [None]:
def report_error(diff_deg, mask_good, dist_type: str):
    """
    Print the mean angular distance separately for the good and the bad elements.

    INPUTS:
        diff_deg:  array of angular distances in degrees, one entry per element
        mask_good: boolean array selecting the good elements; its negation selects the bad ones
        dist_type: label printed in the heading of the report
    OUTPUTS:
        None. The report is printed to stdout.
    """
    # Build both report rows before printing anything
    rows = []
    for label, selector in (('Good', mask_good), ('Bad ', ~mask_good)):
        mean_deg = np.mean(diff_deg[selector])
        # 3600 arc seconds per degree
        mean_sec = mean_deg * 3600.0
        rows.append(f'{label}: {mean_deg:8.4f} deg / {mean_sec:10.3f} arc sec')
    good_row, bad_row = rows

    print(f'Mean Angular Distance: {dist_type}')
    print(good_row)
    # Trailing newline leaves a blank line between consecutive reports
    print(bad_row + '\n')

In [None]:
# Report errors
# TF prediction vs. the direction stored in the elt_u* columns, then vs. the observed direction
# (ast_deg is computed in the comparison loop but not reported here)
report_error(diff_deg=elt_deg, mask_good=mask_good, dist_type='Orbital Elements')
report_error(diff_deg=obs_deg, mask_good=mask_good, dist_type='Observed')

## Review the Real Hits vs. Randomly Close Points

In [None]:
# Good and bad half of ztfp
# Largest element_id belonging to the first (unperturbed) half of the batch
elt_id_max_good = element_id[elt_batch_size//2-1]
# NOTE(review): this split by value assumes element_id is sorted ascending
# (it is in the elts table displayed above); confirm for other batches
is_good = ztfp.element_id <= elt_id_max_good
is_bad = ~is_good
ztf_g = ztfp[is_good]
ztf_b = ztfp[is_bad]

In [None]:
# Threshold for hits is 2.0 arc seconds, converted to a distance with deg2dist
thresh_hit = deg2dist(2.0 / 3600.0)

# Count number of hits; compute mean on good and bad
# A hit requires the nearest asteroid to be the candidate element itself AND to lie within the threshold
is_hit = (ztfp.element_id == ztfp.nearest_ast_num) & (ztfp.nearest_ast_dist < thresh_hit)
# Each half of the batch holds elt_batch_size/2 elements
hits_g = np.sum(is_hit & is_good) / (elt_batch_size/2)
hits_b = np.sum(is_hit & is_bad) / (elt_batch_size/2)

# Report hits
print(f'Mean hits for unperturbed (good) and perturbed (bad) candidate elements.')
print(f'Good: {hits_g:8.2f}')
print(f'Bad : {hits_b:8.2f}')

In [None]:
# Distance to the nearest asteroid, converted with dist2deg and stored in arc seconds
# NOTE: this cell adds / overwrites columns on ztfp in place
nearest_ast_deg = dist2deg(ztfp.nearest_ast_dist)
ztfp['nearest_ast_sec'] = nearest_ast_deg * 3600.0
# Score is exp(-0.5 * (distance / R_deg)^2), using the resolution R_deg in degrees
score_arg = 0.5 * (nearest_ast_deg / R_deg)**2
ztfp['score'] = np.exp(-score_arg)

In [None]:
# Flag hits: nearest asteroid is the candidate element itself and lies within 2.0 arc seconds
# Requires the nearest_ast_sec column added in the score cell
ztfp['is_hit'] = (ztfp.nearest_ast_sec < 2.0) & (ztfp.element_id == ztfp.nearest_ast_num)

In [None]:
# Mean total score per perturbed (bad) element.
# The bad half holds elt_batch_size / 2 elements; derive the divisor from the
# batch size instead of hard-coding 32, matching the hit-count cell.
np.sum(ztfp[is_bad].score) / (elt_batch_size / 2)

In [None]:
ztfp

## Initial Gradients

In [None]:
# Get initial gradients on entire data set
# The tape is persistent because gradient() is called once per control variable below
with tf.GradientTape(persistent=True) as gt:
    # Watch EVERY control variable that is differentiated below. The original cell
    # watched only e_, inc_ and R_; any of the others that is not automatically
    # watched would produce a None gradient, and `None / steps` would then fail.
    gt.watch([
        model.elements.a_,
        model.elements.e_,
        model.elements.inc_,
        model.elements.Omega_,
        model.elements.omega_,
        model.elements.f_,
        model.elements.R_,
    ])
    pred = model.predict_on_batch(ds.take(traj_size))
    # NOTE(review): the unpacking below rebinds notebook-level names defined earlier
    # (elts was the elements DataFrame, epoch the scalar epoch, mu / sigma2 the
    # expected score statistics); re-running earlier cells after this one will
    # see these values instead.
    elts, R, u_pred, z, scores = pred
    # unpack elements
    a = elts[:,0]
    e = elts[:,1]
    inc = elts[:,2]
    Omega = elts[:,3]
    omega = elts[:,4]
    f = elts[:,5]
    epoch = elts[:,6]
    # unpack scores
    raw_score = scores[:, 0]
    mu = scores[:, 1]
    sigma2 = scores[:, 2]
    objective = scores[:, 3]
    # total loss function: negative objective summed over all elements
    loss = tf.reduce_sum(-objective)

# Derivatives of loss w.r.t. control variables for elements and R, divided by the number of steps
dL_da = gt.gradient(loss, model.elements.a_) / steps
dL_de = gt.gradient(loss, model.elements.e_) / steps
dL_dinc = gt.gradient(loss, model.elements.inc_) / steps
dL_dOmega = gt.gradient(loss, model.elements.Omega_) / steps
dL_domega = gt.gradient(loss, model.elements.omega_) / steps
dL_df = gt.gradient(loss, model.elements.f_) / steps
dL_dR = gt.gradient(loss, model.elements.R_) / steps
# Release the resources held by the persistent tape
del gt

In [None]:
# Report gradients
# Absolute gradient per element, split by mask_good into unperturbed vs. perturbed elements
# Only a, e, inc and R are reported; dL_dOmega, dL_domega and dL_df are computed but not shown
report_model_attribute(att=np.abs(dL_da), mask_good=mask_good, att_name='abs(grad_a)')
report_model_attribute(att=np.abs(dL_de), mask_good=mask_good, att_name='abs(grad_e)')
report_model_attribute(att=np.abs(dL_dinc), mask_good=mask_good, att_name='abs(grad_inc)')
report_model_attribute(att=np.abs(dL_dR), mask_good=mask_good, att_name='abs(grad_R)')

## Train Model

In [None]:
# Train model
# Each epoch makes cycles_per_epoch passes over the `steps` batches of time points
steps_per_epoch = steps*cycles_per_epoch
rows_per_epoch = steps_per_epoch * time_batch_size
print_header(f'Training for {epochs} Epochs of Size {rows_per_epoch} Observation Days (Cycles/Epoch = {cycles_per_epoch})...')
print(f'R (degrees):    {R_deg:5.1f}')
print(f'R_is_trainable:  {R_is_trainable}')
print(f'alpha:          {alpha:5.1f}')
print(f'beta:           {beta:5.1f}')
print(f'learning_rate:    {learning_rate:7.2e}')
# NOTE(review): clipvalue is printed here, but it is commented out where the optimizer is built,
# so no gradient clipping is actually applied
print(f'clipvalue:      {clipvalue:5.2f}')

# Time the training run with wall-clock timestamps
train_time_0 = time.time()
hist = model.fit(ds, epochs=epochs, steps_per_epoch=steps_per_epoch)
train_time_1 = time.time()
train_time = train_time_1 - train_time_0
print(f'Elapsed Time: {str(timedelta(seconds=train_time))}')

# Report results
if display:
    print_header('Model After Training:')
# Predictions after training; the fifth output is discarded here
pred1 = model.predict_on_batch(ds)
elts1, R1, u_pred1, z1, _ = pred1
scores1, traj_err1, elt_err1 = \
    report_model(model=model, ds=ds, R_deg=R_deg, mask_good=mask_good, 
                 batch_size=elt_batch_size, steps=steps, elts_true=elts_true, display=display)

# Unpack scores after training
# Columns of scores1, in order: raw_score, mu, sigma2, objective, sigma, t_score
raw_score = scores1[:,0]
mu = scores1[:,1]
sigma2 = scores1[:,2]
objective = scores1[:,3]
sigma = scores1[:, 4]
t_score = scores1[:, 5]
# Raw score in excess of its expected value mu
eff_obs = raw_score - mu

## Change in scores
# Differences between the post-training and pre-training values
d_scores = scores1 - scores0
d_traj_err = traj_err1 - traj_err0
d_elt_err = elt_err1 - elt_err0
d_R = R1 - R0

# Report training progress: scores, orbital element errors, and resolution
print_header('Progress During Training:')
# Each pair holds (value before training, value after training)
scores_01 = (scores0, scores1)
traj_err_01 = (traj_err0, traj_err1)
elt_err_01 = (elt_err0, elt_err1)
R_01 = (R0, R1)
report_training_progress(scores_01, traj_err_01, elt_err_01, R_01, mask_good)

## Test Different Resolutions

In [None]:
# Settings for the resolution sweep
# NOTE: these rebind the notebook-level configuration names set near the top of the notebook
thresh_deg = 1.0
# Resolution is held fixed in this section, unlike the trainable R used above
R_is_trainable = False
elt_batch_size = 64
# Reset to None; the earlier data-loading cell had replaced None with traj_size
time_batch_size = None
epochs = 10
cycles_per_epoch = 10

In [None]:
# Candidate resolutions in degrees: 2^0, 2^-1, ..., 2^-7
R_deg_cand = 2.0 ** -np.arange(0.0, 8.0)
R_deg_cand

In [None]:
# Results table
scores_tbl = dict()
traj_err_tbl = dict()
elt_err_tbl = dict()
R_tbl = dict()

In [None]:
# # Iterate through different resolutions
# R_deg_cand = np.power(2.0, -np.arange(0.0, 8.0))
# for R_deg in R_deg_cand:
#     # Generate scores
#     scores_01, traj_err_01, elt_err_01, R_01, mask_good = \
#         test_easy_batch(R_deg=R_deg, 
#                         R_is_trainable=R_is_trainable,
#                         elt_batch_size=elt_batch_size,
#                         time_batch_size=time_batch_size,
#                         epochs=epochs,
#                         cycles_per_epoch=cycles_per_epoch)
#     # Save to table
#     scores_tbl[R_deg] = scores_01
#     traj_err_tbl[R_deg] = traj_err_01
#     elt_err_tbl[R_deg] = elt_err_01
#     R_tbl[R_deg] = R_01

In [None]:
scores_01[0].shape