In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl

# Utility
import time

# MSE Imports
import kepler_sieve
from asteroid_element import load_ast_elt
from candidate_element import asteroid_elts, perturb_elts, random_elts, elts_add_mixture_params
from ztf_ast import load_ztf_nearest_ast, calc_hit_freq
from ztf_element import load_ztf_batch, make_ztf_batch, ztf_score_by_elt, ztf_elt_summary
from asteroid_model import AsteroidPosition, AsteroidDirection, make_model_ast_pos
from asteroid_search_layers import CandidateElements, TrajectoryScore
from asteroid_search_model import AsteroidSearchModel, make_opt_adam
from asteroid_search_report import traj_diff
from nearest_asteroid import nearest_ast_elt
from element_eda import score_by_elt
from asteroid_dataframe import calc_ast_data, spline_ast_vec_df
from astro_utils import deg2dist, dist2deg, dist2sec

Found 4 GPUs.  Setting memory growth = True.


In [2]:
# Aliases
keras = tf.keras

# Constants: numeric precision (TensorFlow / NumPy flavors) and the
# number of spatial dimensions
dtype, dtype_np = tf.float32, np.float32
space_dims = 3

In [3]:
# Set plot style variables: large figures with a readable font size
mpl.rcParams.update({
    'figure.figsize': [16.0, 10.0],
    'font.size': 16,
})

In [4]:
# Colors for plots, one per plotted series (mean, low/high, min/max)
color_mean, color_lo, color_hi = 'blue', 'orange', 'green'
color_min, color_max = 'red', 'purple'

## Load ZTF Data and Batch of Orbital Elements

In [6]:
# Orbital elements of the known asteroids, and how many rows were loaded
ast_elt = load_ast_elt()
N_ast = len(ast_elt)

In [7]:
# Load ztf nearest asteroid data: ZTF observations annotated with their nearest
# known asteroid (nearest_ast_num, nearest_ast_dist, ast_ra, ast_dec, ...)
ztf_ast = load_ztf_nearest_ast()

In [8]:
# Add the original ztf_id (currently held in the index) as the first column.
# Guarded so the cell is idempotent: DataFrame.insert raises ValueError when
# the column already exists, e.g. when this cell is re-run.
if 'ztf_id' not in ztf_ast.columns:
    ztf_ast.insert(loc=0, column='ztf_id', value=ztf_ast.index)

In [9]:
# Filter ztf_ast to only include hits: observations whose nearest known
# asteroid lies within hit_thresh_sec arcseconds.
hit_thresh_sec = 2.0
# Convert arcseconds -> degrees -> distance; presumably the same units as
# nearest_ast_dist, since the two are compared directly below.
hit_thresh_s = deg2dist(hit_thresh_sec / 3600.0)
ztf_ast['is_hit'] = ztf_ast.nearest_ast_dist < hit_thresh_s
mask_hit = ztf_ast.is_hit
# .copy() makes the filtered frame independent, so later column assignments
# on ztf_ast (including re-running this cell) do not raise
# SettingWithCopyWarning.
ztf_ast = ztf_ast[mask_hit].copy()

In [10]:
# Review ztf_ast, filtered to observations matching a known asteroid (is_hit)
ztf_ast

