In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf

# Utility
import os

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl

# Utility
import time

In [2]:
# Expose only GPU number gpu_num to CUDA; must be set before TensorFlow first initializes the GPU
gpu_num: int = 1
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)

In [3]:
# Configure TF GPU growth
# NOTE(review): kepler_sieve and get_gpu_device are not referenced in the visible cells;
# presumably imported for side effects or later use - confirm before removing.
import kepler_sieve
from tf_utils import gpu_grow_memory, get_gpu_device
# Turn on GPU memory growth; verbose=True prints the number of GPUs found (see output below)
gpu_grow_memory(verbose=True)

Found 1 GPUs.  Setting memory growth = True.


In [4]:
# MSE Imports
from asteroid_element import load_ast_elt
from candidate_element import asteroid_elts, perturb_elts, random_elts, elts_add_mixture_params, elts_add_H
from ztf_ast import load_ztf_nearest_ast, calc_hit_freq
from ztf_element import load_ztf_batch, make_ztf_batch, ztf_score_by_elt, ztf_elt_summary
from asteroid_model import AsteroidPosition, AsteroidDirection, make_model_ast_pos
from asteroid_search_layers import CandidateElements, MixtureParameters, TrajectoryScore
from asteroid_search_model import AsteroidSearchModel
from asteroid_search_report import traj_diff
from nearest_asteroid import nearest_ast_elt_cart, nearest_ast_elt_cov, elt_q_norm
from element_eda import score_by_elt
from asteroid_dataframe import calc_ast_data, spline_ast_vec_df
from astro_utils import deg2dist, dist2deg, dist2sec

In [5]:
# Short alias for the Keras API bundled with TensorFlow
keras = tf.keras

# Floating point precision, as both a TensorFlow and a NumPy dtype
dtype, dtype_np = tf.float32, np.float32

# Number of spatial dimensions
space_dims = 3

In [6]:
# Plot style: large figures and a readable default font size
mpl.rcParams.update({
    'figure.figsize': [16.0, 10.0],
    'font.size': 16,
})

## Load ZTF Data and Batch of Orbital Elements

In [7]:
# Orbital elements of the known asteroids
ast_elt = load_ast_elt()
# Number of known asteroids (rows of ast_elt)
N_ast = len(ast_elt)

In [8]:
# Load ztf nearest asteroid data
# Presumably ZTF observations paired with their nearest known asteroid; consumed by calc_hit_freq below
ztf_ast = load_ztf_nearest_ast()

In [9]:
# Hit count per known asteroid, using a 2.0 threshold (thresh_sec; presumably arc seconds)
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Order asteroids by hit count, most hits first (the first batch_size are taken in the next cell)
idx = np.flip(np.argsort(hit_count))

# Asteroid numbers and hit counts in that descending order
ast_num_best, hit_count_best = ast_num[idx], hit_count[idx]

In [10]:
# Number of candidate elements in the batch
batch_size = 64

# Unperturbed elements for the batch_size asteroids with the most hits
elts_ast = asteroid_elts(ast_nums=ast_num_best[:batch_size])

In [11]:
# # Review unperturbed elements
# elts_ast

In [12]:
# Perturbation scales passed to perturb_elts ("large" setting)
sigma_a, sigma_e = 0.05, 0.01
# Angular perturbation scales, in degrees
sigma_inc_deg = 0.25
sigma_f_deg = sigma_Omega_deg = sigma_omega_deg = 1.0
# No perturbation mask; a fixed seed keeps the perturbed batch reproducible
mask_pert = None
random_seed = 42

In [13]:
# Randomly perturb the unperturbed batch using the scales and seed defined above
elts_pert = perturb_elts(
    elts_ast,
    sigma_a=sigma_a,
    sigma_e=sigma_e,
    sigma_inc_deg=sigma_inc_deg,
    sigma_f_deg=sigma_f_deg,
    sigma_Omega_deg=sigma_Omega_deg,
    sigma_omega_deg=sigma_omega_deg,
    mask_pert=mask_pert,
    random_seed=random_seed,
)

