In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl

# Utility
import time

# MSE Imports
import kepler_sieve
from asteroid_element import load_ast_elt
from candidate_element import asteroid_elts, perturb_elts, random_elts, elts_add_mixture_params, elts_add_H
from ztf_ast import load_ztf_nearest_ast, calc_hit_freq
from ztf_element import load_ztf_batch, make_ztf_batch, ztf_score_by_elt, ztf_elt_summary
from asteroid_model import AsteroidPosition, AsteroidDirection, make_model_ast_pos
from asteroid_search_layers import CandidateElements, TrajectoryScore
from asteroid_search_model import AsteroidSearchModel
from asteroid_search_report import traj_diff
from astro_utils import deg2dist, dist2deg, dist2sec

In [2]:
# Floating-point types: NumPy flavour first, then the TensorFlow counterpart
dtype_np = np.float32
dtype = tf.float32

# Number of spatial dimensions
space_dims = 3

# Shorthand for the Keras API that ships inside TensorFlow
keras = tf.keras

In [3]:
# Notebook-wide matplotlib defaults: figure size and base font size
mpl.rcParams.update({
    'figure.figsize': [16.0, 10.0],
    'font.size': 16,
})

## Load ZTF Data and Batch of Orbital Elements

In [4]:
# Load orbital elements for known asteroids
# (load_ast_elt is imported from the project module asteroid_element)
ast_elt = load_ast_elt()

# Number of asteroids = size of the first axis of ast_elt.
# N_ast is not referenced again in the visible cells.
N_ast = ast_elt.shape[0]

In [5]:
# Load ztf nearest asteroid data (load_ztf_nearest_ast comes from the project module ztf_ast).
# The result is the input to calc_hit_freq in the next cell.
ztf_ast = load_ztf_nearest_ast()

In [6]:
# Hit tally per asteroid, using thresh_sec=2.0 (presumably arc seconds -- confirm in calc_hit_freq)
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Permutation that puts the largest hit counts first (ascending argsort, reversed).
# No truncation happens here; the batch is sliced out of the ranking in a later cell.
idx = np.argsort(hit_count)[::-1]

# Asteroid numbers and their hit counts, both rearranged into that ranking
ast_num_best, hit_count_best = ast_num[idx], hit_count[idx]

In [7]:
# How many of the top-ranked asteroids go into the batch
batch_size = 64

# Unperturbed elements for the first batch_size entries of the hit-count ranking
elts_ast = asteroid_elts(ast_nums=ast_num_best[:batch_size])

In [8]:
# # Review unperturbed elements
# elts_ast

In [9]:
# Selected elements: use the unperturbed asteroid elements as the candidates.
# This is a plain name binding, not a copy -- `elts` and `elts_ast` refer to the same
# object, so anything that modifies `elts` in place is also visible through `elts_ast`.
elts = elts_ast

## Batches of ZTF Data vs. Elements

In [10]:
# Keyword arguments forwarded to load_ztf_batch in the next cell.
# thresh_deg is also reused for the mixture parameters and the search model.
# 2.0 / 3600.0 degrees, i.e. 2.0 arc seconds
thresh_deg = 2.0 / 3600.0
# Passed through unchanged as near_ast= and regenerate=
near_ast, regenerate = True, False

