In [None]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Utility
import time

# Local
from asteroid_integrate import load_ast_elt
from candidate_element import orbital_element_batch, perturb_elts, random_elts
from ztf_data import load_ztf_nearest_ast, calc_hit_freq, load_ztf_batch, make_ztf_batch
from asteroid_model import AsteroidPosition, AsteroidDirection
from asteroid_search_layers import CandidateElements, TrajectoryScore
from asteroid_search_model import AsteroidSearchModel, make_adam_opt
from astro_utils import deg2dist, dist2deg, dist2sec
from tf_utils import Identity

In [None]:
# Aliases
keras = tf.keras  # short handle; later cells use keras.backend and keras.layers

# Constants
dtype = tf.float32      # TensorFlow dtype passed when building tensors in later cells
dtype_np = np.float32   # NumPy dtype used in the .astype() conversions of DataFrame columns
space_dims = 3          # second entry of u_shape, used for the (ux, uy, uz) direction array

## Load ZTF Data and Batch of Orbital Elements

In [None]:
# Load orbital elements for known asteroids
ast_elt = load_ast_elt()

# Number of asteroids: size of the first axis of ast_elt
N_ast = ast_elt.shape[0]

In [None]:
# Load ztf nearest asteroid data; ztf_ast is the input to calc_hit_freq in the next cell
ztf_ast = load_ztf_nearest_ast()

In [None]:
# Asteroid numbers with their hit counts, computed with thresh_sec=2.0
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Ordering that puts the largest hit count first: ascending argsort, then reversed
order_ascending = np.argsort(hit_count)
idx = order_ascending[::-1]

# Asteroid numbers and hit counts rearranged into that descending order;
# the top batch_size entries are sliced off in the next cell
ast_num_best, hit_count_best = ast_num[idx], hit_count[idx]

In [None]:
# Number of candidate elements handled together; elt_batch_size is kept as a
# second name for the same value and is used by later cells
batch_size = 64
elt_batch_size = batch_size

# Unperturbed elements for the batch_size asteroids with the highest hit counts
top_ast_nums = ast_num_best[:batch_size]
elts_ast = orbital_element_batch(ast_nums=top_ast_nums)

In [None]:
# Display the batch of unperturbed elements
elts_ast

In [None]:
# Perturb orbital elements
# Settings handed to perturb_elts; sigma_f_deg is the only nonzero sigma
random_seed = 42
mask_pert = None
sigma_a = 0.0
sigma_e = 0.0
sigma_f_deg = 0.1
sigma_Omega_deg = 0.0
sigma_omega_deg = 0.0

# Perturbed copy of the unperturbed batch
elts_pert = perturb_elts(
    elts_ast,
    sigma_a=sigma_a,
    sigma_e=sigma_e,
    sigma_f_deg=sigma_f_deg,
    sigma_Omega_deg=sigma_Omega_deg,
    sigma_omega_deg=sigma_omega_deg,
    mask_pert=mask_pert,
    random_seed=random_seed,
)

In [None]:
# Display the batch of perturbed elements
elts_pert

In [None]:
# Random elements: elt_batch_size of them with element_id_start=0,
# reusing the random_seed set in the perturbation cell
elts_rand = random_elts(element_id_start=0, size=elt_batch_size, random_seed=random_seed)

In [None]:
# Display the batch of random elements
elts_rand

## Batches of ZTF Data vs. Elements

In [None]:
# Arguments to make_ztf_batch
# NOTE(review): the cells below pass these to load_ztf_batch; make_ztf_batch is
# imported but not called directly in the cells shown here
thresh_deg = 1.0
near_ast = False
regenerate = False

