In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf

# Utility
import os

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl

# Utility
import time

In [2]:
# Expose only one GPU (chosen by index) to CUDA / TensorFlow
gpu_num: int = 1
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)

In [3]:
# Configure TF GPU growth
# NOTE(review): neither kepler_sieve nor get_gpu_device is referenced in the visible
# cells; kepler_sieve may be imported only for its side effects — confirm before removing.
import kepler_sieve
from tf_utils import gpu_grow_memory, get_gpu_device
# Turn on memory growth for the visible GPU(s); verbose=True prints what was found
gpu_grow_memory(verbose=True)

Found 1 GPUs.  Setting memory growth = True.


In [4]:
# MSE Imports
from asteroid_element import load_ast_elt
from candidate_element import asteroid_elts, perturb_elts, random_elts, elts_add_mixture_params, elts_add_H
from ztf_ast import load_ztf_nearest_ast, calc_hit_freq
from ztf_element import load_ztf_batch, make_ztf_batch, ztf_score_by_elt, ztf_elt_summary
from asteroid_model import AsteroidPosition, AsteroidDirection, make_model_ast_pos
from asteroid_search_layers import CandidateElements, MixtureParameters, TrajectoryScore
from asteroid_search_model import AsteroidSearchModel
from asteroid_search_report import traj_diff
from nearest_asteroid import nearest_ast_elt_cart, nearest_ast_elt_cov, elt_q_norm
from element_eda import score_by_elt
from asteroid_dataframe import calc_ast_data, spline_ast_vec_df
from astro_utils import deg2dist, dist2deg, dist2sec

In [5]:
# Shorthand for the Keras API that ships with TensorFlow
keras = tf.keras

# Floating-point precision shared by TensorFlow tensors and numpy arrays
dtype, dtype_np = tf.float32, np.float32
# Number of spatial dimensions
space_dims = 3

In [6]:
# Default plot style: large figures with readable fonts
mpl.rcParams.update({
    'figure.figsize': [16.0, 10.0],
    'font.size': 16,
})

## Load ZTF Data and Batch of Orbital Elements

In [7]:
# Orbital elements of the known asteroids, one asteroid per row
ast_elt = load_ast_elt()
N_ast = len(ast_elt)  # number of known asteroids

In [8]:
# Load ZTF observations matched to their nearest known asteroid
# (the frame has at least a mag_app column, which is used further below)
ztf_ast = load_ztf_nearest_ast()

In [9]:
# Hit counts per known asteroid; thresh_sec is the matching threshold
# (presumably arc seconds, per its name)
ast_num, hit_count = calc_hit_freq(ztf=ztf_ast, thresh_sec=2.0)

# Order asteroids from most to fewest hits (argsort is ascending, so reverse it);
# a batch can then be taken from the front of these arrays
idx = np.argsort(hit_count)[::-1]
ast_num_best, hit_count_best = ast_num[idx], hit_count[idx]

In [10]:
# Number of candidate elements searched together
batch_size = 64

# Unperturbed elements for the batch_size asteroids with the most hits
elts_ast = asteroid_elts(ast_nums=ast_num_best[:batch_size])

In [11]:
# # Review unperturbed elements
# elts_ast

In [12]:
# Perturbation sizes for the "large" setting.
# sigma_a and sigma_e are in the elements' own units; *_deg values are angles in degrees.
sigma_a = 0.05
sigma_e = 0.01
sigma_inc_deg = 0.25
sigma_f_deg = sigma_Omega_deg = sigma_omega_deg = 1.0

# No perturbation mask (presumably: perturb every element) and a fixed seed
mask_pert = None
random_seed = 42

In [13]:
# Apply seeded random perturbations to the unperturbed batch of elements
elts_pert = perturb_elts(
    elts_ast,
    sigma_a=sigma_a,
    sigma_e=sigma_e,
    sigma_inc_deg=sigma_inc_deg,
    sigma_f_deg=sigma_f_deg,
    sigma_Omega_deg=sigma_Omega_deg,
    sigma_omega_deg=sigma_omega_deg,
    mask_pert=mask_pert,
    random_seed=random_seed,
)

In [14]:
# Choose which elements to search on.
# NOTE: this binds a second name to the same object (no copy). Cells below add
# columns to `elts` in place (mixture parameters, H), which also changes elts_pert.
elts = elts_pert

## Batches of ZTF Data Near Initial Candidate Elements

In [15]:
# Arguments for load_ztf_batch below
thresh_deg = 2.0  # threshold, in degrees
near_ast, regenerate = False, False

In [16]:
# ZTF observations near the (perturbed) candidate elements, using threshold thresh_deg;
# regenerate=False presumably allows a previously built batch to be reused
ztf_elt = load_ztf_batch(
    elts=elts,
    thresh_deg=thresh_deg,
    near_ast=near_ast,
    regenerate=regenerate,
)