Unnamed: 0,ztf_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,uz,mag_app,asteroid_prob,nearest_ast_num,nearest_ast_dist,ast_ra,ast_dec,ast_ux,ast_uy,ast_uz,is_hit
602,602,b'ZTF18acqvjbw',689314713015015008,16951,58443.314711,63.025354,17.655906,0.432230,0.899796,-0.059531,18.519199,1.000000,135355,3.663691e-06,63.025567,17.655961,0.432226,0.899798,-0.059531,True
649,649,b'ZTF18acrcadh',693308213515015017,17096,58447.308218,113.354357,33.962194,-0.328791,0.920837,0.209659,18.055201,1.000000,29076,8.997226e-07,113.354324,33.962150,-0.328790,0.920837,0.209658,True
811,811,b'ZTF18abwawbk',621435901715015019,9621,58375.435903,32.348343,11.149056,0.828867,0.558563,-0.031414,19.707199,1.000000,101787,2.084178e-06,32.348465,11.149059,0.828866,0.558565,-0.031414,True
833,833,b'ZTF18acurwxa',707509454415015012,18339,58461.509456,103.310341,27.540611,-0.204137,0.975586,0.080995,19.427401,1.000000,138763,4.068908e-06,103.310581,27.540515,-0.204141,0.975585,0.080994,True
890,890,b'ZTF18acuiphp',707257536115015009,18272,58461.257535,48.496514,22.452973,0.612431,0.786949,0.075091,19.038000,0.976392,105289,3.491132e-06,48.496726,22.453012,0.612428,0.786952,0.075091,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5697829,5697829,b'ZTF20aarepii',1151522393415015017,97090,58905.522396,236.892710,-2.341055,-0.545753,-0.784135,0.295442,15.089800,1.000000,727,2.473898e-06,236.892569,-2.341034,-0.545755,-0.784134,0.295442,True
5697832,5697832,b'ZTF20aareowa',1151522390315015003,97090,58905.522396,237.654792,-5.909325,-0.532176,-0.811960,0.239813,15.717300,1.000000,625,2.355026e-06,237.654657,-5.909328,-0.532178,-0.811958,0.239813,True
5697833,5697833,b'ZTF20aareowm',1151521944715015006,97089,58905.521944,236.239354,-9.304550,-0.548413,-0.817041,0.178007,15.695800,1.000000,464,2.205903e-06,236.239228,-9.304526,-0.548415,-0.817040,0.178007,True
5697838,5697838,b'ZTF20aareowf',1151521946215015005,97089,58905.521944,234.674482,-7.771349,-0.572910,-0.795468,0.197496,16.917900,1.000000,2879,4.163074e-06,234.674242,-7.771368,-0.572914,-0.795466,0.197494,True


In [11]:
# Parameters to build batch of random orbital elements
batch_size_init = 1024  # number of random candidate elements drawn
batch_size = 64         # number of top-scoring candidates kept in the ranking cells below
random_seed = 0
# First element_id for the reduced batch of batch_size elements (used for
# element_id_new below).
# NOTE(review): the initial random batch is always numbered from 0, because
# element_id_start=0 is hard-coded in the random_elts call below. Different
# random seeds therefore presumably produce overlapping ids for elts_rand.
element_id_start = random_seed * batch_size

# Random elements
elts_rand = random_elts(element_id_start=0, size=batch_size_init,
                        random_seed=random_seed, dtype=dtype_np)

In [12]:
# Review random elements: element_id, orbital elements (a, e, f, inc, Omega, omega) and epoch
elts_rand

Unnamed: 0,element_id,a,e,f,inc,Omega,omega,epoch
0,0,2.602843,0.233124,1.436109,0.136210,2.787195,1.639466,58600.0
1,1,2.724244,0.115775,3.466613,0.104491,3.138966,4.711716,58600.0
2,2,2.454118,0.075277,-0.098933,0.046201,5.053379,2.250998,58600.0
3,3,3.236996,0.231935,3.138587,0.310031,6.188627,3.335157,58600.0
4,4,2.688052,0.252831,3.602334,0.130012,1.298741,3.428075,58600.0
...,...,...,...,...,...,...,...,...
1019,1019,2.301447,0.123352,-1.481379,0.152518,3.180693,2.485485,58600.0
1020,1020,2.232799,0.059774,2.314652,0.202272,2.048625,5.188941,58600.0
1021,1021,2.307718,0.106472,-5.096721,0.062396,3.968814,2.726140,58600.0
1022,1022,2.590508,0.058045,3.344757,0.325955,2.490392,4.369721,58600.0


In [13]:
# Alias elts_rand to elts; these are the candidate elements.
# This binds a second name to the same DataFrame (no copy), so in-place
# changes made through elts are also visible through elts_rand.
elts = elts_rand

## Batches of ZTF Data Near Initial Candidate Elements

In [14]:
# Arguments to make_ztf_batch / load_ztf_batch.
# Angular radius (degrees) around each candidate element; 1.0 and 4.0 were also tried.
thresh_deg = 2.0
# Do not restrict the batch to observations near known asteroids
near_ast = False
# Presumably reuses a previously built batch rather than rebuilding it
regenerate = False