In [14]:
# Choose which elements to search on.
# Copy so the in-place column additions made later (elts_add_mixture_params, elts_add_H)
# do not also modify elts_pert, and re-running this cell restores clean elements.
elts = elts_pert.copy()

## Batches of ZTF Data Near Initial Candidate Elements

In [15]:
# Arguments to load_ztf_batch (originally make_ztf_batch)
thresh_deg = 2.0  # threshold in degrees; presumably the angular radius used to select nearby ZTF observations
near_ast, regenerate = False, False

In [16]:
# ZTF observations near the perturbed candidate elements (within thresh_deg)
ztf_elt = load_ztf_batch(
    elts=elts,
    thresh_deg=thresh_deg,
    near_ast=near_ast,
    regenerate=regenerate,
)

In [17]:
# Review ZTF elements
# Each row pairs a ZTF observation (ztf_id) with a candidate element (element_id); the frame is large,
# so pandas only displays its head and tail
ztf_elt

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,vz,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit
0,306,733,b'ZTF18abiyevm',567274570115015018,3341,58321.274572,275.834958,-12.178240,0.099376,-0.976101,...,0.004095,0.080910,-0.982320,0.168822,2.050675,0.031258,6447.753750,0.999511,0.801975,False
1,6391,733,b'ZTF18ablpbwh',617122522515015016,8730,58371.122523,272.156750,-10.136454,0.037046,-0.972528,...,0.003837,0.016116,-0.975703,0.218503,2.491969,0.024005,4951.468379,0.999712,0.472962,False
2,6392,733,b'ZTF18ablpbwh',618126362515015025,8913,58372.126366,272.156760,-10.136446,0.037046,-0.972528,...,0.003830,0.017464,-0.975564,0.219021,2.503645,0.022568,4655.153797,0.999745,0.418050,False
3,6393,733,b'ZTF18ablpbwh',611146562515015015,7585,58365.146562,272.156733,-10.136444,0.037046,-0.972528,...,0.003876,0.010155,-0.976524,0.215169,2.423439,0.030883,6370.298847,0.999523,0.782824,False
4,12249,733,b'ZTF18ablwzcc',584190354815015015,4647,58338.190359,273.272132,-13.497675,0.055502,-0.983530,...,0.004023,0.030085,-0.980846,0.192450,2.155818,0.032724,6750.065879,0.999465,0.878934,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
290766,5650772,324582,b'ZTF20aaqvhnd',1150176700415015000,96618,58904.176701,48.664349,31.318054,0.564235,0.795279,...,-0.001786,0.584015,0.786561,0.200621,2.748814,0.030221,6233.797319,0.999543,0.749638,False
290767,5650773,324582,b'ZTF20aaqvhns',1150176245615015007,96617,58904.176250,45.820577,29.378228,0.607285,0.768505,...,-0.001786,0.584016,0.786559,0.200622,2.748808,0.029465,6077.915470,0.999566,0.712618,False
290768,5650789,324582,b'ZTF20aaqvhnm',1150176245015015006,96617,58904.176250,48.881586,28.300138,0.579016,0.797156,...,-0.001786,0.584016,0.786559,0.200622,2.748808,0.031743,6547.783696,0.999496,0.827049,False
290769,5650791,324582,b'ZTF20aaqvhog',1150176244815015007,96617,58904.176250,49.429756,29.370649,0.566783,0.802441,...,-0.001786,0.584016,0.786559,0.200622,2.748808,0.027275,5625.990260,0.999628,0.610591,False


In [18]:
# Score by element - perturbed
score_by_elt = ztf_score_by_elt(ztf_elt)

In [19]:
# Summarize the ztf element batch: perturbed asteroids
ztf_elt_summary(ztf_elt, score_by_elt, 'Perturbed Asteroids')

ZTF Element Dataframe Perturbed Asteroids:
                  Total     (Per Batch)
Observations   :   290771   (     4543)

Summarize score = sum(-1.0 - log(v)) by batch.  (Mean=0, Variance=num_obs)
Mean score     :      42.77
Sqrt(batch_obs):      67.40
Mean t_score   :       0.58


