In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl

# Utility
import time

# Local
from asteroid_integrate import load_ast_elt
from candidate_element import orbital_element_batch, perturb_elts, random_elts
from ztf_data import load_ztf_nearest_ast, calc_hit_freq, load_ztf_batch, make_ztf_batch
from asteroid_model import AsteroidPosition, AsteroidDirection
from asteroid_search_layers import CandidateElements, TrajectoryScore
from asteroid_search_model import AsteroidSearchModel, make_adam_opt
from astro_utils import deg2dist, dist2deg, dist2sec
from tf_utils import Identity

Found 4 GPUs.  Setting memory growth = True.


In [2]:
# Aliases
keras = tf.keras

# Constants
# TensorFlow and NumPy single-precision float types used for model inputs
dtype = tf.float32
dtype_np = np.float32
# Number of spatial dimensions
space_dims = 3

In [None]:
# Set plot style variables
# NOTE(review): this cell is marked "In [None]" -- it was not executed in the saved run,
# so these settings did not affect the session outputs shown below.
mpl.rcParams['figure.figsize'] = [16.0, 10.0]
mpl.rcParams['font.size'] = 16

## Load ZTF Data and Batch of Orbital Elements

In [3]:
# Load orbital elements for known asteroids
ast_elt = load_ast_elt()

# Number of asteroids
N_ast = ast_elt.shape[0]

In [4]:
# Load ztf nearest asteroid data
ztf_ast = load_ztf_nearest_ast()

In [5]:
# Asteroid numbers and hit counts
# NOTE(review): a "hit" is presumably an observation within thresh_sec arc seconds of the asteroid -- confirm in ztf_data
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Sort ALL asteroids by hit count in descending order; the first batch_size are taken in the next cell
idx = np.argsort(hit_count)[::-1]

# Asteroid numbers and hit counts, ordered from most to fewest hits
ast_num_best = ast_num[idx]
hit_count_best = hit_count[idx]

In [6]:
# Set batch size
batch_size = 64
# Number of candidate elements per batch; also sizes the random element batch and the dummy model input
elt_batch_size = batch_size

# Batch of unperturbed elements: the batch_size asteroids with the most hits
elts_ast = orbital_element_batch(ast_nums=ast_num_best[0:batch_size])

In [7]:
elts_ast

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133491,58600.0
1,59244,2.634727,0.262503,0.465045,5.738297,1.766995,-1.601363,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.246069,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.357345,58600.0
4,142999,2.619944,0.191376,0.514017,0.238022,0.946463,-1.299301,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.016580,58600.0
60,134815,2.612770,0.140831,0.513923,0.272689,0.645552,-0.957836,58600.0
61,27860,2.619406,0.096185,0.200633,5.541399,3.266046,3.948770,58600.0
62,85937,2.342292,0.197267,0.439063,5.279693,3.210025,3.947687,58600.0


In [8]:
# Perturb orbital elements
# Standard deviations of the noise added to the known asteroid elements.
# Alternative "easy" setting (only f is perturbed, by 0.1 degree):
# sigma_a = 0.0
# sigma_e = 0.0
# sigma_f_deg = 0.1
# sigma_Omega_deg = 0.0
# sigma_omega_deg = 0.0

# Large perturbation (used in this run)
sigma_a = 0.05
sigma_e = 0.01
sigma_f_deg = 1.0
sigma_Omega_deg = 1.0
sigma_omega_deg = 1.0

# mask_pert=None presumably perturbs every element -- confirm in candidate_element.perturb_elts.
# The fixed seed makes the perturbation repeatable; random_seed is also reused for the random batch below.
mask_pert = None
random_seed = 42

elts_pert = perturb_elts(elts_ast, sigma_a=sigma_a, sigma_e=sigma_e, sigma_f_deg=sigma_f_deg,
                         sigma_Omega_deg=sigma_Omega_deg, sigma_omega_deg=sigma_omega_deg,
                         mask_pert=mask_pert, random_seed=random_seed)

