In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Utility
import time

# Local
from asteroid_integrate import load_ast_elt
from candidate_element import orbital_element_batch, perturb_elts, random_elts
from ztf_data import load_ztf_nearest_ast, calc_hit_freq, load_ztf_batch, make_ztf_batch
from asteroid_model import AsteroidPosition, AsteroidDirection
from asteroid_search_layers import CandidateElements, TrajectoryScore
from asteroid_search_model import make_model_asteroid_search, ast_search_adaptive
from asteroid_search_model import adjust_learning_rate, compile_search_model
from astro_utils import deg2dist, dist2deg, dist2sec
from tf_utils import Identity

Found 4 GPUs.  Setting memory growth = True.


In [2]:
# Aliases
keras = tf.keras

# Constants
# Floating point type used for TensorFlow tensors, and the matching numpy dtype
dtype = tf.float32
dtype_np = np.float32
# Number of components of a direction vector (ux, uy, uz)
space_dims = 3

## Load ZTF Data and Batch of Orbital Elements

In [3]:
# Orbital elements of the known asteroids
ast_elt = load_ast_elt()

# The asteroid count is the number of rows loaded
N_ast = len(ast_elt)

In [4]:
# Load ztf nearest asteroid data
# The result is passed to calc_hit_freq as its `ztf` argument in the next cell
ztf_ast = load_ztf_nearest_ast()

In [5]:
# Asteroid numbers with their hit counts, using thresh_sec = 2.0
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Ordering that ranks the asteroids from most hits to fewest
idx = np.argsort(hit_count)[::-1]

# Apply that ranking to both arrays; the leading entries become the batch
ast_num_best, hit_count_best = ast_num[idx], hit_count[idx]

In [6]:
# Batch size; the element batch shares the same size
batch_size = elt_batch_size = 64

# Unperturbed elements of the batch_size asteroids with the most hits
elts_ast = orbital_element_batch(ast_nums=ast_num_best[:batch_size])

In [7]:
# Display the batch of unperturbed elements
elts_ast

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133491,58600.0
1,59244,2.634727,0.262503,0.465045,5.738297,1.766995,-1.601363,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.246069,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.357345,58600.0
4,142999,2.619944,0.191376,0.514017,0.238022,0.946463,-1.299301,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.016580,58600.0
60,134815,2.612770,0.140831,0.513923,0.272689,0.645552,-0.957836,58600.0
61,27860,2.619406,0.096185,0.200633,5.541399,3.266046,3.948770,58600.0
62,85937,2.342292,0.197267,0.439063,5.279693,3.210025,3.947687,58600.0


In [8]:
# Perturbation sizes: every sigma is zero except sigma_f_deg
sigma_a, sigma_e = 0.0, 0.0
sigma_f_deg = 0.1
sigma_Omega_deg, sigma_omega_deg = 0.0, 0.0

# No mask on the perturbation; fixed seed for reproducibility
mask_pert = None
random_seed = 42

# Perturb the unperturbed batch
elts_pert = perturb_elts(
    elts_ast,
    sigma_a=sigma_a,
    sigma_e=sigma_e,
    sigma_f_deg=sigma_f_deg,
    sigma_Omega_deg=sigma_Omega_deg,
    sigma_omega_deg=sigma_omega_deg,
    mask_pert=mask_pert,
    random_seed=random_seed,
)

In [9]:
# Display the batch of perturbed elements
elts_pert

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133117,58600.0
1,59244,2.634727,0.262503,0.465045,5.738297,1.766995,-1.603537,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.245767,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.356673,58600.0
4,142999,2.619945,0.191376,0.514017,0.238022,0.946463,-1.300844,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.014978,58600.0
60,134815,2.612770,0.140831,0.513923,0.272689,0.645552,-0.954132,58600.0
61,27860,2.619406,0.096185,0.200633,5.541399,3.266046,3.950572,58600.0
62,85937,2.342292,0.197267,0.439063,5.279693,3.210025,3.945035,58600.0


In [10]:
# Random elements
# Same batch size and random seed as the cells above; element ids start at 0
elts_rand = random_elts(element_id_start=0, size=elt_batch_size, random_seed=random_seed)

In [11]:
# Display the batch of random elements
elts_rand

Unnamed: 0,element_id,a,e,f,inc,Omega,omega,epoch
0,0,2.346512,0.191774,3.947884,0.123224,5.077989,4.698056,58600.0
1,1,3.002211,0.239903,-2.104391,0.032646,3.623584,5.988327,58600.0
2,2,2.317087,0.055763,-4.869156,0.034794,3.476631,2.078165,58600.0
3,3,2.349419,0.200190,3.323487,0.222124,1.726133,3.473125,58600.0
4,4,2.712220,0.116107,4.183110,0.200886,0.668298,3.595820,58600.0
...,...,...,...,...,...,...,...,...
59,59,2.328723,0.639718,1.285795,0.064922,1.580547,4.395165,58600.0
60,60,3.972593,0.039944,-0.006957,0.176738,5.749878,0.869298,58600.0
61,61,3.019390,0.032757,-3.315554,0.115168,1.725482,0.834064,58600.0
62,62,2.481321,0.025182,2.025885,0.337006,2.972029,6.091780,58600.0


