In [None]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl

# Utility
import time

# MSE Imports
import kepler_sieve
from asteroid_element import load_ast_elt
from candidate_element import asteroid_elts, perturb_elts, random_elts
from ztf_ast import load_ztf_nearest_ast, calc_hit_freq
from ztf_element import load_ztf_batch, make_ztf_batch, ztf_elt_summary
from asteroid_model import AsteroidPosition, AsteroidDirection, make_model_ast_pos
from asteroid_search_layers import CandidateElements, TrajectoryScore
from asteroid_search_model import AsteroidSearchModel, make_adam_opt
from asteroid_search_report import traj_diff
from astro_utils import deg2dist, dist2deg, dist2sec
from tf_utils import Identity

In [None]:
# Aliases
keras = tf.keras

# Constants
dtype = tf.float32
dtype_np = np.float32
space_dims = 3

In [None]:
# Set plot style variables
mpl.rcParams['figure.figsize'] = [16.0, 10.0]
mpl.rcParams['font.size'] = 16

## Load ZTF Data and Batch of Orbital Elements

In [None]:
# Load orbital elements for known asteroids
ast_elt = load_ast_elt()

# Number of asteroids
N_ast = ast_elt.shape[0]

In [None]:
# Load ztf nearest asteroid data
ztf_ast = load_ztf_nearest_ast()

In [None]:
# Asteroid numbers and hit counts
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Sort the hit counts in descending order and find the top batch_size
idx = np.argsort(hit_count)[::-1]

# Extract the asteroid number and hit count for this batch
ast_num_best = ast_num[idx]
hit_count_best = hit_count[idx]

In [None]:
# Parameters to build elements batch
batch_size = 64
h = 1.0/64.0  # (1.5625%)
R_deg = 1.0

# Batch of unperturbed elements
elts_ast = asteroid_elts(ast_nums=ast_num_best[0:batch_size])

In [None]:
# Review unperturbed elements
elts_ast

In [None]:
# Inpute to perturb elements: large
sigma_a = 0.05
sigma_e = 0.01
sigma_inc_deg = 0.25
sigma_f_deg = 1.0
sigma_Omega_deg = 1.0
sigma_omega_deg = 1.0
mask_pert = None
random_seed = 42

# Perturb orbital elements
elts_pert = perturb_elts(elts_ast, sigma_a=sigma_a, sigma_e=sigma_e, 
                         sigma_inc_deg=sigma_inc_deg, sigma_f_deg=sigma_f_deg, 
                         sigma_Omega_deg=sigma_Omega_deg, sigma_omega_deg=sigma_omega_deg,
                         mask_pert=mask_pert, random_seed=random_seed)

In [None]:
# Review perturbed elements
elts_pert

In [None]:
# Random elements
elts_rand = random_elts(element_id_start=0, size=batch_size, h=h, R_deg=R_deg,
                        random_seed=random_seed, dtype=dtype_np)

In [None]:
# Review random elements
elts_rand

## Batches of ZTF Data vs. Elements

In [None]:
# Arguments to make_ztf_batch
# thresh_deg = 1.0
thresh_deg = 4.0
near_ast = False
regenerate = False

