# Three Body Problem: Neural Network Training

A simple neural network is trained to learn approximate solutions to the three-body problem.<br>
Data is sampled from three-body systems with parameters similar to planets in the solar system.

In [1]:
# Library imports
import tensorflow as tf
import rebound
import numpy as np
import datetime
import matplotlib.pyplot as plt

# Aliases
keras = tf.keras

In [2]:
# Local imports
from utils import load_vartbl, save_vartbl, plot_style, range_inc
from tf_utils import gpu_grow_memory, TimeHistory
from tf_utils import plot_loss_hist, EpochLoss, TimeHistory
from tf_utils import Identity

from orbital_element import OrbitalElementToConfig, ConfigToOrbitalElement, MeanToTrueAnomaly, G_
from orbital_element import make_model_elt_to_cfg, make_model_cfg_to_elt

from jacobi import CartesianToJacobi, JacobiToCartesian

from g3b_data import make_traj_g3b, make_data_g3b, make_datasets_g3b, traj_to_batch
from g3b_data import make_datasets_solar, make_datasets_hard
from g3b_data import combine_datasets_g3b, combine_datasets_solar
from sej_data import load_data_sej, make_datasets_sej, combine_datasets_sej

from g3b_plot import plot_orbit_q, plot_orbit_v, plot_orbit_a, plot_orbit_energy, plot_orbit_element
from g3b import KineticEnergy_G3B, PotentialEnergy_G3B, Momentum_G3B, AngularMomentum_G3B
from g3b import VectorError, EnergyError
from g3b import Motion_G3B, make_physics_model_g3b
from g3b import fit_model
from g3b_model_math import make_position_model_g3b_math, make_model_g3b_math
from g3b_model_nn import make_position_model_g3b_nn, make_model_g3b_nn

In [3]:
# Set active GPUs
# Restrict TensorFlow to a single GPU: the second physical device when there is one,
# otherwise the first. The bare slice gpus[1:2] is empty on a machine with fewer than
# two GPUs, which would hide every GPU from TensorFlow.
gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(gpus[1:2] or gpus[:1], 'GPU')

In [4]:
# Grow GPU memory (must be first operation in TF)
# gpu_grow_memory()

In [5]:
# Lightweight serialization
# Load the variable table from fname; later cells store the training history in it
# under hist_name and write it back to the same path with save_vartbl.
fname = '../data/g3b/g3b_train.pickle'
vartbl = load_vartbl(fname)

In [6]:
# Set plot style (plot_style is imported from the local utils module)
plot_style()

### Load Data for General Three Body Problem

In [7]:
# Description of datasets to be loaded: time span and sampling frequency
# (presumably years and samples per year - confirm against the data helpers)
n_years, sample_freq = 100, 10
# Number of points on each trajectory
traj_size = 1 + sample_freq * n_years

In [8]:
# Configuration for loading data sets: how many to combine, and the starting random seed
# num_data_sets = 50
num_data_sets, seed0 = 5, 42

# Batch size; a tiny data set holds exactly one batch of trajectories
batch_size = n_traj_tiny = 256
# num_gpus = 1
# full_batch_size = num_gpus * batch_size

In [9]:
# Create a tiny data set with one batch of solar type orbits
# Only the first of the three returned data sets is kept; vt_split=0.0 presumably means
# nothing is held back for validation / test - TODO confirm in make_datasets_solar
ds_tiny_solar, _ , _ = \
    make_datasets_solar(n_traj=n_traj_tiny, vt_split=0.0, 
                        n_years=n_years, sample_freq=sample_freq,
                        batch_size=batch_size, seed=seed0)

Loaded data from ../data/g3b/1789961721.pickle.


W0821 11:16:47.555900 139880889796416 deprecation.py:323] From /home/michael/anaconda3/envs/nbody/lib/python3.7/site-packages/tensorflow/python/data/util/random_seed.py:58: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [10]:
# Build combined solar data sets
# ds_solar_trn, ds_solar_val, ds_solar_tst = \
#     combine_datasets_solar(num_data_sets=num_data_sets, batch_size=batch_size, seed0=seed0)

### Load Data for Perturbed Sun-Earth-Jupiter System

In [11]:
# Orbital perturbation scales (the sd_ prefix presumably denotes standard deviations)
sd_log_a = 0.01
sd_log_e = sd_log_inc = 0.10
sd_Omega = sd_omega = sd_f = np.pi * 0.02

