In [1]:
# Library imports
import tensorflow as tf
import rebound
import numpy as np
import datetime
import matplotlib.pyplot as plt

# Aliases
keras = tf.keras

In [2]:
# Local imports
from utils import load_vartbl, save_vartbl, plot_style
from tf_utils import gpu_grow_memory, TimeHistory
from tf_utils import plot_loss_hist, EpochLoss, TimeHistory
from tf_utils import Identity

from orbital_element import OrbitalElementToConfig, ConfigToOrbitalElement, MeanToTrueAnomaly, G_
from orbital_element import make_model_elt_to_cfg, make_model_cfg_to_elt

from jacobi import CartesianToJacobi, JacobiToCartesian

from g3b_data import make_traj_g3b, make_data_g3b, make_datasets_g3b
from g3b_data import make_filename_g3b, load_data_g3b
from g3b_data import make_datasets_solar, make_datasets_hard
from g3b_plot import plot_orbit_q, plot_orbit_v, plot_orbit_a, plot_orbit_energy, plot_orbit_element
from g3b import KineticEnergy_G3B, PotentialEnergy_G3B, Momentum_G3B, AngularMomentum_G3B
from g3b import VectorError, EnergyError
from g3b import Motion_G3B, make_physics_model_g3b
from g3b import compile_and_fit
from g3b_model_math import make_position_model_g3b_math, make_model_g3b_math
# from g3b_model_nn import make_position_model_g3b_nn, make_model_g3b_nn

In [3]:
# Grow GPU memory (must be first operation in TF)
# gpu_grow_memory is the helper imported from tf_utils; it is called here before any
# dataset or model is built in this notebook.
gpu_grow_memory()

### Load Data and Build Position Model

In [4]:
# Shared settings handed to the g3b dataset and model builders
n_years, sample_freq = 100, 10
vt_split = 0.20

# Points per trajectory: sample_freq points for each of n_years, plus one more
traj_size = 1 + sample_freq * n_years

# Number of trajectories per batch
batch_size = 64

In [5]:
# One hand-specified three-body system used as a quick check of make_traj_g3b.
# m has three entries (one per body); each orbital element array has two entries.
# NOTE(review): the element values resemble Earth and Jupiter orbiting the Sun — confirm the source.
# NOTE: the names m and a are rebound to batch tensors by the example-batch cell further down.
m = np.array([1.0, 1.0E-6, 1.0E-3])
a = np.array([1.00000, 5.2029])
e = np.array([0.0167, 0.0484])

# Angles are entered in degrees and converted to radians
inc = np.radians([0.00, 1.3044])
Omega = np.radians([0.00, 100.47])
# omega and f are entered as angles from which Omega is subtracted
omega = np.radians([102.94, 14.73]) - Omega
f = np.radians([100.46, 34.40]) - Omega

# n_years and sample_freq are taken from the configuration cell rather than redefined here,
# so this trajectory cannot drift out of sync with traj_size.
inputs_traj, outputs_traj = make_traj_g3b(m=m, a=a, e=e, inc=inc, Omega=Omega, omega=omega, f=f,
                                          n_years=n_years, sample_freq=sample_freq)

In [6]:
# Initial configuration q0 of the example trajectory (displayed as a 3 x 3 array)
inputs_traj['q0']

array([[-3.1754794e-03,  3.9717490e-03,  5.4689841e-05],
       [-9.2346436e-01, -3.9427280e-01,  5.4689841e-05],
       [ 3.1764028e+00, -3.9713547e+00, -5.4689895e-02]], dtype=float32)

In [7]:
# Sanity check: the first entry of the output trajectory q should reproduce inputs_traj['q0']
outputs_traj['q'][0]

array([[-3.1754794e-03,  3.9717490e-03,  5.4689841e-05],
       [-9.2346436e-01, -3.9427280e-01,  5.4689841e-05],
       [ 3.1764028e+00, -3.9713547e+00, -5.4689895e-02]], dtype=float32)

In [8]:
# First three entries of the orb_a output (two columns; presumably the semi-major axes — compare with a)
outputs_traj['orb_a'][0:3]

array([[1.       , 5.2029   ],
       [0.999998 , 5.2029   ],
       [1.0000111, 5.2029   ]], dtype=float32)