In [11]:
# Load unperturbed element batch: ZTF observations associated with the candidate elements.
# All four arguments are defined in earlier cells (elts, thresh_deg, near_ast, regenerate).
ztf_elt = load_ztf_batch(elts=elts, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [12]:
# Score by element - unperturbed.
# The result is only consumed by ztf_elt_summary in the next cell.
score_by_elt = ztf_score_by_elt(ztf_elt)

In [13]:
# Summarize the ztf element batch: unperturbed asteroids.
# The third argument is the label shown in the heading of the printed summary.
ztf_elt_summary(ztf_elt, score_by_elt, 'Unperturbed Asteroids')

ZTF Element Dataframe Unperturbed Asteroids:
                  Total     (Per Batch)
Observations   :    10333   (      161)

Summarize score = sum(-1.0 - log(v)) by batch.  (Mean=0, Variance=num_obs)
Mean score     :     159.76
Sqrt(batch_obs):      12.71
Mean t_score   :      12.61


In [14]:
# Mixture-model settings handed to elts_add_mixture_params in the next cell
# 2.0 / 3600.0 degrees, i.e. 2.0 arc seconds
R_deg: float = 2.0 / 3600.0
# Passed as num_hits=
num_hits: int = 20

In [15]:
# Add mixture parameters to candidate elements.
# The return value is discarded, so this presumably modifies `elts` in place -- confirm in
# candidate_element. If so, re-running this cell is not guaranteed to be idempotent.
elts_add_mixture_params(elts=elts, num_hits=num_hits, R_deg=R_deg, thresh_deg=thresh_deg)

In [16]:
# Add brightness parameter H.
# As with the mixture parameters, the return value is discarded (presumably an in-place
# update of `elts` -- confirm in candidate_element).
elts_add_H(elts=elts)

## View Example DataFrames and Hits

In [17]:
# Review ztf_elt DataFrame.
# A bare last expression uses the notebook's rich display; the rendering is truncated to the
# first and last rows and columns, as the saved output shows.
ztf_elt

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,mag_app,ux,...,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit,is_match
0,341737,733,b'ZTF19abizrac',937427766115015019,40797,58691.427766,33.130412,43.596186,17.784599,0.606481,...,0.606486,0.637448,0.475220,2.115260,0.000007,1.509138,1.0,0.569375,True,False
1,345725,733,b'ZTF19abjajfg',937467364815015020,40840,58691.467361,33.148426,43.605278,18.220100,0.606265,...,0.606270,0.637614,0.475274,2.114851,0.000007,1.479386,1.0,0.547145,True,False
2,346522,733,b'ZTF19abjajmr',937468726115015011,40842,58691.468727,33.149062,43.605587,18.129601,0.606257,...,0.606263,0.637619,0.475276,2.114837,0.000007,1.503609,1.0,0.565210,True,False
3,347644,733,b'ZTF19abiyxiu',937402264815015008,40777,58691.402269,33.118785,43.590288,17.705700,0.606621,...,0.606626,0.637341,0.475186,2.115523,0.000007,1.491440,1.0,0.556098,True,False
4,431445,733,b'ZTF19abkkfhr',934448315015015003,40221,58688.448310,31.751906,42.913068,17.974701,0.622775,...,0.622780,0.624447,0.471393,2.145603,0.000007,1.493960,1.0,0.557979,True,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10328,5447574,324582,b'ZTF20aapdfmj',1145121244815015014,95204,58899.121250,42.392273,29.202629,18.231701,0.644677,...,0.644680,0.734021,0.213544,2.915471,0.000005,0.934688,1.0,0.218410,True,False
10329,5450145,324582,b'ZTF20aapdfmj',1145121716115015016,95205,58899.121713,42.392439,29.202638,18.045900,0.644675,...,0.644678,0.734022,0.213543,2.915477,0.000005,1.005771,1.0,0.252894,True,False
10330,5461311,324582,b'ZTF20aapeobw',1145164884815015018,95284,58899.164884,42.405891,29.202386,18.825701,0.644538,...,0.644541,0.734161,0.213479,2.916052,0.000004,0.890844,1.0,0.198401,True,False
10331,5461312,324582,b'ZTF20aapeobw',1145165336115015009,95285,58899.165336,42.406082,29.202412,18.750099,0.644536,...,0.644540,0.734163,0.213478,2.916058,0.000005,1.056357,1.0,0.278973,True,False


In [18]:
# Observation times (mjd column) as a flat NumPy array in the notebook's NumPy dtype.
# NOTE(review): dtype_np is float32; for MJD values near 58691 that leaves a resolution of
# about 2**-8 day (~5.6 minutes) -- confirm this is acceptable wherever ts_np is consumed.
ts_np = ztf_elt['mjd'].values.astype(dtype_np)

# Number of observation rows per element_id, ordered by ascending element_id
row_lengths_np = ztf_elt.groupby('element_id').size().values.astype(np.int32)

In [19]:
# Review results for the asteroid ranked first by hit count
element_id_best = ast_num_best[0]
mask = (ztf_elt.element_id == element_id_best)

# Select the best element's rows once so every statistic below describes the same rows
ztf_best = ztf_elt[mask]
hits_best = np.sum(ztf_best.is_hit)
hit_rate_best = np.mean(ztf_best.is_hit)
rows_best = np.sum(mask)
s_sec_min = np.min(ztf_best.s_sec)

# Position of the closest observation *within the best element's rows*.
# The previous code took the argmin over the whole of ztf_elt and then used that position
# as an index label, so ztf_id could belong to a different element than the one reported.
# `.values` makes the argmin positional regardless of the frame's index labels.
# NOTE: `idx` reuses the name of the sort permutation from the hit-ranking cell.
idx = int(np.argmin(ztf_best.s.values))
ztf_id = ztf_best.ztf_id.iloc[idx]
# ztf_best.iloc[idx:idx+1]
print(f'Best asteroid has element_id = {element_id_best}')
print(f'Hit count: {hits_best} / {rows_best} observations')
print(f'Hit rate : {hit_rate_best:8.6f}')
print(f'Closest hit: {s_sec_min:0.3f} arc seconds')
# ztf_best

Best asteroid has element_id = 51921
Hit count: 158 / 158 observations
Hit rate : 1.000000
Closest hit: 0.381 arc seconds


## Build Asteroid Search Model

In [20]:
# Observatory for ZTF data is Palomar Mountain.
# The string is passed as site_name= when the search model is built below.
site_name = 'palomar'

In [21]:
# Optimizer settings passed to AsteroidSearchModel below
clipnorm = 1.0
# 2**-12 = 1/4096, about 2.441e-04
learning_rate = 2.0 ** -12

# NOTE(review): save_at_end is not referenced by any other visible cell
save_at_end: bool = True

In [22]:
# Assemble the search model from the candidate elements, their ZTF observations,
# the observatory site, the angular threshold and the optimizer settings
model = AsteroidSearchModel(
    elts=elts,
    ztf_elt=ztf_elt,
    site_name=site_name,
    thresh_deg=thresh_deg,
    learning_rate=learning_rate,
    clipnorm=clipnorm,
)

In [23]:
# Report summary outputs.
# Called here before any training: the saved output reports 0 batches / 0 epochs / 0 episodes.
model.report()


Good elements (hits >= 5):  64.00

         \  log_like :  hits  :    R_sec : thresh_sec
Mean Good:     3.36  :  99.23 :     1.00 :     2.00
Mean Bad :      nan  :    nan :      nan :      nan
Mean     :     3.36  :  99.23 :     1.00 :     2.00
Median   :     3.05  : 100.00 :     1.00 :     2.00
GeoMean  :     2.99  :  94.83 :     1.00 :     3.00
Min      :    -8.21  :  33.00 :     1.00 :     2.00
Max      :    11.49  : 150.00 :     1.00 :     2.00
Trained for 0 batches over 0 epochs and 0 episodes (elapsed time 0 seconds).


  out=out, **kwargs)


