In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl

# Utility
import time

# MSE Imports
import kepler_sieve
from asteroid_element import load_ast_elt
from candidate_element import asteroid_elts, perturb_elts, random_elts, elts_add_mixture_params
from ztf_ast import load_ztf_nearest_ast, calc_hit_freq
from ztf_element import load_ztf_batch, make_ztf_batch, ztf_score_by_elt, ztf_elt_summary
from asteroid_model import AsteroidPosition, AsteroidDirection, make_model_ast_pos
from asteroid_search_layers import CandidateElements, TrajectoryScore
from asteroid_search_model import AsteroidSearchModel, make_opt_adam
from asteroid_search_report import traj_diff
from candidate_element import score_by_elt
from astro_utils import deg2dist, dist2deg, dist2sec

Found 4 GPUs.  Setting memory growth = True.


In [2]:
# Alias for the Keras API bundled with TensorFlow
keras = tf.keras

# Floating point precision: TensorFlow dtype and the matching numpy dtype
dtype, dtype_np = tf.float32, np.float32
# Number of spatial dimensions
space_dims = 3

In [3]:
# Default figure size (inches) and font size applied to every matplotlib plot
mpl.rcParams.update({'figure.figsize': [16.0, 10.0], 'font.size': 16})

In [4]:
# Plot colors for the mean, lo/hi, min and max series
color_mean, color_lo, color_hi = 'blue', 'orange', 'green'
color_min, color_max = 'red', 'purple'

## Load ZTF Data and Batch of Orbital Elements

In [5]:
# Orbital elements of the known asteroids
ast_elt = load_ast_elt()
# How many known asteroids were loaded (row count)
N_ast = len(ast_elt)

In [6]:
# Load the ZTF nearest-asteroid data (ztf_ast); consumed by calc_hit_freq in the next cell
ztf_ast = load_ztf_nearest_ast()

In [7]:
# Asteroid numbers and their hit counts in the ZTF data (thresh_sec=2.0)
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Order asteroids from most to fewest hits; the first batch_size are taken in the next cell
idx = np.argsort(hit_count)[::-1]
ast_num_best, hit_count_best = ast_num[idx], hit_count[idx]

In [8]:
# Number of candidate elements in each batch
batch_size = 64

# Unperturbed orbital elements of the batch_size most frequently hit asteroids
elts_ast = asteroid_elts(ast_nums=ast_num_best[:batch_size])

In [10]:
# # Review unperturbed elements
# elts_ast

In [11]:
# "Large" perturbation setting: one sigma per orbital element
# (presumably standard deviations; the *_deg values are in degrees)
sigma_a = 0.05
sigma_e = 0.01
sigma_inc_deg = 0.25
sigma_f_deg = 1.0
sigma_Omega_deg = 1.0
sigma_omega_deg = 1.0
# NOTE(review): presumably None means "perturb every element" — confirm in perturb_elts
mask_pert = None
# Fixed seed for reproducibility; also reused by random_elts below
random_seed = 42

# Perturbed copy of the unperturbed batch
elts_pert = perturb_elts(
    elts_ast,
    sigma_a=sigma_a,
    sigma_e=sigma_e,
    sigma_inc_deg=sigma_inc_deg,
    sigma_f_deg=sigma_f_deg,
    sigma_Omega_deg=sigma_Omega_deg,
    sigma_omega_deg=sigma_omega_deg,
    mask_pert=mask_pert,
    random_seed=random_seed,
)

In [12]:
# Review perturbed elements
# elts_pert

In [13]:
# Batch of random candidate elements, same size as the asteroid batch, with element_id_start=0
elts_rand = random_elts(
    element_id_start=0,
    size=batch_size,
    random_seed=random_seed,
    dtype=dtype_np,
)

In [14]:
# Review random elements
# elts_rand

## Batches of ZTF Data vs. Elements

In [15]:
# Arguments passed to load_ztf_batch below
# thresh_deg is in degrees; values previously tried: 1.0 and 2.0
thresh_deg = 4.0
near_ast, regenerate = False, False