# Bundle the scales into a dictionary keyed by their own names,
# ready to be splatted as keyword arguments
sej_sigma = dict(
    sd_log_a=sd_log_a,
    sd_log_e=sd_log_e,
    sd_log_inc=sd_log_inc,
    sd_Omega=sd_Omega,
    sd_omega=sd_omega,
    sd_f=sd_f,
)

In [12]:
# Create a tiny data set with one batch of perturbed SEJ orbits
# The perturbation scales are passed as keyword arguments from sej_sigma;
# only the first of the three returned data sets is kept
ds_tiny_sej, _ , _ = \
    make_datasets_sej(n_traj=n_traj_tiny, vt_split=0.0, n_years=n_years, sample_freq=sample_freq,
                      **sej_sigma,
                      batch_size=batch_size, seed=seed0)

Loaded data from ../data/sej/1026452775.pickle.


In [13]:
# Create dictionary for sigmas of unperturbed orbits: all sd are zero (always same elements)
# Same keys as sej_sigma, every value multiplied by 0.0
sej_sigma0 = {k: v*0.0 for k, v in sej_sigma.items()}

# Create a tiny data set with the unperturbed SEJ system
# NOTE(review): this call uses seed=0 rather than seed0 like the other tiny data sets -
# presumably irrelevant when every sd is zero; confirm
ds_sej0, _, _ = \
    make_datasets_sej(n_traj=n_traj_tiny, vt_split=0.0, n_years=n_years, sample_freq=sample_freq,
                      **sej_sigma0,
                      batch_size=batch_size, seed=0)

Loaded data from ../data/sej/3203691191.pickle.


In [14]:
# Build combined SEJ data sets
# Returns the training, validation and test data sets, built from num_data_sets
# data sets (presumably one per seed starting at seed0 - TODO confirm in combine_datasets_sej)
ds_sej_trn, ds_sej_val, ds_sej_tst = \
    combine_datasets_sej(num_data_sets=num_data_sets, batch_size=batch_size, seed0=seed0)

Loaded data from ../data/sej/4087833051.pickle.


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

Loaded data from ../data/sej/3169253148.pickle.
Loaded data from ../data/sej/155748689.pickle.
Loaded data from ../data/sej/3645397039.pickle.
Loaded data from ../data/sej/2755636330.pickle.



***Choose Data Set for Analysis: Solar vs. SEJ***

In [15]:
# Alias ds_tiny, ds_trn, ds_val, ds_tst to selected source

# The selected data type for this analysis
data_type = 'SEJ'

# Tables mapping data type to tuple of data sets: (tiny, train, validation, test)
# The 'solar' entry is disabled because the combined solar data sets are not built
# in this notebook (their cell is commented out)
data_by_type = {
    # 'solar': (ds_tiny_solar, ds_solar_trn, ds_solar_val, ds_solar_tst),
    'SEJ': (ds_tiny_sej, ds_sej_trn, ds_sej_val, ds_sej_tst)
}

# Perform the aliasing; raises KeyError if data_type has no entry in the table
ds_tiny, ds_trn, ds_val, ds_tst = data_by_type[data_type]

### Create the Kepler-Jacobi Model as a Benchmark

In [16]:
# Build the Kepler-Jacobi benchmark model with the math-based builder from g3b_model_math
model_kj = make_model_g3b_math(traj_size=traj_size, batch_size=batch_size)

In [17]:
# Optimizer used to compile the benchmark model
# NOTE(review): learning_rate=0.0 presumably because the benchmark is only evaluated,
# never trained - confirm intent
optimizer = keras.optimizers.Adam(learning_rate=0.0)

# One loss object per key; the keys presumably match the model's named outputs - TODO confirm.
# This dictionary is reused when the neural network model is compiled.
loss = {'q': VectorError(name='q_loss'),
        'v': VectorError(name='v_loss'),
        'a': VectorError(regularizer=1.0, name='a_loss'),
        'q0_rec': VectorError(name='q0_loss'),
        'v0_rec': VectorError(name='v0_loss'),
        'H': EnergyError(name='H_loss'),
        'P': VectorError(name='P_loss', regularizer=1.0E-6),
        'L': VectorError(name='L_loss'),
       }

# No additional metrics are tracked
metrics = None

# Weight of each loss term in the total loss; the two *_rec terms are weighted 1.0E4,
# every other term 1.0
loss_weights = {'q': 1.0,
                'v': 1.0,
                'a': 1.0,
                'q0_rec': 1.0E4,
                'v0_rec': 1.0E4,
                'H': 1.0,
                'P': 1.0,
                'L': 1.0}

