This notebook reproduces the nonlinear pendulum results from [Champion et al.](https://www.pnas.org/doi/full/10.1073/pnas.1906995116). The data generation is specified in the [appendix](https://www.pnas.org/action/downloadSupplement?doi=10.1073%2Fpnas.1906995116&file=pnas.1906995116.sapp.pdf) of [Champion et al.](https://www.pnas.org/doi/full/10.1073/pnas.1906995116) and is restated here.






The nonlinear pendulum equation is given by

\begin{equation}
\ddot{z} = -\sin(z)
\end{equation}

Here $z$ denotes the angle between the pendulum and the vertical.
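To make the dynamics concrete, the equation can be rewritten as a first-order system in $(z, \dot{z})$ and integrated numerically. The sketch below uses a hand-written classical Runge-Kutta (RK4) step in NumPy; the step size `dt = 0.02` and the horizon of 500 steps are illustrative assumptions, not values taken from the paper. Because the exact dynamics conserve $E = \dot{z}^2/2 - \cos z$, the energy drift is a quick sanity check on the integrator.

```python
import numpy as np

def pendulum_rhs(state):
    """Right-hand side of z'' = -sin(z), written as a first-order system in (z, dz)."""
    z, dz = state
    return np.array([dz, -np.sin(z)])

def simulate_pendulum(z0, dz0, dt=0.02, n_steps=500):
    """Integrate the pendulum with classical fourth-order Runge-Kutta.

    Returns an array of shape (n_steps + 1, 2) holding (z, dz) at each step.
    dt and n_steps are illustrative choices, not taken from the paper.
    """
    traj = np.empty((n_steps + 1, 2))
    traj[0] = (z0, dz0)
    for k in range(n_steps):
        s = traj[k]
        k1 = pendulum_rhs(s)
        k2 = pendulum_rhs(s + 0.5 * dt * k1)
        k3 = pendulum_rhs(s + 0.5 * dt * k2)
        k4 = pendulum_rhs(s + dt * k3)
        traj[k + 1] = s + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
    return traj

traj = simulate_pendulum(z0=1.0, dz0=0.0)
# E = dz^2/2 - cos(z) is conserved by the exact dynamics, so its drift
# measures the integration error.
energy = 0.5 * traj[:, 1] ** 2 - np.cos(traj[:, 0])
print(f"max energy drift: {np.abs(energy - energy[0]).max():.2e}")
```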

As per the appendix "We generate synthetic video of the pendulum in two spatial dimensions by creating high-dimensional snapshots given by"

$$
x(y_1, y_2, t) = \exp\left(-20\left[\left(y_1 - \cos\left(z(t) - \frac{\pi}{2}\right)\right)^2 + \left(y_2 - \sin\left(z(t) - \frac{\pi}{2}\right)\right)^2\right]\right)
$$


"at a discretization of $y_1, y_2 \in [-1.5, 1.5]$. We use 51 grid points in each dimension resulting in snapshots $x(t) \in \mathbb{R}^{2601}$. To generate
a training set, we simulate Eq. (1) from 100 randomly chosen initial conditions with $z(0) \in [-\pi, \pi]$ and $\dot{z}(0) \in [-2.1, 2.1]$. The
initial conditions are selected from a uniform distribution in the specified range but are restricted to conditions for which the
pendulum does not have enough energy to do a full loop. This condition is determined by checking that $|\dot{z}(0)^2/2 - \cos z(0)| \le 0.99$."
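The data-generation recipe quoted above can be sketched in NumPy. This is a minimal sketch, not the `get_pendulum_data` function used below: the names `pendulum_snapshot` and `sample_initial_conditions` are hypothetical, and the grid and sampling ranges are taken from the quoted appendix text. Since $(\cos(z - \pi/2), \sin(z - \pi/2)) = (\sin z, -\cos z)$, the Gaussian blob sits at the pendulum bob, directly below the pivot when $z = 0$. The energy $E = \dot{z}^2/2 - \cos z$ always satisfies $E \ge -1$, so the check $|E| \le 0.99$ amounts to $E \le 0.99$, just below the value $E = 1$ needed to swing over the top.

```python
import numpy as np

def pendulum_snapshot(z, n_grid=51, extent=1.5, width=20.0):
    """Render the snapshot x(y1, y2) for pendulum angle z, flattened to a vector.

    Evaluates exp(-width * ((y1 - cos(z - pi/2))**2 + (y2 - sin(z - pi/2))**2))
    on an n_grid x n_grid grid covering [-extent, extent]^2.
    """
    y = np.linspace(-extent, extent, n_grid)
    y1, y2 = np.meshgrid(y, y)  # y1 varies along columns, y2 along rows
    bob_1 = np.cos(z - np.pi / 2)  # equals sin(z)
    bob_2 = np.sin(z - np.pi / 2)  # equals -cos(z)
    image = np.exp(-width * ((y1 - bob_1) ** 2 + (y2 - bob_2) ** 2))
    return image.ravel()

def sample_initial_conditions(n_ics, seed=0):
    """Rejection-sample (z0, dz0) uniformly from [-pi, pi] x [-2.1, 2.1],
    keeping only states with |dz0**2 / 2 - cos(z0)| <= 0.99 (no full loop)."""
    rng = np.random.default_rng(seed)
    accepted = []
    while len(accepted) < n_ics:
        z0 = rng.uniform(-np.pi, np.pi)
        dz0 = rng.uniform(-2.1, 2.1)
        if abs(dz0 ** 2 / 2 - np.cos(z0)) <= 0.99:
            accepted.append((z0, dz0))
    return np.array(accepted)

ics = sample_initial_conditions(100)
x0 = pendulum_snapshot(ics[0, 0])
print(ics.shape, x0.shape)  # (100, 2) (2601,)
```

Each accepted initial condition would then be integrated forward in time and every state rendered with `pendulum_snapshot`.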

An animation of this data for one initial condition is shown in `pendulum_animation.gif`.

In [None]:
# Check whether JAX is running on a GPU
import jax

print(f"JAX is using: {jax.default_backend()}")
devices = jax.devices()
print(f"Number of devices: {len(devices)}")
for device in devices:
    print(device)


In [None]:
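import sys
sys.path.append('../')  # data_utils is imported from the parent directory

from pendulumData import get_pendulum_data
from data_utils import create_jax_batches_factory

# Batch factory for a second-order model, since the pendulum equation involves z''
create_jax_batches = create_jax_batches_factory(second_order=True)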


# Set up training and validation data sets as arrays
n_ics_training = 100
n_ics_validation = 20

noise_strength = 1e-6
batch_size_training = 1024
batch_size_validation = 1024

training_data = get_pendulum_data(n_ics_training, noise_strength)
train_loader = create_jax_batches(training_data, batch_size_training)


validation_data = get_pendulum_data(n_ics_validation, noise_strength)
validation_loader = create_jax_batches(validation_data, batch_size_validation)
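`get_pendulum_data` and `create_jax_batches_factory` come from the local `pendulumData` and `data_utils` modules, which are not shown here. As a rough, hypothetical picture of what shuffling and batching involves, the sketch below splits any number of aligned arrays (for example snapshots and their time derivatives) into shuffled mini-batches. The name `make_batches` and the tuple layout are assumptions and need not match the real helpers; conversion to JAX arrays (e.g. with `jax.numpy.asarray`) is omitted.

```python
import numpy as np

def make_batches(*arrays, batch_size, seed=0):
    """Hypothetical batching helper: shuffle the samples once, then split every
    array in `arrays` (all with the same number of rows) into aligned mini-batches."""
    n = arrays[0].shape[0]
    if any(a.shape[0] != n for a in arrays):
        raise ValueError("all arrays must have the same number of samples")
    perm = np.random.default_rng(seed).permutation(n)
    batches = []
    for start in range(0, n, batch_size):
        idx = perm[start:start + batch_size]
        batches.append(tuple(a[idx] for a in arrays))
    return batches

# Toy stand-ins for snapshots x and their time derivatives dx, ddx.
x = np.random.default_rng(1).normal(size=(2500, 8))
batches = make_batches(x, 0.1 * x, 0.01 * x, batch_size=1024)
print([b[0].shape for b in batches])  # [(1024, 8), (1024, 8), (452, 8)]
```

Shuffling once per call keeps the sketch simple; reshuffling every epoch would mean calling it again with a different `seed`.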


In [None]:
# Define hyperparameters
input_dim = 128
latent_dim = 3
poly_order = 3
widths = [128, 64, 32]

initial_epochs = 10001
final_epochs = 1001

# Take the first training batch as an example input
x, dx = train_loader[0]

# Define hyperparameters dictionary
hparams = {
    'input_dim': input_dim,
    'latent_dim': latent_dim,
    'poly_order': poly_order,
    'widths': widths,
    'activation': 'sigmoid',
    'weight_initializer': 'xavier_uniform',
    'bias_initializer': 'zeros',
    'optimizer_hparams': {'optimizer': 'adam'},
    'include_sine': True,  # sine terms are needed to represent -sin(z) in the pendulum equation
    'loss_weights': (1, 1e-4, 1e-5, 1e-5),  # note: different weights than the Lorenz example
    'seed': 42,  # fixed random seed for reproducibility
    'update_mask_every_n_epoch': 500,
    'coefficient_threshold': 0.1,
    'regularization': True,
    'second_order': True,  # the pendulum is a second-order system (z'' = -sin(z))
    'include_constant': True,
}

# Define other parameters dictionary
trainer_params = {
    'exmp_input': x,
    'logger_params': {},
    'enable_progress_bar': True,
    'debug': False,
    'check_val_every_n_epoch': 400
}
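With `poly_order = 3`, `include_sine = True`, `include_constant = True` and `second_order = True`, the SINDy candidate library is evaluated on the latent state together with its first time derivative. The function below is a minimal NumPy sketch of such a library, assuming a layout similar to Champion et al. (a constant, all monomials up to the given order, then $\sin$ of each variable); `sindy_library` is a hypothetical name, and the column order used by `trainer.SINDy_trainer` may differ. For example, `latent_dim = 3` with `second_order = True` gives 6 input variables and 90 candidate functions.

```python
from itertools import combinations_with_replacement

import numpy as np

def sindy_library(Z, poly_order=3, include_sine=True, include_constant=True):
    """Evaluate a polynomial (+ optional sine) candidate library.

    Z has shape (n_samples, n_vars); for a second-order model the columns would be
    the latent state z and its derivative dz stacked side by side.
    Returns Theta with one column per candidate function.
    """
    n_samples, n_vars = Z.shape
    columns = []
    if include_constant:
        columns.append(np.ones(n_samples))
    for order in range(1, poly_order + 1):
        # Every monomial of this total degree, e.g. z0*z0*z1 for order 3.
        for combo in combinations_with_replacement(range(n_vars), order):
            columns.append(np.prod(Z[:, list(combo)], axis=1))
    if include_sine:
        for i in range(n_vars):
            columns.append(np.sin(Z[:, i]))
    return np.stack(columns, axis=1)

# latent_dim = 3 with second_order = True gives 6 library inputs (z and dz).
Z = np.random.default_rng(0).normal(size=(10, 6))
print(sindy_library(Z).shape)  # (10, 90)
```

The sine columns are what let a sparse model express the pendulum's $-\sin z$ term exactly instead of approximating it with a cubic polynomial.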

In [None]:
from trainer import SINDy_trainer

# Merge dictionaries
params = {**hparams, **trainer_params}

# Initialize trainer
trainer = SINDy_trainer(**params)

trainer.train_model(train_loader, validation_loader, num_epochs=initial_epochs, final_epochs=final_epochs)
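The `coefficient_threshold` and `update_mask_every_n_epoch` settings suggest the sequential thresholding used by Champion et al.: periodically, SINDy coefficients whose magnitude falls below the threshold are set to zero and kept at zero through a mask. The snippet below is a minimal sketch of that masking step only, assuming a coefficient matrix of shape (library size, latent dim); `update_mask` is a hypothetical name, and how `SINDy_trainer` stores and applies its mask is not shown here.

```python
import numpy as np

def update_mask(xi, mask, threshold=0.1):
    """Sequential thresholding step: drop coefficients with |xi| < threshold.

    A coefficient that was already masked out stays masked out.
    Returns the new mask and the masked coefficients.
    """
    new_mask = mask & (np.abs(xi) >= threshold)
    return new_mask, xi * new_mask

xi = np.array([[0.05], [-0.98], [0.2], [-0.01]])
mask = np.ones_like(xi, dtype=bool)
mask, xi = update_mask(xi, mask)
print(mask.ravel().tolist())  # [False, True, True, False]
```

In the paper, the thresholded training is followed by a refinement phase with the mask fixed and the sparsity regularization switched off; `final_epochs` plausibly corresponds to that phase, but this is an assumption about `SINDy_trainer`.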