In [17]:
# Review the ZTF observations near the candidate elements
# (each row pairs a ztf_id with an element_id; pandas truncates the display)
ztf_elt

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,vz,elt_ux,elt_uy,elt_uz,elt_r,s,s_sec,z,v,is_hit
0,306,733,b'ZTF18abiyevm',567274570115015018,3341,58321.274572,275.834958,-12.178240,0.099376,-0.976101,...,0.004095,0.080910,-0.982320,0.168822,2.050675,0.031258,6447.753750,0.999511,0.801975,False
1,6391,733,b'ZTF18ablpbwh',617122522515015016,8730,58371.122523,272.156750,-10.136454,0.037046,-0.972528,...,0.003837,0.016116,-0.975703,0.218503,2.491969,0.024005,4951.468379,0.999712,0.472962,False
2,6392,733,b'ZTF18ablpbwh',618126362515015025,8913,58372.126366,272.156760,-10.136446,0.037046,-0.972528,...,0.003830,0.017464,-0.975564,0.219021,2.503645,0.022568,4655.153797,0.999745,0.418050,False
3,6393,733,b'ZTF18ablpbwh',611146562515015015,7585,58365.146562,272.156733,-10.136444,0.037046,-0.972528,...,0.003876,0.010155,-0.976524,0.215169,2.423439,0.030883,6370.298847,0.999523,0.782824,False
4,12249,733,b'ZTF18ablwzcc',584190354815015015,4647,58338.190359,273.272132,-13.497675,0.055502,-0.983530,...,0.004023,0.030085,-0.980846,0.192450,2.155818,0.032724,6750.065879,0.999465,0.878934,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
290766,5650772,324582,b'ZTF20aaqvhnd',1150176700415015000,96618,58904.176701,48.664349,31.318054,0.564235,0.795279,...,-0.001786,0.584015,0.786561,0.200621,2.748814,0.030221,6233.797319,0.999543,0.749638,False
290767,5650773,324582,b'ZTF20aaqvhns',1150176245615015007,96617,58904.176250,45.820577,29.378228,0.607285,0.768505,...,-0.001786,0.584016,0.786559,0.200622,2.748808,0.029465,6077.915470,0.999566,0.712618,False
290768,5650789,324582,b'ZTF20aaqvhnm',1150176245015015006,96617,58904.176250,48.881586,28.300138,0.579016,0.797156,...,-0.001786,0.584016,0.786559,0.200622,2.748808,0.031743,6547.783696,0.999496,0.827049,False
290769,5650791,324582,b'ZTF20aaqvhog',1150176244815015007,96617,58904.176250,49.429756,29.370649,0.566783,0.802441,...,-0.001786,0.584016,0.786559,0.200622,2.748808,0.027275,5625.990260,0.999628,0.610591,False


In [18]:
# Score by element - perturbed
# NOTE(review): this rebinds `score_by_elt`, shadowing the function of the same name
# imported from element_eda in the imports cell.
score_by_elt = ztf_score_by_elt(ztf_elt)

In [19]:
# Summarize the ztf element batch: perturbed asteroids
# (prints observation counts and score statistics under the given title)
ztf_elt_summary(ztf_elt, score_by_elt, 'Perturbed Asteroids')

ZTF Element Dataframe Perturbed Asteroids:
                  Total     (Per Batch)
Observations   :   290771   (     4543)

Summarize score = sum(-1.0 - log(v)) by batch.  (Mean=0, Variance=num_obs)
Mean score     :      42.77
Sqrt(batch_obs):      67.40
Mean t_score   :       0.58


In [20]:
# Initial mixture-model settings applied to every candidate element
R_deg: float = 0.5   # R, in degrees
num_hits: int = 10   # starting number of hits

In [21]:
# Add mixture parameters to candidate elements.
# Modifies elts in place (the return value is not used). When elts is displayed
# afterwards, R_deg=0.5 and thresh_deg=2.0 appear as R=0.008727 and thresh_s=0.034905,
# i.e. converted to radians.
elts_add_mixture_params(elts=elts, num_hits=num_hits, R_deg=R_deg, thresh_deg=thresh_deg)

In [22]:
# Add brightness parameter H (in place; elts afterwards shows H and sigma_mag columns)
elts_add_H(elts=elts)

In [23]:
# Review perturbed elements, now including the mixture parameters
# (num_hits, R, thresh_s) and the brightness parameters (H, sigma_mag)
elts

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch,num_hits,R,thresh_s,H,sigma_mag
0,51921,2.736430,0.219134,0.499988,4.721815,2.452489,-1.129754,58600.0,10,0.008727,0.034905,16.5,4.0
1,59244,2.616575,0.266087,0.462848,5.725946,1.777382,-1.623105,58600.0,10,0.008727,0.034905,16.5,4.0
2,15786,1.945213,0.047621,0.385594,6.142435,0.790543,-1.243047,58600.0,10,0.008727,0.034905,16.5,4.0
3,3904,2.758664,0.099270,0.261841,5.463683,2.238942,-1.350620,58600.0,10,0.008727,0.034905,16.5,4.0
4,142999,2.589450,0.192070,0.509382,0.221844,0.928905,-1.314727,58600.0,10,0.008727,0.034905,16.5,4.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
59,11952,2.330603,0.084892,0.117649,0.042808,2.890716,-3.000560,58600.0,10,0.008727,0.034905,16.5,4.0
60,134815,2.550916,0.141660,0.510228,0.284591,0.630896,-0.920797,58600.0,10,0.008727,0.034905,16.5,4.0
61,27860,2.595202,0.098315,0.194023,5.535984,3.255585,3.966790,58600.0,10,0.008727,0.034905,16.5,4.0
62,85937,2.216242,0.195323,0.437115,5.285351,3.172956,3.921169,58600.0,10,0.008727,0.034905,16.5,4.0


## Build Asteroid Search Model

In [24]:
# Observatory for ZTF data is Palomar Mountain
# (site name passed to AsteroidSearchModel)
site_name = 'palomar'