In [20]:
# Initial mixture model parameters assigned to every candidate element
num_hits: int = 10
# Initial R in degrees (stored in the elements table in radians, e.g. 0.5 deg -> 0.008727)
R_deg: float = 0.5

In [21]:
# Add mixture parameters to candidate elements
# Modifies elts in place (the return value is unused): adds num_hits, R and thresh_s columns,
# with R and thresh_s given in radians (see the elts output below)
elts_add_mixture_params(elts=elts, num_hits=num_hits, R_deg=R_deg, thresh_deg=thresh_deg)

In [22]:
# Add brightness parameter H
# Modifies elts in place: adds H and sigma_mag columns (16.5 and 4.0 for every row in the output below)
elts_add_H(elts=elts)

In [23]:
# Review perturbed elements, now including mixture parameters (num_hits, R, thresh_s) and brightness (H, sigma_mag)
elts

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch,num_hits,R,thresh_s,H,sigma_mag
0,51921,2.736430,0.219134,0.499988,4.721815,2.452489,-1.129754,58600.0,10,0.008727,0.034905,16.5,4.0
1,59244,2.616575,0.266087,0.462848,5.725946,1.777382,-1.623105,58600.0,10,0.008727,0.034905,16.5,4.0
2,15786,1.945213,0.047621,0.385594,6.142435,0.790543,-1.243047,58600.0,10,0.008727,0.034905,16.5,4.0
3,3904,2.758664,0.099270,0.261841,5.463683,2.238942,-1.350620,58600.0,10,0.008727,0.034905,16.5,4.0
4,142999,2.589450,0.192070,0.509382,0.221844,0.928905,-1.314727,58600.0,10,0.008727,0.034905,16.5,4.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
59,11952,2.330603,0.084892,0.117649,0.042808,2.890716,-3.000560,58600.0,10,0.008727,0.034905,16.5,4.0
60,134815,2.550916,0.141660,0.510228,0.284591,0.630896,-0.920797,58600.0,10,0.008727,0.034905,16.5,4.0
61,27860,2.595202,0.098315,0.194023,5.535984,3.255585,3.966790,58600.0,10,0.008727,0.034905,16.5,4.0
62,85937,2.216242,0.195323,0.437115,5.285351,3.172956,3.921169,58600.0,10,0.008727,0.034905,16.5,4.0


## Build Asteroid Search Model

In [24]:
# Observatory for ZTF data is Palomar Mountain
# Passed to AsteroidSearchModel as site_name
site_name = 'palomar'

In [25]:
# Optimizer settings passed to AsteroidSearchModel
learning_rate = 1.0 / 4096   # == 2**-12
clipnorm = 1.0
# NOTE(review): save_at_end is not referenced in the visible cells
save_at_end: bool = True

In [26]:
# Assemble the asteroid search model from the candidate elements and their ZTF batch.
# NOTE(review): the file name encodes only the seed (e.g. ..._0042.h5), not the perturbation scales or batch size.
model = AsteroidSearchModel(
    elts=elts,
    ztf_elt=ztf_elt,
    site_name=site_name,
    thresh_deg=thresh_deg,
    learning_rate=learning_rate,
    clipnorm=clipnorm,
    name='model',
    file_name=f'candidate_elt_pert_large_{random_seed:04d}.h5',
)

In [27]:
# Load trained model
# Restores candidate elements and training history from the model's file (see message below)
model.load()

Loaded candidate elements and training history from ../data/candidate_elt/candidate_elt_pert_large_0042.h5.


In [28]:
# Report on the loaded model; it was restored already trained (see the training summary in the output)
model.report()


Good elements (hits >= 10):   9.00

         \  log_like :  hits  :    R_sec : thresh_sec
Mean Good:  1172.76  : 123.78 :     9.28 :   560.91
Mean Bad :    53.60  :   0.07 :   407.25 :  2578.63
Mean     :   210.98  :  17.47 :   351.28 :  2294.89
Median   :    45.48  :   0.00 :   345.88 :  2533.41
GeoMean  :    69.05  :   1.01 :   195.52 :  2059.26
Min      :    19.65  :   0.00 :     4.43 :   374.50
Max      :  1501.34  : 173.00 :  1200.00 :  2790.43
Trained for 13376 batches over 209 epochs and 65 episodes (elapsed time 1297 seconds).