In [None]:
# Freeze orbital elements; train only mixture parameters
# NOTE(review): this cell has no execution count (In [None]) in the saved notebook, so it
# is unclear whether it ran in the session that produced the outputs -- confirm.
model.freeze_candidate_elements()

In [None]:
# Load model
# model.load()

In [None]:
# Report summary outputs
# model.report()

In [None]:
# NOTE(review): `model_gold` is not defined in any visible cell; if this plot is re-enabled
# it presumably should use `model` -- confirm, or delete this dead cell.
# fig, ax = model_gold.plot_bar('hits', sorted=False)

## Check that Predicted Direction Matches Expected Direction

In [24]:
# Recalibrate the model before comparing its predictions with the values stored in ztf_elt.
# Presumably this is what sets model.position.dq / dv, which are read back a few cells
# below -- confirm in asteroid_search_model.
model.recalibrate()

In [25]:
# Predicted position: the model returns a (position, velocity) pair.
# The printed errors below label these as AU and AU / day.
q_pred, v_pred = model.predict_position()

In [26]:
# Reference state vectors stored in ztf_elt: position columns qx/qy/qz, velocity vx/vy/vz
cols_q = [f'q{axis}' for axis in 'xyz']
cols_v = [f'v{axis}' for axis in 'xyz']
q_true, v_true = ztf_elt[cols_q].values, ztf_elt[cols_v].values

In [27]:
# Compare predicted and reference state vectors, one quantity at a time.
# Each error is the Euclidean norm along the last axis; the relative error divides that by
# the norm of the reference vector.

# Position
dq = q_pred - q_true
q_err = np.linalg.norm(dq, axis=-1)
q_err_rel = q_err / np.linalg.norm(q_true, axis=-1)
q_err_mean, q_err_rel_mean = np.mean(q_err), np.mean(q_err_rel)

# Velocity
dv = v_pred - v_true
v_err = np.linalg.norm(dv, axis=-1)
v_err_rel = v_err / np.linalg.norm(v_true, axis=-1)
v_err_mean, v_err_rel_mean = np.mean(v_err), np.mean(v_err_rel)