In [25]:
# Optimizer settings
clipnorm = 1.0
learning_rate = 2.0 ** (-12)

# Save flag (per its name; not read by any visible cell)
save_at_end: bool = True

In [26]:
# Build the asteroid search model; its file name is keyed by the random seed
model = AsteroidSearchModel(
    elts=elts,
    ztf_elt=ztf_elt,
    site_name=site_name,
    thresh_deg=thresh_deg,
    learning_rate=learning_rate,
    clipnorm=clipnorm,
    name='model',
    file_name=f'candidate_elt_pert_large_{random_seed:04d}.h5',
)

In [27]:
# Load trained model: restores candidate elements and training history from disk
# (the loaded path is printed)
model.load()

Loaded candidate elements and training history from ../data/candidate_elt/candidate_elt_pert_large_0042.h5.


In [28]:
# Report on the candidate elements just loaded from disk; these are already trained
# (the report ends with the training summary)
model.report()


Good elements (hits >= 10):   7.00

         \  log_like :  hits  :    R_sec : thresh_sec
Mean Good:  1146.91  :  68.29 :    25.41 :  2400.00
Mean Bad :   200.48  :   0.14 :  1055.07 :  2400.00
Mean     :   304.00  :   7.59 :   942.45 :  2400.00
Median   :   181.50  :   0.00 :  1200.00 :  2400.00
GeoMean  :   203.19  :   0.58 :   579.91 :  2401.00
Min      :    43.86  :   0.00 :     6.67 :  2400.00
Max      :  1407.04  : 163.00 :  1200.00 :  2400.00
Trained for 10496 batches over 164 epochs and 41 episodes (elapsed time 417 seconds).


In [29]:
# ZTF rows treated as hits for the current candidates (per the method name);
# the progress bar runs over the 64 candidate elements
ztf_hits = model.calc_ztf_hits()

HBox(children=(FloatProgress(value=0.0, max=64.0), HTML(value='')))




In [30]:
# Columns shown when reviewing hits
cols = [
    # identifiers and observation time
    'ztf_id', 'element_id', 'mjd',
    # observed position, magnitude and direction
    'ra', 'dec', 'mag_app', 'ux', 'uy', 'uz',
    # predicted direction for the candidate element
    'elt_ux', 'elt_uy', 'elt_uz',
    # distance (s_sec) and magnitude comparison
    's_sec', 'mag_pred', 'mag_diff',
]

In [31]:
# Display the hits restricted to the columns of interest
ztf_hits[cols]

Unnamed: 0,ztf_id,element_id,mjd,ra,dec,mag_app,ux,uy,uz,elt_ux,elt_uy,elt_uz,s_sec,mag_pred,mag_diff
0,1733585,733,58744.420972,52.978730,52.846579,16.808901,0.363646,0.759448,0.539444,0.363627,0.759439,0.539470,7.028630,17.612526,0.803625
1,1813612,733,58747.446991,53.578605,53.159376,16.784700,0.355989,0.761005,0.542350,0.355976,0.760997,0.542369,5.009152,17.570366,0.785666
2,1819129,733,58747.490312,53.585943,53.163630,17.194000,0.355892,0.761021,0.542391,0.355879,0.761013,0.542410,4.999936,17.569696,0.375696
3,1911039,733,58750.402049,54.074838,53.433569,17.140800,0.349546,0.762104,0.544991,0.349539,0.762099,0.545003,3.153805,17.529083,0.388283
4,1916432,733,58750.469144,54.084122,53.439536,16.727800,0.349419,0.762119,0.545052,0.349412,0.762113,0.545064,3.027908,17.528160,0.800360
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
480,3274393,104194,58797.233981,42.760303,48.014329,17.912399,0.491140,0.712364,0.501318,0.491143,0.712354,0.501328,3.120997,18.488108,0.575708
481,3276453,104194,58797.270278,42.735914,48.015615,17.766100,0.491321,0.712168,0.501419,0.491325,0.712157,0.501431,3.336879,18.488625,0.722525
482,3282149,104194,58797.325417,42.698832,48.017455,17.144800,0.491597,0.711870,0.501572,0.491601,0.711858,0.501584,3.626259,18.489401,1.344601
483,3312487,104194,58798.183044,42.137575,48.041911,18.004299,0.495781,0.707350,0.503843,0.495796,0.707324,0.503864,7.403895,18.501766,0.497467


In [32]:
# Debugging shortcut: bind `self` so method bodies can be pasted into cells.
# NOTE(review): `self` is rebound to model.score further down; remove in a final notebook.
self = model
# Current (trained) candidate elements as a table
elts_fit = self.candidates_df()