In [11]:
# outputs_traj['orb_e'][0:3]

In [12]:
# First three entries of the orb_f output (two columns; presumably the anomaly f in radians)
outputs_traj['orb_f'][0:3]

array([[-4.5298276, -1.153139 ],
       [ 2.37244  , -1.0977441],
       [-3.3007643, -1.0420896]], dtype=float32)

In [13]:
# Create a tiny data set of solar type orbits with n_traj = 100
# (the call returns three datasets, bound here as train / validation / test)
n_traj = 100

ds_small_trn, ds_small_val, ds_small_tst = make_datasets_solar(n_traj=n_traj, vt_split=vt_split, 
                                                               n_years=n_years, sample_freq=sample_freq)

Loaded data from ../data/g3b/2509250945.pickle.


In [14]:
# # Create a medium data set with 10,000 solar type orbits
# n_traj = 10000

# ds_small_trn, ds_small_val, ds_small_tst = make_datasets_solar(n_traj=n_traj, vt_split=vt_split, 
#                                                                n_years=n_years, sample_freq=sample_freq)

In [15]:
# Pull one (inputs, outputs) batch from the training dataset
first_batch = list(ds_small_trn.take(1))
batch_in, batch_out = first_batch[0]
print('Input field names: ', list(batch_in.keys()))
print('Output field names:', list(batch_out.keys()))

# Input fields of the batch
t, q0, v0, m = (batch_in[key] for key in ('t', 'q0', 'v0', 'm'))

# Output fields of the batch
q, v, a = (batch_out[key] for key in ('q', 'v', 'a'))
q0_rec, v0_rec = batch_out['q0_rec'], batch_out['v0_rec']
H, P, L = (batch_out[key] for key in ('H', 'P', 'L'))

# Report the shape of each field; labels are padded to two characters.
# q0_rec and v0_rec are deliberately left out of the report.
print('\nExample batch sizes:')
for label, tensor in (('t', t), ('q0', q0), ('v0', v0), ('m', m),
                      ('q', q), ('v', v), ('a', a),
                      ('H', H), ('P', P), ('L', L)):
    print(f'{label:<2} = {tensor.shape}')

Input field names:  ['t', 'q0', 'v0', 'm']
Output field names: ['q', 'v', 'a', 'orb_a', 'orb_e', 'orb_inc', 'orb_Omega', 'orb_omega', 'orb_f', 'q0_rec', 'v0_rec', 'T', 'U', 'H', 'P', 'L']

Example batch sizes:
t  = (64, 1001)
q0 = (64, 3, 3)
v0 = (64, 3, 3)
m  = (64, 3)
q  = (64, 1001, 3, 3)
v  = (64, 1001, 3, 3)
a  = (64, 1001, 3, 3)
H  = (64, 1001)
P  = (64, 1001, 3)
L  = (64, 1001, 3)


In [16]:
# Build the math-based position model with the traj_size and batch_size from the configuration cell
model = make_position_model_g3b_math(traj_size=traj_size, batch_size=batch_size)

In [17]:
# Adam optimizer with learning rate 1e-3
optimizer = keras.optimizers.Adam(learning_rate=1.0E-3)

# Only the q and v outputs are scored, each with its own VectorError loss
loss = dict(
    q=VectorError(name='q_loss'),
    v=VectorError(name='v_loss'),
)

# Both losses carry the same weight
loss_weights = dict(q=1.0, v=1.0)

# No additional metrics are tracked
metrics = None

In [18]:
# Compile the model with the optimizer, losses, metrics and loss weights defined in the previous cell
model.compile(optimizer=optimizer, loss=loss, metrics=metrics, loss_weights=loss_weights)

### Find batch and row with NAN Output

