In [None]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Utility
import logging
# import time

# Local imports
from ztf_data import make_ztf_easy_batch
# from ztf_data import load_ztf_det_all, load_ztf_nearest_ast, ztf_nearest_ast, calc_hit_freq
from asteroid_data import make_ztf_dataset, orbital_element_batch
from asteroid_integrate import load_data as load_data_asteroids, calc_ast_pos
from asteroid_model import AsteroidDirection, make_model_ast_pos
from search_score_functions import score_mean, score_var, score_mean_2d, score_var_2d
from astro_utils import deg2dist
from utils import print_header
from tf_utils import Identity

# Typing
from typing import Dict

# Import from asteroid_search
from asteroid_search import perturb_elts, make_model_asteroid_search, report_model

In [2]:
# Aliases
keras = tf.keras

# Silence the TensorFlow logger entirely: this was the only reliable way found to
# suppress the flood of unresolvable AutoGraph warnings.
logging.getLogger('tensorflow').disabled = True

# Number of spatial dimensions
space_dims = 3

# Asteroid names and orbital elements
ast_elt = load_data_asteroids()

# Search bounds on the semi-major axis a
a_min_: float = 0.5
a_max_: float = 32.0

# Search bounds on the eccentricity e; the upper bound is kept strictly below 1.0
# (presumably so that candidate orbits stay elliptical)
e_min_: float = 0.0
e_max_: float = 1.0 - 2.0**-10

# Search bounds on the resolution parameter R (1 arcsecond up to 10 degrees),
# together with their natural logarithms
R_min_, R_max_ = deg2dist(1.0/3600), deg2dist(10.0)
log_R_min_, log_R_max_ = np.log(R_min_), np.log(R_max_)

## Generate a Dataset Matching an Easy Batch

In [3]:
# Batch sizes: number of time points per batch, and number of candidate
# orbital elements per batch
time_batch_size, elt_batch_size = 128, 64

In [4]:
# Build the "easy batch" of ZTF detections (including nearest-asteroid columns)
# sized for elt_batch_size candidate orbital elements
ztf = make_ztf_easy_batch(batch_size=elt_batch_size)

In [5]:
# Preview the easy-batch ZTF DataFrame (pandas truncates the rendered rows)
ztf

Unnamed: 0,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,uz,mag_app,asteroid_prob,nearest_ast_num,nearest_ast_dist,ast_ra,ast_dec,ast_ux,ast_uy,ast_uz
1656,b'ZTF19aaxbtge',883432713515015010,32606,58637.432720,349.617804,26.695540,0.878779,0.030983,0.476222,20.118000,0.528649,313521,0.006915,349.874445,26.372674,0.881970,0.032184,0.470206
3345,b'ZTF18abtyroq',893472970715010014,34299,58647.472975,354.025163,30.103966,0.860417,0.116893,0.496002,18.765600,0.932566,313521,0.003449,353.882857,29.949447,0.861533,0.113872,0.494767
6049,b'ZTF18abosfjn',896427220115015004,34661,58650.427222,355.296384,30.983843,0.854425,0.140274,0.500280,16.204500,0.916464,313521,0.003766,355.045433,31.000861,0.853957,0.136955,0.501997
6105,b'ZTF18abosfjn',896427220115015003,34661,58650.427222,355.296371,30.983844,0.854425,0.140274,0.500281,16.128201,0.916464,313521,0.003766,355.045433,31.000861,0.853957,0.136955,0.501997
6127,b'ZTF18abosfjn',896427220115015005,34661,58650.427222,355.296415,30.983890,0.854425,0.140275,0.500281,15.890600,0.916464,313521,0.003767,355.045433,31.000861,0.853957,0.136955,0.501997
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5412924,b'ZTF20aaqvlmy',1150185755115015010,96635,58904.185752,41.868105,28.422819,0.654918,0.727863,0.203216,18.203100,1.000000,59123,0.000009,41.867533,28.422845,0.654924,0.727857,0.203219
5441789,b'ZTF20aarbpnr',1151112683815015002,96901,58905.112685,32.327962,34.020453,0.700369,0.629210,0.337013,18.458200,1.000000,73961,0.000007,32.327504,34.020497,0.700372,0.629205,0.337016
5441935,b'ZTF20aarbqkw',1151121034915015003,96919,58905.121030,25.806967,29.268916,0.785333,0.542902,0.297505,18.497400,1.000000,48453,0.000006,25.806544,29.268905,0.785336,0.542897,0.297507
5442136,b'ZTF20aarbqpo',1151121032115015000,96919,58905.121030,24.489749,25.919630,0.818493,0.515939,0.252736,18.606501,1.000000,31539,0.000003,24.489554,25.919688,0.818494,0.515937,0.252739