In [33]:
# Review fitted candidate elements (compare with the initial elts above)
elts_fit

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch,num_hits,R,...,thresh_deg,thresh_sec,log_like,hits,num_rows_close,H,sigma_mag,weight_joint,weight_element,weight_mixture
0,51921,2.672028,0.217864,0.499715,4.699493,2.446779,-1.129590,58600.0,194.909027,0.000095,...,0.666667,2399.999756,1291.511963,16.996742,625.0,15.430998,3.308102,0.25,1.0,1.0000
1,59244,2.636551,0.264598,0.464793,5.737494,1.780742,-1.611774,58600.0,187.710724,0.000235,...,0.666667,2399.999756,972.112427,14.985753,537.0,16.615879,3.315832,0.25,1.0,1.0000
2,15786,1.917499,0.044975,0.379297,6.131769,0.783329,-1.249776,58600.0,377.000000,0.005818,...,0.666667,2399.999756,132.732040,0.000000,377.0,17.612007,3.329152,0.25,1.0,0.5000
3,3904,2.756904,0.126885,0.256991,5.504299,2.174521,-1.363815,58600.0,361.999939,0.005818,...,0.666667,2399.999756,150.461990,0.000000,362.0,16.443434,3.320377,0.25,1.0,1.0000
4,142999,2.543661,0.188454,0.506491,0.213543,0.954120,-1.348161,58600.0,417.000061,0.005818,...,0.666667,2399.999756,125.745827,0.000000,417.0,16.951111,3.341388,0.25,1.0,1.0000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
59,11952,2.328239,0.108479,0.103699,0.099810,2.886406,-3.013675,58600.0,544.210449,0.005818,...,0.666667,2399.999756,293.935638,0.000000,1274.0,16.217176,3.313385,0.25,1.0,1.0000
60,134815,2.604447,0.150049,0.519907,0.291115,0.639903,-0.919676,58600.0,245.000000,0.005818,...,0.666667,2399.999756,109.600258,0.000000,245.0,16.548264,3.337271,0.25,1.0,1.0000
61,27860,2.600416,0.096817,0.194389,5.530719,3.259895,3.969121,58600.0,83.200272,0.001401,...,0.666667,2399.999756,54.210567,0.000000,981.0,16.162300,3.343579,0.25,1.0,0.0625
62,85937,2.190138,0.207091,0.434241,5.242270,3.154911,3.953590,58600.0,390.000092,0.005818,...,0.666667,2399.999756,169.643311,0.000000,390.0,17.236227,3.322382,0.25,1.0,1.0000


In [34]:
# Absolute magnitude residual, so over- and under-predictions count the same
ztf_hits['mag_diff'] = (ztf_hits.mag_app - ztf_hits.mag_pred).abs()

In [35]:
# Re-display hits now that mag_diff is an absolute value
ztf_hits[cols]

Unnamed: 0,ztf_id,element_id,mjd,ra,dec,mag_app,ux,uy,uz,elt_ux,elt_uy,elt_uz,s_sec,mag_pred,mag_diff
0,1733585,733,58744.420972,52.978730,52.846579,16.808901,0.363646,0.759448,0.539444,0.363627,0.759439,0.539470,7.028630,17.612526,0.803625
1,1813612,733,58747.446991,53.578605,53.159376,16.784700,0.355989,0.761005,0.542350,0.355976,0.760997,0.542369,5.009152,17.570366,0.785666
2,1819129,733,58747.490312,53.585943,53.163630,17.194000,0.355892,0.761021,0.542391,0.355879,0.761013,0.542410,4.999936,17.569696,0.375696
3,1911039,733,58750.402049,54.074838,53.433569,17.140800,0.349546,0.762104,0.544991,0.349539,0.762099,0.545003,3.153805,17.529083,0.388283
4,1916432,733,58750.469144,54.084122,53.439536,16.727800,0.349419,0.762119,0.545052,0.349412,0.762113,0.545064,3.027908,17.528160,0.800360
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
480,3274393,104194,58797.233981,42.760303,48.014329,17.912399,0.491140,0.712364,0.501318,0.491143,0.712354,0.501328,3.120997,18.488108,0.575708
481,3276453,104194,58797.270278,42.735914,48.015615,17.766100,0.491321,0.712168,0.501419,0.491325,0.712157,0.501431,3.336879,18.488625,0.722525
482,3282149,104194,58797.325417,42.698832,48.017455,17.144800,0.491597,0.711870,0.501572,0.491601,0.711858,0.501584,3.626259,18.489401,1.344601
483,3312487,104194,58798.183044,42.137575,48.041911,18.004299,0.495781,0.707350,0.503843,0.495796,0.707324,0.503864,7.403895,18.501766,0.497467


In [36]:
# Mean absolute magnitude residual over all hits
np.mean(ztf_hits.mag_diff)

0.9113583004351746

In [37]:
# Spread of observed magnitudes over all hits, for scale against the residual above
np.std(ztf_hits.mag_app)

0.8981321434181515

In [38]:
# Rows for a single element; 733 is the element_id of the first hits shown above
mask = ztf_hits.element_id == 733

In [39]:
# Hits for the selected element only (single .loc lookup: row mask + column list)
ztf_hits.loc[mask, cols]

Unnamed: 0,ztf_id,element_id,mjd,ra,dec,mag_app,ux,uy,uz,elt_ux,elt_uy,elt_uz,s_sec,mag_pred,mag_diff
0,1733585,733,58744.420972,52.97873,52.846579,16.808901,0.363646,0.759448,0.539444,0.363627,0.759439,0.53947,7.02863,17.612526,0.803625
1,1813612,733,58747.446991,53.578605,53.159376,16.7847,0.355989,0.761005,0.54235,0.355976,0.760997,0.542369,5.009152,17.570366,0.785666
2,1819129,733,58747.490312,53.585943,53.16363,17.194,0.355892,0.761021,0.542391,0.355879,0.761013,0.54241,4.999936,17.569696,0.375696
3,1911039,733,58750.402049,54.074838,53.433569,17.1408,0.349546,0.762104,0.544991,0.349539,0.762099,0.545003,3.153805,17.529083,0.388283
4,1916432,733,58750.469144,54.084122,53.439536,16.7278,0.349419,0.762119,0.545052,0.349412,0.762113,0.545064,3.027908,17.52816,0.80036
5,1916633,733,58750.470961,54.084376,53.439703,16.7143,0.349415,0.762119,0.545054,0.349409,0.762113,0.545065,2.997128,17.528097,0.813797
6,1994410,733,58756.486481,54.789836,53.890263,16.6791,0.339796,0.763137,0.54969,0.339805,0.763136,0.549687,1.943897,17.444452,0.765352
7,2071539,733,58759.36875,54.979604,54.049413,16.5123,0.336911,0.76313,0.551475,0.336926,0.763131,0.551463,3.98793,17.404829,0.892529
8,2079718,733,58759.382384,54.980077,54.050121,16.5128,0.336901,0.763128,0.551483,0.336917,0.763129,0.551472,4.026111,17.404606,0.891806
9,2085480,733,58759.413148,54.981065,54.051651,16.9265,0.33688,0.763123,0.551503,0.336896,0.763125,0.551491,3.959413,17.404177,0.477676


