In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Local
from ztf_data import load_ztf_nearest_ast, calc_hit_freq, load_ztf_batch,make_ztf_batch
from asteroid_integrate import load_ast_elt, load_ast_pos
from candidate_element import orbital_element_batch, perturb_elts, random_elts

from asteroid_model import make_model_ast_pos, make_model_ast_dir, AsteroidDirection

Found 4 GPUs.  Setting memory growth = True.


## Load ZTF Data and Batch of Orbital Elements

In [2]:
# Orbital elements of the known asteroids, from the project loader
ast_elt = load_ast_elt()

# Asteroid count = length of the first axis of the loaded table
N_ast = len(ast_elt)

In [3]:
# Load ztf nearest asteroid data
# (this table is the `ztf` input handed to calc_hit_freq in the next cell)
ztf_ast = load_ztf_nearest_ast()

In [4]:
# Asteroid numbers with their hit counts, computed with thresh_sec=2.0
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Ordering that ranks entries from the largest hit count down to the smallest
# (ascending argsort, then reversed)
idx = np.argsort(hit_count)[::-1]

# Both arrays rearranged into that ranking; nothing is truncated here --
# the leading batch_size entries are sliced off where the batch is built
ast_num_best, hit_count_best = ast_num[idx], hit_count[idx]

In [5]:
# Number of asteroids / candidate elements handled as one batch
batch_size = 64

# Unperturbed orbital elements for the batch_size asteroids that rank first by hit count
elts_ast = orbital_element_batch(ast_nums=ast_num_best[:batch_size])

In [6]:
# Display the unperturbed element batch
elts_ast

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133491,58600.0
1,59244,2.634727,0.262503,0.465045,5.738297,1.766995,-1.601363,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.246069,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.357345,58600.0
4,142999,2.619944,0.191376,0.514017,0.238022,0.946463,-1.299301,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.016580,58600.0
60,134815,2.612770,0.140831,0.513923,0.272689,0.645552,-0.957836,58600.0
61,27860,2.619406,0.096185,0.200633,5.541399,3.266046,3.948770,58600.0
62,85937,2.342292,0.197267,0.439063,5.279693,3.210025,3.947687,58600.0


In [7]:
# Perturbation settings; sigma_f_deg is the only nonzero scale in this run
sigma_a = 0.0
sigma_e = 0.0
sigma_f_deg = 0.1
sigma_Omega_deg = 0.0
sigma_omega_deg = 0.0
mask_pert = None
# Fixed seed so the perturbation can be reproduced on a re-run
random_seed = 42

# Apply perturb_elts to the unperturbed batch with the settings above
elts_pert = perturb_elts(
    elts_ast,
    sigma_a=sigma_a,
    sigma_e=sigma_e,
    sigma_f_deg=sigma_f_deg,
    sigma_Omega_deg=sigma_Omega_deg,
    sigma_omega_deg=sigma_omega_deg,
    mask_pert=mask_pert,
    random_seed=random_seed,
)

In [8]:
# Display the perturbed element batch for comparison with the unperturbed one
elts_pert

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,51921,2.669306,0.217361,0.499554,4.699703,2.450796,-1.133117,58600.0
1,59244,2.634727,0.262503,0.465045,5.738297,1.766995,-1.603537,58600.0
2,15786,1.883227,0.047655,0.392360,6.134689,0.804823,-1.245767,58600.0
3,3904,2.556387,0.098279,0.261542,5.450163,2.202423,-1.356673,58600.0
4,142999,2.619945,0.191376,0.514017,0.238022,0.946463,-1.300844,58600.0
...,...,...,...,...,...,...,...,...
59,11952,2.219650,0.086091,0.117967,0.042442,2.904823,-3.014978,58600.0
60,134815,2.612770,0.140831,0.513923,0.272689,0.645552,-0.954132,58600.0
61,27860,2.619406,0.096185,0.200633,5.541399,3.266046,3.950572,58600.0
62,85937,2.342292,0.197267,0.439063,5.279693,3.210025,3.945035,58600.0


## Batches of ZTF Data vs. Elements

In [16]:
# Arguments for building ZTF batches; passed to the load_ztf_batch calls in the following cells
thresh_deg = 2.0
near_ast = False
# NOTE(review): regenerate=True presumably forces the batch to be rebuilt instead of
# being read from a saved copy, which makes a full re-run expensive -- confirm intended
regenerate = True