In [9]:
# Display the perturbed elements; compare against elts_ast above
elts_pert

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.736430,0.219134,0.499554,4.721815,2.452489,-1.129754,58600.0
1,59244,2.616575,0.266087,0.465045,5.725946,1.777382,-1.623105,58600.0
2,15786,1.945213,0.047621,0.392360,6.142436,0.790543,-1.243047,58600.0
3,3904,2.758664,0.099270,0.261542,5.463683,2.238942,-1.350620,58600.0
4,142999,2.589450,0.192070,0.514017,0.221844,0.928905,-1.314727,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.330603,0.084892,0.117967,0.042808,2.890716,-3.000560,58600.0
60,134815,2.550916,0.141660,0.513923,0.284591,0.630896,-0.920797,58600.0
61,27860,2.595202,0.098315,0.200633,5.535985,3.255585,3.966790,58600.0
62,85937,2.216242,0.195323,0.439063,5.285351,3.172956,3.921169,58600.0


In [10]:
# Random elements
# Same batch size and seed as above; dtype_np (float32) is passed through to random_elts
elts_rand = random_elts(element_id_start=0, size=elt_batch_size, random_seed=random_seed, dtype=dtype_np)

In [11]:
elts_rand

Unnamed: 0,element_id,a,e,f,inc,Omega,omega,epoch
0,0,2.346512,0.191774,3.947884,0.123224,5.077989,4.698056,58600.0
1,1,3.002211,0.239903,-2.104391,0.032646,3.623584,5.988327,58600.0
2,2,2.317087,0.055763,-4.869156,0.034794,3.476631,2.078165,58600.0
3,3,2.349419,0.200190,3.323487,0.222124,1.726133,3.473125,58600.0
4,4,2.712220,0.116107,4.183110,0.200886,0.668298,3.595820,58600.0
...,...,...,...,...,...,...,...,...
59,59,2.328723,0.639718,1.285795,0.064922,1.580547,4.395165,58600.0
60,60,3.972593,0.039944,-0.006957,0.176738,5.749878,0.869298,58600.0
61,61,3.019390,0.032757,-3.315554,0.115168,1.725482,0.834064,58600.0
62,62,2.481321,0.025182,2.025885,0.337006,2.972029,6.091780,58600.0


## Batches of ZTF Data vs. Elements

In [12]:
# Arguments to make_ztf_batch
# (passed to load_ztf_batch below, which presumably forwards them when it builds a batch -- confirm in ztf_data)
# Threshold in degrees for matching ZTF observations to each element; alternative tighter setting:
# thresh_deg = 1.0
thresh_deg = 4.0
# NOTE(review): meaning of near_ast is not visible here -- confirm in ztf_data.load_ztf_batch
near_ast = False
# Presumably False reuses a previously saved batch instead of recomputing it -- confirm
regenerate = False