In [40]:
# Mean predicted magnitude for the selected element
np.mean(ztf_hits.mag_pred[mask])

17.441818237304688

In [41]:
# Mean observed magnitude for the selected element
np.mean(ztf_hits.mag_app[mask])

16.80821287631989

In [43]:
# Observed magnitudes of every row in the nearest-asteroid data, as a numpy array
mag = ztf_ast.mag_app.values

In [44]:
# Spread of observed magnitudes (nearest-asteroid data)
np.std(mag)

1.1037232485682564

In [45]:
# Mean observed magnitude (nearest-asteroid data)
np.mean(mag)

18.95098666821053

In [46]:
# Predicted magnitudes of the hits.
# NOTE(review): mag_pred is rebound below by model.predict_direction() and
# model.magnitude(); a distinct name would avoid confusion.
mag_pred = ztf_hits.mag_pred

In [47]:
# Mean predicted magnitude over the hits
np.mean(mag_pred)

18.765668869018555

In [48]:
# Spread of predicted magnitudes over the hits
np.std(mag_pred)

0.6102663278579712

In [49]:
# NOTE(review): same expression as an earlier cell; duplicate, can be removed
np.mean(ztf_hits.mag_diff)

0.9113583004351746

In [50]:
# Model-predicted asteroid state (q_ast / v_ast presumably position / velocity)
q_ast, v_ast = model.predict_position()

In [51]:
# Model-predicted observation direction u_pred, plus r_pred and mag_pred (rebinds mag_pred)
u_pred, r_pred, mag_pred = model.predict_direction()

In [52]:
# Current mixture parameters, one value per candidate element.
# NOTE: rebinds num_hits (previously the int 10 from the mixture-parameter cell).
num_hits, R, R_max = model.get_mixture_params()

In [53]:
# Predicted magnitude and its sigma from the predicted positions; mag_pred has one
# entry per data row (shape (290771,) in the output below)
mag_pred, sigma_mag = model.magnitude(q_ast)

In [54]:
# Inspect predicted magnitudes for all rows
mag_pred

<tf.Tensor: shape=(290771,), dtype=float32, numpy=
array([18.86412 , 19.24609 , 19.254805, ..., 19.800415, 19.800415,
       19.800413], dtype=float32)>

In [55]:
# Inspect the fitted number of hits per element
num_hits

array([194.90903, 187.71072, 377.     , 361.99994, 417.00006, 163.95807,
       245.99994, 615.871  , 295.0001 , 525.81836, 220.00003, 172.98471,
       411.     , 409.00003, 582.6705 ,  52.66849, 461.     , 481.43085,
       559.76764, 500.7688 , 248.00003, 167.39249, 162.98485, 534.4035 ,
       493.12103, 161.97577, 423.00006, 479.9984 , 250.99997, 528.73016,
       283.1051 , 531.3354 , 162.71024, 310.99997, 529.9829 ,  50.21997,
       440.00006, 241.99997, 559.6961 , 417.00006, 153.96823, 160.96904,
       319.99994, 314.     , 289.99997, 515.7891 ,  77.99371, 224.49268,
       541.46094, 529.6459 , 153.40881, 178.00002, 354.00003, 551.1699 ,
       558.30707, 302.99988, 331.99994, 376.00003, 663.3324 , 544.21045,
       245.     ,  83.20027, 390.0001 , 538.51044], dtype=float32)

In [57]:
# Per-element magnitude sigma (64 values, one per candidate)
model.magnitude.get_sigma_mag()

<tf.Tensor: shape=(64,), dtype=float32, numpy=
array([3.3081017, 3.315832 , 3.329152 , 3.3203773, 3.341388 , 3.316377 ,
       3.3158445, 3.320999 , 3.318493 , 3.3379204, 3.3291774, 3.3098276,
       3.324485 , 3.3207762, 3.3081017, 3.3378108, 3.3023841, 3.3333054,
       3.3166535, 3.3258023, 3.314421 , 3.311572 , 3.3263142, 3.3140342,
       3.3317494, 3.3320727, 3.3207207, 3.3243542, 3.3147388, 3.3081033,
       3.345127 , 3.3126004, 3.3130991, 3.3182645, 3.3178587, 3.3384562,
       3.3214433, 3.3102634, 3.3136227, 3.3171084, 3.3039503, 3.3103623,
       3.325276 , 3.3116884, 3.334243 , 3.3065903, 3.3402922, 3.2868593,
       3.3133473, 3.3215094, 3.3138676, 3.3151205, 3.319129 , 3.3060615,
       3.3083794, 3.3142262, 3.3248687, 3.3140807, 3.3109474, 3.3133848,
       3.3372715, 3.3435788, 3.3223822, 3.3026392], dtype=float32)>