In [29]:
# ZTF observations counted as hits for the candidate elements (progress runs over the 64 candidates)
ztf_hits = model.calc_ztf_hits()
# Absolute magnitude error between observed and predicted magnitude.
# Computed here because the review columns (cols) below include 'mag_diff'.
ztf_hits['mag_diff'] = np.abs(ztf_hits.mag_app - ztf_hits.mag_pred)

HBox(children=(FloatProgress(value=0.0, max=64.0), HTML(value='')))




In [36]:
# Columns shown when reviewing ZTF hits: ids and time, observed position and magnitude,
# observed (ux..uz) vs. element (elt_ux..elt_uz) directions, s_sec, and predicted magnitude / error
cols = [
    'ztf_id', 'element_id', 'mjd',
    'ra', 'dec', 'mag_app',
    'ux', 'uy', 'uz',
    'elt_ux', 'elt_uy', 'elt_uz',
    's_sec', 'mag_pred', 'mag_diff',
]

In [37]:
# Review ZTF hits; requires the mag_diff column to exist on ztf_hits
ztf_hits[cols]

Unnamed: 0,ztf_id,element_id,mjd,ra,dec,mag_app,ux,uy,uz,elt_ux,elt_uy,elt_uz,s_sec,mag_pred,mag_diff
0,1994410,733,58756.486481,54.789836,53.890263,16.679100,0.339796,0.763137,0.549690,0.339809,0.763107,0.549725,9.767155,17.782135,1.103035
1,2071539,733,58759.368750,54.979604,54.049413,16.512300,0.336911,0.763130,0.551475,0.336926,0.763102,0.551504,8.938659,17.742491,1.230190
2,2079718,733,58759.382384,54.980077,54.050121,16.512800,0.336901,0.763128,0.551483,0.336916,0.763100,0.551513,8.934473,17.742270,1.229469
3,2085480,733,58759.413148,54.981065,54.051651,16.926500,0.336880,0.763123,0.551503,0.336895,0.763095,0.551532,8.918493,17.741838,0.815338
4,2089292,733,58759.425891,54.981501,54.052296,16.979200,0.336871,0.763121,0.551511,0.336887,0.763093,0.551540,8.916037,17.741682,0.762482
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1141,5381516,90777,58896.182523,39.709609,36.789589,15.601800,0.616080,0.707654,0.345936,0.616079,0.707656,0.345934,0.700642,20.481018,4.879218
1142,5446528,90777,58899.110625,40.456356,36.751514,15.202700,0.609662,0.715006,0.342168,0.609660,0.715009,0.342165,1.110163,20.503975,5.301275
1143,5461111,90777,58899.163958,40.470055,36.750858,15.900400,0.609543,0.715140,0.342100,0.609541,0.715143,0.342097,1.121960,20.504402,4.604002
1144,1066101,104194,58717.445116,65.036459,25.006511,19.418301,0.382479,0.921946,0.061028,0.382441,0.921963,0.061018,8.752734,17.490957,1.927343


In [33]:
# Fitted candidate elements as a DataFrame (no need for a notebook-level `self` alias)
elts_fit = model.candidates_df()