In [18]:
# Compile the full mathematical model with the optimizer, losses and loss weights defined above
model_kj.compile(optimizer=optimizer, loss=loss, metrics=metrics, loss_weights=loss_weights)

In [19]:
# Evaluate KJ model on unperturbed SEJ data set
# The displayed list has nine entries; the baseline cell further down labels them
# loss, q_loss, v_loss, a_loss, q0_rec_loss, v0_rec_loss, H_loss, P_loss, L_loss
model_kj.evaluate(ds_sej0)



[6.651984585914761e-05,
 1.9148267e-05,
 1.9135909e-05,
 2.8235167e-05,
 4.116526e-14,
 9.423976e-15,
 2.2974615e-13,
 4.1866443e-14,
 7.2123155e-14]

In [20]:
# Evaluate KJ model on tiny data set
# These values are the ones hard coded as loss_baseline_list further down
model_kj.evaluate(ds_tiny)



[6.427881453419104e-05,
 1.8427074e-05,
 1.8414992e-05,
 2.7436467e-05,
 1.4878642e-14,
 1.3017593e-14,
 9.33544e-14,
 3.877194e-14,
 2.0240327e-14]

In [21]:
# Evaluate KJ model on full validation data (ds_val is the validation set selected via data_type)
model_kj.evaluate(ds_val)



[6.587040261365473e-05,
 1.8911158e-05,
 1.8898823e-05,
 2.806015e-05,
 1.4462772e-14,
 1.2528137e-14,
 9.376004e-14,
 3.876589e-14,
 1.9768472e-14]

### Train the Neural Network Model

In [22]:
# Configuration for neural network model architecture
# hidden_sizes = [64, 16]
hidden_sizes = []
skip_layers = True
# Derive the trajectory length from the data set description rather than repeating the
# literal 1001, so it cannot drift from the value used to build the data and model_kj
traj_size = n_years * sample_freq + 1

# Training configuration
# A single regularization strength is shared by the kernel and activity regularizers
reg = 1.0E2
kernel_reg = reg
activity_reg = reg
learning_rate = 1.0E-4

In [23]:
# Build neural network model
# With hidden_sizes == [] no hidden layer sizes are requested
# (presumably only the skip connections remain - TODO confirm in make_model_g3b_nn)
model_nn = make_model_g3b_nn(hidden_sizes=hidden_sizes, skip_layers=skip_layers, 
                             kernel_reg=kernel_reg, activity_reg=activity_reg,
                             traj_size=traj_size, batch_size=batch_size)

In [24]:
# model_nn.summary()

In [25]:
# Optimizer for training the neural network model.
# This rebinds the name `optimizer`, previously the zero-learning-rate Adam used for the benchmark.
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
# optimizer = keras.optimizers.Adagrad(learning_rate=learning_rate)
# optimizer = keras.optimizers.Adadelta()

In [26]:
# Compile the NN model with the same losses, metrics and loss weights as the benchmark model
model_nn.compile(optimizer=optimizer, loss=loss, metrics=metrics, loss_weights=loss_weights)

In [27]:
# Evaluate the NN model on the tiny data set
# Before any training, the recorded output equals that of model_kj.evaluate(ds_tiny)
model_nn.evaluate(ds_tiny)



[6.427881453419104e-05,
 1.8427074e-05,
 1.8414992e-05,
 2.7436467e-05,
 1.4878642e-14,
 1.3017593e-14,
 9.33544e-14,
 3.877194e-14,
 2.0240327e-14]

In [28]:
# Evaluate the NN model on the full validation data
# Before any training, the recorded output equals that of model_kj.evaluate(ds_val)
model_nn.evaluate(ds_val)



[6.587040261365473e-05,
 1.8911158e-05,
 1.8898823e-05,
 2.806015e-05,
 1.4462772e-14,
 1.2528137e-14,
 9.376004e-14,
 3.876589e-14,
 1.9768472e-14]

In [29]:
# Baseline: before training, the NN model should score the same as the Kepler-Jacobi model
# model_kj.evaluate(ds_tiny)

# Names of the entries returned by evaluate, in order
keys = ['loss', 'q_loss', 'v_loss', 'a_loss', 'q0_rec_loss', 'v0_rec_loss', 'H_loss', 'P_loss', 'L_loss']