In [58]:
# Per-element brightness H (matches the H column of elts_fit)
model.magnitude.get_H()

<tf.Tensor: shape=(64,), dtype=float32, numpy=
array([15.430998, 16.61588 , 17.612007, 16.443434, 16.95111 , 16.497623,
       15.492089, 16.71932 , 17.637424, 16.463398, 17.185095, 15.337681,
       16.38165 , 17.689379, 15.699258, 17.31567 , 17.256306, 16.331377,
       15.703431, 17.094046, 17.563665, 15.395374, 15.45465 , 15.942504,
       16.441114, 16.45381 , 17.629517, 17.422169, 17.705387, 17.377289,
       16.763489, 16.171293, 15.967063, 17.59524 , 16.402407, 17.157513,
       16.974741, 17.708378, 17.563158, 15.948264, 15.356865, 15.437008,
       15.780988, 17.700438, 16.076597, 17.687643, 16.89887 , 15.363563,
       16.208687, 17.274439, 16.398052, 17.641155, 17.255655, 17.623085,
       15.538877, 15.393722, 15.853599, 16.624027, 15.388755, 16.217176,
       16.548264, 16.1623  , 17.236227, 16.054396], dtype=float32)>

In [62]:
# Debugging shortcut: bind `self` to the score layer so its method body can be
# stepped through in the cells below.
self = model.score
# NOTE(review): late import; belongs in the imports cell at the top
from typing import Tuple

In [63]:
# Per-element squared threshold distance from the score layer.
# NOTE(review): assigns an attribute on the live score layer object.
self.thresh_s2_elt: tf.Tensor = self.get_thresh_s2()

# Convert R to the exponential decay parameter lam = (thresh_s2 / 2) / R^2
half_thresh_s2: tf.Tensor = tf.multiply(self.thresh_s2_elt, 0.5, name='half_thresh_s2')
R2: tf.Tensor = tf.square(R, name='R2')
lam: tf.Tensor = tf.divide(half_thresh_s2, R2, name='lam')

# Difference between predicted and observed directions (u_pred - u_obs)
du: tf.Tensor = keras.layers.subtract(inputs=[u_pred, self.u_obs], name='du')
# Squared distance between predicted and observed directions (summed over the last axis)
s2: tf.Tensor = tf.reduce_sum(tf.square(du), axis=(-1), name='s2')

In [64]:
# Upsample thresh_s2 from one value per element to one value per data row
# (each element's value repeated self.row_lengths times).
# NOTE(review): assigns attributes on the live score layer object.
thresh_shape: Tuple[int] = (self.data_size,)
self.thresh_s2_rep: tf.Tensor = \
    tf.repeat(input=self.thresh_s2_elt, repeats=self.row_lengths, name='thresh_s2_rep')
self.thresh_s2: tf.Tensor = \
    tf.reshape(tensor=self.thresh_s2_rep, shape=thresh_shape, name='thresh_s2')

# Keep only rows where s2 is strictly below the squared threshold distance
is_close: tf.Tensor = tf.math.less(s2, self.thresh_s2, name='is_close')

# Relative distance v = s2 / thresh_s2 on close rows, so 0 <= v < 1; shape is [num_close,]
v_num: tf.Tensor = tf.boolean_mask(tensor=s2, mask=is_close, name='v_num')
v_den: tf.Tensor = tf.boolean_mask(tensor=self.thresh_s2, mask=is_close, name='v_den')
v: tf.Tensor = tf.divide(v_num, v_den, name='v')

In [65]:
# Row lengths for close observations only.
# Regroup the flat is_close mask into one ragged row per candidate element, using the
# per-element row counts of the full data (self.row_lengths).
ragged_map_func = lambda x : \
    tf.RaggedTensor.from_row_lengths(values=x, row_lengths=self.row_lengths)
is_close_r: tf.RaggedTensor = \
    tf.keras.layers.Lambda(function=ragged_map_func, name='is_close_r')(is_close)
# Compute the row lengths (number of close observations per candidate element)
row_lengths_close: tf.Tensor = \
    tf.reduce_sum(tf.cast(is_close_r, tf.int32), axis=1, name='row_lengths_close')
row_lengths_close_float: tf.Tensor = \
    tf.cast(x=row_lengths_close, dtype=dtype)

In [67]:
# Compute the implied hit rate h from the number of hits and row_lengths_close; shape [batch_size,]
# NOTE(review): an element with zero close rows gives num_hits / 0 = inf (clipped to 1
# below), or nan if num_hits is also 0. Such elements get 0 copies in the repeats below,
# so they do not reach the flat tensors.
h: tf.Tensor = tf.divide(num_hits, row_lengths_close_float, name='h_raw')
# The hit rate must be in [0, 1]
h = tf.clip_by_value(h, clip_value_min=0.0, clip_value_max=1.0)

# Shape of the per-close-row parameter vectors: [total number of close rows,]
close_size: tf.Tensor = tf.reduce_sum(row_lengths_close)
param_shape = (close_size,)