In [None]:
# Load unperturbed element batch
ztf_elt_ast = load_ztf_batch(elts=elts_ast, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [None]:
# Load perturbed element batch
ztf_elt_pert = load_ztf_batch(elts=elts_pert, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [None]:
# Load random element batch
ztf_elt_rand = load_ztf_batch(elts=elts_rand, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [None]:
# Summarize the ZTF batch built from the unperturbed asteroid elements;
# the second argument is the label passed to ztf_elt_summary for this batch
score_by_elt_ast = ztf_elt_summary(ztf_elt_ast, 'Unperturbed Asteroids')

In [None]:
# Optional display of the summary frame (left disabled):
# score_by_elt_ast

In [None]:
# Summarize the ZTF batch built from the perturbed asteroid elements
score_by_elt_pert = ztf_elt_summary(ztf_elt_pert, 'Perturbed Asteroids')

In [None]:
# Summarize the ZTF batch built from the random elements
score_by_elt_rand = ztf_elt_summary(ztf_elt_rand, 'Random Elements')

## View Example DataFrames and Hits

In [None]:
ztf_elt_ast

In [None]:
# Review hits
mask = ztf_elt_ast.is_hit
ztf_elt_ast[mask]

In [None]:
ztf_elt_ast.columns

In [None]:
# Alias ztf_elt_ast to ztf_elt
ztf_elt = ztf_elt_ast

In [None]:
# Build numpy array of times
ts_np = ztf_elt.mjd.values.astype(dtype_np)

# Get observation count per element
row_lengths_np = ztf_elt.element_id.groupby(ztf_elt.element_id).count().values.astype(np.int32)

In [None]:
# Review results for the first asteroid in the hit-count ranking
element_id_best = ast_num_best[0]
mask = (ztf_elt.element_id == element_id_best)
# Work on this element's rows only, so every statistic below describes the same subset
ztf_elt_best = ztf_elt[mask]
hits_best = np.sum(ztf_elt_best.is_hit)
hit_rate_best = np.mean(ztf_elt_best.is_hit)
rows_best = np.sum(mask)
s_sec_min = np.min(ztf_elt_best.s_sec)
# Closest observation of THIS element. idx is a position within ztf_elt_best, not an
# index label, so it must be used with .iloc; searching the whole frame and indexing
# by label (as before) could return a row belonging to a different element.
if rows_best > 0:
    idx = int(np.argmin(ztf_elt_best.s.values))
    ztf_id = ztf_elt_best.ztf_id.iloc[idx]
else:
    # No observations matched this element; nothing to look up
    idx, ztf_id = None, None
# ztf_elt_best.iloc[idx:idx+1]
print(f'Best asteroid has element_id = {element_id_best}')
print(f'Hit count: {hits_best} / {rows_best} observations')
print(f'Hit rate : {hit_rate_best:8.6f}')
print(f'Closest hit: {s_sec_min:0.3f} arc seconds')
# ztf_elt_best

## Build Asteroid Search Model

In [None]:
# Additional arguments for asteroid search models
site_name = 'palomar'

# Training parameters
learning_rate = 2.0**-13 # (1.22E-4)
clipnorm = 1.0
regenerate = False

In [None]:
# Review candidate elements
elts_ast

In [None]:
# Build asteroid search model
model = AsteroidSearchModel(
        elts=elts_ast, ztf_elt=ztf_elt, 
        site_name=site_name, thresh_deg=thresh_deg, 
        learning_rate=learning_rate, clipnorm=clipnorm)

In [None]:
# Dummy inputs for search model; any array with shape [batch_size,] is good
x = tf.ones(batch_size)

In [None]:
# Run model on unperturbed elements
log_like, orbital_elements, mixture_parameters = model(x)

In [None]:
# Total, mean and standard deviation of the log likelihood over the batch
log_like_tot = np.sum(log_like)
log_like_mean = np.mean(log_like)
log_like_std = np.std(log_like)

# Print the summary, one statistic per line, followed by the first five values
print(
    'Log likelihood:',
    f'Total: {log_like_tot:8.2f}',
    f'Mean: {log_like_mean:8.2f}',
    f'Std : {log_like_std:8.2f}',
    'First 5:',
    sep='\n',
)
print(log_like[:5].numpy())

In [None]:
# Fit one batch on the dummy input; the return value is kept in hist
hist = model.fit(x)

In [None]:
# Evaluate the model on the same dummy input (result shown as the cell output)
model.evaluate(x)

In [None]:
# Built in log likelihood calculation
model.calc_log_like()

In [None]:
# Visualize model summary - layers and parameters
model.summary()

## Fit Model on Unperturbed Elements

In [None]:
# Training parameters
learning_rate = 2.0**-13 # (1.22E-4)
clipnorm = 1.0
regenerate = False

In [None]:
# Adaptive search parameters
max_batches = 20000
batches_per_epoch = 100
epochs_per_episode = 5
min_learning_rate = 2.0**-20 # about 9.54E-7
verbose = 1

In [None]:
# Build asteroid search model
model_ast = AsteroidSearchModel(
                elts=elts_ast, ztf_elt=ztf_elt, 
                site_name=site_name, thresh_deg=thresh_deg, 
                learning_rate=learning_rate, clipnorm=clipnorm)

In [None]:
# Tiny size for fast testing
max_batches = 1000
batches_per_epoch = 10

In [None]:
# Alias model_ast to model for interactive testing.
# After this cell, `model` and `model_ast` are the same object.
model = model_ast

In [None]:
# Disabled: load candidate elements into the model instead of using its current state
# elts_df = model.load_candidates(verbose=True)

In [None]:
# Report the model's current candidate elements as a DataFrame
elts_df = model.candidates_df()

# Display the candidate elements
elts_df

In [None]:
# Train unperturbed model
model_ast.search_adaptive(
    max_batches=max_batches, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    min_learning_rate=min_learning_rate,
    regenerate=True,
    verbose=verbose)

In [None]:
model.train_hist_summary

In [None]:
model.train_hist_elt

In [None]:
train_hist = model.train_hist_elt
mask = train_hist.element_num == 55
train_hist[mask]

In [None]:
model2.load

In [None]:
model2 = AsteroidSearchModel(
                elts=elts_ast, ztf_elt=ztf_elt, 
                site_name=site_name, thresh_deg=thresh_deg, 
                learning_rate=learning_rate, clipnorm=clipnorm)

In [None]:
model2.load_candidates(verbose=True)

In [None]:
# NOTE(review): this bare raise stops a "Run All" at this cell, so the cells below only
# execute when run manually. Presumably a deliberate halt during development — confirm,
# and remove it once the rest of the notebook should run end to end.
raise ValueError

In [None]:
# Generate the outputs
log_like, orbital_elements, mixture_params = model.calc_outputs()
# Total log likelihood
total_log_like = tf.reduce_sum(log_like)
# Current loss
loss = model.calc_loss()

In [None]:
total_log_like.numpy()

In [None]:
# Report elements from model
elts_df = model.candidates_df()

# Review regenerated elements and scores
elts_df

In [None]:
# model.review_members()

In [None]:
# # Second (tune-up) adaptive training
# model.search_adaptive(
#     max_batches=max_batches, 
#     batches_per_epoch=batches_per_epoch,
#     epochs_per_episode=epochs_per_episode,
#     min_learning_rate=min_learning_rate,
#     regenerate=True,
#     verbose=verbose)

In [None]:
def plot_total_log_like(model):
    """Placeholder for a total log likelihood plot; not implemented, ignores model and returns None."""
    return None

In [None]:
# Axes for the training-progress chart of total log likelihood vs. batch count
fig, ax = plt.subplots()
ax.set(
    title='Training Progress: Total Log Likelihood',
    xlabel='Batch Count',
    ylabel='Total Log Likelihood',
)
# TODO: no data series is drawn yet; only the labelled, gridded axes are shown
ax.grid()
# fig.savefig('../figs/training/total_log_like.png', bbox_inches='tight')
plt.show()

In [None]:
# Review likelihoods by element
log_like_ast, orbital_elements_ast, mixture_parameters_ast = model_ast.calc_outputs()

In [None]:
log_like_ast

In [None]:
# Review resolution in arc seconds
dist2deg(mixture_parameters_ast[2])

In [None]:
elts_fit = model_ast.candidates_df()
elts_fit

In [None]:
np.median(elts_fit.R_deg)

In [None]:
model_pos = make_model_ast_pos(ts_np=ztf_elt.mjd, row_lengths_np=row_lengths_np)

In [None]:
traj_err = traj_diff(elts_ast, elts_fit, model_pos)
traj_err

In [None]:
np.exp(np.mean(np.log(traj_err)))

In [None]:
np.median(traj_err)

In [None]:
# plot log_like

## Train on Perturbed Elements

In [None]:
# Build asteroid search model
model_pert = AsteroidSearchModel(
                 elts=elts_pert, ztf_elt=ztf_elt_pert, site_name=site_name,
                 thresh_deg=thresh_deg, h=h, R_deg=R_deg,
                 learning_rate=learning_rate, clipnorm=clipnorm)

In [None]:
# Train model on perturbed elements
model_pert.search_adaptive(
    max_batches=max_batches, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    min_learning_rate=min_learning_rate,
    regenerate=regenerate,
    verbose=verbose)

In [None]:
# Review likelihoods by element
log_like_per, orbital_elements_pert, mixture_parameters_pert = model_pert.calc_outputs()

In [None]:
log_like_pert

## Train on Random Elements

In [None]:
# Filter elts_rand down to only those that had matching ztf observations
idx = np.unique(ztf_elt_rand.element_id)
elts_rand = elts_rand.loc[idx]

In [None]:
elts_rand['a'].dtype

In [None]:
elts_rand['a'].dtype

In [None]:
for col in ['a', 'e', 'f', 'inc', 'Omega', 'omega', 'epoch']:
    elts_rand[col] = elts_rand[col].astype(np.float32)

In [None]:
# Build asteroid search model
model_rand = AsteroidSearchModel(
                 elts=elts_rand, ztf_elt=ztf_elt_rand, site_name=site_name,
                 thresh_deg=thresh_deg, h=h, R_deg=R_deg,
                 learning_rate=learning_rate, clipnorm=clipnorm)

In [None]:
# Train model on perturbed elements
model_rand.search_adaptive(
    max_batches=max_batches, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    min_learning_rate=min_learning_rate,
    verbose=verbose)

In [None]:
# Review likelihoods by element
log_like_rand, orbital_elements_rand, mixture_parameters_rand = model_rand.calc_outputs()

In [None]:
log_like_rand

In [None]:
# orbital_elements_rand

In [None]:
from candidate_element import elts_np2df