In [None]:
# Evaluate the model one batch at a time (no training) and stop at the first batch whose
# reported loss is NaN, keeping that batch's inputs, targets and model outputs for inspection.
loss_hist = []
for i, ds_i in enumerate(ds_small_trn):
    batch_in, batch_out = ds_i
    # Entry 0 of the test_on_batch result is the value reported and checked below
    loss_i = model.test_on_batch(batch_in, batch_out)
    loss_hist.append(loss_i)
    print(f'Loss on batch {i} = {loss_i[0]:5.2e}')
    if np.isnan(loss_i[0]):
        print(f'Loss at i={i} is NAN.  Saving inputs for debugging.')
        # Inputs (from batch); batch_in is a dict, so each field is looked up by key
        t = batch_in['t']
        q0 = batch_in['q0']
        v0 = batch_in['v0']
        m = batch_in['m']

        # Outputs (from batch)
        q = batch_out['q']
        v = batch_out['v']

        # Outputs (from model)
        model_out = model.predict_on_batch([t, q0, v0, m])
        q_out, v_out = model_out

        break
else:
    # The loop ran to completion without hitting the break above
    print('No batch produced a NAN loss; the debugging cells below have nothing to inspect.')

# Convert loss_hist to numpy array
loss_hist = np.array(loss_hist)

In [None]:
# Verify error: re-evaluate the loss on the saved batch, passing inputs and targets as lists
model.test_on_batch([t, q0, v0, m], [q, v])

In [None]:
# Run model in numpy mode
# NOTE(review): this calls the model object directly, whereas q_out / v_out came from
# predict_on_batch — confirm that the direct call is what "numpy mode" refers to.
q_np, v_np = model([t, q0, v0, m])

In [None]:
# Locate the first batch row whose model output q contains a NaN.
# Averaging over axes 1-3 propagates any NaN within a row into that row's mean.
row_mean = np.mean(q_out, axis=(1,2,3))
nan_rows = np.where(np.isnan(row_mean))[0]
idx = nan_rows[0]
idx

In [None]:
# Model output q for row idx, first entry along axis 1
q_out[idx][0]

In [None]:
# Same slice for the following row (idx+1), for comparison
q_out[idx+1][0]

In [None]:
# Same slice of row idx taken from the direct model call (q_np)
q_np[idx][0]

In [None]:
# Difference between numpy and regular mode usually very small
# Compare row idx+1 of the two model outputs and report the mean absolute difference
diff = q_out[idx+1] - q_np[idx+1]
np.mean(np.abs(diff))

In [None]:
# Input q0 for row idx of the saved batch
q0[idx]

In [None]:
# Input v0 for row idx of the saved batch
v0[idx]

In [None]:
# Instantiate CartesianToJacobi and call it on the batch inputs; note the argument order [m, q0, v0].
# The three results are bound to qj0, vj0 and mu0.
qj0, vj0, mu0 = CartesianToJacobi()([m, q0, v0])

In [None]:
# Jacobi-coordinate result qj0 for row idx
qj0[idx]

In [None]:
# Jacobi-coordinate result vj0 for row idx
vj0[idx]

In [None]:
num_particles = 3
space_dims = 3

# Extract Jacobi coordinates of p1 and p2 (indices 1 and 2 along axis 1)
qj0_1 = qj0[:, 1, :]
vj0_1 = vj0[:, 1, :]
qj0_2 = qj0[:, 2, :]
vj0_2 = vj0[:, 2, :]

# Extract gravitational field strength for orbital element conversion of p1 and p2
# (width-1 slices, so the second axis is kept)
mu0_1 = mu0[:, 1:2]
mu0_2 = mu0[:, 2:3]

# Manually set the shapes to work around documented bug on slices losing shape info.
# Looping over the tensors ensures each of the four coordinate slices gets its shape set
# exactly once (vj0_2 was previously skipped because vj0_1 was listed twice).
jacobi_shape = (batch_size, space_dims)
for sliced in (qj0_1, qj0_2, vj0_1, vj0_2):
    sliced.set_shape(jacobi_shape)
mu_shape = (batch_size, 1)
for sliced in (mu0_1, mu0_2):
    sliced.set_shape(mu_shape)

In [None]:
# Build the model that converts a configuration into orbital elements
model_c2e = make_model_cfg_to_elt()

# One input tuple per orbiting body: the Jacobi q0 and v0 slices and the matching mu slice
cfg_1 = (qj0_1, vj0_1, mu0_1)
cfg_2 = (qj0_2, vj0_2, mu0_2)

# Orbital elements of the initial conditions; the model result unpacks into eight values per body
(a1_0, e1_0, inc1_0, Omega1_0,
 omega1_0, f1_0, M1_0, N1_0) = model_c2e(cfg_1)