# Hard coded baseline losses (the values shown above for model_kj.evaluate(ds_tiny))
loss_baseline_list = [
    6.427881453419104e-05,
    1.8427074e-05,
    1.8414992e-05,
    2.7436467e-05,
    1.4878642e-14,
    1.3017593e-14,
    9.33544e-14,
    3.877194e-14,
    2.0240327e-14,
]

# Baseline position loss is the second entry
q_loss_baseline = loss_baseline_list[1]

# Table of baseline losses, extended with a dummy batch_num and time
loss_baseline = dict(zip(keys, loss_baseline_list))
loss_baseline.update(batch_num=0, time=0.0)

# Initial history before training: every entry becomes a one-element float32 array
hist0 = {name: np.array([value], dtype=np.float32) for name, value in loss_baseline.items()}

# Review baseline loss table
# loss_baseline

In [30]:
# Set up training
# The model name suffix lists the hidden layer sizes joined by '_' (e.g. '64_16');
# it is the empty string when hidden_sizes is []
suffix = '_'.join(str(sz) for sz in hidden_sizes)
if data_type == 'solar':
    model_name = f'model_g3b_nn_{suffix}'
    folder = 'g3b'
elif data_type == 'SEJ':
    model_name = f'model_sej_nn_{suffix}'
    folder = 'sej'
else:
    # Fail here with a clear message; otherwise model_name and folder would be undefined
    # (or stale from an earlier run) in the lines below
    raise ValueError(f"Unsupported data_type {data_type!r}; expected 'solar' or 'SEJ'.")
# NOTE(review): this path always points into models/g3b, even when folder is 'sej' -
# confirm that it matches the location fit_model writes to
model_h5 = f'../models/g3b/{model_name}.h5'
# Key under which the training history is stored in vartbl
hist_name = model_name.replace('model_', 'hist_')
# Settings passed to each fit_model call
epochs = 1
save_freq = 'epoch'

In [31]:
# Fit the NN model on the tiny data set, starting from the baseline history hist0
hist = fit_model(model=model_nn,
                 folder=folder,
                 ds=ds_tiny, 
                 epochs=epochs,
                 save_freq=save_freq,
                 prev_history = hist0, 
                 batch_num=1)


Epoch 0001; loss 6.43e-05; elapsed 0:00:06


In [32]:
# Attempt to load model or train a single epoch
# Loading needs both the saved weights file and a stored history entry in vartbl;
# if either step fails the model is fitted instead and the new history is saved.
try:
    model_nn.load_weights(model_h5)
    model_nn.compile(loss=loss, optimizer=optimizer, metrics=metrics, loss_weights=loss_weights)
    hist = vartbl[hist_name]
    print(f'Loaded {model_name} from {model_h5}.')
# Catch Exception instead of using a bare except, so that KeyboardInterrupt and
# SystemExit are not swallowed and mistaken for a missing model
except Exception as exc:
    print(f'Unable to load {model_name} from {model_h5}. Fitting...')
    # Report why loading failed instead of silently discarding the error
    print(f'Reason: {type(exc).__name__}: {exc}')
    hist = fit_model(model=model_nn,
                     folder=folder,
                     # ds=ds_trn,
                     ds=ds_tiny,
                     epochs=epochs,
                     save_freq=save_freq,
                     prev_history=hist0,
                     batch_num=1)
    # Persist the new history so the next run can load it
    vartbl[hist_name] = hist
    save_vartbl(vartbl, fname)

Unable to load model_sej_nn_ from ../models/g3b/model_sej_nn_.h5. Fitting...

Epoch 0001; loss 6.52e+03; elapsed 0:00:00


In [None]:
# Main training loop: each pass calls fit_model on the training data, continuing from the
# history so far, then saves the updated history to the variable table
# num_epochs = 50
num_epochs = 1
# range_inc presumably includes its upper bound (i runs 1..num_epochs) - TODO confirm in utils
for i in range_inc(1, num_epochs):
    # Timestamp each pass in the log output
    ts = datetime.datetime.now()
    st = ts.strftime('%Y-%m-%d %H:%M:%S')
    print(f'*** Training loop {i:3} *** - {st}')
    # NOTE(review): this call passes loss, optimizer and metrics, which the other fit_model
    # calls in this notebook do not - confirm against fit_model's signature
    hist = fit_model(model=model_nn,
                     folder=folder,
                     ds=ds_trn, 
                     epochs=epochs,
                     loss=loss, 
                     optimizer=optimizer,
                     metrics=metrics,
                     save_freq=save_freq,
                     prev_history = hist, 
                     batch_num=i)
    # Save the history after every pass
    vartbl[hist_name] = hist
    save_vartbl(vartbl, fname)