In [13]:
# Load unperturbed element batch
ztf_elt_ast = load_ztf_batch(elts=elts_ast, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [14]:
# Load perturbed element batch
ztf_elt_pert = load_ztf_batch(elts=elts_pert, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [15]:
# Load random element batch
ztf_elt_rand = load_ztf_batch(elts=elts_rand, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [16]:
def ztf_elt_summary(ztf_elt: pd.DataFrame, elt_name: str) -> pd.DataFrame:
    """Report summary attributes of a ztf_elt dataframe and return per-element scores.

    Each observation is scored as ``-1.0 - log(v)``.  When V ~ Unif[0, 1] (no signal),
    -log(V) is Exponential(1), with mean 1 and variance 1.  The score therefore has
    mean 0 and variance 1.  Summed over the num_obs observations of one element, it
    has mean 0 and variance num_obs.  ``t_score = score_sum / sqrt(num_obs)`` puts
    every element on a unit-variance scale.  Observations with small v give large
    positive scores.

    Parameters
    ----------
    ztf_elt : pd.DataFrame
        ZTF observations matched to candidate elements; must contain the columns
        'element_id', 'v' and 'is_hit'.
    elt_name : str
        Label used in the printed report.

    Returns
    -------
    pd.DataFrame
        Indexed by element_id, with columns 'score_sum', 'num_obs' and 't_score'.

    Raises
    ------
    ValueError
        If ztf_elt has no rows, so the per-element averages are undefined.
    """
    # Calculate summary statistics
    num_obs = ztf_elt.shape[0]
    # Number of distinct elements in the batch (local name avoids shadowing the global batch_size)
    num_elts = np.unique(ztf_elt.element_id).size
    if num_elts == 0:
        raise ValueError(f'ztf_elt_summary: no observations in ztf_elt for {elt_name}')
    obs_per_batch = num_obs / num_elts
    num_hits = int(np.sum(ztf_elt.is_hit))
    hits_per_batch = num_hits / num_elts

    # Score all observations at once (vectorized) and aggregate the scores by element
    score = -1.0 - np.log(ztf_elt['v'])
    score_by_elt = score.groupby(ztf_elt.element_id).agg(score_sum='sum', num_obs='count')
    score_by_elt['t_score'] = score_by_elt['score_sum'] / np.sqrt(score_by_elt['num_obs'])

    # Average the per-element scores over the batch
    mean_score_sum = np.mean(score_by_elt.score_sum)
    mean_t_score = np.mean(score_by_elt.t_score)

    # Report results
    print(f'ZTF Element Dataframe {elt_name}:')
    print(f'                  Total     (Per Batch)')
    print(f'Observations   : {num_obs:8d}   ({obs_per_batch:9.0f})')
    print(f'Hits           : {num_hits:8d}   ({hits_per_batch:9.2f})')
    print(f'\nSummarize score = sum(-1.0 - log(v)) by batch.  (Mean=0, Variance=num_obs)')
    print(f'Mean score     :  {mean_score_sum:9.2f}')
    print(f'Sqrt(batch_obs):  {np.sqrt(obs_per_batch):9.2f}')
    print(f'Mean t_score   :  {mean_t_score:9.2f}')

    return score_by_elt

In [17]:
# Summarize the ztf element batch: unperturbed asteroids
# Prints the summary and keeps the per-element scores (score_sum, num_obs, t_score)
score_by_elt_ast = ztf_elt_summary(ztf_elt_ast, 'Unperturbed Asteroids')

ZTF Element Dataframe Unperturbed Asteroids:
                  Total     (Per Batch)
Observations   :  1233691   (    19276)
Hits           :    10331   (   161.42)

Summarize score = sum(-1.0 - log(v)) by batch.  (Mean=0, Variance=num_obs)
Mean score     :    3347.30
Sqrt(batch_obs):     138.84
Mean t_score   :      26.13


In [18]:
# score_by_elt_ast
# (disabled: uncomment the line above to display the per-element scores of the unperturbed asteroids)

In [19]:
# Summarize the ztf element batch: perturbed asteroids
score_by_elt_pert = ztf_elt_summary(ztf_elt_pert, 'Perturbed Asteroids')

ZTF Element Dataframe Perturbed Asteroids:
                  Total     (Per Batch)
Observations   :  1149827   (    17966)
Hits           :        0   (     0.00)

Summarize score = sum(-1.0 - log(v)) by batch.  (Mean=0, Variance=num_obs)
Mean score     :     150.10
Sqrt(batch_obs):     134.04
Mean t_score   :       0.95


In [20]:
# Summarize the ztf element batch: random elements
# Baseline with no real asteroids behind the elements; compare its mean t_score with the two batches above
score_by_elt_rand = ztf_elt_summary(ztf_elt_rand, 'Random Elements')

ZTF Element Dataframe Random Elements:
                  Total     (Per Batch)
Observations   :  1056706   (    16773)
Hits           :        0   (     0.00)

Summarize score = sum(-1.0 - log(v)) by batch.  (Mean=0, Variance=num_obs)
Mean score     :      82.17
Sqrt(batch_obs):     129.51
Mean t_score   :       0.28


## View Example DataFrames and Hits

In [21]:
# Display the ZTF observations matched to the unperturbed asteroids (pandas truncates the long frame)
ztf_elt_ast

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,vz,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit
0,25248,733,b'ZTF18absqzef',611130485415015022,7576,58365.130486,269.830331,-14.496884,-0.002867,-0.987841,...,0.003988,-0.053007,-0.979529,0.194197,2.419761,0.063916,13185.823199,0.997957,0.838529,False
1,37580,733,b'ZTF18abcsqhf',618126363115010066,8913,58372.126366,271.256633,-11.466062,0.021493,-0.978028,...,0.003945,-0.042686,-0.979082,0.198938,2.501611,0.064738,13355.520139,0.997904,0.860243,False
2,37581,733,b'ZTF18abcsqhf',617209183115010025,8768,58371.209190,271.256646,-11.466037,0.021493,-0.978028,...,0.003950,-0.044314,-0.979129,0.198349,2.490799,0.066431,13704.796673,0.997793,0.905809,False
3,37582,733,b'ZTF18abcsqhf',617122523115010032,8730,58371.122523,271.256648,-11.466060,0.021493,-0.978028,...,0.003951,-0.044456,-0.979134,0.198292,2.489770,0.066579,13735.439330,0.997784,0.909863,False
4,37587,733,b'ZTF18abcsqhf',567274573115010023,3341,58321.274572,271.256647,-11.466069,0.021493,-0.978028,...,0.004191,-0.000053,-0.989617,0.143727,2.018643,0.068175,14064.826109,0.997676,0.954008,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1233686,5651378,324582,b'ZTF20aaqvkyo',1150185754815015007,96635,58904.185752,42.122572,29.195466,0.647485,0.731228,...,-0.001541,0.627608,0.750726,0.206199,2.981918,0.029094,6001.223850,0.999577,0.173740,False
1233687,5651434,324582,b'ZTF20aaqvkus',1150185310315015002,96634,58904.185313,43.256211,29.954124,0.631015,0.743348,...,-0.001541,0.627610,0.750725,0.206200,2.981912,0.017707,3652.453794,0.999843,0.064359,False
1233688,5651513,324582,b'ZTF20aaqvlna',1150185755115015011,96635,58904.185752,42.388038,28.228923,0.650751,0.733102,...,-0.001541,0.627608,0.750726,0.206199,2.981918,0.030306,6251.308520,0.999541,0.188521,False
1233689,5651704,324582,b'ZTF20aaqvlmz',1150185755115015003,96635,58904.185752,41.916630,28.631276,0.653128,0.728579,...,-0.001541,0.627608,0.750726,0.206199,2.981918,0.033791,6970.132666,0.999429,0.234365,False


In [22]:
# Review hits
# is_hit is a boolean column, so it serves directly as a row mask
mask = ztf_elt_ast.is_hit
ztf_elt_ast[mask]

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,vz,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit
178,341737,733,b'ZTF19abizrac',937427766115015019,40797,58691.427766,33.130412,43.596186,0.606481,0.637452,...,-0.002127,0.606487,0.637448,0.475220,2.115259,0.000008,1.551631,1.0,1.161524e-08,True
191,345725,733,b'ZTF19abjajfg',937467364815015020,40840,58691.467361,33.148426,43.605278,0.606265,0.637618,...,-0.002129,0.606270,0.637613,0.475274,2.114851,0.000007,1.522407,1.0,1.118183e-08,True
196,346522,733,b'ZTF19abjajmr',937468726115015011,40842,58691.468727,33.149062,43.605587,0.606257,0.637624,...,-0.002129,0.606263,0.637619,0.475276,2.114837,0.000008,1.547274,1.0,1.155011e-08,True
205,347644,733,b'ZTF19abiyxiu',937402264815015008,40777,58691.402269,33.118785,43.590288,0.606621,0.637345,...,-0.002127,0.606626,0.637341,0.475186,2.115523,0.000007,1.535273,1.0,1.137163e-08,True
222,431445,733,b'ZTF19abkkfhr',934448315015015003,40221,58688.448310,31.751906,42.913068,0.622775,0.624451,...,-0.002040,0.622780,0.624447,0.471393,2.145603,0.000007,1.536414,1.0,1.138854e-08,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1233374,5447574,324582,b'ZTF20aapdfmj',1145121244815015014,95204,58899.121250,42.392273,29.202629,0.644677,0.734024,...,-0.001502,0.644680,0.734021,0.213544,2.915471,0.000005,0.932672,1.0,4.196714e-09,True
1233417,5450145,324582,b'ZTF20aapdfmj',1145121716115015016,95205,58899.121713,42.392439,29.202638,0.644675,0.734026,...,-0.001502,0.644678,0.734022,0.213543,2.915477,0.000005,1.003729,1.0,4.860540e-09,True
1233479,5461311,324582,b'ZTF20aapeobw',1145164884815015018,95284,58899.164884,42.405891,29.202386,0.644538,0.734164,...,-0.001503,0.644541,0.734161,0.213479,2.916052,0.000004,0.888835,1.0,3.811477e-09,True
1233480,5461312,324582,b'ZTF20aapeobw',1145165336115015009,95285,58899.165336,42.406082,29.202412,0.644536,0.734166,...,-0.001503,0.644540,0.734163,0.213478,2.916058,0.000005,1.054282,1.0,5.362472e-09,True


In [23]:
# List the columns available in a ztf element dataframe
ztf_elt_ast.columns

Index(['ztf_id', 'element_id', 'ObjectID', 'CandidateID', 'TimeStampID', 'mjd',
       'ra', 'dec', 'ux', 'uy', 'uz', 'qx', 'qy', 'qz', 'vx', 'vy', 'vz',
       'elt_ux', 'elt_uy', 'elt_uz', 'elt_r', 's', 's_sec', 'z', 'v',
       'is_hit'],
      dtype='object')

In [24]:
# Alias ztf_elt_ast to ztf_elt
ztf_elt = ztf_elt_ast

In [25]:
# Build numpy array of times
# NOTE(review): MJD values near 58600 stored as float32 have a resolution of 2**-8 day (~5.6 minutes);
# use float64 if these times are used for precise orbit propagation.
ts_np = ztf_elt.mjd.values.astype(dtype_np)

# Get observation count per element
# groupby sorts by element_id, so these counts line up with the rows of ztf_elt only if the
# rows are contiguous and sorted by element_id -- NOTE(review): confirm before using as ragged row lengths
row_lengths_np = ztf_elt.element_id.groupby(ztf_elt.element_id).count().values.astype(np.int32)

In [26]:
# Review results for the asteroid with the most hits
element_id_best = ast_num_best[0]
mask = (ztf_elt.element_id == element_id_best)
ztf_best = ztf_elt[mask]
hits_best = np.sum(ztf_best.is_hit)
hit_rate_best = np.mean(ztf_best.is_hit)
rows_best = np.sum(mask)
# Closest observation of THIS element: search only its rows, and use the row label from
# idxmin so the ztf_id lookup does not depend on the frame having a RangeIndex
idx = ztf_best.s_sec.idxmin()
s_sec_min = ztf_best.s_sec[idx]
ztf_id = ztf_best.ztf_id[idx]
print(f'Best asteroid has element_id = {element_id_best}')
print(f'Hit count: {hits_best} / {rows_best} observations')
print(f'Hit rate : {hit_rate_best:8.6f}')
print(f'Closest hit: {s_sec_min:0.3f} arc seconds')

Best asteroid has element_id = 51921
Hit count: 158 / 10110 observations
Hit rate : 0.015628
Closest hit: 0.382 arc seconds


## Build Asteroid Search Model

In [27]:
# Additional arguments for asteroid search models
# Observatory site name passed to AsteroidSearchModel
site_name = 'palomar'
# NOTE(review): h is presumably the initial mixture weight (fraction of hits) -- confirm in AsteroidSearchModel
h = 1.0/64.0  # (1.5625%)
# NOTE(review): R_deg is presumably an initial resolution/radius in degrees -- confirm in AsteroidSearchModel
R_deg = 2.0

# Training parameters
learning_rate = 2.0**-13 # (1.22E-4)
# Gradient-norm clipping value passed to the model's optimizer
clipnorm = 1.0

In [28]:
# Build asteroid search model
# Quick-test model on the unperturbed elements and their ZTF batch
model = AsteroidSearchModel(
        elts=elts_ast, ztf_elt=ztf_elt, site_name=site_name,
        thresh_deg=thresh_deg, h=h, R_deg=R_deg,
        learning_rate=learning_rate, clipnorm=clipnorm)

In [29]:
# Dummy inputs for search model; any array with shape [elt_batch_size,] is good
x = np.ones(elt_batch_size)

In [30]:
# Run model on unperturbed elements
# Outputs are unpacked as (per-element log likelihood, orbital elements, mixture parameters)
log_like, orbital_elements, mixture_parameters = model(x)

In [31]:
# Summary statistics of the per-element log likelihood returned by model(x)
log_like_tot, log_like_mean, log_like_std = np.sum(log_like), np.mean(log_like), np.std(log_like)

# Report on unperturbed elements: aggregate statistics, then the first 5 per-element values
print(f'Log likelihood:')
print(f'Total: {log_like_tot:8.2f}')
print(f'Mean: {log_like_mean:8.2f}')
print(f'Std : {log_like_std:8.2f}')
print(f'First 5:')
print(log_like[:5].numpy())

Log likelihood:
Total:    19.03
Mean:     0.30
Std :     0.36
First 5:
[0.5026116  0.35954165 0.16162954 0.20722373 0.05370883]


In [32]:
# Quick test fit of the model on the dummy input; the returned training history is kept in hist
hist = model.fit(x)

Train on 64 samples


In [33]:
# Evaluate the quick-test model on the same dummy input
model.evaluate(x)



-70.60837936401367

In [34]:
# model.calc_log_like()
# (disabled diagnostic; uncomment to run)

In [35]:
# model.summary()
# (disabled diagnostic; presumably prints the Keras model summary)

## Fit Model on Unperturbed Elements

In [36]:
# Build asteroid search model
# Fresh model on the unperturbed elements, separate from the quick-test `model` above
model_ast = AsteroidSearchModel(
                elts=elts_ast, ztf_elt=ztf_elt, site_name=site_name,
                thresh_deg=thresh_deg, h=h, R_deg=R_deg,
                learning_rate=learning_rate, clipnorm=clipnorm)

In [37]:
# Adaptive search parameters
max_batches = 10000
batches_per_epoch = 100
epochs_per_episode = 5
# min_learning_rate = 1.0E-7
# 2**-23 is about 1.19E-7; the log below shows the learning rate being halved during training,
# and presumably the search stops once it falls below this floor -- confirm in search_adaptive
min_learning_rate = 2.0**-23
verbose = 1

# Tiny size for fast testing
# max_batches = 100
# batches_per_epoch = 10

In [38]:
# model = model_ast
# (disabled alias; uncomment to reuse code written against `model`)

In [39]:
# Train unperturbed model
model_ast.search_adaptive(
    max_batches=max_batches, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    min_learning_rate=min_learning_rate,
    verbose=verbose)

Train on 6400 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Log Likelihood:  5179.32

Training episode 2: Epoch    5
learning_rate=1.221e-04, total training time 0 sec.
Train on 6400 samples
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Log Likelihood: 31020.76

Training episode 3: Epoch   10
learning_rate=1.221e-04, total training time 0 sec.
Train on 6400 samples
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Log Likelihood: 31985.88

Training episode 4: Epoch   14
learning_rate=1.221e-04, total training time 0 sec.
Train on 6400 samples
Epoch 15/19
Epoch 16/19
Changing learning rate by factor 0.500000 from 1.221e-04 to 6.104e-05.
Log Likelihood: 31708.33

Training episode 5: Epoch   16
learning_rate=6.104e-05, total training time 0 sec.
Train on 6400 samples
Epoch 17/21
Epoch 18/21
Changing learning rate by factor 0.500000 from 6.104e-05 to 3.052e-05.
Log Likelihood: 31920.88

Training episode 6: Epoch   18
learning_rate=3.052e-05, total training time 0 sec.
T

In [None]:
# # Second (tune-up) adaptive training
# NOTE(review): disabled. It targets `model` (the quick-test model), not `model_ast`;
# point it at model_ast if re-enabled.
# model.search_adaptive(
#     max_batches=max_batches, 
#     batches_per_epoch=batches_per_epoch,
#     epochs_per_episode=epochs_per_episode,
#     min_learning_rate=min_learning_rate,
#     verbose=verbose)

In [40]:
# Review likelihoods by element
# Outputs of the trained model_ast, unpacked as (log likelihood, orbital elements, mixture parameters)
log_like_ast, orbital_elements_ast, mixture_parameters_ast = model_ast.calc_outputs()

In [41]:
log_like_ast

<tf.Tensor: shape=(64,), dtype=float32, numpy=
array([ 388.58304 ,  240.03488 ,  653.15985 ,  520.4308  ,  274.88257 ,
        334.38904 ,  353.10324 , 1379.2067  ,   92.35781 ,  662.12396 ,
        413.40378 , 1034.4708  ,  145.2835  ,  337.44818 , 1206.2009  ,
        154.60413 ,  149.15475 ,  747.82776 ,  742.9674  ,  554.65216 ,
        230.15546 ,  165.5966  , 1052.2552  ,  636.9827  ,  269.21576 ,
        271.72366 ,  292.50385 ,  487.80005 ,  113.5315  , 1028.0378  ,
        217.48804 ,  749.14966 ,  585.52716 ,   80.493256,  970.43884 ,
        115.38408 ,  398.9076  ,   69.50052 ,  606.45044 ,  323.05075 ,
        131.63742 ,  571.65076 ,  410.7415  ,  382.54666 ,  364.81277 ,
        264.0494  , 1237.5557  ,  493.4021  ,  703.44824 , 1380.3977  ,
        200.50754 ,  415.82388 ,  471.89285 ,  335.72308 ,  887.5253  ,
         23.285824,  358.50272 ,  196.9346  ,  556.87085 , 1554.7644  ,
        196.08194 ,  548.1326  ,  126.0803  ,  720.82196 ], dtype=float32)>

In [42]:
# plot log_like
# TODO: placeholder -- no plot of the per-element log likelihood is implemented yet

## Train on Perturbed Elements

In [43]:
# Build asteroid search model
# Model initialized at the perturbed elements, scored against their own ZTF batch
model_pert = AsteroidSearchModel(
                 elts=elts_pert, ztf_elt=ztf_elt_pert, site_name=site_name,
                 thresh_deg=thresh_deg, h=h, R_deg=R_deg,
                 learning_rate=learning_rate, clipnorm=clipnorm)

In [44]:
# Train model on perturbed elements
# Same adaptive search settings as the unperturbed run
model_pert.search_adaptive(
    max_batches=max_batches, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    min_learning_rate=min_learning_rate,
    verbose=verbose)

Train on 6400 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Log Likelihood:  4510.63

Training episode 2: Epoch    5
learning_rate=1.221e-04, total training time 0 sec.
Train on 6400 samples
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Log Likelihood: 29583.26

Training episode 3: Epoch   10
learning_rate=1.221e-04, total training time 0 sec.
Train on 6400 samples
Epoch 11/15
Epoch 12/15
Epoch 13/15
Log Likelihood: 30189.61

Training episode 4: Epoch   13
learning_rate=1.221e-04, total training time 0 sec.
Train on 6400 samples
Epoch 14/18
Epoch 15/18
Changing learning rate by factor 0.500000 from 1.221e-04 to 6.104e-05.
Log Likelihood: 29901.54

Training episode 5: Epoch   15
learning_rate=6.104e-05, total training time 0 sec.
Train on 6400 samples
Epoch 16/20
Epoch 17/20
Changing learning rate by factor 0.500000 from 6.104e-05 to 3.052e-05.
Log Likelihood: 29987.93

Training episode 6: Epoch   17
learning_rate=3.052e-05, total training time 0 sec.
Train on 6400

In [45]:
# Review likelihoods by element
log_like_per, orbital_elements_pert, mixture_parameters_pert = model_pert.calc_outputs()

In [46]:
log_like_pert

<tf.Tensor: shape=(64,), dtype=float32, numpy=
array([ 474.798   ,  213.48239 ,  141.31227 ,  891.5929  ,  424.9019  ,
        178.26047 ,   47.32715 , 1543.3555  ,  160.85414 ,  571.67694 ,
        266.31735 ,  984.4761  ,  360.67383 ,  322.9265  , 1146.867   ,
         60.905857,  395.2658  ,  855.116   ,  564.28674 ,  435.42606 ,
        196.7451  ,  231.01488 , 1061.7494  ,  567.2667  ,  330.53247 ,
        247.92955 ,  253.66682 ,  391.057   ,   66.71367 ,  959.40924 ,
        426.2126  ,  390.57266 ,  599.01965 ,  157.75606 ,  742.4881  ,
        467.25656 ,  554.0606  ,   97.73875 ,  828.1733  ,  301.03305 ,
        161.27496 ,  742.74066 ,  293.26907 ,  190.17546 ,  160.88339 ,
        244.04848 ,  451.6518  ,  604.96234 ,  917.8943  ,  864.8018  ,
        223.8088  ,  101.54476 ,   68.18154 ,  426.27496 , 1697.8339  ,
        167.51418 ,  349.40793 ,   88.89637 ,  590.7444  ,  282.53323 ,
        373.57996 ,  518.2189  ,  199.97359 ,  843.6132  ], dtype=float32)>

## Train on Random Elements

In [57]:
# Filter elts_rand down to only those that had matching ztf observations.
# Match on the element_id column (not on the index labels), and take an explicit copy so the
# dtype conversion below writes into an independent frame rather than a slice of the original.
idx = np.unique(ztf_elt_rand.element_id)
elts_rand = elts_rand[elts_rand.element_id.isin(idx)].copy()

In [60]:
# Debug check: dtype of the 'a' column after filtering (float64 in the saved output)
elts_rand['a'].dtype

dtype('float64')

In [67]:
# Debug check repeated after the float32 cast.
# NOTE(review): execution counts are out of order (In [67] before the cast cell In [70]), so this
# output reflects an earlier run of the cast; re-run top to bottom before relying on it.
elts_rand['a'].dtype

dtype('float32')

In [70]:
# Cast the orbital element columns to single precision (dtype_np) in one step.
# The filtered frame shows these columns as float64 in this run; NOTE(review): the cast presumably
# matches the float32 element inputs AsteroidSearchModel expects -- confirm.
elts_rand = elts_rand.astype(
    {col: dtype_np for col in ['a', 'e', 'f', 'inc', 'Omega', 'omega', 'epoch']})

In [72]:
# Build asteroid search model
# Model initialized at the filtered random elements (63 in this run, since one had no matching observations)
model_rand = AsteroidSearchModel(
                 elts=elts_rand, ztf_elt=ztf_elt_rand, site_name=site_name,
                 thresh_deg=thresh_deg, h=h, R_deg=R_deg,
                 learning_rate=learning_rate, clipnorm=clipnorm)

In [74]:
# Train model on random elements
# Same adaptive search settings as the unperturbed and perturbed runs
model_rand.search_adaptive(
    max_batches=max_batches, 
    batches_per_epoch=batches_per_epoch,
    epochs_per_episode=epochs_per_episode,
    min_learning_rate=min_learning_rate,
    verbose=verbose)

Train on 6300 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Log Likelihood:  3696.17

Training episode 2: Epoch    5
learning_rate=1.221e-04, total training time 0 sec.
Train on 6300 samples
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Log Likelihood: 22294.73

Training episode 3: Epoch   10
learning_rate=1.221e-04, total training time 0 sec.
Train on 6300 samples
Epoch 11/15
Epoch 12/15
Log Likelihood: 22691.61

Training episode 4: Epoch   12
learning_rate=1.221e-04, total training time 0 sec.
Train on 6300 samples
Epoch 13/17
Epoch 14/17
Changing learning rate by factor 0.500000 from 1.221e-04 to 6.104e-05.
Log Likelihood: 22666.66

Training episode 5: Epoch   14
learning_rate=6.104e-05, total training time 0 sec.
Train on 6300 samples
Epoch 15/19
Epoch 16/19
Changing learning rate by factor 0.500000 from 6.104e-05 to 3.052e-05.
Log Likelihood: 22565.74

Training episode 6: Epoch   16
learning_rate=3.052e-05, total training time 0 sec.
Train on 6300 samples
Epo

In [75]:
# Review likelihoods by element
# Outputs of the trained model_rand, unpacked as (log likelihood, orbital elements, mixture parameters);
# one entry per filtered random element
log_like_rand, orbital_elements_rand, mixture_parameters_rand = model_rand.calc_outputs()

In [76]:
log_like_rand

<tf.Tensor: shape=(63,), dtype=float32, numpy=
array([4.77463226e+02, 1.17010254e+03, 4.15739410e+02, 1.77882538e+01,
       1.95976425e+02, 1.28115244e+01, 4.19807037e+02, 3.46641312e+01,
       5.77843872e+02, 6.63247299e+00, 1.33319759e+00, 2.24694004e+01,
       3.79695679e+02, 1.13020776e+03, 6.90465637e+02, 5.43634338e+02,
       1.34409680e+03, 9.65692932e+02, 1.64268665e+01, 6.57113403e+02,
       4.86166039e+01, 1.11709668e+03, 1.50391602e+02, 7.67535305e+00,
       7.43313416e+02, 7.26073685e+01, 6.08193970e+02, 4.75551910e+01,
       2.53791656e+02, 2.14364044e+02, 1.96174355e+01, 4.33977013e+01,
       3.55882294e+02, 3.06253662e+02, 5.39834518e+01, 3.12860527e+01,
       1.50992584e+02, 6.74602509e+01, 3.11920227e+02, 4.93390686e+02,
       1.43001572e+02, 8.96031799e+01, 1.81505432e+02, 1.22521286e+02,
       5.18647270e+01, 2.77763367e+02, 2.85232147e+02, 4.46494019e+02,
       1.24276953e+03, 1.60976685e+02, 8.73707642e+02, 1.90191040e+02,
       1.87260380e+01, 2.76888

In [78]:
# orbital_elements_rand
# (disabled: uncomment the line above to display the fitted orbital elements for the random batch)