In [1]:
# Core
import numpy as np
import pandas as pd

# Tensorflow / ML
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Utility
import time
from datetime import timedelta

# Local imports
from asteroid_search_model import make_model_asteroid_search
from ztf_data import load_ztf_easy_batch
from asteroid_data import make_ztf_dataset, orbital_element_batch
from asteroid_integrate import calc_ast_pos
from asteroid_search_report import report_model, report_training_progress
from utils import print_header

# Typing
from typing import Optional

Found 4 GPUs.  Setting memory growth = True.


In [2]:
from asteroid_search import perturb_elts, test_easy_batch
from asteroid_search_report import report_model_attribute
from search_score_functions import score_mean_var
from asteroid_dataframe import calc_ast_data, spline_ast_vec_df, calc_ast_dir
from ztf_data import make_ztf_batch, make_ztf_near_elt
from astro_utils import deg2dist, dist2deg

In [3]:
# Aliases
keras = tf.keras

## Generate DataSet Matching an Easy Batch

In [4]:
# Resolution and threshold in degrees
R_deg: float = 1.0
thresh_deg: float = 1.0
# Whether resolution R is trainable
R_is_trainable = False

# Batch setup
time_batch_size = None
elt_batch_size = 64
epochs = 10
cycles_per_epoch = 10

# Set calibration flag
use_calibration: bool = True

In [5]:
# Load all ZTF data with nearest asteroid calculations
ztf, elts = load_ztf_easy_batch(batch_size=elt_batch_size, thresh_deg=thresh_deg)

# Build a TensorFlow DataSet from ZTF DataFrame
ds, ts, row_len = make_ztf_dataset(ztf=ztf, batch_size=time_batch_size)

# Trajectory size and steps per batch
traj_size = ts.shape[0]
# If time_batch_size was None, use all the time points in each batch
if time_batch_size is None:
    time_batch_size = traj_size
steps = int(np.ceil(traj_size / time_batch_size))

# The element_id is the same as the asteroid number on the easy batch
element_id = elts.element_id.values
# The epoch (scalar)
epoch = elts['epoch'][0]

# The correct orbital elements; also a DataFrame
elts_true = elts.copy()

# Get example batch
batch_in, batch_out = list(ds.take(1))[0]
# Contents of this batch
t = batch_in['t']
idx = batch_in['idx']
row_len = batch_in['row_len']
u_obs = batch_in['u_obs']

# Get max_obs and number of observations
max_obs: int = u_obs.shape[1]
# The number of observations is the TOTAL FOR THE ZTF DATA SET!
# It's not just the size of this easy batch, b/c the easy batch has been harvested to be close!
# num_obs: float = np.sum(row_len, dtype=np.float32)
num_obs: float = 5.69E6

In [6]:
ztf

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,nearest_ast_dist,ast_ra,ast_dec,ast_ux,ast_uy,ast_uz,elt_ux,elt_uy,elt_uz,elt_r
0,37606,733,b'ZTF18abtdsvv',621336711115010010,9577,58375.336713,350.405941,0.879484,0.985897,-0.146790,...,0.001973,350.518390,0.891030,0.986219,-0.144934,0.079786,0.983072,-0.161029,0.087400,2.595640
1,37607,733,b'ZTF18abtdsvv',621309991115010019,9572,58375.310000,350.405990,0.879483,0.985897,-0.146789,...,0.002098,350.525361,0.893904,0.986239,-0.144804,0.079784,0.983090,-0.160937,0.087369,2.595627
2,37610,733,b'ZTF18abtdsvv',614358404915010002,8240,58368.358403,350.405972,0.879528,0.985897,-0.146789,...,0.001976,350.425686,0.991042,0.985923,-0.145699,0.082020,0.987293,-0.137772,0.079197,2.599708
3,37612,733,b'ZTF18abtdsvv',615324371115015021,8425,58369.324375,350.405987,0.879540,0.985897,-0.146789,...,0.001601,350.497212,0.889307,0.986159,-0.145280,0.079903,0.986748,-0.140962,0.080359,2.598268
4,141833,733,b'ZTF18abwpfho',623339405515010000,9980,58377.339410,348.764546,0.189321,0.980829,-0.177448,...,0.002356,348.831076,0.306747,0.981046,-0.175586,0.081961,0.981751,-0.167709,0.089660,2.597194
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90205,5584064,324582,b'ZTF18aaaknvz',1149162620015010022,96198,58903.162627,67.404504,45.304078,0.270241,0.878532,...,0.006270,67.794274,45.536954,0.264723,0.878880,0.396851,0.271240,0.872415,0.406597,1.545974
90206,5584155,324582,b'ZTF20aaqjjhn',1149162620115015001,96198,58903.162627,66.164947,45.635778,0.282557,0.871180,...,0.003937,66.390553,45.474743,0.280840,0.873100,0.398528,0.271240,0.872415,0.406597,1.545974
90207,5584262,324582,b'ZTF20aaqjltu',1149162621915015008,96198,58903.162627,67.072721,46.398450,0.268658,0.870797,...,0.000034,67.069974,46.398068,0.268690,0.870788,0.411746,0.271240,0.872415,0.406597,1.545974
90208,5584879,324582,b'ZTF20aaqjltx',1149162621915015001,96198,58903.162627,66.993445,46.563928,0.268718,0.869478,...,0.000004,66.993111,46.563943,0.268721,0.869477,0.414487,0.271240,0.872415,0.406597,1.545974


