In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl

# Utility
import time

# MSE Imports
import kepler_sieve
from asteroid_element import load_ast_elt
from candidate_element import asteroid_elts, perturb_elts, random_elts, elts_add_mixture_params
from ztf_ast import load_ztf_nearest_ast, calc_hit_freq
from ztf_element import load_ztf_batch, make_ztf_batch, ztf_score_by_elt, ztf_elt_summary
from asteroid_model import AsteroidPosition, AsteroidDirection, make_model_ast_pos
from asteroid_search_layers import CandidateElements, TrajectoryScore
from asteroid_search_model import AsteroidSearchModel, make_opt_adam
from asteroid_search_report import traj_diff
from nearest_asteroid import nearest_ast_elt_cart, nearest_ast_elt_cov, elt_q_norm
from element_eda import score_by_elt
from asteroid_dataframe import calc_ast_data, spline_ast_vec_df
from astro_utils import deg2dist, dist2deg, dist2sec

Found 4 GPUs.  Setting memory growth = True.


In [None]:
# Aliases
keras = tf.keras

# Constants
dtype = tf.float32
dtype_np = np.float32
space_dims = 3

In [None]:
# Set plot style variables
mpl.rcParams['figure.figsize'] = [16.0, 10.0]
mpl.rcParams['font.size'] = 16

## Load ZTF Data and Batch of Orbital Elements

In [None]:
# Load orbital elements for known asteroids
ast_elt = load_ast_elt()

# Number of asteroids
N_ast = ast_elt.shape[0]

In [None]:
# Load ztf nearest asteroid data
ztf_ast = load_ztf_nearest_ast()

In [None]:
# Asteroid numbers and hit counts
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Sort the hit counts in descending order and find the top batch_size
idx = np.argsort(hit_count)[::-1]

# Extract the asteroid number and hit count for this batch
ast_num_best = ast_num[idx]
hit_count_best = hit_count[idx]

In [None]:
# a = elts.a.values
# e = elts.e.values
# inc = elts.inc.values
# Omega = elts.Omega.values
# omega = elts.omega.values
# f = elts.f.values
# epoch = elts.epoch.values

In [None]:
# q, v = model.position(a, e, inc, Omega, omega, f, epoch)

In [None]:
# Candidate batch = the batch_size asteroids with the most ZTF hits
batch_size = 64

# Their unperturbed orbital elements
elts_ast = asteroid_elts(ast_nums=ast_num_best[:batch_size])

In [None]:
# # Review unperturbed elements
# elts_ast

In [None]:
# Inputs to perturb elements: large
sigma_a = 0.05
sigma_e = 0.01
sigma_inc_deg = 0.25
sigma_f_deg = 1.0
sigma_Omega_deg = 1.0
sigma_omega_deg = 1.0
mask_pert = None
random_seed = 42

# Perturb orbital elements
elts_pert= perturb_elts(elts_ast, sigma_a=sigma_a, sigma_e=sigma_e, 
                    sigma_inc_deg=sigma_inc_deg, sigma_f_deg=sigma_f_deg, 
                    sigma_Omega_deg=sigma_Omega_deg, sigma_omega_deg=sigma_omega_deg,
                    mask_pert=mask_pert, random_seed=random_seed)

In [None]:
# Choose which elements to search on
elts = elts_pert

In [None]:
# Search for nearest asteroids to these elements
# elts_near = nearest_ast_elt(elts)

In [None]:
# Review selected initial candidate elements, including the nearest asteroid information
# elts

In [None]:
# Review nearest asteroid to these candidate elements
# elts_near

In [None]:
# How many elements are still closest to the original elements?
# np.sum(elts.nearest_ast_num == elts.element_id)

## Batches of ZTF Data Near Initial Candidate Elements

In [None]:
# Arguments to make_ztf_batch
thresh_deg = 2.0
near_ast = False
regenerate = False

In [None]:
elts