# Upsample h and lambda from one value per element to one value per close row
h_rep: tf.Tensor = tf.repeat(input=h, repeats=row_lengths_close, name='h_rep')
h_vec: tf.Tensor = tf.reshape(tensor=h_rep, shape=param_shape, name='h_vec')
lam_rep: tf.Tensor = tf.repeat(input=lam, repeats=row_lengths_close, name='lam_rep')
lam_vec: tf.Tensor = tf.reshape(tensor=lam_rep, shape=param_shape, name='lam_vec')

In [68]:
# Conditional density of the relative distance v given a hit: an exponential with
# rate lam, truncated to v in [0, 1) and renormalized there:
#   p(v) = lam * exp(-lam * v) / (1 - exp(-lam))
emlx: tf.Tensor = tf.exp(-lam_vec * v, name='emlx')
p_cond_dist_num: tf.Tensor = tf.multiply(emlx, lam_vec, name='p_cond_dist_num')
p_cond_dist_den: tf.Tensor = tf.subtract(1.0, tf.exp(-lam_vec), name='p_cond_dist_den')
p_cond_dist: tf.Tensor = tf.divide(p_cond_dist_num, p_cond_dist_den, name='p_cond_dist')

In [74]:
# Difference between predicted and observed magnitude, for every data row; shape [data_size,]
mag_diff_all: tf.Tensor = tf.subtract(mag_pred, self.mag_obs, name='mag_diff_all')

# Standardized residual z and the Gaussian exponent -z^2 / 2, again for every row.
# Each step consumes the *_all tensor from the previous line. mag_diff / mag_z / mag_z2
# (the close-row versions) are only defined in later cells, so they must not be used here.
# sigma_mag is the full-length tensor returned by model.magnitude at this point.
mag_z_all: tf.Tensor = tf.divide(mag_diff_all, sigma_mag, name='mag_z_all')
mag_z2_all: tf.Tensor = tf.square(mag_z_all, name='mag_z2_all')
mag_arg_all: tf.Tensor = tf.multiply(mag_z2_all, -0.5, name='mag_arg_all')

In [76]:
# Map from flat tensors of close rows to ragged tensors (one row per element).
# NOTE: the lambda reads the global row_lengths_close at call time, not at definition.
ragged_map_func_close = lambda x : \
    tf.RaggedTensor.from_row_lengths(values=x, row_lengths=row_lengths_close)

# The normalized probability based on the magnitude is the unnormalized prob over the normalizer
# The numerator is the exp(mag_arg); shape [data_size,]
p_mag_num_all: tf.Tensor = tf.exp(mag_arg_all, name='p_mag_num_all')

In [77]:
# Filter to only the close rows; shape [num_close,]
# Gaussian-shaped density exp(-z^2 / 2) / sigma on the close rows. sigma_mag must still be
# full-length here (it is masked in place in a later cell).
p_mag_num: tf.Tensor = tf.boolean_mask(tensor=p_mag_num_all, mask=is_close, name='p_mag_num')
p_mag_den: tf.Tensor = tf.boolean_mask(tensor=sigma_mag, mask=is_close, name='p_mag_den')
p_mag_pdf: tf.Tensor = tf.divide(p_mag_num, p_mag_den, name='p_mag_pdf')
# NOTE(review): presumably sigma_mag_normalizer supplies the remaining constant
# (e.g. 1/sqrt(2*pi)) — confirm against the score layer
p_cond_mag: tf.Tensor = tf.multiply(self.sigma_mag_normalizer, p_mag_pdf, name='p_cond_mag')

In [128]:
# Combined conditional probability of a hit
# NOTE(review): only the distance term is used; p_cond_mag (computed above) is not
# multiplied in — confirm whether that is intentional.
p_hit_cond: tf.Tensor = p_cond_dist

# Mixture model per close row: p = h * p_cond (hit) + (1 - h) * 1 (miss, uniform in v)
p_hit: tf.Tensor = tf.multiply(h_vec, p_hit_cond, name='p_hit')
p_miss: tf.Tensor = tf.subtract(1.0, h_vec, name='p_miss')
p: tf.Tensor = tf.add(p_hit, p_miss, name='p')
log_p_flat: tf.Tensor = keras.layers.Activation(tf.math.log, name='log_p_flat')(p)

# The posterior hit probability is p_hit / p
p_hit_post_flat: tf.Tensor = tf.divide(p_hit, p)

In [71]:
# Inspect the per-row magnitude residual mag_pred - mag_obs; shape [data_size,].
# Uses mag_diff_all (defined in the magnitude-residual cell above) rather than mag_diff,
# which only exists after the close-row filtering further down.
mag_diff_all

<tf.Tensor: shape=(290771,), dtype=float32, numpy=
array([-0.30197906,  2.2160892 ,  2.1946049 , ...,  0.27001572,
        0.13521576,  0.39011383], dtype=float32)>

In [93]:
# Per-row magnitude residual (recomputed; same expression as the earlier cell),
# then restricted to the close rows; mag_diff has shape [num_close,]
mag_diff_all: tf.Tensor = tf.subtract(mag_pred, self.mag_obs, name='mag_diff_all')
mag_diff: tf.Tensor = tf.boolean_mask(tensor=mag_diff_all, mask=is_close, name='mag_diff')

In [94]:
# Mean signed residual over all rows
np.mean(mag_diff_all)

-0.06570314

In [95]:
# Mean signed residual over close rows only
np.mean(mag_diff)

-0.010758142