In [7]:
elts

Unnamed: 0,element_id,a,e,inc,Omega,omega,f,epoch
0,733,3.398871,0.058840,0.354075,5.951576,3.315651,3.703184,58600.0
1,1476,2.280890,0.189703,0.110415,5.768549,6.107239,0.309986,58600.0
2,1803,2.349173,0.247823,0.376183,5.886310,4.431771,2.468021,58600.0
3,2015,2.335363,0.103840,0.207837,6.012107,4.788967,1.897608,58600.0
4,2294,2.581424,0.116455,0.109954,5.075755,0.753231,0.373019,58600.0
...,...,...,...,...,...,...,...,...
59,203722,3.171154,0.291647,0.426091,5.309571,1.789398,-1.162711,58600.0
60,253246,2.677164,0.227275,0.446592,5.260714,1.455787,-0.777289,58600.0
61,306781,2.534977,0.209075,0.228417,5.551083,0.846400,-0.451545,58600.0
62,313521,2.637372,0.254435,0.524255,4.487565,2.113770,-1.015696,58600.0


## Orbital Element Batch

In [8]:
# Mask where data perturbed vs not
mask_good = np.arange(elt_batch_size) < (elt_batch_size//2)
mask_bad = ~mask_good
# Perturb second half of orbital elements
elts_pert = perturb_elts(elts, sigma_a=0.00, sigma_e=0.00, sigma_f_deg=0.1, mask=mask_bad)
# elts_pert = perturb_elts(elts, mask=mask_bad)

# Orbits for calibration on perturbed elements
if use_calibration:
    print(f'Numerically integrating calibration trajectories on perturbed elements, q_cal...')
    q_cal = calc_ast_pos(elts=elts_pert, epoch=epoch, ts=ts)
else:
    q_cal = None

Numerically integrating calibration trajectories on perturbed elements, q_cal...


In [9]:
# ZTF batch of perturbed elements
file_path = '../data/ztf/ztf-easy-batch-pert.h5'
try:
    ztfp = pd.read_hdf(file_path)
except:
    ztfp = make_ztf_batch(elts=elts_pert, thresh_deg=thresh_deg, near_ast=True)
    ztf.to_hdf(file_path, key='ztf', mode='w')

In [10]:
ztfp

Unnamed: 0,ztf_id,element_id,ObjectID,CandidateID,TimeStampID,mjd,ra,dec,ux,uy,...,nearest_ast_dist,ast_ra,ast_dec,ast_ux,ast_uy,ast_uz,elt_ux,elt_uy,elt_uz,elt_r
0,37606,733,b'ZTF18abtdsvv',621336711115010010,9577,58375.336713,350.405941,0.879484,0.985897,-0.146790,...,0.001973,350.518390,0.891030,0.986219,-0.144934,0.079786,0.983072,-0.161029,0.087400,2.595640
1,37607,733,b'ZTF18abtdsvv',621309991115010019,9572,58375.310000,350.405990,0.879483,0.985897,-0.146789,...,0.002098,350.525361,0.893904,0.986239,-0.144804,0.079784,0.983090,-0.160937,0.087369,2.595627
2,37610,733,b'ZTF18abtdsvv',614358404915010002,8240,58368.358403,350.405972,0.879528,0.985897,-0.146789,...,0.001976,350.425686,0.991042,0.985923,-0.145699,0.082020,0.987293,-0.137772,0.079197,2.599708
3,37612,733,b'ZTF18abtdsvv',615324371115015021,8425,58369.324375,350.405987,0.879540,0.985897,-0.146789,...,0.001601,350.497212,0.889307,0.986159,-0.145280,0.079903,0.986748,-0.140962,0.080359,2.598268
4,141833,733,b'ZTF18abwpfho',623339405515010000,9980,58377.339410,348.764546,0.189321,0.980829,-0.177448,...,0.002356,348.831076,0.306747,0.981046,-0.175586,0.081961,0.981751,-0.167709,0.089660,2.597194
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90205,5584064,324582,b'ZTF18aaaknvz',1149162620015010022,96198,58903.162627,67.404504,45.304078,0.270241,0.878532,...,0.006270,67.794274,45.536954,0.264723,0.878880,0.396851,0.271240,0.872415,0.406597,1.545974
90206,5584155,324582,b'ZTF20aaqjjhn',1149162620115015001,96198,58903.162627,66.164947,45.635778,0.282557,0.871180,...,0.003937,66.390553,45.474743,0.280840,0.873100,0.398528,0.271240,0.872415,0.406597,1.545974
90207,5584262,324582,b'ZTF20aaqjltu',1149162621915015008,96198,58903.162627,67.072721,46.398450,0.268658,0.870797,...,0.000034,67.069974,46.398068,0.268690,0.870788,0.411746,0.271240,0.872415,0.406597,1.545974
90208,5584879,324582,b'ZTF20aaqjltx',1149162621915015001,96198,58903.162627,66.993445,46.563928,0.268718,0.869478,...,0.000004,66.993111,46.563943,0.268721,0.869477,0.414487,0.271240,0.872415,0.406597,1.545974


## Build Model

In [11]:
# The observation site for ZTF data is Palomar Mountain.
site_name = 'palomar'

# Alpha and beta parameters for the objective function
alpha = 1.0
beta = 0.0

# Build functional model for asteroid score
model = make_model_asteroid_search(\
    ts=ts, elts=elts_pert, max_obs=max_obs, num_obs=num_obs, site_name=site_name,
    elt_batch_size=elt_batch_size, time_batch_size=time_batch_size,
    R_deg=R_deg, thresh_deg=thresh_deg, R_is_trainable=R_is_trainable, alpha=alpha, beta=beta, 
    q_cal=q_cal, use_calibration=use_calibration)

# Use Adam optimizer with gradient clipping
# learning_rate = 2.0e-5
learning_rate = 2.0E-5  # default 1.0E-3
beta_1 = 0.900          # default 0.900
beta_2 = 0.999          # default 0.999
epsilon = 1.0E-7        # default 1.0E-7
amsgrad = False         # default False
clipvalue = 5.0         # default not used
opt = keras.optimizers.Adam(learning_rate=learning_rate, 
                            beta_1=beta_1, 
                            beta_2=beta_2,
                            epsilon=epsilon, 
                            amsgrad=amsgrad,
                            # clipvalue=clipvalue, 
                            )
model.compile(optimizer=opt)
# Whether to display results
display = False

## Initial Losses

In [12]:
R = deg2dist(R_deg)
A = 1.0 / R**2
thresh = deg2dist(thresh_deg)
mu_per_obs, sigma2_per_obs = score_mean_var(A, thresh=thresh)
N_obs = 5.69e6
mu = N_obs * mu_per_obs
sigma2 = N_obs * sigma2_per_obs
sigma = np.sqrt(sigma2)

# report
print(f'Resolution: {R_deg} degrees / {R:8.6} cartesian')
print(f'A         : {A:8.3f}')
print('\nPer Observation:')
print(f'mu:     {mu_per_obs:5.3e}')
print(f'sigma2: {sigma2_per_obs:5.3e}')
print('\nFor Data:')
print(f'mu:     {mu:10.6f}')
print(f'sigma2: {sigma2:10.6f}')
print(f'simga:  {sigma:10.6f}')

Resolution: 1.0 degrees / 0.0174531 cartesian
A         : 3282.890

Per Observation:
mu:     5.993e-05
sigma2: 4.813e-05

For Data:
mu:     340.986259
sigma2: 273.882006
simga:   16.549381


In [13]:
# Report losses before training
print(f'Processed easy batch with {elt_batch_size} asteroids. Perturbed second half.')
if display:
    print_header('Model Before Training:')
pred0 = model.predict_on_batch(ds)
elts0, R0, u_pred0, z0, _ = pred0
scores0, traj_err0, elt_err0 = \
    report_model(model=model, ds=ds, R_deg=R_deg, mask_good=mask_good, 
                 batch_size=elt_batch_size, steps=steps, elts_true=elts_true, display=display)

Processed easy batch with 64 asteroids. Perturbed second half.


## Check that untrained model predictions match expected results

In [14]:
# The results predicted by the TF model, as a numpy array of shape (batch_size=64, N_t=2371, 3)
u_pred_tf = u_pred0.numpy()

In [15]:
# Alias tensorflow of times to numpy array mjd
mjd = ts.numpy()

# Show that mjd is sorted
# np.min(np.diff(mjd))

In [16]:
# Arrays of element errors in Cartesian distance degrees
err_obs_dist = np.zeros(elt_batch_size)
err_elt_dist = np.zeros(elt_batch_size)
err_ast_dist = np.zeros(elt_batch_size)

# Columns to extract integrated asteroid direction
cols_obs_dir = ['ux', 'uy', 'uz']
cols_ast_dir = ['ast_ux', 'ast_uy', 'ast_uz']
cols_elt_dir = ['elt_ux', 'elt_uy', 'elt_uz']

# Theshold for hits is 2.0 arc seconds
thresh_hit = deg2dist(2.0 / 3600.0)

for i, elt_id in enumerate(element_id):
    # Select entries for this element that are hits
    mask = (ztfp.element_id == elt_id) & (ztfp.nearest_ast_num == elt_id) #& (ztf.nearest_ast_dist < thresh_hit)
    ztf_i = ztfp[mask]
    
    # Time indices on left and right side side; want idx s/t mjd[idx] == ztf_i.mjd
    idx_l = np.searchsorted(a=mjd, v=ztf_i.mjd)-1
    idx_r = idx_l + 1

    # Find the side that is closer
    dist_l = ztf_i.mjd.values - mjd[idx_l]
    dist_r = mjd[idx_r] - ztf_i.mjd.values
    is_left = (dist_l < dist_r)
    idx = idx_r.copy()

    # The index that is closest
    idx[is_left] = idx_l[is_left]

    # TensorFlow predictions on the ZTF dates for this element_id
    u_pred_tf_i = u_pred_tf[i, idx, :]

    # Direction: observation and MSE prediction in Numpy
    u_obs_np_i = ztf_i[cols_obs_dir].values
    u_elt_np_i = ztf_i[cols_elt_dir].values
    u_ast_np_i = ztf_i[cols_ast_dir].values

    # Difference between TensorFlow and Numpy
    u_diff_obs_i = u_pred_tf_i - u_obs_np_i
    u_diff_elt_i = u_pred_tf_i - u_elt_np_i
    u_diff_ast_i = u_pred_tf_i - u_ast_np_i
    
    # Save this entry to err_obs_dist and err_ast_dist
    err_obs_dist[i] = np.mean(np.linalg.norm(u_diff_obs_i, axis=1))
    err_elt_dist[i] = np.mean(np.linalg.norm(u_diff_elt_i, axis=1))
    err_ast_dist[i] = np.mean(np.linalg.norm(u_diff_ast_i, axis=1))

# Distance in degrees
err_obs_deg = dist2deg(err_obs_dist)
err_elt_deg = dist2deg(err_elt_dist)
err_ast_deg = dist2deg(err_ast_dist)

In [17]:
# ztf_i

In [18]:
def report_error(err_deg, err_type: str):
    mean_err_deg_g = np.mean(err_deg[mask_good])
    mean_err_deg_b = np.mean(err_deg[mask_bad])
    mean_err_sec_g = mean_err_deg_g * 3600.0
    mean_err_sec_b = mean_err_deg_b * 3600.0

    print(f'Mean error: {err_type}')
    print(f'Good: {mean_err_deg_g:8.4f} deg / {mean_err_sec_g:10.3f} arc sec')
    print(f'Bad : {mean_err_deg_b:8.4f} deg / {mean_err_sec_b:10.3f} arc sec\n')

In [19]:
# Report errors
report_error(err_deg=err_elt_deg, err_type='Orbital Elements')
report_error(err_deg=err_ast_deg, err_type='Nearest Asteroid')
report_error(err_deg=err_obs_deg, err_type='Observed')

Mean error: Orbital Elements
Good:   0.0002 deg /      0.780 arc sec
Bad :   0.1404 deg /    505.606 arc sec

Mean error: Nearest Asteroid
Good:   0.0002 deg /      0.782 arc sec
Bad :   0.1404 deg /    505.593 arc sec

Mean error: Observed
Good:   0.0728 deg /    262.140 arc sec
Bad :   0.3048 deg /   1097.260 arc sec



## Initial Gradients

In [20]:
# Get intital gradients on entire data set
with tf.GradientTape(persistent=True) as gt:
    gt.watch([model.elements.e_, model.elements.inc_, model.elements.R_])
    pred = model.predict_on_batch(ds.take(traj_size))
    elts, R, u_pred, z, scores = pred
    # unpack elements
    a = elts[:,0]
    e = elts[:,1]
    inc = elts[:,2]
    Omega = elts[:,3]
    omega = elts[:,4]
    f = elts[:,5]
    epoch = elts[:,6]
    # unpack scores
    raw_score = scores[:, 0]
    mu = scores[:, 1]
    sigma2 = scores[:, 2]
    objective = scores[:, 3]
    # total loss function
    loss = tf.reduce_sum(-objective)

# Derivatives of loss w.r.t. control variables for elements and R
dL_da = gt.gradient(loss, model.elements.a_) / steps
dL_de = gt.gradient(loss, model.elements.e_) / steps
dL_dinc = gt.gradient(loss, model.elements.inc_) / steps
dL_dOmega = gt.gradient(loss, model.elements.Omega_) / steps
dL_domega = gt.gradient(loss, model.elements.omega_) / steps
dL_df = gt.gradient(loss, model.elements.f_) / steps
dL_dR = gt.gradient(loss, model.elements.R_) / steps
del gt

In [21]:
# Report gradients
report_model_attribute(att=np.abs(dL_da), mask_good=mask_good, att_name='abs(grad_a)')
report_model_attribute(att=np.abs(dL_de), mask_good=mask_good, att_name='abs(grad_e)')
report_model_attribute(att=np.abs(dL_dinc), mask_good=mask_good, att_name='abs(grad_inc)')


Mean & Std abs(grad_a) by Category:
Good: 25648.88 +/- 21397.89
Bad:  49811.03 +/- 43486.70
All:  37729.96 +/- 36337.77

Mean & Std abs(grad_e) by Category:
Good:  6024.63 +/-  6063.05
Bad:  12884.88 +/- 15466.12
All:   9454.75 +/- 12237.09

Mean & Std abs(grad_inc) by Category:
Good: 11522.25 +/- 13333.79
Bad:  10115.27 +/-  9552.84
All:  10818.76 +/- 11619.74


In [22]:
dL_da[mask_good].numpy()

array([ 12736.823  ,   9434.081  ,   4342.021  ,   9732.37   ,
        55865.918  ,   3188.746  , -62406.977  ,  58793.027  ,
        28065.217  ,   4005.748  ,  29271.057  ,   2095.1665 ,
        18635.027  ,  40190.324  ,  14777.196  ,  60594.273  ,
       -43584.016  ,   2572.2803 ,  14333.315  ,  48017.93   ,
        -4918.378  , -43182.9    ,  63662.938  , -22469.154  ,
         6636.6226 , -59049.816  ,  -1640.0939 , -41446.04   ,
        30715.836  ,  10133.0625 ,    994.10504,  13273.74   ],
      dtype=float32)

In [23]:
dL_da[mask_bad].numpy()

array([ -70809.67  ,   -4852.3994,  -32782.727 ,  -63238.63  ,
        -75701.484 ,  -25549.617 ,  -41343.71  , -114343.66  ,
        -33651.59  ,  -23140.73  ,   74283.74  ,  -79540.88  ,
         19228.627 ,   84720.1   ,  -22723.934 ,  161927.31  ,
        -21348.568 ,  -87418.07  ,  -56703.645 ,     806.107 ,
         29640.346 ,   11769.445 ,   -9194.842 ,  -18312.852 ,
         12403.965 ,    3902.699 ,  -63047.746 ,  -51807.117 ,
          3704.0317,   39822.844 ,   69676.68  ,  186555.12  ],
      dtype=float32)

## Train Model

In [24]:
# Train model
steps_per_epoch = steps*cycles_per_epoch
rows_per_epoch = steps_per_epoch * time_batch_size
print_header(f'Training for {epochs} Epochs of Size {rows_per_epoch} Observation Days (Cycles/Epoch = {cycles_per_epoch})...')
print(f'R (degrees):    {R_deg:5.1f}')
print(f'R_is_trainable: {R_is_trainable}')
print(f'alpha:          {alpha:5.1f}')
print(f'beta:           {beta:5.1f}')
print(f'learning_rate:  {learning_rate:7.2e}')
print(f'clipvalue:      {clipvalue:5.2f}')

train_time_0 = time.time()
hist = model.fit(ds, epochs=epochs, steps_per_epoch=steps_per_epoch)
train_time_1 = time.time()
train_time = train_time_1 - train_time_0
print(f'Elapsed Time: {str(timedelta(seconds=train_time))}')

# Report results
if display:
    print_header('Model After Training:')
pred1 = model.predict_on_batch(ds)
elts1, R1, u_pred1, z1, _ = pred1
scores1, traj_err1, elt_err1 = \
    report_model(model=model, ds=ds, R_deg=R_deg, mask_good=mask_good, 
                 batch_size=elt_batch_size, steps=steps, elts_true=elts_true, display=display)

# Unpack scores after training
raw_score = scores1[:,0]
mu = scores1[:,1]
sigma2 = scores1[:,2]
objective = scores1[:,3]
sigma = scores1[:, 4]
t_score = scores1[:, 5]
eff_obs = raw_score - mu

## Change in scores
d_scores = scores1 - scores0
d_traj_err = traj_err1 - traj_err0
d_elt_err = elt_err1 - elt_err0
d_R = R1 - R0

# Report training progress: scores, orbital element errors, and resolution
print_header('Progress During Training:')
scores_01 = (scores0, scores1)
traj_err_01 = (traj_err0, traj_err1)
elt_err_01 = (elt_err0, elt_err1)
R_01 = (R0, R1)
report_training_progress(scores_01, traj_err_01, elt_err_01, R_01, mask_good)


********************************************************************************
Training for 10 Epochs of Size 23710 Observation Days (Cycles/Epoch = 10)...
********************************************************************************
R (degrees):      1.0
R_is_trainable: False
alpha:            1.0
beta:             0.0
learning_rate:  2.00e-05
clipvalue:       5.00
Train for 10 steps
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Elapsed Time: 0:00:12.616490

********************************************************************************
Progress During Training:
********************************************************************************

Mean & Std Change in Trajectory Error (AU) by Category:
Good: +0.007125 +/- 0.004867 -- from 0.000000 to 0.007125     (+1653765.8%)
Bad:  +0.002501 +/- 0.005611 -- from 0.003529 to 0.006029     (+70.9%)
All:  +0.004813 +/- 0.005739 -- from 0.001764 to 0.006577     (+272.8%)



## Test Different Resolutions

In [25]:
thresh_deg = 1.0
R_is_trainable = False
elt_batch_size = 64
time_batch_size = None
epochs = 10
cycles_per_epoch = 10

In [26]:
R_deg_cand = np.power(2.0, -np.arange(0.0, 8.0))
R_deg_cand

array([1.       , 0.5      , 0.25     , 0.125    , 0.0625   , 0.03125  ,
       0.015625 , 0.0078125])

In [27]:
# Results table
scores_tbl = dict()
traj_err_tbl = dict()
elt_err_tbl = dict()
R_tbl = dict()

In [28]:
# # Iterate through different resolutions
# R_deg_cand = np.power(2.0, -np.arange(0.0, 8.0))
# for R_deg in R_deg_cand:
#     # Generate scores
#     scores_01, traj_err_01, elt_err_01, R_01, mask_good = \
#         test_easy_batch(R_deg=R_deg, 
#                         R_is_trainable=R_is_trainable,
#                         elt_batch_size=elt_batch_size,
#                         time_batch_size=time_batch_size,
#                         epochs=epochs,
#                         cycles_per_epoch=cycles_per_epoch)
#     # Save to table
#     scores_tbl[R_deg] = scores_01
#     traj_err_tbl[R_deg] = traj_err_01
#     elt_err_tbl[R_deg] = elt_err_01
#     R_tbl[R_deg] = R_01

In [29]:
scores_01[0].shape

(64, 6)