In [15]:
# Load the batch of ZTF observations within thresh_deg of each (random) candidate element
ztf_elt = load_ztf_batch(elts=elts, ztf=ztf_ast, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [16]:
# Review the ZTF batch: each row pairs an observation (ztf_id) with a candidate element (element_id)
ztf_elt

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit,is_match
0,283228,0,b'ZTF19abiqrbg',936424663415015005,942,58690.424664,21.919034,5.144727,0.923975,0.376782,...,0.935787,0.348785,-0.051495,2.201675,0.033509,6911.997221,0.999439,0.921607,False,False
1,283566,0,b'ZTF19abiqrcs',936424662015015002,942,58690.424664,20.959303,4.269688,0.931243,0.356892,...,0.935787,0.348785,-0.051495,2.201675,0.023965,4943.252882,0.999713,0.471394,False,False
2,284573,0,b'ZTF19abiqrkb',936424665515015008,942,58690.424664,20.684350,6.963277,0.928640,0.369906,...,0.935787,0.348785,-0.051495,2.201675,0.032219,6646.020730,0.999481,0.852050,False,False
3,284720,0,b'ZTF19abiqrlb',936424664015015005,942,58690.424664,18.489075,6.043591,0.943113,0.331218,...,0.935787,0.348785,-0.051495,2.201675,0.029584,6102.325735,0.999562,0.718353,False,False
4,286920,0,b'ZTF19abiqtqw',936424664115015039,942,58690.424664,18.051860,5.620294,0.946206,0.321897,...,0.935787,0.348785,-0.051495,2.201675,0.034357,7087.025737,0.999410,0.968868,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3106507,5682134,1023,b'ZTF20aarbsss',1151129913015015008,39107,58905.129919,37.442314,2.834092,0.792995,0.576780,...,0.781182,0.598195,-0.178658,3.584732,0.030082,6205.131098,0.999548,0.742760,False,False
3106508,5682143,1023,b'ZTF20aarbssf',1151129914415015006,39107,58905.129919,38.147911,5.689426,0.782545,0.603365,...,0.781182,0.598195,-0.178658,3.584732,0.025682,5297.443819,0.999670,0.541362,False,False
3106509,5682179,1023,b'ZTF20aarbsvd',1151129914115015014,39107,58905.129919,38.792809,5.842300,0.775368,0.612312,...,0.781182,0.598195,-0.178658,3.584732,0.028558,5890.634000,0.999592,0.669381,False,False
3106510,5682185,1023,b'ZTF20aarbsvl',1151129914315015006,39107,58905.129919,39.864452,5.293346,0.764290,0.622271,...,0.781182,0.598195,-0.178658,3.584732,0.030884,6370.439611,0.999523,0.782859,False,False


In [18]:
# Score by element: presumably one row per candidate element, indexed by element_id
# NOTE(review): this rebinds the name score_by_elt, shadowing the function of the
# same name imported from element_eda in the import cell.
score_by_elt = ztf_score_by_elt(ztf_elt)
# score_by_elt['element_id'] = score_by_elt.index.values

In [29]:
# Scratch check of zero-padded integer formatting; nothing else in the notebook depends on it
f'{5:03d}'

'005'

In [25]:
# Initial candidate elements for best_elts (same object as elts_rand, not a copy)
elts_init = elts_rand

In [27]:
# Select the batch_size best-scoring candidates from elts_init, renumbered from element_id_start
# NOTE(review): best_elts is not imported or defined in this notebook, so this cell raises
# NameError on a fresh kernel. TODO: import it from the module that defines it.
# NOTE(review): element_id_start=64 is hard-coded here, whereas the notebook variable
# element_id_start is 0 for random_seed = 0.
best_elts(elts_init=elts_init, ztf_elt=ztf_elt, batch_size=batch_size, element_id_start=64)

Unnamed: 0,element_id,a,e,f,inc,Omega,omega,epoch,num_obs,score,t_score
0,64,3.064054,0.121235,-1.291546,0.068009,0.662207,0.935398,58600.0,12711,998.002460,8.852007
1,65,2.771068,0.151421,-0.531263,0.068958,1.364763,6.111150,58600.0,11901,886.807208,8.129007
2,66,3.172947,0.095819,1.903106,0.035792,3.984825,0.687484,58600.0,12572,781.406566,6.969071
3,67,2.697356,0.098236,0.132486,0.025427,1.610715,5.259525,58600.0,9835,767.421579,7.738322
4,68,3.113809,0.257043,2.596678,0.090822,0.106255,4.320585,58600.0,9129,697.797101,7.303273
...,...,...,...,...,...,...,...,...,...,...,...
59,123,1.890665,0.083441,3.132056,0.035993,0.656159,3.255417,58600.0,9064,280.408751,2.945314
60,124,3.168699,0.088259,0.853523,0.052655,5.188591,0.902693,58600.0,10197,280.102124,2.773832
61,125,2.255124,0.055124,-1.131100,0.065725,4.743550,1.882832,58600.0,6700,279.483345,3.414432
62,126,2.325415,0.176758,-0.619556,0.046713,3.318604,2.235867,58600.0,4532,277.344008,4.119778


In [None]:
# Review per-element scores
score_by_elt

In [None]:
# Rank candidate elements by total score (best first, in place) and take the
# top batch_size. Selection is positional, written with .iloc to make that
# explicit: score_by_elt presumably has an integer index (element_id), where
# plain integer slicing like s[0:n] is easily misread as label-based.
score_by_elt.sort_values(by='score_sum', ascending=False, inplace=True)
element_id = score_by_elt.index.values[:batch_size]
num_obs = score_by_elt['num_obs'].iloc[:batch_size].values
score = score_by_elt['score_sum'].iloc[:batch_size].values
t_score = score_by_elt['t_score'].iloc[:batch_size].values

In [None]:
# Number of observations for each of the top batch_size elements
num_obs

In [None]:
# Gather the top-scoring candidates from elts and attach their scores.
# Label lookup: assumes the elts index coincides with element_id (true when
# elts_rand is numbered from 0).
elts_good = (elts.loc[element_id]
             .copy()
             .reset_index(drop=True)
             .assign(num_obs=num_obs, score=score, t_score=t_score))

In [None]:
# New element ids for the reduced batch: element_id_start .. element_id_start + batch_size - 1
# NOTE(review): only displayed here; these ids are not assigned to elts_good.
element_id_new = np.arange(element_id_start, element_id_start+batch_size, dtype=np.int32)
element_id_new

In [None]:
# Review the top-scoring candidates together with num_obs, score and t_score
elts_good

In [None]:
# Summarize the ztf element batch for the random candidate elements
ztf_elt_summary(ztf_elt, score_by_elt, 'Random Elements')

In [None]:
# Mixture parameters passed to elts_add_mixture_params
R_deg: float = 0.5   # in degrees, per the name
num_hits: int = 20

In [None]:
# Add mixture parameters to candidate elements.
# The return value is discarded, so this presumably updates elts in place
# (and hence elts_rand, which is the same object) — confirm.
elts_add_mixture_params(elts=elts, score_by_elt=score_by_elt, num_hits=num_hits, R_deg=R_deg)

In [None]:
# Sort element scores best first (in place) and review
score_by_elt.sort_values(by='score_sum', ascending=False, inplace=True)
score_by_elt

In [None]:
# Review the random candidate elements, presumably now including the mixture parameters
elts

## Train on Random Candidate Elements: Learn Mixture Parameters

In [None]:
# Observatory for ZTF data is Palomar Mountain; site name passed to AsteroidSearchModel
site_name = 'palomar'

In [None]:
# Training parameters
learning_rate = 2.0**-12  # initial learning rate passed to AsteroidSearchModel
clipnorm = 1.0            # gradient clipping norm passed to AsteroidSearchModel
# NOTE(review): save_at_end is reassigned to False in the adaptive search parameters
# cell and is never passed to search_adaptive, so this value has no effect.
save_at_end: bool = True

In [None]:
# Build the asteroid search model from the candidate elements and their ZTF batch
model = AsteroidSearchModel(elts=elts,
                            ztf_elt=ztf_elt,
                            site_name=site_name,
                            thresh_deg=thresh_deg,
                            learning_rate=learning_rate,
                            clipnorm=clipnorm,
                            name='model')

In [None]:
# Report before training starts (baseline for comparison after training)
model.report()

In [None]:
# Visualize log likelihood by element before training
fig, ax = model.plot_bar('log_like', sorted=False)

In [None]:
# Adaptive search parameters

# Batch budgets: mixture-only warm-up (elements frozen), then joint training
max_batches_mixture, max_batches_element = 2000, 10000

# Schedule granularity
batches_per_epoch, epochs_per_episode = 100, 5

# NOTE(review): max_bad_episodes and min_learning_rate are not passed to any
# search_adaptive call in this notebook; kept for reference only.
max_bad_episodes = 3
min_learning_rate = None

# Overrides the save_at_end set with the training parameters.
# NOTE(review): save_at_end is not passed to search_adaptive either.
save_at_end = False
reset_active_weight = False
verbose = 1

In [None]:
# Load model, presumably restoring previously saved state if any, then report.
# NOTE(review): if a saved state is loaded, the training below resumes from it
# rather than starting fresh — confirm.
model.load()
model.report()

In [None]:
# Preliminary round of training with frozen elements: freeze the candidate
# orbital elements so the next training run learns the mixture parameters
model.freeze_candidate_elements()

In [None]:
# Train the model on the random candidate elements with the orbital elements frozen
model.search_adaptive(
    max_batches=max_batches_mixture, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    reset_active_weight=reset_active_weight,
    verbose=verbose)

In [None]:
# Report after initial training on mixture parameters
model.report()

In [None]:
# Bar chart - log likelihood by element (after mixture training)
fig, ax = model.plot_bar('log_like', sorted=False)

In [None]:
# Bar chart - hits by element
fig, ax = model.plot_bar('hits', sorted=False)

In [None]:
# Bar chart - mixture parameter R_deg by element
fig, ax = model.plot_bar('R_deg', sorted=False)

## Joint Training: Candidate Elements & Mixture Parameters

In [None]:
# Unfreeze the candidate orbital elements for joint training.
# NOTE(review): whether this also freezes the mixture parameters is not visible
# here; the section header describes joint (elements & mixture) training — confirm.
model.thaw_candidate_elements()

In [None]:
# New smaller learning rate for joint training (down from the initial 2.0**-12)
learning_rate = 2.0**-15

In [None]:
# Train the model on the random candidate elements in joint mode (elements unfrozen)
model.search_adaptive(
    max_batches=max_batches_element, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    learning_rate=learning_rate,
    reset_active_weight=reset_active_weight,
    verbose=verbose)

In [None]:
# Report after joint training
model.report()

In [None]:
# Optional checkpoint: uncomment to save the model state after joint training
# model.save_state()

In [None]:
# Bar chart - log likelihood by element, sorted
fig, ax = model.plot_bar('log_like', sorted=True)

## Extended Training

In [None]:
# Batch budget for the extended joint training run
max_batches_element_2 = 15000

In [None]:
# Extended joint training at the same learning rate, without saving at the end
model.search_adaptive(
    max_batches=max_batches_element_2, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    learning_rate=learning_rate,
    save_at_end=False,
    verbose=verbose)

In [None]:
# Report after extended training
model.report()

In [None]:
# Optional checkpoint: uncomment to save the model state after extended training
# model.save_state()

In [None]:
# Final long run: up to 100000 batches at a smaller learning rate (2.0**-16),
# with reset_active_weight=True and no save at the end
model.search_adaptive(
    max_batches=100000, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    learning_rate=2.0**-16,
    reset_active_weight=True,
    save_at_end=False,
    verbose=verbose)

In [None]:
# Save the trained model state
model.save_state()

## Review Results Graphically

In [None]:
# Bar chart - log likelihood by element (final)
fig, ax = model.plot_bar('log_like', sorted=False)

In [None]:
# Bar chart: hits by element (final)
fig, ax = model.plot_bar('hits', sorted=False)

In [None]:
# Bar chart: log_R by element
fig, ax = model.plot_bar('log_R', sorted=False)

In [None]:
# Learning curve: log likelihood over training
fig, ax = model.plot_hist('log_like')

In [None]:
# Learning curve: hits over training
fig, ax = model.plot_hist('hits')

## Distance vs. Nearest Asteroid for Fitted Elements

In [None]:
# Find the nearest known asteroid to each fitted element; returns the fitted
# elements and the corresponding nearest-asteroid table
elts_fit, elts_near = model.nearest_ast()

In [None]:
# Review asteroids nearest to the fitted elements
elts_near

In [None]:
# Review fitted elements, including their nearest asteroid
elts_fit

In [None]:
# Well-converged fits: log likelihood above 200 and R_sec below 60
# (R_sec presumably in arcseconds)
mask = elts_fit.log_like.gt(200) & elts_fit.R_sec.lt(60)
elts_fit[mask]

In [None]:
# Incomplete convergence: log likelihood above 200 but R_sec still 60 or more
mask = elts_fit.log_like.gt(200) & elts_fit.R_sec.ge(60)
elts_fit[mask]

## Visualize Error vs. Nearest Asteroid

In [None]:
# Nearest-asteroid table stored on the model
model.elts_near_ast

In [None]:
# Plot position error vs. known elements (log scale)
fig, ax = model.plot_q_error(is_log=True, use_near_ast_dist=False)

In [None]:
# Plot error in orbital element a (log scale)
fig, ax = model.plot_elt_error(elt_name='a', is_log=True, elt_num=None)

In [None]:
# Plot error in orbital element e (log scale)
fig, ax = model.plot_elt_error(elt_name='e', is_log=True, elt_num=None)