In [6]:
# Convert the ZTF DataFrame into a TensorFlow dataset batched over time.
# Also returned: the times ts and row_len (presumably the number of observations
# per time point — confirm against make_ztf_dataset)
ds, ts, row_len = make_ztf_dataset(batch_size=time_batch_size, ztf=ztf)

In [7]:
# ds

In [8]:
# ts

In [9]:
# row_len

In [10]:
print('mean(row_len) = {:5.1f}'.format(np.mean(row_len)))

mean(row_len) =  12.8


## Follow `main()` in `asteroid_search.py`

In [11]:
# Number of time samples in the trajectory, and how many time batches are
# needed to cover it (integer ceiling division)
traj_size = ts.shape[0]
steps = int(-(-traj_size // time_batch_size))

In [12]:
print(f'traj_size = {traj_size}', f'steps per epoch = {steps}.', sep='\n')

traj_size = 2176
steps per epoch = 17.


In [13]:
# Get one example batch from the dataset for inspection.
# next(iter(...)) fetches only the first element instead of building a list.
batch_in, batch_out = next(iter(ds.take(1)))

# Unpack the inputs of this batch
t = batch_in['t']
idx = batch_in['idx']
u_obs = batch_in['u_obs']
# NOTE(review): this overwrites the `row_len` returned by make_ztf_dataset with
# this batch's value. The num_obs cell below relies on the batch-level value, but
# re-running the earlier mean(row_len) cell after this one reports the batch mean
# instead of the original one (hidden kernel state).
row_len = batch_in['row_len']

In [14]:
# t

In [15]:
# idx

In [16]:
# u_obs

In [17]:
# Shape of the observation tensor for this batch; axis 1 is taken as max_obs
# below, and the last axis presumably holds the space_dims=3 direction components
u_obs.shape

TensorShape([128, 233, 3])

In [18]:
# Total number of observations in this batch (sum of row_len, as float32)
num_obs: float = np.sum(row_len, dtype=np.float32)
# Size of the observation axis of u_obs
max_obs: int = u_obs.shape[1]

In [19]:
print(f'max_obs = {max_obs}', f'num_obs = {num_obs}', sep='\n')

max_obs = 233
num_obs = 1098.0


In [20]:
# Resolution parameter R in degrees (passed to the model and the report below)
R_deg: float = 2.0

# Orbital elements of every distinct asteroid that is nearest to a detection
ast_nums = np.unique(ztf['nearest_ast_num'])
elts_np = orbital_element_batch(ast_nums)

# Epoch of the element batch, read from its first entry
# (presumably shared by all entries — confirm)
epoch = elts_np['epoch'][0]

In [21]:
# Ground-truth orbital elements as an array with 7 columns
# (a, e, inc, Omega, omega, f, epoch) and one row per element set
elts_true = np.array([elts_np[col] for col in ('a', 'e', 'inc', 'Omega', 'omega', 'f', 'epoch')]).T

In [22]:
# Mask where data expected vs not: the first half of the element batch is "good",
# the second half "bad" (the half that perturb_elts is allowed to modify).
# Size the mask from the element batch itself instead of hard-coding 64/32, so it
# stays aligned with elts_np if elt_batch_size or the easy batch changes.
num_elts: int = len(elts_np['a'])
mask_good = np.arange(num_elts) < num_elts // 2
mask_bad = ~mask_good

# Perturb second half of orbital elements.
# NOTE(review): every sigma is currently 0.0, so this presumably leaves the
# elements unchanged (the pre-training report shows zero error for the "Bad"
# half). Raise the sigmas to actually test recovery from perturbed elements.
elts_np2 = perturb_elts(elts_np, sigma_a=0.00, sigma_e=0.00, sigma_f_deg=0.0, mask=mask_bad)

In [23]:
# Orbits for calibration.
# The numerical integration is cached in the kernel: q_cal is only computed when
# it does not exist yet. Set recompute_q_cal = True after changing elts_np2,
# epoch or ts; otherwise a stale q_cal from an earlier run is silently reused.
recompute_q_cal: bool = False
if recompute_q_cal or 'q_cal' not in globals():
    print('Numerically integrating calibration trajectories q_cal...')
    q_cal = calc_ast_pos(elts=elts_np2, epoch=epoch, ts=ts)

Numerically integrating calibration trajectories q_cal...


In [24]:
# Alpha and beta parameters of the objective function
alpha, beta = 8.0, 20.0

# Whether the model should use the q_cal calibration trajectories
use_calibration: bool = True

In [25]:
# Functional model computing the asteroid search score, initialized from the
# (possibly perturbed) elements elts_np2
model = make_model_asteroid_search(
    ts=ts,
    elts_np=elts_np2,
    max_obs=max_obs,
    num_obs=num_obs,
    elt_batch_size=elt_batch_size,
    time_batch_size=time_batch_size,
    R_deg=R_deg,
    alpha=alpha,
    beta=beta,
    q_cal=q_cal,
    use_calibration=use_calibration,
)

In [26]:
# Adam optimizer; each gradient value is clipped to [-clipvalue, clipvalue]
learning_rate = 2.0e-5
clipvalue = 5.0
opt = keras.optimizers.Adam(
    learning_rate=learning_rate,
    beta_1=0.9, beta_2=0.999, epsilon=1e-7,
    amsgrad=False,
    clipvalue=clipvalue,
)
model.compile(optimizer=opt)

In [27]:
# Report losses before training.
# Explicit verbosity flag for the report. The previous version read a `display`
# variable that is never assigned in this notebook, so it silently resolved to
# IPython's built-in display() function (always truthy; a NameError outside
# IPython) and passed that function to report_model as its display flag.
show_report: bool = True
print(f'Processed easy batch with {elt_batch_size} asteroids. Perturbed second half.')
if show_report:
    print_header('Model Before Training:')

# Model outputs before any training step
pred0 = model.predict_on_batch(ds)
elts0, R0, u_pred0, z0, _ = pred0

scores0, traj_err0, elt_err0 = report_model(
    model=model, ds=ds, R_deg=R_deg, mask_good=mask_good,
    batch_size=elt_batch_size, steps=steps, elts_true=elts_true, display=show_report)

Processed easy batch with 64 asteroids. Perturbed second half.

********************************************************************************
Model Before Training:
********************************************************************************

Mean & Std Trajectory Error (AU) vs. True Elements by Category:
Good:     0.00 +/-     0.00
Bad:      0.00 +/-     0.00
All:      0.00 +/-     0.00

Error in orbital elements:
(Angles shown in degrees)
      a          e         inc       Omega      omega     f
Good: 0.000000,  0.000000, 0.000002, 0.000000,  0.000002, 0.000002, 
Bad : 0.000000,  0.000000, 0.000003, 0.000000,  0.000000, 0.000001, 

Mean & Std Raw Score by Category:
Good:   578.52 +/-   244.41
Bad:    822.80 +/-   493.86
All:    700.66 +/-   408.33

Mean & Std Mu by Category:
Good:    11.37 +/-     0.00
Bad:     11.37 +/-     0.00
All:     11.37 +/-     0.00

Mean & Std Effective Observations by Category:
Good:   567.14 +/-   244.41
Bad:    811.42 +/-   493.86
All:    689.28 

In [None]:
# Scores of the untrained model, as returned by report_model above
scores0

In [28]:
from datetime import timedelta