(a2_0, e2_0, inc2_0, Omega2_0,
 omega2_0, f2_0, M2_0, N2_0) = model_c2e(cfg_2)

In [None]:
# Report the initial orbital elements of body 1 for row idx, three elements per line
print('Initial Orbital Elements - Body 1:')
for named_elts in ((('a    ', a1_0), ('e    ', e1_0), ('inc  ', inc1_0)),
                   (('Omega', Omega1_0), ('omega', omega1_0), ('f    ', f1_0))):
    print(', '.join(f'{name}={elt[idx][0]:10f}' for name, elt in named_elts))

In [None]:
# Report the initial orbital elements of body 2 for row idx, three elements per line
print('Initial Orbital Elements - Body 2:')
for named_elts in ((('a    ', a2_0), ('e    ', e2_0), ('inc  ', inc2_0)),
                   (('Omega', Omega2_0), ('omega', omega2_0), ('f    ', f2_0))):
    print(', '.join(f'{name}={elt[idx][0]:10f}' for name, elt in named_elts))

### Review Original Input Data as Orbital Elements

In [19]:
# Parameters passed to load_data_g3b to select the saved data set
# (presumably mass range, semi-major axis range, maximum eccentricity and maximum inclination — TODO confirm)
vt_split = 0.20
m_min, m_max = 1.0E-7, 2.0E-3
a_min, a_max = 0.50, 32.0
e_max = 0.08
inc_max = 0.04
seed = 42

# n_traj, n_years and sample_freq are reused from the earlier cells
data = load_data_g3b(
    n_traj=n_traj, vt_split=vt_split, n_years=n_years, sample_freq=sample_freq,
    m_min=m_min, m_max=m_max, a_min=a_min, a_max=a_max, e_max=e_max, inc_max=inc_max,
    seed=seed,
)

Loaded data from ../data/g3b/2509250945.pickle.


In [20]:
# data unpacks into six items; only the first two are used below.
# Presumably these are the inputs and outputs of the training split — TODO confirm the order against load_data_g3b.
inputs_trn, outputs_trn, _3, _4, _5, _6 = data

In [21]:
# Field names available in the loaded inputs dict
inputs_trn.keys()

dict_keys(['t', 'q0', 'v0', 'm'])

In [22]:
# Field names available in the loaded outputs dict
outputs_trn.keys()

dict_keys(['q', 'v', 'a', 'orb_a', 'orb_e', 'orb_inc', 'orb_Omega', 'orb_omega', 'orb_f', 'q0_rec', 'v0_rec', 'T', 'U', 'H', 'P', 'L'])

In [24]:
# Index of the trajectory inspected in the next cell
i=23

In [25]:
# q0 for trajectory i
inputs_trn['q0'][i]

array([[-4.3338328e-05,  2.0454632e-05, -1.0817035e-06],
       [ 1.3953951e+01, -4.5750203e+00,  2.6270902e-01],
       [-3.0916464e+01, -1.0864068e+01,  3.1162494e-01]], dtype=float32)

In [26]:
# Shape of the orb_a output array
outputs_trn['orb_a'].shape

(100, 1001, 2)

In [29]:
# orb_a at index 0 along axis 1 for the first ten trajectories (one row per trajectory, two columns)
outputs_trn['orb_a'][0:10, 0, :]

array([[ 3.150409  , 20.723997  ],
       [ 5.5913043 , 28.804457  ],
       [ 0.7897071 , 19.602516  ],
       [ 3.6963537 , 21.400305  ],
       [ 0.65943986,  5.5654535 ],
       [17.785114  , 22.294699  ],
       [ 7.564483  , 21.03678   ],
       [ 7.973346  , 22.933645  ],
       [10.750091  , 24.014479  ],
       [20.963436  , 27.250538  ]], dtype=float32)

In [None]:
# NOTE(review): [0:10][0] slices the first ten trajectories and then takes the first of them,
# so this displays the same thing as inputs_trn['q0'][0]. If the intent was one body's q0 across
# ten trajectories (as in the orb_a cell above), the index should probably be [0:10, 0] — confirm.
inputs_trn['q0'][0:10][0]