In [34]:
# Review fitted elements and fit statistics (log_like, hits, H, mixture weights, ...)
elts_fit

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch,num_hits,R,...,thresh_deg,thresh_sec,log_like,hits,num_rows_close,H,sigma_mag,weight_joint,weight_element,weight_mixture
0,51921,2.671419,0.217827,0.499639,4.699614,2.447760,-1.130743,58600.0,194.943146,0.000079,...,0.157495,566.981689,1121.526733,26.994644,222.0,15.769428,4.0,0.007812,1.0,0.03125
1,59244,2.634395,0.262890,0.464968,5.738090,1.771472,-1.604850,58600.0,189.965530,0.000075,...,0.254521,916.275452,1280.354858,98.989349,234.0,16.205391,4.0,0.007812,1.0,0.12500
2,15786,1.942106,0.050723,0.385029,6.140004,0.788396,-1.249232,58600.0,20.630554,0.000462,...,0.645947,2325.407715,31.780399,0.000000,352.0,17.861952,4.0,0.007812,1.0,0.06250
3,3904,2.760293,0.128028,0.257008,5.504178,2.180960,-1.369055,58600.0,70.118988,0.002502,...,0.773805,2785.696289,42.812237,0.000000,460.0,16.437471,4.0,0.500000,1.0,0.06250
4,142999,2.505064,0.179645,0.501702,0.207104,0.945653,-1.344816,58600.0,35.063999,0.001113,...,0.769156,2768.960693,24.255808,0.000000,550.0,16.981972,4.0,0.500000,1.0,0.03125
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
59,11952,2.224995,0.078338,0.121356,0.069336,2.828827,-3.042223,58600.0,89.200439,0.001694,...,0.703615,2533.014404,51.075352,0.000000,1309.0,16.782936,4.0,0.062500,1.0,0.06250
60,134815,2.621922,0.155097,0.520578,0.292239,0.660736,-0.939725,58600.0,47.555294,0.002643,...,0.666667,2399.999756,25.066959,0.000000,246.0,16.530491,4.0,0.007812,1.0,0.03125
61,27860,2.600507,0.096844,0.194392,5.530711,3.259202,3.969775,58600.0,64.929626,0.001221,...,0.666667,2399.999756,41.732182,0.000000,979.0,16.195171,4.0,0.007812,1.0,0.06250
62,85937,2.227119,0.198896,0.437545,5.270397,3.090596,3.999848,58600.0,66.213310,0.002724,...,0.703643,2533.114746,30.344482,0.000000,423.0,17.293219,4.0,0.062500,1.0,0.06250


In [None]:
# Absolute difference between observed and predicted apparent magnitude
ztf_hits['mag_diff'] = (ztf_hits['mag_app'] - ztf_hits['mag_pred']).abs()

In [35]:
# Review ZTF hits again (same display as the earlier ztf_hits[cols] cell)
ztf_hits[cols]

Unnamed: 0,ztf_id,element_id,mjd,ra,dec,mag_app,ux,uy,uz,elt_ux,elt_uy,elt_uz,s_sec,mag_pred
0,1994410,733,58756.486481,54.789836,53.890263,16.679100,0.339796,0.763137,0.549690,0.339809,0.763107,0.549725,9.767155,17.782135
1,2071539,733,58759.368750,54.979604,54.049413,16.512300,0.336911,0.763130,0.551475,0.336926,0.763102,0.551504,8.938659,17.742491
2,2079718,733,58759.382384,54.980077,54.050121,16.512800,0.336901,0.763128,0.551483,0.336916,0.763100,0.551513,8.934473,17.742270
3,2085480,733,58759.413148,54.981065,54.051651,16.926500,0.336880,0.763123,0.551503,0.336895,0.763095,0.551532,8.918493,17.741838
4,2089292,733,58759.425891,54.981501,54.052296,16.979200,0.336871,0.763121,0.551511,0.336887,0.763093,0.551540,8.916037,17.741682
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1141,5381516,90777,58896.182523,39.709609,36.789589,15.601800,0.616080,0.707654,0.345936,0.616079,0.707656,0.345934,0.700642,20.481018
1142,5446528,90777,58899.110625,40.456356,36.751514,15.202700,0.609662,0.715006,0.342168,0.609660,0.715009,0.342165,1.110163,20.503975
1143,5461111,90777,58899.163958,40.470055,36.750858,15.900400,0.609543,0.715140,0.342100,0.609541,0.715143,0.342097,1.121960,20.504402
1144,1066101,104194,58717.445116,65.036459,25.006511,19.418301,0.382479,0.921946,0.061028,0.382441,0.921963,0.061018,8.752734,17.490957