In [16]:
# ZTF observation batch for the unperturbed elements (arguments set in the previous cell)
ztf_elt_ast = load_ztf_batch(elts=elts_ast, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [17]:
# ZTF observation batch for the perturbed elements (same arguments as the unperturbed batch)
ztf_elt_pert = load_ztf_batch(elts=elts_pert, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [18]:
# ZTF observation batch for the random elements (same arguments as the unperturbed batch)
ztf_elt_rand = load_ztf_batch(elts=elts_rand, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [19]:
# Score by element - unperturbed (used by ztf_elt_summary and elts_add_mixture_params below)
score_by_elt_ast = ztf_score_by_elt(ztf_elt_ast)

In [20]:
# Score by element - perturbed (used by ztf_elt_summary and elts_add_mixture_params below)
score_by_elt_pert = ztf_score_by_elt(ztf_elt_pert)

In [21]:
# Score by element - random (used by ztf_elt_summary and elts_add_mixture_params below)
score_by_elt_rand = ztf_score_by_elt(ztf_elt_rand)

In [22]:
# Summarize the ztf element batch: unperturbed asteroids
# (prints observation counts and score statistics for the batch)
ztf_elt_summary(ztf_elt_ast, score_by_elt_ast, 'Unperturbed Asteroids')

ZTF Element Dataframe Unperturbed Asteroids:
                  Total     (Per Batch)
Observations   :  1233691   (    19276)

Summarize score = sum(-1.0 - log(v)) by batch.  (Mean=0, Variance=num_obs)
Mean score     :    3347.31
Sqrt(batch_obs):     138.84
Mean t_score   :      26.13


In [23]:
# Summarize the ztf element batch: perturbed asteroids
# (prints observation counts and score statistics for the batch)
ztf_elt_summary(ztf_elt_pert, score_by_elt_pert, 'Perturbed Asteroids')

ZTF Element Dataframe Perturbed Asteroids:
                  Total     (Per Batch)
Observations   :  1148437   (    17944)

Summarize score = sum(-1.0 - log(v)) by batch.  (Mean=0, Variance=num_obs)
Mean score     :     146.79
Sqrt(batch_obs):     133.96
Mean t_score   :       0.90


In [24]:
# Summarize the ztf element batch: random elements
# (prints observation counts and score statistics for the batch)
ztf_elt_summary(ztf_elt_rand, score_by_elt_rand, 'Random Elements')

ZTF Element Dataframe Random Elements:
                  Total     (Per Batch)
Observations   :  1056703   (    16773)

Summarize score = sum(-1.0 - log(v)) by batch.  (Mean=0, Variance=num_obs)
Mean score     :      82.21
Sqrt(batch_obs):     129.51
Mean t_score   :       0.28


In [25]:
# Mixture parameters handed to elts_add_mixture_params in the next cell
# R_deg is in degrees (it appears as R = 0.017453 radians in the element table below)
R_deg: float = 1.0
num_hits: int = 10

In [26]:
# Add mixture parameters to every batch of candidate elements.
# The frames are updated in place (the h and R columns show up in elts_ast afterwards).
for _elts, _score in ((elts_ast, score_by_elt_ast),
                      (elts_pert, score_by_elt_pert),
                      (elts_rand, score_by_elt_rand)):
    elts_add_mixture_params(elts=_elts, score_by_elt=_score, num_hits=num_hits, R_deg=R_deg)

In [28]:
# Review unperturbed elements; h and R are the mixture parameter columns added in the previous cell
elts_ast

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch,h,R
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133491,58600.0,0.000560,0.017453
1,59244,2.634727,0.262503,0.465045,5.738298,1.766995,-1.601363,58600.0,0.000981,0.017453
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.246069,58600.0,0.000723,0.017453
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.357345,58600.0,0.000687,0.017453
4,142999,2.619944,0.191376,0.514017,0.238022,0.946463,-1.299301,58600.0,0.000687,0.017453
...,...,...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.016580,58600.0,0.000225,0.017453
60,134815,2.612770,0.140831,0.513922,0.272689,0.645552,-0.957836,58600.0,0.000746,0.017453
61,27860,2.619406,0.096185,0.200633,5.541400,3.266046,3.948770,58600.0,0.000442,0.017453
62,85937,2.342291,0.197267,0.439063,5.279693,3.210025,3.947687,58600.0,0.000863,0.017453


## View Example DataFrames and Hits

In [None]:
# Review the ztf_elt DataFrame for the unperturbed elements (pandas truncates the displayed rows)
ztf_elt_ast

In [None]:
# Only the rows flagged as hits for the unperturbed elements
mask = ztf_elt_ast['is_hit']
ztf_elt_ast[mask]

In [None]:
# Column names of the unperturbed ZTF element frame
ztf_elt_ast.columns

In [None]:
# Alias ztf_elt_ast to ztf_elt (same object, not a copy: changes through either name are shared)
ztf_elt = ztf_elt_ast

In [None]:
# Observation times (mjd column) as a numpy array of dtype_np
ts_np = ztf_elt.mjd.to_numpy(dtype=dtype_np)

# Observation count for each element_id, in groupby (sorted element_id) order.
# NOTE(review): these lengths only line up with ts_np if ztf_elt rows are sorted by element_id — confirm.
row_lengths_np = ztf_elt.element_id.groupby(ztf_elt.element_id).size().to_numpy(dtype=np.int32)

In [None]:
# Review how the single most frequently hit asteroid is matched in ztf_elt
element_id_best = ast_num_best[0]
mask = (ztf_elt.element_id == element_id_best)
ztf_elt_best = ztf_elt[mask]
hits_best = np.sum(ztf_elt_best.is_hit)
hit_rate_best = np.mean(ztf_elt_best.is_hit)
rows_best = np.sum(mask)
# Smallest distance over ALL of this element's observations, not only the hits
s_sec_min = np.min(ztf_elt_best.s_sec)
# Closest observation among this element's rows only; the position within the subset
# is paired with .iloc, so the lookup does not depend on the frame's index labels
idx = int(np.argmin(ztf_elt_best.s.to_numpy()))
ztf_id = ztf_elt_best.ztf_id.iloc[idx]
print(f'Best asteroid has element_id = {element_id_best}')
print(f'Hit count: {hits_best} / {rows_best} observations')
print(f'Hit rate : {hit_rate_best:8.6f}')
print(f'Closest observation: {s_sec_min:0.3f} arc seconds (ztf_id = {ztf_id})')

## Build "Gold" Asteroid Search Model (Reference)

In [None]:
# Observatory for ZTF data is Palomar Mountain
site_name = 'palomar'

In [None]:
# Optimizer settings for the search models
learning_rate, clipnorm = 2.0**-12, 1.0
# Reuse existing ZTF batches instead of regenerating them
regenerate = False

In [None]:
# "Golden elements": independent copies of the unperturbed elements and their ZTF batch.
# NOTE(review): the intended sharpening of the mixture parameters (sharp resolution,
# approximately correct hit rate) is NOT applied; the call below is disabled and refers to
# h_gold and add_mixture_params, neither of which is defined or imported in the visible notebook.
elts_gold, ztf_elt_gold = elts_ast.copy(), ztf_elt_ast.copy()
# add_mixture_params(elts_gold, h=h_gold, R_deg=10.0/3600.0, dtype=np.float32)

In [None]:
# Build the reference ("gold") asteroid search model from the gold elements and their ZTF batch
model_gold = AsteroidSearchModel(
        elts=elts_gold, ztf_elt=ztf_elt_gold, 
        site_name=site_name, thresh_deg=thresh_deg, 
        learning_rate=learning_rate, clipnorm=clipnorm)

In [None]:
# Report summary outputs for the gold model before training
model_gold.report()

In [None]:
# Score using scipy optimization of h, lambda
# elt_score_gold = score_by_elt(ztf_elt_gold, thresh_deg=thresh_deg, fit_mixture=True)
# elt_score_gold = score_by_elt(ztf_elt_gold, thresh_deg=thresh_deg, fit_mixture=False)
# elt_score_gold

In [None]:
# Freeze orbital elements so the following search trains only the mixture parameters
model_gold.freeze_candidate_elements()

In [None]:
# Adaptive search parameters for the gold model ("small" run)
max_batches, batches_per_epoch, epochs_per_episode = 2000, 100, 5
min_learning_rate = 2.0**-20
verbose = 1

In [None]:
# Adaptive search on the gold model (elements frozen above), using the parameters from the previous cell
model_gold.search_adaptive(
    max_batches=max_batches, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    min_learning_rate=min_learning_rate,
    regenerate=regenerate,
    verbose=verbose)

In [None]:
# Report gold model after training
model_gold.report()

In [None]:
# Visualize log likelihood of gold model (bar chart)
model_gold.plot_score_bar('log_like')

## Check that Predicted Direction Matches Expected Direction

In [None]:
# Predicted direction u_pred from the gold model; r_pred is not used in the cells below
u_pred, r_pred = model_gold.predict_direction()

In [None]:
# Expected direction: the elt_ux, elt_uy, elt_uz columns of the gold ZTF batch as a numpy array
cols_u = [f'elt_u{axis}' for axis in 'xyz']
u_true = ztf_elt_gold[cols_u].to_numpy()

In [None]:
# Direction error: Euclidean norm of (predicted - expected) on the last axis, averaged,
# and also converted to arc seconds with dist2sec
du = u_pred - u_true
u_err = np.linalg.norm(du, axis=-1)
u_err_mean = u_err.mean()
u_err_mean_sec = dist2sec(u_err_mean)
print(f'Mean direction error: {u_err_mean:6.2e} Cartesian / {u_err_mean_sec:6.2f} arc seconds')

## Build Asteroid Search Model

In [None]:
# Review candidate elements (elts_ast, including the mixture columns h and R)
elts_ast

In [None]:
# elts_ast.h = 1.0/256.0

In [None]:
# Build asteroid search model on the unperturbed elements (ztf_elt is an alias of ztf_elt_ast)
model = AsteroidSearchModel(
        elts=elts_ast, ztf_elt=ztf_elt, 
        site_name=site_name, thresh_deg=thresh_deg, 
        learning_rate=learning_rate, clipnorm=clipnorm)

In [None]:
# Report for `model` before training starts
model.report()

In [None]:
# Built in log likelihood caclulation
# model.calc_log_like()

In [None]:
# Visualize model summary - layers and parameters
model.summary()

In [None]:
# Visualize log likelihood before training (sorted=False: bars kept in element order)
model.plot_score_bar('log_like', sorted=False)

## Fit Model on Unperturbed Elements

In [None]:
# Optimizer settings (same values as the gold section, restated so this section can be rerun alone)
learning_rate, clipnorm = 2.0**-12, 1.0
# Reuse existing ZTF batches instead of regenerating them
regenerate = False

In [None]:
# Build a separate model, model_ast, on the same unperturbed inputs; it is trained in the cells below
model_ast = AsteroidSearchModel(
                elts=elts_ast, ztf_elt=ztf_elt, 
                site_name=site_name, thresh_deg=thresh_deg, 
                learning_rate=learning_rate, clipnorm=clipnorm)

In [None]:
# Report for model_ast (built in the previous cell) before training starts
model_ast.report()

In [None]:
# Adaptive search parameters for the two training phases:
# max_batches_mixture -> phase with frozen elements, max_batches_element -> phase with frozen mixture
max_batches_mixture = 10_000
max_batches_element = 20_000
batches_per_epoch, epochs_per_episode = 10, 5
# NOTE(review): max_bad_episodes is not passed to any search_adaptive call in the visible cells
max_bad_episodes = 3
learning_rate = 2.0**-12
# NOTE(review): presumably None disables the learning-rate floor — confirm in search_adaptive
min_learning_rate = None
verbose = 1

In [None]:
# # Tiny size for fast testing
# max_batches = 1000
# batches_per_epoch = 10

In [None]:
# Preliminary round of training with frozen elements: freeze the orbital elements of model_ast
model_ast.freeze_candidate_elements()

In [None]:
# Phase 1: train unperturbed model with frozen orbital elements (mixture parameters only).
# regenerate=True is hard-coded here rather than taken from the `regenerate` variable.
model_ast.search_adaptive(
    max_batches=max_batches_mixture, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    min_learning_rate=min_learning_rate,
    regenerate=True,
    verbose=verbose)

In [None]:
# Unfreeze the elements, freeze the mixture model parameters (set up for phase 2)
model_ast.thaw_candidate_elements()
model_ast.freeze_mixture_parameters()

In [None]:
# Phase 2: train unperturbed model with frozen mixture parameters (orbital elements only)
model_ast.search_adaptive(
    max_batches=max_batches_element, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    min_learning_rate=min_learning_rate,
    regenerate=regenerate,
    verbose=verbose)

In [None]:
# Report model_ast after training
model_ast.report()

In [None]:
# Bar chart: gold model log likelihood (episode=0), for comparison with model_ast in the next cell
model_gold.plot_score_bar('log_like', episode=0)

In [None]:
# Bar chart: log likelihood of model_ast (episode=0)
model_ast.plot_score_bar('log_like', episode=0)

In [None]:
# Bar chart: hits for model_ast
model_ast.plot_score_bar('hits')

In [None]:
# Learning curve: log likelihood of model_ast over training
model_ast.plot_score_hist('log_like')

In [None]:
# Learning curve: hits of model_ast over training
model_ast.plot_score_hist('hits')

In [None]:
# Mixture parameter: resolution (R_deg) of model_ast
model_ast.plot_mixture_param('R_deg')

In [None]:
# Plot error in orbital element 'a' against the unperturbed elements (log scale, all elements)
model_ast.plot_elt_error(elts_true=elts_ast, elt_name='a', is_log=True, elt_num=None)

In [None]:
# Plot error in orbital element 'e' against the unperturbed elements (log scale, all elements)
model_ast.plot_elt_error(elts_true=elts_ast, elt_name='e', is_log=True, elt_num=None)

In [None]:
# Plot control variables for worst element
# NOTE(review): 55 is hard-coded, presumably read off the plots above — recheck after a new run
model_ast.plot_control(element_num=55)

In [None]:
# Plot control variables for best element
# NOTE(review): 59 is hard-coded, presumably read off the plots above — recheck after a new run
model_ast.plot_control(element_num=59)

## Train on Perturbed Elements

In [None]:
# Build asteroid search model on the perturbed elements.
# Built like model_gold / model / model_ast: no explicit h or R_deg. The mixture parameters
# are presumably taken from the h and R columns that elts_add_mixture_params added to elts_pert.
# (`h` is not defined anywhere in the notebook, so passing h=h raised NameError.)
model_pert = AsteroidSearchModel(
                 elts=elts_pert, ztf_elt=ztf_elt_pert, site_name=site_name,
                 thresh_deg=thresh_deg,
                 learning_rate=learning_rate, clipnorm=clipnorm)

In [None]:
# Freeze the orbital elements so the first search trains only the mixture parameters
# (same method as model_gold / model_ast use)
model_pert.freeze_candidate_elements()

In [None]:
# Phase 1: train model on perturbed elements with frozen orbital elements (mixture parameters only)
model_pert.search_adaptive(
    max_batches=max_batches_mixture, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    min_learning_rate=min_learning_rate,
    regenerate=True,
    verbose=verbose)

In [None]:
# Unfreeze the elements, freeze the mixture model parameters (set up for phase 2)
model_pert.thaw_candidate_elements()
model_pert.freeze_mixture_parameters()

In [None]:
# Phase 2: train perturbed model with frozen mixture parameters (orbital elements only)
model_pert.search_adaptive(
    max_batches=max_batches_element, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    min_learning_rate=min_learning_rate,
    regenerate=True,
    verbose=verbose)

In [None]:
# Report perturbed model after training
model_pert.report()

## Train on Random Elements

In [None]:
# Filter elts_rand down to only those that had matching ztf observations.
# Match on the element_id column (not the row index) so the result does not depend on the
# index happening to equal element_id, and so a missing id cannot raise KeyError.
# copy() makes an independent frame, so later column assignments don't warn about chained assignment.
idx = np.unique(ztf_elt_rand.element_id)
elts_rand = elts_rand[elts_rand.element_id.isin(idx)].copy()

In [None]:
# dtype of the 'a' column before the float32 cast below
elts_rand['a'].dtype

In [None]:
# NOTE(review): duplicate of the previous cell; presumably meant to run after the float32 cast
elts_rand['a'].dtype

In [None]:
# Cast the orbital element columns of elts_rand to float32 in a single astype call
elts_rand = elts_rand.astype({col: np.float32 for col in ['a', 'e', 'f', 'inc', 'Omega', 'omega', 'epoch']})

In [None]:
# Build asteroid search model on the random elements.
# Built like model_gold / model / model_ast: no explicit h or R_deg. The mixture parameters
# are presumably taken from the h and R columns that elts_add_mixture_params added to elts_rand.
# (`h` is not defined anywhere in the notebook, so passing h=h raised NameError.)
model_rand = AsteroidSearchModel(
                 elts=elts_rand, ztf_elt=ztf_elt_rand, site_name=site_name,
                 thresh_deg=thresh_deg,
                 learning_rate=learning_rate, clipnorm=clipnorm)

In [None]:
# Train model on random elements in a single run (no freeze/thaw calls are made for model_rand).
# NOTE(review): max_batches is the value set in the gold section (2000), not
# max_batches_mixture / max_batches_element, and regenerate is not passed — confirm intended.
model_rand.search_adaptive(
    max_batches=max_batches, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    min_learning_rate=min_learning_rate,
    verbose=verbose)

In [None]:
# Review likelihoods by element: calc_outputs returns log likelihoods, orbital elements and mixture parameters
log_like_rand, orbital_elements_rand, mixture_parameters_rand = model_rand.calc_outputs()

In [None]:
# Log likelihood by element for the random-element model
log_like_rand

In [None]:
# orbital_elements_rand