## Batches of ZTF Data vs. Elements

In [12]:
# Arguments to make_ztf_batch
# (the following cells pass the same arguments to load_ztf_batch)
thresh_deg = 1.0
near_ast = False
regenerate = False

In [13]:
# Load unperturbed element batch
# ZTF observations paired with the unperturbed elements elts_ast
ztf_elt_ast = load_ztf_batch(elts=elts_ast, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [14]:
# Load perturbed element batch
# ZTF observations paired with the perturbed elements elts_pert
ztf_elt_pert = load_ztf_batch(elts=elts_pert, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [15]:
# Load random element batch
# ZTF observations paired with the random elements elts_rand
ztf_elt_rand = load_ztf_batch(elts=elts_rand, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [16]:
# Display the ZTF observations for the unperturbed element batch
ztf_elt_ast

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,vz,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit
0,53851,733,b'ZTF18abnothj',594197584815010004,5501,58348.197581,266.229165,-13.513802,-0.063945,-0.983101,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.010624,2191.371398,0.999944,0.370539,False
1,73604,733,b'ZTF18ablwzmb',594197584815015003,5501,58348.197581,265.761024,-13.509148,-0.071871,-0.982578,...,0.004080,-0.057300,-0.982042,0.179751,2.234078,0.016809,3467.103003,0.999859,0.927533,False
2,82343,733,b'ZTF18abiydvm',635193253015015018,12089,58389.193252,270.331454,-11.244934,0.005674,-0.977422,...,0.003825,0.000918,-0.977996,0.208622,2.703478,0.005450,1124.142942,0.999985,0.097510,False
3,257221,733,b'ZTF18acakcqg',931471223715015007,39920,58685.471227,29.693832,42.180412,0.643725,0.603886,...,-0.001953,0.639004,0.610779,0.467571,2.175851,0.008712,1797.042210,0.999962,0.249184,False
4,327000,733,b'ZTF18achmdmw',937465970615015011,40837,58691.465972,33.104905,44.059131,0.601970,0.636719,...,-0.002129,0.606278,0.637608,0.475272,2.114865,0.007949,1639.537152,0.999968,0.207418,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90205,5650588,324582,b'ZTF20aaqvhld',1150176701515015008,96618,58904.176701,44.164238,29.650540,0.623416,0.752309,...,-0.001541,0.627640,0.750696,0.206212,2.981799,0.008187,1688.636853,0.999966,0.220027,False
90206,5650589,324582,b'ZTF20aaqvhld',1150176245715015005,96617,58904.176250,44.164062,29.650536,0.623417,0.752307,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.008187,1688.600639,0.999966,0.220018,False
90207,5650665,324582,b'ZTF20aaqvhll',1150176245815015010,96617,58904.176250,44.368640,28.490480,0.628284,0.753618,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.013370,2757.856469,0.999911,0.586871,False
90208,5650697,324582,b'ZTF20aaqvhmb',1150176246015015005,96617,58904.176250,43.296207,29.505908,0.633424,0.743491,...,-0.001541,0.627641,0.750695,0.206213,2.981793,0.012388,2555.279465,0.999923,0.503822,False


In [17]:
# Review hits: keep only the observations flagged in the is_hit column
mask = ztf_elt_ast['is_hit']
ztf_elt_ast.loc[mask]

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,vz,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit
11,341737,733,b'ZTF19abizrac',937427766115015019,40797,58691.427766,33.130412,43.596186,0.606481,0.637452,...,-0.002127,0.606487,0.637448,0.475220,2.115259,0.000008,1.551631,1.0,1.857731e-07,True
12,345725,733,b'ZTF19abjajfg',937467364815015020,40840,58691.467361,33.148426,43.605278,0.606265,0.637618,...,-0.002129,0.606270,0.637613,0.475274,2.114851,0.000007,1.522407,1.0,1.788411e-07,True
14,346522,733,b'ZTF19abjajmr',937468726115015011,40842,58691.468727,33.149062,43.605587,0.606257,0.637624,...,-0.002129,0.606263,0.637619,0.475276,2.114837,0.000008,1.547274,1.0,1.847314e-07,True
15,347644,733,b'ZTF19abiyxiu',937402264815015008,40777,58691.402269,33.118785,43.590288,0.606621,0.637345,...,-0.002127,0.606626,0.637341,0.475186,2.115523,0.000007,1.535273,1.0,1.818767e-07,True
20,431445,733,b'ZTF19abkkfhr',934448315015015003,40221,58688.448310,31.751906,42.913068,0.622775,0.624451,...,-0.002040,0.622780,0.624447,0.471393,2.145603,0.000007,1.536414,1.0,1.821473e-07,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90176,5447574,324582,b'ZTF20aapdfmj',1145121244815015014,95204,58899.121250,42.392273,29.202629,0.644677,0.734024,...,-0.001502,0.644680,0.734021,0.213544,2.915471,0.000005,0.932672,1.0,6.712185e-08,True
90182,5450145,324582,b'ZTF20aapdfmj',1145121716115015016,95205,58899.121713,42.392439,29.202638,0.644675,0.734026,...,-0.001502,0.644678,0.734022,0.213543,2.915477,0.000005,1.003729,1.0,7.773903e-08,True
90186,5461311,324582,b'ZTF20aapeobw',1145164884815015018,95284,58899.164884,42.405891,29.202386,0.644538,0.734164,...,-0.001503,0.644541,0.734161,0.213479,2.916052,0.000004,0.888835,1.0,6.096041e-08,True
90187,5461312,324582,b'ZTF20aapeobw',1145165336115015009,95285,58899.165336,42.406082,29.202412,0.644536,0.734166,...,-0.001503,0.644540,0.734163,0.213478,2.916058,0.000005,1.054282,1.0,8.576689e-08,True


In [18]:
# List the columns available on the observation table
ztf_elt_ast.columns

Index(['ztf_id', 'element_id', 'ObjectID', 'CandidateID', 'TimeStampID', 'mjd',
       'ra', 'dec', 'ux', 'uy', 'uz', 'qx', 'qy', 'qz', 'vx', 'vy', 'vz',
       'elt_ux', 'elt_uy', 'elt_uz', 'elt_r', 's', 's_sec', 'z', 'v',
       'is_hit'],
      dtype='object')

In [19]:
# Alias ztf_elt_ast to ztf_elt
# The cells below read the observations through this shorter name
ztf_elt = ztf_elt_ast

In [20]:
# Observation times (mjd column) as a flat numpy array
ts_np = ztf_elt['mjd'].values.astype(dtype_np)

# Number of observations belonging to each element_id, as int32
row_lengths_np = ztf_elt.groupby('element_id').size().values.astype(np.int32)

In [21]:
# Review results for the asteroid with the highest hit count
element_id_best = ast_num_best[0]
mask = (ztf_elt.element_id == element_id_best)
# Observations belonging to the best asteroid only
ztf_elt_best = ztf_elt[mask]
hits_best = np.sum(ztf_elt_best.is_hit)
hit_rate_best = np.mean(ztf_elt_best.is_hit)
rows_best = np.sum(mask)
s_sec_min = np.min(ztf_elt_best.s_sec)
# Index label and ZTF id of the closest observation of the best asteroid.
# idxmin returns an index label, so the lookup goes through .loc; this keeps the
# search on the same subset as the statistics above and avoids mixing a positional
# argmin result with label-based indexing.
# NOTE: idx rebinds the name used for the hit-count sort order in an earlier cell.
idx = ztf_elt_best.s.idxmin()
ztf_id = ztf_elt.ztf_id.loc[idx]
# ztf_elt.loc[idx:idx]
print(f'Best asteroid has element_id = {element_id_best}')
print(f'Hit count: {hits_best} / {rows_best} observations')
print(f'Hit rate : {hit_rate_best:8.6f}')
print(f'Closest hit: {s_sec_min:0.3f} arc seconds')
# ztf_elt[mask]

Best asteroid has element_id = 51921
Hit count: 158 / 708 observations
Hit rate : 0.223164
Closest hit: 0.382 arc seconds


## Build Asteroid Search Model

In [22]:
# Additional arguments for asteroid search models
# Site name passed to the search model as site_name
site_name = 'palomar'
# Passed as h (initial hit probability in the mixture model)
h = 0.01
# Passed as R_deg (initial resolution parameter, in degrees)
R_deg = 1.0

In [23]:
# Build asteroid search model
# Initialized from the unperturbed elements and their ZTF observations
model = make_model_asteroid_search(
        elts=elts_ast, ztf_elt=ztf_elt, site_name=site_name,
        thresh_deg=thresh_deg, h=h, R_deg=R_deg)

In [24]:
# Dummy inputs for search model; any array with shape [elt_batch_size,] is good
# The values themselves are placeholders; a vector of ones is used here
x = np.ones(elt_batch_size)

In [25]:
# Run model on unperturbed elements
# The three outputs are unpacked as log likelihood, orbital elements and mixture parameters
log_like, elts_tf, mixture = model(x)

In [26]:
# Total, mean and standard deviation of the log likelihood on unperturbed elements
log_like_tot, log_like_mean, log_like_std = np.sum(log_like), np.mean(log_like), np.std(log_like)

# Print the summary followed by the first five entries
print('Log likelihood:')
print(f'Total: {log_like_tot:8.2f}')
print(f'Mean: {log_like_mean:8.2f}')
print(f'Std : {log_like_std:8.2f}')
print('First 5:')
print(log_like[0:5].numpy())

Log likelihood:
Total:    30.02
Mean:     0.47
Std :     0.11
First 5:
[0.57954353 0.49264225 0.6762784  0.393577   0.42903855]


In [30]:
class AsteroidSearchModel(tf.keras.Model):
    """
    Subclassed Keras model that scores a batch of candidate orbital elements against ZTF observations.
    Calling the model returns the log likelihood by element together with the current orbital
    elements and mixture parameters, and registers the negative total log likelihood as a loss.
    """
    def __init__(self, elts: pd.DataFrame, ztf_elt: pd.DataFrame, site_name: str='geocenter', 
                 thresh_deg: float = 1.0, h: float = 0.01, R_deg: float = 1.0, 
                 **kwargs):
        """
        Build the layers of the search model.
        INPUTS:
            elts:       DataFrame with initial guess for orbital elements.
                        Columns: element_id, a, e, inc, Omega, omega, f, epoch
                        Output of orbital_element_batch, perturb_elts or random_elts
            ztf_elt:    DataFrame with ZTF observations within thresh_deg degrees
                        of the orbits predicted by these elements.
                        Output of make_ztf_batch or load_ztf_batch
            site_name:  Used for topos adjustment, e.g. 'geocenter' or 'palomar'
            thresh_deg: Threshold in degrees; forwarded to the TrajectoryScore layer
            h:          Initial value of hit probability in mixture model
            R_deg:      Initial value of resolution parameter (in degrees) in mixture model
            kwargs:     Forwarded to tf.keras.Model
        """
        # Initialize tf.keras.Model
        super().__init__(**kwargs)

        # Batch size comes from elts
        self.batch_size = elts.shape[0]

        # Numpy array of observation times; flat, shape (data_size,)
        ts_np = ztf_elt.mjd.values.astype(dtype_np)

        # Observation count per element as an int32 numpy array.
        # Converted from the groupby result so the layers below receive a plain
        # ndarray, the same form the rest of this notebook passes to them.
        row_lengths_np = ztf_elt.element_id.groupby(ztf_elt.element_id).count().values.astype(np.int32)
        self.row_lengths = keras.backend.constant(value=row_lengths_np, shape=(self.batch_size,), dtype=tf.int32)

        # Shape of the observed trajectories
        self.data_size = ztf_elt.shape[0]
        self.traj_shape = (self.data_size, space_dims)

        # Observed directions; extract from ztf_elt DataFrame
        cols_u_obs = ['ux', 'uy', 'uz']
        u_obs_np = ztf_elt[cols_u_obs].values.astype(dtype_np)

        # Convert resolution from degrees to Cartesian distance
        R_s = deg2dist(R_deg)

        # Set of trainable weights with candidate orbital elements; initialize according to elts
        self.candidate_elements = CandidateElements(elts=elts, h=h, R=R_s, name='candidate_elements')

        # Layer that stacks the current orbital elements; shape is [elt_batch_size, 7,]
        self.orbital_elements = tf.keras.layers.Concatenate(axis=-1, name='orbital_elements')

        # Layer that stacks the mixture model parameters
        self.mixture_params = tf.keras.layers.Concatenate(axis=-1, name='mixture_params')

        # The predicted direction; shape is [data_size, 3,]
        self.direction = AsteroidDirection(ts_np=ts_np, row_lengths_np=row_lengths_np,
                                           site_name=site_name, name='direction')

        # Calibration arrays (flat)
        cols_q_ast = ['qx', 'qy', 'qz']
        cols_v_ast = ['vx', 'vy', 'vz']
        q_ast = ztf_elt[cols_q_ast].values.astype(dtype_np)
        v_ast = ztf_elt[cols_v_ast].values.astype(dtype_np)

        # Run calibration
        self.direction.q_layer.calibrate(elts=elts, q_ast=q_ast, v_ast=v_ast)

        # Score layer for these observations
        self.score = TrajectoryScore(row_lengths_np=row_lengths_np, u_obs_np=u_obs_np,
                                     thresh_deg=thresh_deg, name='score')

        # Save the learning rate on the model object to facilitate adaptive training
        self.learning_rate = 1.0E-4
        self.clipnorm = 1.0

        # The model is not compiled here; compile it separately before calling fit / evaluate
        # compile_search_model(self)

    def call(self, inputs=None):
        """
        Score the current candidate elements.
        INPUTS:
            inputs: Dummy inputs; only passed on to the candidate elements layer
        RETURNS:
            Tuple (log_like, orbital_elements, mixture_params)
        """
        # Extract the candidate elements and mixture parameters; pass dummy inputs to satisfy keras Layer API
        a, e, inc, Omega, omega, f, epoch, h, lam, R = self.candidate_elements(inputs=inputs)

        # Stack the current orbital elements
        orbital_elements = self.orbital_elements([a, e, inc, Omega, omega, f, epoch])
        # Stack mixture model parameters
        mixture_params = self.mixture_params([h, lam, R])

        # Tensor of predicted directions.  Shape is [data_size, 3,]
        # The second output r_pred is not used by the score
        u_pred, r_pred = self.direction(a, e, inc, Omega, omega, f, epoch)

        # Compute the log likelihood by element from the predicted direction and mixture model parameters
        # Shape is [elt_batch_size,]
        log_like = self.score(u_pred, h=h, lam=lam)

        # Add the loss function - the NEGATIVE of the log likelihood
        # (Take negative b/c TensorFlow minimizes the loss function)
        self.add_loss(-tf.reduce_sum(log_like))

        # Wrap outputs
        return (log_like, orbital_elements, mixture_params)

In [31]:
# Build the class-based search model on the same elements and observations as `model`
# NOTE(review): h and R_deg are passed as literals here rather than the notebook variables h and R_deg
model2 = AsteroidSearchModel(elts=elts_ast, ztf_elt=ztf_elt, site_name=site_name, 
                             thresh_deg=thresh_deg, h=0.01, R_deg=1.0)

In [32]:
# Run the class-based model on the same dummy inputs
log_like2, elts_tf2, mixture2 = model2(x)

In [33]:
# Display the log likelihood by element from model2
log_like2

<tf.Tensor: shape=(64,), dtype=float32, numpy=
array([0.57954407, 0.49264216, 0.6762789 , 0.39357695, 0.42903855,
       0.46678448, 0.4540908 , 0.8074737 , 0.50630903, 0.4601222 ,
       0.5014895 , 0.42889953, 0.51551783, 0.4629032 , 0.729247  ,
       0.44831958, 0.50237894, 0.44125974, 0.5343248 , 0.5904063 ,
       0.4815468 , 0.592896  , 0.51904047, 0.46014905, 0.52789986,
       0.39928222, 0.45920268, 0.45837677, 0.51399684, 0.35359716,
       0.44146407, 0.5588145 , 0.39859107, 0.4259917 , 0.30092096,
       0.4072559 , 0.24244528, 0.5085129 , 0.62426984, 0.43850031,
       0.5087013 , 0.56281   , 0.34524867, 0.5212759 , 0.47598663,
       0.53847355, 0.5233214 , 0.40033844, 0.30608696, 0.44641176,
       0.51399535, 0.43811274, 0.461171  , 0.40968397, 0.17428087,
       0.27800792, 0.35467115, 0.24122518, 0.49376643, 0.798387  ,
       0.42239437, 0.5063009 , 0.4115517 , 0.35718793], dtype=float32)>

In [36]:
from asteroid_search_model import make_adam_opt

In [37]:
# Build the optimizer from model2's own training settings, then compile model2.
# The settings are read from model2 (not from `model`) so the compiled model always
# uses the learning rate and clipnorm stored on the object being compiled.
optimizer = make_adam_opt(learning_rate=model2.learning_rate, clipnorm=model2.clipnorm)
model2.compile(optimizer=optimizer)

In [38]:
# Evaluate model2; the loss added in call() is the negative total log likelihood
model2.evaluate(x)



-30.022746086120605

## Fit Model on Unperturbed Elements

In [None]:
# Callbacks
# Early stopping monitors the training loss with patience=0 and restore_best_weights=True
early_stop = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=0, restore_best_weights=True, verbose=True)
callbacks = [early_stop]

In [None]:
# Set training length: epochs, steps per epoch
epochs = 10
steps_per_epoch = 200
# Dummy training inputs: one entry per sample, steps_per_epoch batches of batch_size
samples_per_epoch = steps_per_epoch * batch_size
x_trn = np.ones(samples_per_epoch)

In [None]:
# Evaluate before training
# Uses the dummy input x with one entry per element in the batch
model.evaluate(x)

In [None]:
# Review learning rate
print(f'learning_rate = {model.learning_rate}')
print(f'clipnorm      = {model.clipnorm}')
# clipvalue is not set by every search model (the class defined in this notebook only
# stores learning_rate and clipnorm), so report None instead of raising AttributeError
print(f'clipvalue     = {getattr(model, "clipvalue", None)}')

In [None]:
# Train model
# NOTE(review): epochs is the literal 20 here, not the notebook variable `epochs` (10) - confirm intended
hist = model.fit(x=x_trn, batch_size=batch_size, epochs=20, steps_per_epoch=steps_per_epoch, 
                 callbacks=callbacks, shuffle=False, verbose=1)

In [None]:
# Evaluate after the first round of training
model.evaluate(x)

In [None]:
# Train for a second round, this time for `epochs` epochs; hist is overwritten
hist = model.fit(x=x_trn, batch_size=batch_size, epochs=epochs, steps_per_epoch=steps_per_epoch, 
                 callbacks=callbacks, shuffle=False, verbose=1)

In [None]:
# Evaluate after the second round of training
model.evaluate(x)

In [None]:
# Predict
log_like, elts_tf, mixture = model.predict(x)

# Report mixture
# Columns 0 and 1 of mixture are read as h and lam
# NOTE(review): this rebinds the notebook-level h (set to 0.01 in an earlier cell) to an array;
# later cells pass h=h when building models, so they will see this array - confirm intended
h = mixture[:, 0]
lam = mixture[:, 1]
h_mean = np.mean(h)
lam_mean = np.mean(lam)
print(f'h_mean      = {h_mean:8.6f}')
print(f'lambda_mean = {lam_mean:6.2e}')

In [None]:
# model.elements.get_weights()

In [None]:
def ast_search_adaptive(model, learning_rate=None, clipnorm=None,
                        batch_size: int = 64, max_epochs: int = 100):
    """
    Run asteroid search model adaptively.
    Training proceeds in episodes of a fixed number of epochs with early stopping on the loss.
    When an episode is cut short by early stopping, the learning rate is reduced before
    the next episode.  The best loss seen and the matching weights are tracked on the model;
    if an episode ends with a worse loss, the best weights are restored.
    NOTE: this notebook-local definition shadows the function of the same name imported
    from asteroid_search_model at the top of the notebook.
    INPUTS:
        model:         Search model to train.  Must have learning_rate and clipnorm attributes,
                       and an `elements` attribute supporting get_weights / set_weights.
        learning_rate: Learning rate to set on the model before training; None keeps the current one
        clipnorm:      Gradient clipping norm to set on the model; None keeps the current one
        batch_size:    Batch size for training
        max_epochs:    Training continues while the epoch count is below this number;
                       the last episode may run past it
    RETURNS:
        None.  The model is trained in place; model.current_epoch, model.best_loss
        and model.best_weights are set as side effects.
    """
    # Start timer
    t0 = time.time()

    # Set the learning rate and clipnorm to the model if they were specified
    if learning_rate is not None:
        model.learning_rate = learning_rate
    if clipnorm is not None:
        model.clipnorm = clipnorm

    # Recompile the model
    compile_search_model(model)

    # Early stopping callback; the same callback is used for every episode
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=0, restore_best_weights=True)
    callbacks = [early_stop]

    # Define one epoch as a number of batches
    steps_per_epoch: int = 200
    samples_per_epoch: int = steps_per_epoch * batch_size
    x_trn: np.ndarray = np.ones(samples_per_epoch)

    # Set number of epochs for one episode of training
    epochs_per_episode: int = 5
    training_episode: int = 1

    # Epoch and episode counters; best loss starts at infinity so the first episode always
    # records its loss and weights (model.best_weights is set at that point)
    episode_length: int = 0
    model.current_epoch = 0
    model.best_loss = np.inf

    # Factor applied to the learning rate after an episode that stopped early
    lr_factor_dn: float = 0.5

    # Verbose flag for training
    verbose = 1

    def run_episode():
        """Train the model for one episode and return the keras History"""
        return model.fit(x=x_trn, batch_size=batch_size, epochs=epochs_per_episode,
                         steps_per_epoch=steps_per_epoch, callbacks=callbacks,
                         shuffle=False, verbose=verbose)

    def after_episode(hist):
        """Update counters and best loss / weights from the History of the episode just finished"""
        nonlocal episode_length, training_episode
        # Number of the episode that just finished; captured before the counter moves on
        finished_episode = training_episode

        # Update training counters
        episode_length = hist.epoch[-1] + 1
        model.current_epoch += episode_length
        training_episode += 1
        elapsed_time = time.time() - t0

        # Update best loss and weights
        current_loss: float = hist.history['loss'][-1]
        print(f'Updating best_loss at end of episode {finished_episode}')
        if current_loss < model.best_loss:
            print(f'New best_loss = {current_loss:8.2f}.  Old best_loss was {model.best_loss:8.2f}')
            model.best_loss = current_loss
            model.best_weights = model.elements.get_weights()
        else:
            print(f'Manually restoring best weights to recover best_loss = {model.best_loss:8.2f}')
            model.elements.set_weights(model.best_weights)
            current_loss = model.best_loss

        # Log likelihood is negative of the loss
        log_like = -current_loss

        # Status message
        print(f'Epoch {model.current_epoch:4}. Elapsed time = {elapsed_time:0.0f} sec')
        print(f'Log Likelihood: {log_like:8.2f}')

    # Run first episode of training
    after_episode(run_episode())

    # Continue training until max_epochs have elapsed
    while model.current_epoch < max_epochs:
        # If the last training hit early stopping, decrease the learning rate;
        # after a full-length episode the learning rate is left unchanged
        if episode_length != epochs_per_episode:
            adjust_learning_rate(model, lr_factor_dn, verbose=False)
        # Train for another episode
        print(f'\nTraining episode {training_episode}:')
        print(f'Epoch {model.current_epoch:4}, learning_rate={model.learning_rate:8.3e}, clipnorm={model.clipnorm:6.3f}.')
        after_episode(run_episode())

In [None]:
# Build asteroid search model
# Rebuilds `model` so that adaptive training starts from the unperturbed elements
# NOTE(review): h is read from the notebook namespace; the "Report mixture" cell rebinds it to an array
model = make_model_asteroid_search(
        elts=elts_ast, ztf_elt=ztf_elt, site_name=site_name,
        thresh_deg=thresh_deg, h=h, R_deg=R_deg)

In [None]:
# model.best_loss

In [None]:
# model.best_weights

In [None]:
# Train model adaptively
# Sets learning_rate and clipnorm on the model explicitly for this first run
ast_search_adaptive(model,
                    learning_rate=1.0E-4, clipnorm=1.0,
                    max_epochs=20, batch_size=batch_size)

In [None]:
# Second (tune-up) adaptive training
# Passing None keeps the learning_rate and clipnorm currently stored on the model
ast_search_adaptive(model,
                    learning_rate=None, clipnorm=None,
                    max_epochs=10, batch_size=batch_size)

## Model Diagnostic

In [None]:
# Threshold as Cartesian distance s, squared distance s2, and the matching value of z
thresh_s = keras.backend.constant(deg2dist(thresh_deg))
thresh_s2 = keras.backend.constant(thresh_s**2)
# For unit direction vectors s^2 = |u1 - u2|^2 = 2 - 2 z, so z = 1 - s^2 / 2.
# This matches the s and z columns of ztf_elt (e.g. s = 0.010624 gives z = 0.999944);
# no square root is involved.
thresh_z = keras.backend.constant(1.0 - thresh_s2 / 2.0)

# Report thresholds
print('Thresholds:')
print(f's  : {thresh_s:6.2e}')
print(f's2 : {thresh_s2:6.2e}')
# print(f'z  : {thresh_z:10.8f}')
print(f'1-z: {1.0 - thresh_z:6.2e}')

In [None]:
from asteroid_model import AsteroidPosition, AsteroidDirection
from asteroid_search_model import OrbitalElements, TrajectoryScore

In [None]:
# Number of components of a direction vector; same value as set in the constants cell
space_dims = 3

In [None]:
# Alias inputs
# The diagnostic cells below work on the unperturbed elements and their observations
elts = elts_ast
ztf_elt = ztf_elt_ast

In [None]:
# Observed directions; extract from ztf_elt DataFrame
cols_u_obs = ['ux', 'uy', 'uz']
u_obs_np = ztf_elt[cols_u_obs].values.astype(dtype_np)    

# Set of trainable weights with candidate orbital elements; initialize according to elts
# NOTE(review): h and lam are read from the notebook namespace; lam is only assigned
# by the "Report mixture" cell, so that cell must have run first
elements_layer = OrbitalElements(elts=elts, h=h, lam=lam, name='candidates')

# Extract the candidate elements and mixture parameters; pass dummy inputs to satisfy keras Layer API
# This rebinds h and lam to the outputs of the layer
a, e, inc, Omega, omega, f, epoch, h, lam = elements_layer(inputs=x)

In [None]:
# The orbital elements; stack to shape (elt_batch_size, 7)
# Order of the stacked columns: a, e, inc, Omega, omega, f, epoch
elts_tf = tf.stack(values=[a, e, inc, Omega, omega, f, epoch], axis=1, name='elts')

In [None]:
# The predicted direction
# Built from the observation times and per-element row lengths computed earlier in the notebook
direction_layer = AsteroidDirection(ts_np=ts_np, row_lengths_np=row_lengths_np, 
                                    site_name=site_name, name='direction_layer')

# Calibration arrays (flat)
# Taken from the qx, qy, qz and vx, vy, vz columns of ztf_elt
cols_q_ast = ['qx', 'qy', 'qz']
cols_v_ast = ['vx', 'vy', 'vz']
q_ast = ztf_elt[cols_q_ast].values.astype(dtype_np)
v_ast = ztf_elt[cols_v_ast].values.astype(dtype_np)

In [None]:
# Run calibration
# Calibrates the q_layer of the direction layer against q_ast and v_ast
direction_layer.q_layer.calibrate(elts=elts, q_ast=q_ast, v_ast=v_ast)

# Tensor of predicted directions
# The layer returns two outputs, unpacked here as u_pred and r_pred
u_pred, r_pred = direction_layer(a, e, inc, Omega, omega, f, epoch)

In [None]:
# Score layer for these observations
score_layer = TrajectoryScore(row_lengths_np=row_lengths_np, u_obs_np=u_obs_np,
                              thresh_deg=thresh_deg, name='score_layer')

# Compute the log likelihood by element from the predicted direction and mixture model parameters
# This rebinds log_like, replacing the value from the earlier predict cell
log_like = score_layer(u_pred, h=h, lam=lam)

In [None]:
# Check selected row: row 11 has ztf_id = 341737, elt_id 733 (first hit)
# Display the predicted direction for that row
u_pred[11]

In [None]:
# Data shapes
# data_size is the sum of the row lengths, i.e. the total number of observations
data_size = keras.backend.constant(value=tf.reduce_sum(row_lengths_np), dtype=tf.int32)
row_lengths = keras.backend.constant(value=row_lengths_np, shape=row_lengths_np.shape, dtype=tf.int32)
# Shape of the flat array of directions: one row of space_dims components per observation
u_shape = (data_size, space_dims,)        

In [None]:
# Save u_obs
# Observed directions as a constant tensor with shape u_shape
u_obs = keras.backend.constant(value=u_obs_np, shape=u_shape, dtype=dtype)

In [None]:
# Calculate distance
# du is the difference between predicted and observed directions
du = u_pred - u_obs
# s2 is the squared length of du, summed over the last axis
s2 = tf.reduce_sum(tf.square(du), axis=(-1), name='s2')

In [None]:
# Squared distance for the selected row
s2[11]

In [None]:
# Filter to only include terms where s2 is within the threshold distance^2
is_close = tf.math.less(s2, thresh_s2, name='is_close')

In [None]:
# Whether the selected row is within the threshold
is_close[11]

In [None]:
# Relative distance v on data inside threshold
# v = s2 / thresh_s2, kept only where is_close is True; rows outside the threshold are dropped
v = tf.divide(tf.boolean_mask(tensor=s2, mask=is_close), thresh_s2, name='v')

In [None]:
# Entry 11 of v; v is indexed over close observations only, so this is row 11 of the data
# only if no earlier row was dropped by the mask
v[11]

In [None]:
# Row_lengths, for close observations only
# is_close_r = tf.RaggedTensor.from_row_lengths(values=is_close, row_lengths=self.row_lengths, name='is_close_r')
# Split the flat is_close mask into one ragged row per element, using row_lengths
ragged_map_func = lambda x : tf.RaggedTensor.from_row_lengths(values=x, row_lengths=row_lengths)
is_close_r = tf.keras.layers.Lambda(function=ragged_map_func, name='is_close_r')(is_close)
# Count the True entries in each ragged row
row_lengths_close = tf.reduce_sum(tf.cast(is_close_r, tf.int32), axis=1, name='row_lengths_close')

In [None]:
# Number of close observations for the first element
row_lengths_close[0]

In [None]:
# Shape of parameters
# One parameter value per close observation
close_size = tf.reduce_sum(row_lengths_close)
param_shape = (close_size,)

# Upsample h and lambda
# Each element's value is repeated once per close observation of that element
h_rep = tf.repeat(input=h, repeats=row_lengths_close, name='h_rep')
h_vec = tf.reshape(tensor=h_rep, shape=param_shape, name='h_vec')
lam_rep = tf.repeat(input=lam, repeats=row_lengths_close, name='lam_rep')
lam_vec = tf.reshape(tensor=lam_rep, shape=param_shape, name='lam_vec')

In [None]:
# Upsampled h at entry 11
h_vec[11]

In [None]:
# Upsampled lambda at entry 11
lam_vec[11]

In [None]:
# Probability according to mixture model
# p = h * lam * exp(-lam * v) / (1 - exp(-lam)) + (1 - h)
# NOTE(review): the tensor holding exp(-lam * v) is given the name 'p_hit_cond' - confirm intended
emlx = tf.exp(-lam_vec * v, name='p_hit_cond') 
# Conditional probability given a hit: numerator and denominator
p_hit_cond_num = tf.multiply(emlx, lam_vec)
p_hit_cond_den = tf.subtract(1.0, tf.exp(-lam_vec))
p_hit_cond = tf.divide(p_hit_cond_num, p_hit_cond_den)
# Weight the hit term by h and the miss term by 1 - h, then add them
p_hit = tf.multiply(h_vec, p_hit_cond, name='p_hit')
p_miss = tf.subtract(1.0, h_vec, name='p_miss')
p = tf.add(p_hit, p_miss, name='p')
# Log of the mixture probability, one value per close observation
log_p_flat = keras.layers.Activation(tf.math.log, name='log_p_flat')(p)

In [None]:
# Log probability at entry 11
log_p_flat[11]

In [None]:
# Mixture probability at entry 11
p[11]

In [None]:
# Conditional hit probability at entry 11
p_hit_cond[11]