# Mean absolute and mean relative error
print(f'Mean position error: {q_err_mean:6.2e} AU       ({q_err_rel_mean:6.2e} rel)')
print(f'Mean velocity error: {v_err_mean:6.2e} AU / day ({v_err_rel_mean:6.2e} rel)')

Mean position error: 1.15e-05 AU       (5.20e-06 rel)
Mean velocity error: 5.77e-08 AU / day (4.71e-06 rel)


In [28]:
# Snapshot the dq / dv variables held by the model's position layer as NumPy arrays.
# They are only referenced by the commented-out reset cell further down.
dq_cal = model.position.dq.numpy()
dv_cal = model.position.dv.numpy()

In [29]:
# Predicted direction: the model returns three values, unpacked as (u, r, mag).
# Only u_pred is used in the comparisons that follow.
u_pred, r_pred, mag_pred = model.predict_direction()

In [30]:
# Reference direction stored with each row of ztf_elt (columns elt_ux / elt_uy / elt_uz)
cols_u = [f'elt_u{axis}' for axis in 'xyz']
u_true = ztf_elt[cols_u].values

In [31]:
# Gap between the predicted direction and the reference direction
du = u_pred - u_true
# Euclidean norm of each residual, then the same quantity converted with dist2sec
u_err = np.linalg.norm(du, axis=-1)
u_err_sec = dist2sec(u_err)
# Averages in both units
u_err_mean, u_err_mean_sec = np.mean(u_err), np.mean(u_err_sec)
print(f'Mean direction error: {u_err_mean:6.2e} Cartesian / {u_err_mean_sec:6.2f} arc seconds')

Mean direction error: 4.12e-06 Cartesian /   0.85 arc seconds


In [32]:
# Worst-case direction error (predicted vs. reference direction), in arc seconds
np.max(u_err_sec)

3.5408127

In [33]:
# Observed direction of each ZTF detection (columns ux / uy / uz of ztf_elt)
cols_u_obs = [f'u{axis}' for axis in 'xyz']
u_obs = ztf_elt[cols_u_obs].values

In [34]:
# Gap between the predicted direction and the observed direction.
# `du` is rebound here; it previously held the predicted-minus-reference residual.
du = u_pred - u_obs
# Euclidean norm of each residual, then the same quantity converted with dist2sec
u_diff = np.linalg.norm(du, axis=-1)
u_diff_sec = dist2sec(u_diff)
# Averages in both units
u_diff_mean, u_diff_mean_sec = np.mean(u_diff), np.mean(u_diff_sec)
print(f'Mean direction diff: {u_diff_mean:6.2e} Cartesian / {u_diff_mean_sec:6.2f} arc seconds')

Mean direction diff: 5.08e-06 Cartesian /   1.05 arc seconds


In [35]:
# Largest gap between predicted and observed direction, in arc seconds
np.max(u_diff_sec)

5.3548026

In [None]:
# _ = model.position.dq.assign(np.zeros_like(dq_cal))
# _ = model.position.dv.assign(np.zeros_like(dv_cal))

In [None]:
# Run the model's sieve routine. The saved log shows training in rounds and episodes,
# e.g. "Round 1: 512 batches @ LR 2^-12 in mixture mode".
# Long-running: the log reports 71 sec of training time by the start of episode 1.
model.sieve()


********************************************************************************
Round 1: 512 batches @ LR 2^-12 in mixture mode; thresh_sec_max = 7200.0
********************************************************************************

Training episode 0: Epoch    0, Batch      0
effective_learning_rate=2.441e-04, training_time 0 sec.
Train on 4096 samples
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4
Adjusted element weight down on 64 candidate elements. Mean weight = 5.00e-01
Increasing bad_episode_count to 1.
                    \  All Elts : Bad Elts : Good Elts (64)
Geom Mean Resolution:      1.00 :      nan :     1.00 arc seconds
Geom Mean Threshold :      2.00 :      nan :     2.00 arc seconds
Mean Log Likelihood :     12.29 :      nan :    12.29
Mean Hits           :    137.80 :      nan :   137.80
Good Elements       :     64.00

Training episode 1: Epoch    4, Batch    256
effective_learning_rate=1.221e-04, training_time 71 sec.
Train on 4096 samples
Epoch 5/8
Epoch 6/8
Epoch 7/8