In [None]:
# Load perturbed element batch
ztf_elt = load_ztf_batch(elts=elts, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [None]:
ztf_elt

In [None]:
# Score by element - perturbed
score_by_elt = ztf_score_by_elt(ztf_elt)

In [None]:
# Summarize the ztf element batch: perturbed asteroids
ztf_elt_summary(ztf_elt, score_by_elt, 'Perturbed Asteroids')

In [None]:
# Mixture parameters
num_hits: int = 10
R_deg: float = 0.5

In [None]:
# Add mixture parameters to candidate elements
elts_add_mixture_params(elts=elts, num_hits=num_hits, R_deg=R_deg, thresh_deg=thresh_deg)

In [None]:
# Review perturbed elements; includes nearest asteroid number and distance
elts

## Train on Perturbed Elements: Learn Mixture Parameters

In [None]:
# Observatory for ZTF data is Palomar Mountain
site_name = 'palomar'

In [None]:
# Training parameters
learning_rate = 2.0**-15
clipnorm = 1.0
save_at_end: bool = True

In [None]:
# Build asteroid search model
model = AsteroidSearchModel(
                elts=elts, ztf_elt=ztf_elt, 
                site_name=site_name, thresh_deg=thresh_deg, 
                learning_rate=learning_rate, clipnorm=clipnorm,
                name='model')

In [None]:
# Report before training starts
model.report()

In [None]:
# Visualize log likelihood before traning
fig, ax = model.plot_bar('log_like', sorted=False)

In [None]:
# Adaptive search parameters
max_batches_mixture = 2000
max_batches_element = 10000
batches_per_epoch = 100
epochs_per_episode = 5
max_bad_episodes = 3
min_learning_rate = None
save_at_end = False
reset_active_weight = False
verbose = 1

In [None]:
# # Load model
# model.load()
# model.report()

In [None]:
# Preliminary round of training with frozen elements
model.freeze_candidate_elements()
# model.thaw_score()

In [None]:
# Train perturbed model with frozen orbital elements
model.search_adaptive(
    max_batches=1000, 
    learning_rate=2.0**-12,
)

In [None]:
# Report after initial training on mixture parameters
model.report()

In [None]:
model.save_state()

In [None]:
# Bar chart - log likelihood by element
fig, ax = model.plot_bar('log_like', sorted=False)

In [None]:
# # Bar chart - hits by element
# fig, ax = model.plot_bar('hits', sorted=False)

In [None]:
# # Bar chart - resolution by element
# fig, ax = model.plot_bar('R_deg', sorted=False)

## Train on Perturbed Elements: Joint (Elements & Mixture Parameters)

In [None]:
model.load()

In [None]:
# Unfreeze the elements
model.thaw_candidate_elements()
model.thaw_mixture_parameters()
# model.thaw_score()

In [None]:
# Length of training
max_batches_element = 2000

# thresh_deg at end: don't use, score layer is thawed
thresh_deg_end = None

# New smaller learning rate
learning_rate = 2.0**-15

# Reset active weight
reset_active_weight = True

In [None]:
# Train model in joint mode
model.search_adaptive(max_batches=2000)

In [None]:
# Report after training
model.report()

In [None]:
model.save_state()

In [None]:
model.load()

In [None]:
model.set_thresh_deg_max(1.75)
model.set_R_deg_max(1.75/4.0)
model.save_weights()

In [None]:
model.report()

In [None]:
# Deliberate stop: "Run All" halts here, before the long training sessions below.
# Set CONTINUE_TRAINING = True (or run the cells below manually) to go further.
CONTINUE_TRAINING = False
if not CONTINUE_TRAINING:
    raise RuntimeError('Intentional stop before extended training; '
                       'set CONTINUE_TRAINING = True to continue.')

In [None]:
# Train model in joint mode
model.search_adaptive(
    max_batches=3000, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    thresh_deg_end=thresh_deg_end,
    learning_rate=learning_rate,
    reset_active_weight=reset_active_weight,
    verbose=verbose)

In [None]:
model.report()

In [None]:
# Bar chart - log likelihood by element
fig, ax = model.plot_bar('log_like', sorted=False)

In [None]:
# Bar chart - hits
fig, ax = model.plot_bar('hits', sorted=False)

In [None]:
# Bar chart - hits
fig, ax = model.plot_bar('R_deg', sorted=False)

In [None]:
# Bar chart - hits
fig, ax = model.plot_bar('thresh_deg', sorted=False)

In [None]:
# Save model
model.save_state()

In [None]:
model.freeze_candidate_elements()

In [None]:
# Train model in joint mode
model.search_adaptive(
    max_batches=model.current_batch+2000, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    learning_rate=learning_rate,
    reset_active_weight=True,
    verbose=verbose)

In [None]:
model.report()

In [None]:
model.save_state()

In [None]:
model.thaw_score()

In [None]:
# Train model in joint mode
model.search_adaptive(
    max_batches=model.current_batch+2000, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    learning_rate=learning_rate,
    reset_active_weight=True,
    verbose=verbose)

In [None]:
model.report()

In [None]:
model.save_state()

## Extended Training

In [None]:
# # Unfreeze the score layer
# model.thaw_score()

In [None]:
# Length of training
max_batches_element = 22000

# thresh_deg at end: 500 arc seconds
thresh_deg_end = 500 / 3600.0

# New smaller learning rate
learning_rate = 2.0**-15

# Reset active weight
reset_active_weight = True

In [None]:
# Train model in joint mode
model.search_adaptive(
    max_batches=max_batches_element, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    thresh_deg_end=thresh_deg_end,
    learning_rate=learning_rate,
    reset_active_weight=reset_active_weight,
    verbose=verbose)

In [None]:
model.report()

In [None]:
# model.save_state()

In [None]:
# Lower threshold manually
thresh_deg_score = 1.0
model.set_thresh_deg_score(thresh_deg_score)
model.freeze_score()

In [None]:
model.search_adaptive(
    max_batches=20000, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    learning_rate=2.0**-15,
    reset_active_weight=True,
    verbose=verbose)

In [None]:
model.report()

In [None]:
# model.save_state()

## Review Results Graphically

In [None]:
# Bar chart - log likelihood by element
fig, ax = model.plot_bar('log_like', sorted=False)

In [None]:
# Bar chart: hits
fig, ax = model.plot_bar('hits', sorted=False)

In [None]:
# Bar chart: log likelihood
fig, ax = model.plot_bar('log_R', sorted=False)

In [None]:
# # Learning curve: log likelihood
# fig, ax = model.plot_hist('log_like')

In [None]:
# # Learning curve: hits
# fig, ax = model.plot_hist('hits')

## Distance vs. Nearest Asteroid for Fitted Elements

In [None]:
# Find nearest asteroid to fitted elements
elts_fit, elts_near = model.nearest_ast()

In [None]:
q_norm = elt_q_norm(elts=elts_fit, ast_num=elts_fit.nearest_ast_num)
elts_fit['nearest_ast_q_norm'] = q_norm

In [None]:
# # Review asteroids nearest to the fitted elements
# elts_near

In [None]:
# Review fitted elements including nearest asteroid
elts_fit

In [None]:
# Filter for only the good ones
mask = (elts_fit.log_like > 200) & (elts_fit.R_sec < 60)
elts_fit[mask]

In [None]:
# Filter for incomplete convergence
mask = (elts_fit.log_like > 200) & (elts_fit.R_sec >= 60)
elts_fit[mask]

## Visualize Error vs. Nearest Asteroid

In [None]:
model.elts_near_ast

In [None]:
# Plot position error vs. known elements
fig, ax = model.plot_q_error(is_log=True, use_near_ast_dist=True)

In [None]:
# Plot error in orbital elements
fig, ax = model.plot_elt_error(elt_name='a', is_log=True, elt_num=None)

In [None]:
# Plot error in orbital elements
fig, ax = model.plot_elt_error(elt_name='e', is_log=True, elt_num=None)