In [None]:
# Load unperturbed element batch
# (ZTF data batch built against elts_ast with the arguments set in the previous cell)
ztf_elt_ast = load_ztf_batch(elts=elts_ast, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

HBox(children=(FloatProgress(value=0.0, max=64.0), HTML(value='')))

In [None]:
# Load perturbed element batch
# (same call and arguments as for the unperturbed batch, with elts_pert as the elements)
ztf_elt_pert = load_ztf_batch(elts=elts_pert, thresh_deg=thresh_deg, near_ast=near_ast, regenerate=regenerate)

In [None]:
# Display the ZTF batch for the unperturbed elements
ztf_elt_ast

In [None]:
# List the column names available in the ZTF batch
ztf_elt_ast.columns

In [None]:
# Review results for the asteroid that ranks first by hit count
ztf_elt = ztf_elt_ast
element_id_best = ast_num_best[0]
mask = (ztf_elt.element_id == element_id_best)

# Every statistic below is restricted to the rows of this one element
ztf_best = ztf_elt[mask]
hits_best = np.sum(ztf_best.is_hit)
s_sec_min = np.min(ztf_best.s_sec)

# Closest observation of this element: idx is a POSITION within the masked rows
# (usable as ztf_elt[mask].iloc[idx]); the argmin is taken on the raw values and
# ztf_id is read with .iloc so positions are never confused with index labels.
# Previously the argmin ran over the unmasked frame, so it could select a row
# belonging to a different element.
if len(ztf_best) > 0:
    idx = int(np.argmin(ztf_best.s.values))
    ztf_id = ztf_best.ztf_id.iloc[idx]
else:
    # No observations matched this element; leave the lookups empty instead of raising
    idx = None
    ztf_id = None

print(f'Best asteroid has element_id = {element_id_best}')
print(f'Hit count: {hits_best}')
print(f'Closest hit: {s_sec_min:0.3f} arc seconds')

## Load Position and Direction Models

In [None]:
# Unique times
# (np.unique returns the distinct values in sorted order)
ts_np = np.unique(ztf_elt_ast.mjd)
TimeStampID = np.unique(ztf_elt_ast.TimeStampID)

# The epoch
# (read from row label 0 of the element batch; the rows displayed earlier all show 58600.0)
epoch0 = elts_ast.epoch[0]

# Observation site
# (passed as site_name when the direction model and direction layer are built)
site_name = 'palomar'

In [None]:
# Load positions for calibration
# (positions of the batch elements at the unique observation times, relative to epoch0)
q_cal = load_ast_pos(elts=elts_ast, epoch=epoch0, ts=ts_np)

In [None]:
# Build position model
# (constructed for the unique observation times and the element batch size)
model_pos = make_model_ast_pos(ts=ts_np, batch_size=batch_size)

In [None]:
# Build direction model
# (same times and batch size as the position model, plus the observation site)
model_dir = make_model_ast_dir(ts=ts_np, batch_size=batch_size, site_name=site_name)

In [None]:
# Element columns that are handed to the models
cols_elt = ['a', 'e', 'inc', 'Omega', 'omega', 'f', 'epoch']

# Map each column name to that column's values as a numpy array;
# this dict is the input given to model_pos.predict / model_dir.predict
elts_ast_dict = dict((col, elts_ast[col].values) for col in cols_elt)

In [None]:
# Predict position model
# (the model returns two outputs, unpacked here as q_pred and v_pred)
q_pred, v_pred = model_pos.predict(elts_ast_dict)

In [None]:
# Predict direction model
# (the model returns two outputs, unpacked here as u_pred and r_pred)
u_pred, r_pred = model_dir.predict(elts_ast_dict)

# Review shape of predictions
# (last expression in the cell, so the shape is shown as the cell output)
u_pred.shape

## Assemble Tensors for Prototype Model

In [None]:
# Single dtype used for every tensor assembled in this cell
dtype = tf.float32

# Observation times
# NOTE(review): float32 has a 24-bit significand; if ts_np holds MJD values near the
# 58600 epoch, adjacent representable values are 1/256 day (~5.6 minutes) apart.
# Confirm that precision is acceptable, or that AsteroidDirection re-bases the times.
ts = tf.constant(value=ts_np, dtype=dtype)

# Orbital elements: one constant tensor per element column, unpacked in column order
a, e, inc, Omega, omega, f, epoch = (
    tf.constant(value=elts_ast[col], dtype=dtype)
    for col in ('a', 'e', 'inc', 'Omega', 'omega', 'f', 'epoch')
)

In [None]:
# Build direction layer
# (elt_batch_size is simply another name for batch_size at this point)
elt_batch_size = batch_size

# Instantiate the layer on the float32 observation-time tensor; the layer is named 'u_pred'
direction_layer = AsteroidDirection(ts=ts, site_name=site_name, batch_size=elt_batch_size, name='u_pred')

In [None]:
# Inspect the observation-time tensor
ts

In [None]:
# Display the ZTF batch for the unperturbed elements again before it is copied below
ztf_elt_ast

In [None]:
# Independent copy of the unperturbed batch (a copy, not an alias of ztf_elt_ast)
ztf_elt = ztf_elt_ast.copy()

# Number of observation rows per element_id.
# NOTE(review): groupby sorts its keys by default, so these counts come out in ascending
# element_id order, which may differ from the row order of elts_ast -- confirm before
# pairing them with the element batch. They are per-group counts (row lengths), not
# cumulative split offsets, despite the variable name.
element_ids = ztf_elt.element_id
row_splits = element_ids.groupby(element_ids).count().values
row_splits

In [None]:
# Show the three asteroid numbers that rank first by hit count
ast_num_best[0:3]

In [None]:
# Observations belonging to the asteroid that ranks first by hit count.
# The id is taken from ast_num_best instead of a hard-coded asteroid number, so the
# cell keeps selecting the top asteroid if the data or the ranking changes.
mask = ztf_elt.element_id == ast_num_best[0]
ztf_elt[mask]