In [None]:
# Load unperturbed element batch (ZTF data for elts_ast, using the arguments set in the cell above)
ztf_elt_ast = load_ztf_batch(elts=elts_ast, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [None]:
# Load perturbed element batch (ZTF data for elts_pert, same arguments as the unperturbed batch)
ztf_elt_pert = load_ztf_batch(elts=elts_pert, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [None]:
# Load random element batch (ZTF data for elts_rand, same arguments as the unperturbed batch)
ztf_elt_rand = load_ztf_batch(elts=elts_rand, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [None]:
# Display the ZTF batch loaded for the unperturbed elements
ztf_elt_ast

In [None]:
# Review hits: keep only the rows whose is_hit flag is set
mask = ztf_elt_ast.is_hit
ztf_elt_ast[mask]

In [None]:
# List the columns available in the ZTF batch
ztf_elt_ast.columns

In [None]:
# Alias ztf_elt_ast to ztf_elt; the cells below work on ztf_elt
# (this binds a second name to the same object, it does not copy it)
ztf_elt = ztf_elt_ast

In [None]:
# Observation times (mjd column) as a flat numpy array of dtype_np
mjd_column = ztf_elt['mjd']
ts_np = mjd_column.values.astype(dtype_np)

# Number of observations belonging to each element_id, as int32;
# the entries follow the group order produced by groupby
element_ids = ztf_elt['element_id']
obs_per_element = element_ids.groupby(element_ids).count()
row_lengths_np = obs_per_element.values.astype(np.int32)

In [None]:
# Review results for the element taken from the top of the hit-count ranking
element_id_best = ast_num_best[0]
mask = (ztf_elt.element_id == element_id_best)

# Restrict to this element's observations once, so that every statistic below
# is computed on the same set of rows
ztf_elt_best = ztf_elt[mask]
hits_best = np.sum(ztf_elt_best.is_hit)
hit_rate_best = np.mean(ztf_elt_best.is_hit)
rows_best = np.sum(mask)
s_sec_min = np.min(ztf_elt_best.s_sec)

# Closest observation of this element. The argmin is taken over the masked rows
# and applied positionally (via .values); taking it over the whole frame, or
# using it as an index label, could pick an observation of a different element.
idx = np.argmin(ztf_elt_best.s.values)
ztf_id = ztf_elt_best.ztf_id.values[idx]
# ztf_elt[mask].iloc[idx:idx+1]
print(f'Best asteroid has element_id = {element_id_best}')
print(f'Hit count: {hits_best} / {rows_best} observations')
print(f'Hit rate : {hit_rate_best:8.6f}')
print(f'Closest hit: {s_sec_min:0.3f} arc seconds')
# ztf_elt[mask]

## Build Asteroid Search Model

In [None]:
# Additional arguments for asteroid search models
site_name = 'palomar'
# NOTE(review): the name h is rebound later by the mixture report cell
# (h = mixture[:, 0]), so cells after that point no longer see this scalar
h = 0.01
R_deg = 1.0

# Training parameters (passed to the AsteroidSearchModel constructor below)
learning_rate = 2.0E-4
clipnorm = 1.0

In [None]:
# Build asteroid search model on the unperturbed elements and their ZTF batch,
# using the search and training parameters set in the cell above
model = AsteroidSearchModel(
        elts=elts_ast, ztf_elt=ztf_elt, site_name=site_name,
        thresh_deg=thresh_deg, h=h, R_deg=R_deg,
        learning_rate=learning_rate, clipnorm=clipnorm)

In [None]:
# Dummy inputs for search model; any array with shape [elt_batch_size,] is good
# (np.ones with no dtype argument produces float64 values)
x = np.ones(elt_batch_size)

In [None]:
# Run model on unperturbed elements; calling the model returns three outputs,
# unpacked here as the log likelihood, the elements and the mixture parameters
log_like, elts_tf, mixture = model(x)

In [None]:
# Total, mean and standard deviation of the log likelihood on unperturbed elements
log_like_tot = np.sum(log_like)
log_like_mean = np.mean(log_like)
log_like_std = np.std(log_like)

# Print the three summary statistics, then the first five raw values
print('Log likelihood:')
for label, value in (('Total', log_like_tot), ('Mean', log_like_mean), ('Std ', log_like_std)):
    print(f'{label}: {value:8.2f}')
print('First 5:')
print(log_like[:5].numpy())

In [None]:
# model.summary()

In [None]:
# Evaluate the newly built model on the dummy inputs, before any fitting
model.evaluate(x)

## Fit Model on Unperturbed Elements

In [None]:
# Current loss, shown before restore_best_weights is called in the next cell
model.current_loss()

In [None]:
# Restore the model's best weights (method provided by AsteroidSearchModel)
model.restore_best_weights()

In [None]:
# Current loss again, for comparison with the value shown before the restore
model.current_loss()

In [None]:
# Training inputs for a single-step epoch: one batch of dummy ones
# NOTE(review): steps_per_epoch, samples_per_epoch and x_trn are all reassigned
# by the "Set training length" cell further down
steps_per_epoch = 1
samples_per_epoch = batch_size*steps_per_epoch
x_trn = tf.ones(samples_per_epoch, dtype=dtype)
# model.recompile()

In [None]:
# Fit for one epoch of steps_per_epoch steps, without shuffling the dummy inputs
hist = model.fit(x_trn, batch_size=batch_size, epochs=1, steps_per_epoch=steps_per_epoch, shuffle=False)

In [None]:
# Run the model's adaptive search with the batch and epoch settings given here
model.search_adaptive(max_batches=10000, 
                      batches_per_epoch=20,
                      epochs_per_episode=10,
                      verbose=0)

In [None]:
# Callbacks
# Early stopping monitors the training loss with patience=0 and restores the
# best weights seen when it stops
early_stop = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=0, restore_best_weights=True, verbose=True)
callbacks = [early_stop]

In [None]:
# Set training length: epochs, steps per epoch
# These reassign steps_per_epoch, samples_per_epoch and x_trn from the earlier
# single-step cell; x_trn is now a NumPy array of ones rather than a tensor
epochs = 10
steps_per_epoch = 200
samples_per_epoch = steps_per_epoch * batch_size
x_trn = np.ones(samples_per_epoch)

In [None]:
# Evaluate before training
model.evaluate(x)

In [None]:
# Review the learning rate and clipnorm attributes held by the model
print(f'learning_rate = {model.learning_rate}')
print(f'clipnorm      = {model.clipnorm}')

In [None]:
# Train model with the early-stopping callback
# NOTE(review): epochs is hard-coded to 20 here, while the `epochs` variable (10)
# set in the training-length cell is what the second fit below uses
hist = model.fit(x=x_trn, batch_size=batch_size, epochs=20, steps_per_epoch=steps_per_epoch, 
                 callbacks=callbacks, shuffle=False, verbose=1)

In [None]:
# Evaluate after the first training run
model.evaluate(x)

In [None]:
# Second training run, this time for `epochs` epochs; hist is overwritten
hist = model.fit(x=x_trn, batch_size=batch_size, epochs=epochs, steps_per_epoch=steps_per_epoch, 
                 callbacks=callbacks, shuffle=False, verbose=1)

In [None]:
# Evaluate after the second training run
model.evaluate(x)

In [None]:
# Predict: three outputs, unpacked as log likelihood, elements and mixture parameters
log_like, elts_tf, mixture = model.predict(x)

# Report mixture: column 0 is reported as h, column 1 as lambda
# NOTE(review): this rebinds h (until now the scalar passed to the model
# constructor) and defines lam; later cells read both names
h, lam = mixture[:, 0], mixture[:, 1]
h_mean, lam_mean = np.mean(h), np.mean(lam)
print(f'h_mean      = {h_mean:8.6f}')
print(f'lambda_mean = {lam_mean:6.2e}')

In [None]:
# model.elements.get_weights()

In [None]:
# Build asteroid search model
# Uses AsteroidSearchModel, the constructor imported at the top of this notebook;
# the call mirrors the one that built the first model, including the training
# parameters learning_rate and clipnorm.
# NOTE(review): h was rebound by the mixture report cell (h = mixture[:, 0]), so
# it is no longer the scalar set with the other search-model arguments — confirm
# that passing it here is intended.
model = AsteroidSearchModel(
        elts=elts_ast, ztf_elt=ztf_elt, site_name=site_name,
        thresh_deg=thresh_deg, h=h, R_deg=R_deg,
        learning_rate=learning_rate, clipnorm=clipnorm)

In [None]:
# model.best_loss

In [None]:
# model.best_weights

In [None]:
# Train model adaptively
# NOTE(review): ast_search_adaptive is not imported or defined in the cells shown;
# the earlier adaptive run used the method model.search_adaptive(...) instead.
# Confirm where this function should come from before running this cell.
ast_search_adaptive(model,
                    learning_rate=1.0E-4, clipnorm=1.0,
                    max_epochs=20, batch_size=batch_size)

In [None]:
# Second (tune-up) adaptive training; learning_rate and clipnorm are passed as None
# NOTE(review): ast_search_adaptive is not imported or defined in the cells shown —
# confirm its source before running this cell.
ast_search_adaptive(model,
                    learning_rate=None, clipnorm=None,
                    max_epochs=10, batch_size=batch_size)

## Model Diagnostic

In [None]:
# Threshold: distance obtained from thresh_deg via deg2dist, its square,
# and the z value sqrt(1 - s^2 / 2)
make_constant = keras.backend.constant
thresh_s = make_constant(deg2dist(thresh_deg))
thresh_s2 = make_constant(thresh_s**2)
thresh_z = make_constant(np.sqrt(1.0 - thresh_s2/2.0))

# Report thresholds
print('Thresholds:')
# print(f'z  : {thresh_z:10.8f}')
for label, value in (('s  ', thresh_s), ('s2 ', thresh_s2), ('1-z', 1.0 - thresh_z)):
    print(f'{label}: {value:6.2e}')

In [None]:
# NOTE(review): AsteroidPosition and AsteroidDirection are already imported in the
# first cell. TrajectoryScore is imported there from asteroid_search_layers, so the
# second line rebinds it to the asteroid_search_model version; OrbitalElements is
# only imported here. Consider moving these into the top import cell.
from asteroid_model import AsteroidPosition, AsteroidDirection
from asteroid_search_model import OrbitalElements, TrajectoryScore

In [None]:
# Repeats the space_dims constant already set in the constants cell near the top
space_dims = 3

In [None]:
# Alias inputs for the step-by-step diagnostic below
# (ztf_elt was already bound to ztf_elt_ast earlier; this restates it)
elts = elts_ast
ztf_elt = ztf_elt_ast

In [None]:
# Observed directions; extract from ztf_elt DataFrame
cols_u_obs = ['ux', 'uy', 'uz']
u_obs_np = ztf_elt[cols_u_obs].values.astype(dtype_np)    

# Set of trainable weights with candidate orbital elements; initialize according to elts
# NOTE(review): h and lam here are whatever the mixture report cell left bound
# (columns of the predicted mixture), not the scalar h from the model arguments
elements_layer = OrbitalElements(elts=elts, h=h, lam=lam, name='candidates')

# Extract the candidate elements and mixture parameters; pass dummy inputs to satisfy keras Layer API
# This unpacking rebinds h and lam to the layer outputs
a, e, inc, Omega, omega, f, epoch, h, lam = elements_layer(inputs=x)

In [None]:
# The orbital elements; stack to shape (elt_batch_size, 7)
# This rebinds elts_tf, which previously held the second output of the model call
elts_tf = tf.stack(values=[a, e, inc, Omega, omega, f, epoch], axis=1, name='elts')

In [None]:
# The predicted direction: layer built from the observation times and the
# per-element observation counts computed earlier
direction_layer = AsteroidDirection(ts_np=ts_np, row_lengths_np=row_lengths_np, 
                                    site_name=site_name, name='direction_layer')

# Calibration arrays (flat), taken from the qx, qy, qz and vx, vy, vz columns of ztf_elt
cols_q_ast = ['qx', 'qy', 'qz']
cols_v_ast = ['vx', 'vy', 'vz']
q_ast = ztf_elt[cols_q_ast].values.astype(dtype_np)
v_ast = ztf_elt[cols_v_ast].values.astype(dtype_np)

In [None]:
# Run calibration of the direction layer's q_layer with the arrays built above
direction_layer.q_layer.calibrate(elts=elts, q_ast=q_ast, v_ast=v_ast)

# Tensor of predicted directions; the layer returns two outputs, of which only
# u_pred is used in the cells below
u_pred, r_pred = direction_layer(a, e, inc, Omega, omega, f, epoch)

In [None]:
# Score layer for these observations
score_layer = TrajectoryScore(row_lengths_np=row_lengths_np, u_obs_np=u_obs_np,
                              thresh_deg=thresh_deg, name='score_layer')

# Compute the log likelihood by element from the predicted direction and mixture model parameters
# This rebinds log_like, which previously held the model's own output
log_like = score_layer(u_pred, h=h, lam=lam)

In [None]:
# Check selected row: row 11 has ztf_id = 341737, elt_id 733 (first hit)
u_pred[11]

In [None]:
# Data shapes
# Per-element observation counts as an int32 constant
row_lengths = keras.backend.constant(
    value=row_lengths_np,
    shape=row_lengths_np.shape,
    dtype=tf.int32,
)
# Total number of observations: the sum of the per-element counts
total_rows = tf.reduce_sum(row_lengths_np)
data_size = keras.backend.constant(value=total_rows, dtype=tf.int32)
# Shape of the flat direction array: data_size rows by space_dims columns
u_shape = (data_size, space_dims)

In [None]:
# Save u_obs: the observed directions as a constant tensor of shape u_shape and dtype `dtype`
u_obs = keras.backend.constant(value=u_obs_np, shape=u_shape, dtype=dtype)

In [None]:
# Difference between predicted and observed directions
du = tf.subtract(u_pred, u_obs)
# Squared distance: sum of squared components over the last axis
du_squared = tf.square(du)
s2 = tf.reduce_sum(du_squared, axis=-1, name='s2')

In [None]:
# Squared distance for row 11 (the row inspected in u_pred above)
s2[11]

In [None]:
# Flag the observations whose squared distance s2 is below the squared threshold thresh_s2
is_close = tf.math.less(x=s2, y=thresh_s2, name='is_close')

In [None]:
# Whether row 11 falls inside the threshold
is_close[11]

In [None]:
# Keep only the squared distances flagged by is_close
s2_close = tf.boolean_mask(tensor=s2, mask=is_close)
# Relative distance v: those squared distances divided by the squared threshold
v = tf.divide(s2_close, thresh_s2, name='v')

In [None]:
# NOTE(review): v holds only the observations kept by boolean_mask, so position 11
# here is not necessarily the same observation as row 11 of s2 / is_close
v[11]

In [None]:
# Row_lengths, for close observations only
# is_close_r = tf.RaggedTensor.from_row_lengths(values=is_close, row_lengths=self.row_lengths, name='is_close_r')
def ragged_map_func(flat_values):
    """Regroup a flat tensor into ragged rows according to row_lengths."""
    return tf.RaggedTensor.from_row_lengths(values=flat_values, row_lengths=row_lengths)

# Apply the regrouping through a Lambda layer, then count the True flags in each row
is_close_r = tf.keras.layers.Lambda(function=ragged_map_func, name='is_close_r')(is_close)
is_close_int = tf.cast(is_close_r, tf.int32)
row_lengths_close = tf.reduce_sum(is_close_int, axis=1, name='row_lengths_close')

In [None]:
# Number of close observations in the first ragged row
row_lengths_close[0]

In [None]:
# Shape of parameters: one entry per close observation
close_size = tf.reduce_sum(row_lengths_close)
param_shape = (close_size,)


def upsample_to_close(values, label):
    """Repeat each entry of values row_lengths_close times and reshape to param_shape."""
    repeated = tf.repeat(input=values, repeats=row_lengths_close, name=f'{label}_rep')
    flat = tf.reshape(tensor=repeated, shape=param_shape, name=f'{label}_vec')
    return repeated, flat


# Upsample h and lambda
h_rep, h_vec = upsample_to_close(h, 'h')
lam_rep, lam_vec = upsample_to_close(lam, 'lam')

In [None]:
# Entry 11 of the upsampled h; h_vec has one entry per close observation
# (param_shape), so this indexes the filtered set rather than the original rows
h_vec[11]

In [None]:
# Entry 11 of the upsampled lambda, in the same filtered index space as h_vec
lam_vec[11]

In [None]:
# Probability according to mixture model
# Conditional hit term: lam * exp(-lam * v) / (1 - exp(-lam))
emlx = tf.exp(tf.negative(lam_vec) * v, name='p_hit_cond')
p_hit_cond_num = emlx * lam_vec
p_hit_cond_den = 1.0 - tf.exp(-lam_vec)
p_hit_cond = p_hit_cond_num / p_hit_cond_den

# Mixture: the hit term weighted by h, plus the miss term (1 - h)
p_hit = tf.multiply(h_vec, p_hit_cond, name='p_hit')
p_miss = tf.subtract(1.0, h_vec, name='p_miss')
p = tf.add(p_hit, p_miss, name='p')

# Elementwise natural log of p, applied through an Activation layer
log_layer = keras.layers.Activation(tf.math.log, name='log_p_flat')
log_p_flat = log_layer(p)

In [None]:
# Log probability at position 11 of the close observations
log_p_flat[11]

In [None]:
# Mixture probability at position 11 of the close observations
p[11]

In [None]:
# Conditional hit term at position 11 of the close observations
p_hit_cond[11]