In [38]:
# Mean absolute magnitude error over all hits
ztf_hits['mag_diff'].mean()

1.3001267448145668

In [39]:
# Spread of observed magnitudes, for comparison with the mean error (population std, ddof=0 as np.std)
ztf_hits['mag_app'].std(ddof=0)

1.2608708324561555

In [40]:
# Select the hits of candidate element 90777 (its mag_diff is around 5 in the output below)
mask = ztf_hits['element_id'].eq(90777)

In [41]:
# Review columns of the selected element's hits
ztf_hits.loc[mask, cols]

Unnamed: 0,ztf_id,element_id,mjd,ra,dec,mag_app,ux,uy,uz,elt_ux,elt_uy,elt_uz,s_sec,mag_pred,mag_diff
993,395815,90777,58693.445729,44.285479,33.579335,15.9857,0.596406,0.753718,0.276059,0.596413,0.753709,0.276070,3.412059,20.678802,4.693103
994,420554,90777,58693.485671,44.293555,33.585205,15.5081,0.596283,0.753793,0.276120,0.596290,0.753783,0.276131,3.368749,20.678478,5.170379
995,430852,90777,58688.476401,43.232256,32.855154,16.0814,0.612042,0.743715,0.268871,0.612047,0.743707,0.268884,3.336704,20.719473,4.638073
996,433118,90777,58688.476863,43.232370,32.855227,16.0991,0.612041,0.743716,0.268872,0.612045,0.743708,0.268885,3.357001,20.719475,4.620375
997,549984,90777,58696.451701,44.889994,34.015204,15.7284,0.587237,0.759236,0.280557,0.587246,0.759226,0.280567,3.492202,20.653498,4.925097
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1139,5242925,90777,58893.159641,38.972055,36.841868,15.5513,0.622190,0.700314,0.349914,0.622190,0.700315,0.349913,0.406294,20.456871,4.905571
1140,5364815,90777,58896.123785,39.695137,36.790572,15.3835,0.616202,0.707510,0.346013,0.616200,0.707513,0.346011,0.721670,20.480553,5.097053
1141,5381516,90777,58896.182523,39.709609,36.789589,15.6018,0.616080,0.707654,0.345936,0.616079,0.707656,0.345934,0.700642,20.481018,4.879218
1142,5446528,90777,58899.110625,40.456356,36.751514,15.2027,0.609662,0.715006,0.342168,0.609660,0.715009,0.342165,1.110163,20.503975,5.301275


In [44]:
# Mean predicted magnitude for the selected element
ztf_hits.loc[mask, 'mag_pred'].mean()

20.127634048461914

In [42]:
# Starting (perturbed) elements for element 733
elts.loc[elts['element_id'] == 733]

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch,num_hits,R,thresh_s,H,sigma_mag
47,733,3.583354,0.058824,0.357684,5.962229,3.315871,3.718062,58600.0,10,0.008727,0.034905,16.5,4.0


In [None]:
# Review the starting (perturbed) elements
elts

In [None]:
# Presumably marks all model components as trainable; confirm against AsteroidSearchModel.thaw_all
model.thaw_all()

In [None]:
# Check whether the model's magnitude component is currently trainable
model.magnitude.trainable

In [None]:
# model.sieve_round(round=1, 
#                  num_batches=1024, 
#                  batches_per_epoch=64,
#                  epochs_per_episode=4,
#                  training_mode='joint',
#                  learning_rate=2**-15, 
#                  min_learning_rate=2**-23,
#                  reset_active_weight=True)

In [None]:
# Names of the model's trainable variables.
# tf.trainable_variables() is a TF1 graph-collection API that does not exist in TF2 (this notebook uses tf.keras);
# Keras models expose trainable_variables directly.
# NOTE(review): assumes AsteroidSearchModel subclasses keras.Model - confirm.
variables_names = [v.name for v in model.trainable_variables]

In [None]:
# Current brightness parameter H of the candidates (presumably matches the H column of candidates_df())
model.get_H()

In [None]:
# Current candidate elements and fit statistics as a DataFrame (same call that produced elts_fit)
model.candidates_df()