In [None]:
# Plot the position loss (q_loss) history, passing the baseline position loss for reference
fig, ax = plot_loss_hist(hist=hist, model_name=model_nn.name, key='q_loss', baseline=q_loss_baseline)

In [None]:
# Plot the total loss history (no baseline is passed here)
fig, ax = plot_loss_hist(hist=hist, model_name=model_nn.name, key='loss')

In [None]:
# Display the accumulated training history
hist

In [None]:
# Evaluate the trained model on the training data
# model_nn.evaluate(ds_trn)

In [None]:
# Evaluate the trained model on the test data
# model_nn.evaluate(ds_tst)

## Perturbed Sun-Earth-Jupiter System

In [None]:
from sej_data import load_data_sej, make_datasets_sej, combine_datasets_sej
import numpy as np
from g3b_plot import plot_orbit_q, plot_orbit_v, plot_orbit_a, plot_orbit_energy, plot_orbit_element
import matplotlib.pyplot as plt

In [None]:
# NOTE: this cell rebinds n_years, sample_freq and batch_size, which are also set near the
# top of the notebook; batch_size changes from 256 to 64 for any cell run after this one.
# Trajectory length
n_years = 100
sample_freq = 10

# Number of trajectories
num_batches = 1
n_traj = 100
vt_split = 0.20
batch_size = 64

In [None]:
# Orbital perturbation scales (same values as in the earlier SEJ loading section)
sd_log_a = 0.01
sd_log_e = sd_log_inc = 0.10
sd_Omega = sd_omega = sd_f = np.pi * 0.02

In [None]:
# Seeds for the data sets: num_batches values starting at seed0, spaced 3 apart
seed0 = 42
seeds = [seed0 + 3 * k for k in range(num_batches)]
# First seed past the end of the list
seed1 = seed0 + num_batches * 3

In [None]:
# Load the perturbed SEJ data for the first seed, using the trajectory settings and
# perturbation scales defined above
data = load_data_sej(n_traj=n_traj, vt_split=vt_split, n_years=n_years, sample_freq=sample_freq,
                     sd_log_a=sd_log_a, sd_log_e=sd_log_e, sd_log_inc=sd_log_inc,
                     sd_Omega=sd_Omega, sd_omega=sd_omega, sd_f=sd_f, seed=seed0)

In [None]:
# The first two entries of data are taken as the training inputs and outputs
# (presumably both dictionaries of arrays, given the ** merge below)
inputs_trn, outputs_trn = data[0:2]

# Merge inputs and outputs into one dictionary; on a shared key the output value wins
data_trn = {**inputs_trn, **outputs_trn}

In [None]:
# Shape of the 't' array in the training inputs
inputs_trn['t'].shape

In [None]:
# Shape of the 'q' array in the training outputs
outputs_trn['q'].shape

In [None]:
# For the first trajectory, plot q[:, 1, 0] against t
# (presumably one position coordinate of the body at index 1 - TODO confirm axis meaning)
plt.plot(inputs_trn['t'][0], outputs_trn['q'][0][:, 1, 0])

In [None]:
# For the first trajectory, plot q[:, 1, 1] against t
# (presumably the next position coordinate of the same body - TODO confirm axis meaning)
plt.plot(inputs_trn['t'][0], outputs_trn['q'][0][:, 1, 1])

In [None]:
# Mean of orb_a over the first axis
np.mean(outputs_trn['orb_a'], axis=0)

In [None]:
# Shape of the 'orb_a' array in the training outputs
outputs_trn['orb_a'].shape

In [None]:
# Mean over the first axis of orb_a at index 0 of the second axis
# (presumably the initial semi-major axes averaged over trajectories - TODO confirm axis meaning)
np.mean(outputs_trn['orb_a'][:, 0, :], axis=0)

In [None]:
# Standard deviation over the first axis of the same orb_a slice
np.std(outputs_trn['orb_a'][:, 0, :], axis=0)

In [None]:
# Mean over the first axis of orb_e at index 0 of the second axis
# (presumably the initial eccentricities averaged over trajectories - TODO confirm axis meaning)
np.mean(outputs_trn['orb_e'][:, 0, :], axis=0)

In [None]:
# Standard deviation over the first axis of the same orb_e slice
np.std(outputs_trn['orb_e'][:, 0, :], axis=0)