In [96]:
# Restrict sigma_mag to the close rows; shape [num_close,]
# NOTE(review): overwrites the full-length sigma_mag, so re-running this cell, or any
# earlier cell that masks sigma_mag with is_close, fails on a shape mismatch.
sigma_mag: tf.Tensor = tf.boolean_mask(tensor=sigma_mag, mask=is_close, name='sigma_mag')

In [97]:
# Standardized residual z and Gaussian exponent -z^2 / 2 on the close rows
mag_z: tf.Tensor = tf.divide(mag_diff, sigma_mag, name='mag_z')
mag_z2: tf.Tensor = tf.square(mag_z, name='mag_z2')
mag_arg: tf.Tensor = tf.multiply(mag_z2, -0.5, name='mag_arg')

In [98]:
# Mean Gaussian exponent over close rows
np.mean(mag_arg)

-0.06334289

In [103]:
# Mean squared z on close rows; equals -2 * mean(mag_arg) from the previous check
np.mean((mag_diff/sigma_mag)**2)

0.12668578

In [104]:
# log(sigma) on close rows, for the Gaussian log-density below
log_sigma_mag: tf.Tensor = tf.math.log(sigma_mag, name='log_sigma_mag')

In [108]:
# Gaussian log-density of the magnitude residual, omitting the constant -0.5 * log(2 * pi)
mag_log_pdf: tf.Tensor = tf.subtract(mag_arg, log_sigma_mag, name='mag_log_pdf')

In [111]:
# Equal weight 1 / (number of close rows) for each close observation of an element;
# shape [batch_size,]. An element with no close rows gets inf, but it is repeated 0 times below.
w_equal_elt = tf.divide(1.0, row_lengths_close_float, name='w_equal_elt')

In [113]:
# Upsample the equal weights to one per close row; they sum to 1 within each element
w_equal = tf.repeat(input=w_equal_elt, repeats=row_lengths_close)

In [115]:
# Shape check: one weight per candidate element
w_equal_elt.shape

TensorShape([64])

In [116]:
# Shape check: one weight per close row
w_equal.shape

TensorShape([39649])

In [119]:
# Scratch: ragged view of w_equal (one row per element).
# NOTE(review): `xxx` is reused for a different tensor below; a descriptive name is better.
xxx = ragged_map_func_close(w_equal)

In [130]:
# Regroup p_hit into one ragged row per element so it can be summed by element below
p_hit_r: tf.Tensor = \
    tf.keras.layers.Lambda(function=ragged_map_func_close, name='p_hit_r')(p_hit)

In [134]:
# Denominator for normalizing p_hit: its sum over each element's close rows
p_hit_den_elt: tf.Tensor = tf.reduce_sum(p_hit_r, axis=-1, name='p_hit_den_elt')
# Upsample the denominator to one value per close row
p_hit_den: tf.Tensor = tf.repeat(input=p_hit_den_elt, repeats=row_lengths_close, name='p_hit_den')

In [136]:
# Normalized hit probability; sums to 1 over each element's close rows
p_hit_norm: tf.Tensor = tf.divide(p_hit, p_hit_den, name='p_hit_norm')

In [137]:
# Weights for log likelihood of magnitude: hit-weighted minus equal-weighted.
# Both terms sum to 1 within each element, so w_mag sums to 0 per element.
w_mag: tf.Tensor = tf.subtract(p_hit_norm, w_equal, name='w_mag')

In [138]:
# Scratch: ragged view of w_mag (rebinds xxx) for the per-element sum check below
xxx = ragged_map_func_close(w_mag)

In [139]:
# Sanity check: per-element sums of w_mag should be ~0 (up to float32 rounding)
tf.reduce_sum(xxx, axis=-1)

<tf.Tensor: shape=(64,), dtype=float32, numpy=
array([ 7.61356205e-07,  3.30619514e-07,  6.94766641e-07, -1.46101229e-07,
       -1.00466423e-07, -6.52740709e-07,  1.80909410e-07, -8.10803613e-07,
       -1.06636435e-07, -9.71893314e-07,  2.24448740e-07,  2.41911039e-07,
        2.17696652e-07, -1.32829882e-07,  4.90108505e-08, -4.67058271e-07,
       -5.72763383e-08,  3.74391675e-07, -2.86090653e-07,  7.66594894e-08,
        6.02332875e-07,  2.65426934e-08,  9.45292413e-08, -4.90748789e-07,
        2.45985575e-07,  1.80443749e-07, -1.91736035e-07, -2.13389285e-07,
       -3.64147127e-07,  7.03148544e-08,  2.90921889e-07, -5.97676262e-07,
        3.38419341e-07,  3.16067599e-07, -6.08968548e-07,  4.28408384e-08,
       -3.02447006e-07,  1.04540959e-07, -2.18278728e-07,  8.68458301e-08,
       -3.25962901e-07,  2.32574530e-06,  6.99190423e-07, -1.66823156e-07,
        5.29456884e-07,  1.90222636e-07,  1.91503204e-08,  3.88128683e-07,
        1.38999894e-07,  5.20376489e-08,  6.20959327e

In [141]:
# The log likelihood of the magnitude; flat, by observation.
# Because w_mag sums to 0 within each element, adding any per-element constant to
# mag_log_pdf (e.g. the omitted -0.5 * log(2 * pi)) leaves each element's total unchanged.
mag_log_like_flat = tf.multiply(w_mag, mag_log_pdf, name='mag_log_like_flat')

In [145]:
# Total weighted magnitude log likelihood over all close rows
np.sum(mag_log_like_flat)

0.18304774