Copyright 2022 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
#@title License
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Action-Angle Networks



In [None]:
example_directory = 'action_angle_networks'
editor_relpaths = ('configs/action_angle_flow.py', 'shm_simulation.py', 'models.py', 'train.py')

repo, branch = 'https://github.com/google-research/google-research', 'master'
# (If you run this code in Jupyter[lab], then you're already in the
#  example directory and nothing needs to be done.)

#@title Fetch source code
if 'google.colab' in str(get_ipython()):
  import os
  os.chdir('/content')
  # Download source repo from Github.
  if not os.path.isdir('s'):
    !git clone --depth=1 -b $branch $repo sourcerepo
  # Copy example files & change directory.
  example_root_path = f'/content/{example_directory}'
  if not os.path.isdir(example_root_path):
    os.makedirs(example_root_path)
    !cp -r sourcerepo/$example_directory/* "$example_root_path"
  os.chdir(example_root_path)
  from google.colab import files
  for relpath in editor_relpaths:
    files.view(f'{example_root_path}/{relpath}')

In [None]:
# Show the current working directory; on Colab the cell above changes into the
# copied example directory.
!pwd

In [None]:
!pip install -r requirements.txt

In [None]:
#@title Base Imports
# NOTE(review): wildcard import; `Tuple` (used by `pad_coords` below) comes from
# here. Consider importing the needed names explicitly.
from typing import *
import functools
import sys
import tempfile

import collections
import chex
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core import frozen_dict
from flax.training import train_state
import optax
import distrax
import tensorflow as tf
import ml_collections
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib
import sklearn.preprocessing
# Render matplotlib animations as interactive JavaScript/HTML in cell output.
matplotlib.rc('animation', html='jshtml')

In [None]:
#@title Source Imports
# Reload the local modules below before each cell executes, so edits to the
# .py files opened in the editor take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Project modules, imported from the current (example) directory.
import chm_simulation
import models
import shm_simulation
import train
from configs import action_angle_flow

In [None]:
# Enable the %tensorboard magic used below to monitor training.
%load_ext tensorboard

## The Hamiltonian

The Hamiltonian describes the 'energy' of a system in terms of
the positions $q$ and momenta $p$ of the particles.

For example, in the case of a simple harmonic oscillator, we have:

$$
H(q, p) = \frac{p^2}{2m} + \frac{m\omega^2q^2}{2}
$$

The time evolution of any system with Hamiltonian $H(q, p)$ is given by
Hamilton's equations:

$$
\begin{align}
\frac{dq}{dt} &= \frac{\partial H}{\partial p} \\
\frac{dp}{dt} &= -\frac{\partial H}{\partial q}
\end{align}
$$

and hence, the Hamiltonian is conserved across time:

$$
\frac{dH}{dt} = \frac{\partial H}{\partial q} \frac{dq}{dt} + \frac{\partial H}{\partial p} \frac{dp}{dt} = 0 
$$

## Action-Angle Coordinates

For integrable systems, there exists a mapping from canonical coordinates ($q$, $p$) to action-angle coordinates ($\theta$, $I$).

In these coordinates, the Hamiltonian is only a function of the actions $I$,
not the angles $\theta$:

$$
H(q, p) \to H(\theta, I) = K(I)
$$

and hence, by Hamilton's equations:

$$
\frac{d\theta}{dt} = \frac{\partial H}{\partial I},
$$

which is a constant across time as $I$ does not change with time:

$$
\frac{dI}{dt} = -\frac{\partial H}{\partial \theta} = 0.
$$

Thus, the angles $\theta$ evolve linearly, while the actions $I$ remain constant through time. This makes them very simple to model, requiring no integration to find the state of the system at any instant of time.

To enable our exploration, we focus on the simple harmonic oscillator system, introduced above.

## The Simple Harmonic Oscillator

The Hamiltonian for the simple harmonic oscillator is:
$$
H(q, p) = \frac{p^2}{2m} + \frac{m\omega^2q^2}{2}
$$

Applying Hamilton's equations,
$$
\begin{align}
\frac{dq}{dt} &= \frac{\partial H}{\partial p} = \frac{p}{m} \\
\frac{dp}{dt} &= -\frac{\partial H}{\partial q} = -m\omega^2q
\end{align}
$$

which can be solved as:
$$
\begin{align}
q &= A\cos(\omega t + \phi) \\
p &= -m\omega A\sin(\omega t + \phi)
\end{align}
$$
where $A$ and $\phi$ are integration constants.

This allows us to generate $(q, p)$ coordinates at any instant, given the values of $A, \phi, m$ and $\omega$.

In [None]:
# Parameters of system: two independent simple harmonic oscillators, with one
# entry per oscillator in every array.
simulation_parameters = dict(
    phi=jnp.asarray([0, jnp.pi / 2]),
    A=jnp.asarray([1, 2]),
    m=jnp.asarray([1, 1]),
    w=jnp.asarray([0.05, 0.2]),
)

In [None]:
# Generate coordinates at multiple instants of time!
times = jnp.arange(100)
# Inner vmap: over the time axis, with parameters held fixed.
per_time_fn = jax.vmap(shm_simulation.generate_canonical_coordinates, in_axes=(0, None))
# Outer vmap: over the entries of every parameter array (one per oscillator),
# placing that axis second, after time.
generate_canonical_coordinates_fn = jax.vmap(per_time_fn, in_axes=(None, 0), out_axes=1)
positions, momentums = generate_canonical_coordinates_fn(times, simulation_parameters)

In [None]:
# Compute Hamiltonians at the different instants (parameters shared over time).
compute_hamiltonians_fn = jax.vmap(shm_simulation.compute_hamiltonian, in_axes=(0, 0, None))
hamiltonians = compute_hamiltonians_fn(positions, momentums, simulation_parameters)
# Display the first few instants only.
hamiltonians[:10]

In [None]:
# Animate the ground-truth SHM trajectories generated above.
shm_simulation.plot_coordinates(
    positions, momentums, simulation_parameters, title='TRUE TRAJECTORIES')

## Coupled Harmonic Oscillators

Consider a system of $n$ particles with masses $m_1, m_2, ..., m_n$.
Each particle is connected to a wall by a spring with spring constant $k$,
and every pair of particles is connected by a spring with spring constant $\kappa$.
Then, the equations of motion take the following form:
$$
m_i\frac{d^2 q_i}{dt^2} = -(k + (n - 1)\kappa) q_i + \sum_{j \neq i}\kappa q_j
$$
for each $i \in \{1, 2, \ldots, n\}$.

To find the normal modes where all particles oscillate with the same angular frequency $\omega$, we make the exponential ansatz:
$$
q_i = c_i e^{\mathrm{i}\omega t}
$$
for each $i \in \{1, 2, \ldots, n\}$, where $\mathrm{i}$ denotes the imaginary unit.

This gives us:
$$
-\omega^2c_i = -\frac{k + (n - 1)\kappa}{m_i}c_i + \sum_{j \neq i}\frac{\kappa}{m_i}c_j
$$
for each $i \in \{1, 2, \ldots, n\}$.
In matrix form, the coefficients $c$ satisfy
$$
(M + \omega^2I_n)c = 0
$$
where $M$ is the matrix such that
$M_{ii} = -\frac{k + (n - 1)\kappa}{m_i}$ and $M_{ij} = \frac{\kappa}{m_i}$ for $i \neq j$.
(For $n = 2$, as in the example below, $k + (n - 1)\kappa = k + \kappa$.)

Thus, the angular frequencies are $\omega = \sqrt{-\lambda}$, where $\lambda$ ranges over the
(negative) eigenvalues of $M$.

In [None]:
# Parameters of a system of two coupled harmonic oscillators.
# NOTE: this rebinds `simulation_parameters`, replacing the SHM parameters above.
simulation_parameters = dict(
    m=jnp.asarray([2., 10.]),
    k_wall=jnp.asarray([10., 4.]),
    k_pair=jnp.asarray(1.),
    A=jnp.asarray([3., 1.]),
    phi=jnp.asarray([jnp.pi / 2, 0.]),
)

In [None]:
# Generate coordinates at multiple instants of time (vmapped over time only).
times = jnp.arange(100)
generate_chm_coordinates_fn = jax.vmap(
    chm_simulation.generate_canonical_coordinates, in_axes=(0, None))
positions, momentums = generate_chm_coordinates_fn(times, simulation_parameters)

In [None]:
# Compute Hamiltonians at the different instants; they should be constant.
compute_chm_hamiltonians_fn = jax.vmap(
    chm_simulation.compute_hamiltonian, in_axes=(0, 0, None))
hamiltonians = compute_chm_hamiltonians_fn(positions, momentums, simulation_parameters)
hamiltonians

## Our Approach

We propose learning the transformation to and from action-angle coordinates
as neural networks.

$$
(q_0, p_0) \xrightarrow{Enc} (\theta_0, I) \xrightarrow{\theta_t \ = \ \theta_0 + t \cdot \dot{\theta}} (\theta_t, I) \xrightarrow{Dec} (q_t, p_t)
$$

To be precise, our ActionAngleNetwork performs the following operations:
* Given canonical coordinates $(q_0, p_0)$, encode them with $Enc$ to action-angle coordinates $(\theta_0, I)$.
* Compute the Hamiltonian in these coordinates $H = Ham(I)$ with a HamiltonianNet $Ham$. Take the gradient of $H$ with respect to $I$ to get the angular velocity $\dot{\theta}$:

$$
\dot{\theta} = \frac{d\theta}{dt} = \frac{\partial H}{\partial I}.
$$

* In the current implementation, we directly parametrize $\dot{\theta} = f_\phi(I)$.
* To get the angles at time $t$, we simply need to multiply with the angular velocity:

$$
\theta_t = \theta_0 + t \cdot \dot{\theta}
$$

* Convert action-angle coordinates $(\theta_t, I)$ back to canonical coordinates $(q_t, p_t)$, via the decoder $Dec$.

We explore different parametrizations for the encoder $Enc$, HamiltonianNet $Ham$ and decoder 
$Dec$ neural networks.

We also explore representing the action-angle space in Cartesian coordinates.
In this formulation, the Encoder is followed by a fixed transformation to polar coordinates, and the Decoder is preceded by the inverse transformation back to Cartesian coordinates:
$$
(q_0, p_0)
\xrightarrow{Enc} (\hat{q}_0, \hat{p}_0)
\xrightarrow{\text{to Polar}} (\theta_0, I)
\xrightarrow{\theta_t \ = \ \theta_0 + t \cdot \dot{\theta}} (\theta_t, I) 
\xrightarrow{\text{to Cartesian}} (\hat{q}_t, \hat{p}_t)
\xrightarrow{Dec} (q_t, p_t)
$$

In [None]:
#@title Training Configuration
# Fresh temporary working directory for training; TensorBoard reads from it below.
workdir = tempfile.mkdtemp()
config = action_angle_flow.get_config()

In [None]:
# Monitor the training run; `--port=0` asks TensorBoard to pick an unused port.
%tensorboard --logdir={workdir} --port=0

In [None]:
# Train the model. The result is the data scaler, the train state and a dict of
# per-split data and metrics (unpacked in the next cell).
training_outputs = train.train_and_evaluate(config, workdir)
scaler, state, aux = training_outputs

In [None]:
# Unpack the data and evaluation metrics recorded for each split.
train_aux, test_aux = aux['train'], aux['test']

train_positions = train_aux['positions']
train_momentums = train_aux['momentums']
train_simulation_parameters = train_aux['simulation_parameters']
all_train_metrics = train_aux['metrics']

test_positions = test_aux['positions']
test_momentums = test_aux['momentums']
test_simulation_parameters = test_aux['simulation_parameters']
all_test_metrics = test_aux['metrics']

### Plotting Loss



In [None]:
%config InlineBackend.figure_format = 'retina'

metrics = all_test_metrics
total_losses = {jump: np.asarray(list(metrics[step][jump]['prediction_loss'] for step in metrics)) for jump in config.test_time_jumps}
steps = list(metrics.keys())
colors = plt.cm.viridis(np.linspace(0, 1, len(total_losses)))

fig, ax = plt.subplots()
for jump_color, jump in zip(colors, config.test_time_jumps):
  total_losses_for_jump = total_losses[jump]
  ax.plot(steps, total_losses_for_jump, label=jump, color=jump_color)

ax.set_title('Time Jump Sizes: Test Loss')
ax.set_xlabel('Steps')
ax.set_ylabel('Loss')
ax.set_yscale('log')
# ax.set_xscale('log')
ax.legend(title='Jump Size')
plt.show()

In [None]:
fig.savefig('test_losses.pdf', dpi=1000)
# `google.colab.files` only exists on Colab; skip the browser download elsewhere.
if 'google.colab' in str(get_ipython()):
  from google.colab import files
  files.download('test_losses.pdf')

### Plotting Change in Hamiltonians

In [None]:
%config InlineBackend.figure_format = 'retina'

metrics = all_test_metrics
total_changes = {jump: np.asarray(list(metrics[step][jump]['mean_change_in_hamiltonians'] for step in metrics)) for jump in config.test_time_jumps}
steps = list(metrics.keys())
colors = plt.cm.viridis(np.linspace(0, 1, len(total_changes)))

fig, ax = plt.subplots()
for jump_color, jump in zip(colors, config.test_time_jumps):
  total_changes_for_jump = total_changes[jump]
  ax.plot(steps, total_changes_for_jump, label=jump, color=jump_color)

true_position, true_momentum = train.inverse_transform_with_scaler(test_positions[0, :1], test_momentums[0, :1], scaler)
actual_hamiltonian = shm_simulation.compute_hamiltonian(true_position, true_momentum, simulation_parameters)
actual_hamiltonian = np.asarray(actual_hamiltonian).squeeze()
ax.axhline(y=actual_hamiltonian, c='gray', linestyle='--')
ax.set_title('Time Jump Sizes: Mean Change in Hamiltonian')
ax.set_xlabel('Steps')
ax.set_ylabel('Change')
ax.set_yscale('log')
# ax.set_xscale('log')
ax.legend(title='Jump Size')
plt.show()

In [None]:
fig.savefig('test_change_in_hamiltonian.pdf', dpi=1000)
# `google.colab.files` only exists on Colab; skip the browser download elsewhere.
if 'google.colab' in str(get_ipython()):
  from google.colab import files
  files.download('test_change_in_hamiltonian.pdf')

### Plotting Action and Angle Space

In [None]:
# Sample position-momentum space.
# Build a 100x100 (q, p) grid in the scaled training coordinates, spanning
# +/-1.2x the largest absolute training value of each coordinate. A trailing
# singleton axis is added so the grid can be padded into one trajectory slot.
max_position = 1.2 * np.abs(train_positions).max()
max_momentum = 1.2 * np.abs(train_momentums).max()
plot_positions = jnp.linspace(-max_position, max_position, num=100)
plot_momentums = jnp.linspace(-max_momentum, max_momentum, num=100)
grid = jnp.meshgrid(plot_positions, plot_momentums)
plot_positions = grid[0][:, :, jnp.newaxis]
plot_momentums = grid[1][:, :, jnp.newaxis]

# Pad the remaining coordinates with zeros.
def pad_coords(positions: chex.Array, momentums: chex.Array, index: int) -> Tuple[chex.Array, chex.Array]:
  """Zero-pads the last axis so the input occupies slot `index`.

  Args:
    positions: 3-D array; here its last axis has size 1 (the grid above).
    momentums: array with the same shape as `positions`.
    index: position along the last axis that receives the input values.

  Returns:
    (positions, momentums), each padded with `index` zeros before and
    `config.num_trajectories - index - 1` zeros after along the last axis.
  """
  positions = jnp.pad(positions, ((0, 0), (0, 0), (index, config.num_trajectories - index - 1)))
  momentums = jnp.pad(momentums, ((0, 0), (0, 0), (index, config.num_trajectories - index - 1)))
  return positions, momentums

trajectory_index = 1
plot_positions, plot_momentums = pad_coords(plot_positions, plot_momentums, trajectory_index)

# Compute actions and angles.
# Apply the model to each grid row (vmap over axis 0), with a time argument of 0.
_, _, auxiliary_predictions = jax.vmap(state.apply_fn, in_axes=(None, 0, 0, None))(state.params, plot_positions, plot_momentums, 0)
plot_actions = auxiliary_predictions['actions']
plot_angles = auxiliary_predictions['current_angles']

train_positions_rescaled, train_momentums_rescaled = train.inverse_transform_with_scaler(train_positions, train_momentums, scaler)
# Keep only the slot holding the grid values, then map the grid back through
# the scaler's inverse transform so it matches the rescaled training data.
plot_positions, plot_momentums, plot_actions, plot_angles = jax.tree.map(lambda arr: arr[:, :, trajectory_index], (plot_positions, plot_momentums, plot_actions, plot_angles))
plot_positions, plot_momentums = train.inverse_transform_with_scaler(plot_positions, plot_momentums, scaler)

In [None]:
# 50 contour levels of the learned actions over the (q, p) grid, with the
# training trajectory for the same index overlaid.
fig, ax = plt.subplots()
contours = ax.contour(plot_positions, plot_momentums, plot_actions, 50, cmap='viridis')
fig.colorbar(contours)
ax.plot(
    train_positions_rescaled[:, trajectory_index],
    train_momentums_rescaled[:, trajectory_index],
    c='gray', linestyle='--')
ax.set(xlabel='q', ylabel='p', title='Actions Contour')
plt.show()

In [None]:
fig.savefig('actions_contour.pdf', dpi=1000)
# `google.colab.files` only exists on Colab; skip the browser download elsewhere.
if 'google.colab' in str(get_ipython()):
  from google.colab import files
  files.download('actions_contour.pdf')

In [None]:
# 50 contour levels of the learned angles over the (q, p) grid, with the
# training trajectory for the same index overlaid.
fig, ax = plt.subplots()
contours = ax.contour(plot_positions, plot_momentums, plot_angles, 50, cmap='viridis')
fig.colorbar(contours)
ax.plot(
    train_positions_rescaled[:, trajectory_index],
    train_momentums_rescaled[:, trajectory_index],
    c='gray', linestyle='--')
ax.set(xlabel='q', ylabel='p', title='Angles Contour')
plt.show()

In [None]:
fig.savefig('angles_contour.pdf', dpi=1000)
# `google.colab.files` only exists on Colab; skip the browser download elsewhere.
if 'google.colab' in str(get_ipython()):
  from google.colab import files
  files.download('angles_contour.pdf')

### Plotting True Trajectories

In [None]:
# Training trajectories in original units; axis limits from their extent.
train_positions_rescaled, train_momentums_rescaled = train.inverse_transform_with_scaler(
    train_positions, train_momentums, scaler)
max_position = np.abs(train_positions_rescaled).max()
max_momentum = np.abs(train_momentums_rescaled).max()
shm_simulation.static_plot_coordinates_in_phase_space(
    train_positions_rescaled, train_momentums_rescaled,
    title='TRAIN TRAJECTORIES',
    max_position=max_position, max_momentum=max_momentum)

In [None]:
# Test trajectories in original units. Axis limits are taken from the
# (rescaled) training data computed in the previous cell, presumably so both
# plots share the same axes.
test_positions_rescaled, test_momentums_rescaled = train.inverse_transform_with_scaler(
    test_positions, test_momentums, scaler)
max_position = np.abs(train_positions_rescaled).max()
max_momentum = np.abs(train_momentums_rescaled).max()
shm_simulation.static_plot_coordinates_in_phase_space(
    test_positions_rescaled, test_momentums_rescaled,
    title='TEST TRAJECTORIES',
    max_position=max_position, max_momentum=max_momentum)

### One-step Predictions

In [None]:
def predict_for_trajectory(positions_for_trajectory: chex.Array, momentums_for_trajectory: chex.Array, jump: int) -> Tuple[chex.Array, chex.Array]:
  """Predicts coordinates `jump` time steps ahead and maps them to original units.

  Uses the notebook-level `state`, `config` and `scaler` from training.

  Args:
    positions_for_trajectory: positions as stored in `aux` (these are passed
      through `train.inverse_transform_with_scaler` elsewhere, i.e. scaled).
    momentums_for_trajectory: momentums, in the same form as the positions.
    jump: number of time steps ahead to predict; each step is
      `config.time_delta` long.

  Returns:
    (predicted_positions, predicted_momentums) after the scaler's inverse
    transform.
  """
  # Only the current coordinates are needed; the target coordinates are unused.
  curr_positions, curr_momentums, _, _ = train.get_coordinates_for_time_jump(positions_for_trajectory, momentums_for_trajectory, jump)
  predicted_positions, predicted_momentums, _ = train.compute_predictions(state, curr_positions, curr_momentums, jump * config.time_delta)
  predicted_positions, predicted_momentums = train.inverse_transform_with_scaler(predicted_positions, predicted_momentums, scaler)
  return predicted_positions, predicted_momentums

In [None]:
jump = 1
# Axis limits come from the extent of the rescaled training data.
train_positions_rescaled, train_momentums_rescaled = train.inverse_transform_with_scaler(
    train_positions, train_momentums, scaler)
max_position = np.abs(train_positions_rescaled).max()
max_momentum = np.abs(train_momentums_rescaled).max()

predicted_positions, predicted_momentums = predict_for_trajectory(
    train_positions, train_momentums, jump)
shm_simulation.static_plot_coordinates_in_phase_space(
    predicted_positions, predicted_momentums,
    title=f'PREDICTED TRAIN TRAJECTORIES: JUMP {jump}',
    max_position=max_position, max_momentum=max_momentum)

In [None]:
# Animate the first 200 predictions of trajectory 0 with that trajectory's parameters.
first_train_trajectory_parameters = jax.tree.map(lambda arr: arr[0], train_simulation_parameters)
train_one_step_predicted_trajectories_anim = shm_simulation.plot_coordinates(
    predicted_positions[:200, 0],
    predicted_momentums[:200, 0],
    first_train_trajectory_parameters,
    title=f'ONE-STEP PREDICTED TRAIN TRAJECTORIES: JUMP {jump}')
train_one_step_predicted_trajectories_anim

In [None]:
jump = 200
# Axis limits come from the extent of the rescaled training data.
train_positions_rescaled, train_momentums_rescaled = train.inverse_transform_with_scaler(
    train_positions, train_momentums, scaler)
max_position = np.abs(train_positions_rescaled).max()
max_momentum = np.abs(train_momentums_rescaled).max()

predicted_positions, predicted_momentums = predict_for_trajectory(
    test_positions, test_momentums, jump)
shm_simulation.static_plot_coordinates_in_phase_space(
    predicted_positions, predicted_momentums,
    title=f'PREDICTED TEST TRAJECTORIES: JUMP {jump}',
    max_position=max_position, max_momentum=max_momentum)

In [None]:
# Animate the first 200 predictions of trajectory 0 with that trajectory's parameters.
first_test_trajectory_parameters = jax.tree.map(lambda arr: arr[0], test_simulation_parameters)
test_one_step_predicted_trajectories_anim = shm_simulation.plot_coordinates(
    predicted_positions[:200, 0],
    predicted_momentums[:200, 0],
    first_test_trajectory_parameters,
    title=f'ONE-STEP PREDICTED TEST TRAJECTORIES: JUMP {jump}')
test_one_step_predicted_trajectories_anim

In [None]:
test_one_step_predicted_trajectories_anim.save('test_one_step_predicted_trajectories.gif')
# `google.colab.files` only exists on Colab; skip the browser download elsewhere.
if 'google.colab' in str(get_ipython()):
  from google.colab import files
  files.download('test_one_step_predicted_trajectories.gif')

In [None]:
# Phase-space animation of the first 100 test predictions.
phase_space_title = f'ONE-STEP PREDICTED TEST TRAJECTORIES: JUMP {jump}'
test_one_step_predicted_trajectories_phase_space_anim = (
    shm_simulation.plot_coordinates_in_phase_space(
        predicted_positions[:100], predicted_momentums[:100],
        test_simulation_parameters, title=phase_space_title))
test_one_step_predicted_trajectories_phase_space_anim

In [None]:
test_one_step_predicted_trajectories_phase_space_anim.save('test_one_step_predicted_trajectories_phase_space.gif')
# `google.colab.files` only exists on Colab; skip the browser download elsewhere.
if 'google.colab' in str(get_ipython()):
  from google.colab import files
  files.download('test_one_step_predicted_trajectories_phase_space.gif')

### Distribution of Actions

In [None]:
# Compute actions for the current coordinates of `train_positions[0]`.
jump = 1
curr_positions, curr_momentums, *_ = train.get_coordinates_for_time_jump(train_positions[0], train_momentums[0], jump)

_, _, auxiliary_predictions = state.apply_fn(state.params, curr_positions, curr_momentums, 0)
# `plot_actions`/`plot_angles` are left untouched: they hold the grid values
# used by the contour plots above.

actions = np.asarray(auxiliary_predictions['actions']).flatten()
plt.hist(actions, bins=1000)
plt.gca().xaxis.set_major_formatter(matplotlib.ticker.StrMethodFormatter('{x:,.3f}'))
plt.xlabel('Action')
plt.ylabel('Count')
plt.title('Distribution of Actions')
plt.show()

In [None]:
# Distribution of the predicted angular velocities from the previous cell.
tick_formatter = matplotlib.ticker.StrMethodFormatter('{x:,.4f}')
angular_velocities = np.asarray(auxiliary_predictions['angular_velocities']).flatten()
plt.hist(angular_velocities, bins=1000)
plt.gca().xaxis.set_major_formatter(tick_formatter)
plt.show()

### Non-Recursive Multi-Step Predictions

In [None]:
jump = 50
# Predict directly from `test_positions[:1]`: the k-th prediction is
# k * jump time steps ahead, for k = 1, ..., 499.
future_times = jnp.arange(1, 500) * jump * config.time_delta
non_recursive_predictions = state.apply_fn(
    state.params, test_positions[:1], test_momentums[:1], future_times,
    method=models.ActionAngleNetwork.predict_multi_step)
(non_recursive_multi_step_predicted_positions,
 non_recursive_multi_step_predicted_momentums,
 non_recursive_multi_step_auxiliary_predictions) = non_recursive_predictions
(non_recursive_multi_step_auxiliary_predictions['angular_velocities'],
 non_recursive_multi_step_auxiliary_predictions['current_angles'],
 non_recursive_multi_step_auxiliary_predictions['future_angles'][:10])

In [None]:
# Map the non-recursive predictions back to original units.
(non_recursive_multi_step_predicted_positions,
 non_recursive_multi_step_predicted_momentums) = train.inverse_transform_with_scaler(
    non_recursive_multi_step_predicted_positions,
    non_recursive_multi_step_predicted_momentums,
    scaler)

In [None]:
# Static phase-space plot of the first 100 non-recursive predictions.
fig = shm_simulation.static_plot_coordinates_in_phase_space(
    non_recursive_multi_step_predicted_positions[:100],
    non_recursive_multi_step_predicted_momentums[:100],
    title=f'PREDICTED TEST TRAJECTORIES: JUMP {jump}',
)
fig

In [None]:
# One filename for both save and download (the download previously lacked the
# `f` prefix and requested a literal '{jump}' file).
pdf_path = f'test_multi_step_predicted_phase_space_jump_{jump}.pdf'
fig.savefig(pdf_path, dpi=1000, bbox_inches='tight', pad_inches=0)
# `google.colab.files` only exists on Colab; skip the browser download elsewhere.
if 'google.colab' in str(get_ipython()):
  from google.colab import files
  files.download(pdf_path)

In [None]:
# Animate the non-recursive predictions. `plot_coordinates` lives in
# `shm_simulation`, and the test-set parameters are used because the global
# `simulation_parameters` now holds the coupled-oscillator parameters.
test_multi_step_predicted_trajectories_anim = shm_simulation.plot_coordinates(
    non_recursive_multi_step_predicted_positions, non_recursive_multi_step_predicted_momentums, test_simulation_parameters,
    title='NON-RECURSIVE MULTI-STEP PREDICTED TEST TRAJECTORIES')
test_multi_step_predicted_trajectories_anim

In [None]:
test_multi_step_predicted_trajectories_anim.save('test_non_recursive_multi_step_predicted_trajectories.gif')
# `google.colab.files` only exists on Colab; skip the browser download elsewhere.
if 'google.colab' in str(get_ipython()):
  from google.colab import files
  files.download('test_non_recursive_multi_step_predicted_trajectories.gif')

In [None]:
# Phase-space animation of the first 100 non-recursive predictions.
phase_space_title = f'PREDICTED TEST TRAJECTORIES: JUMP {jump}'
test_multi_step_predicted_trajectories_phase_space_anim = (
    shm_simulation.plot_coordinates_in_phase_space(
        non_recursive_multi_step_predicted_positions[:100],
        non_recursive_multi_step_predicted_momentums[:100],
        test_simulation_parameters,
        title=phase_space_title))
test_multi_step_predicted_trajectories_phase_space_anim

In [None]:
gif_path = f'test_non_recursive_multi_step_predicted_trajectories_phase_space_jump_{jump}.gif'
test_multi_step_predicted_trajectories_phase_space_anim.save(gif_path)
# `google.colab.files` only exists on Colab; skip the browser download elsewhere.
if 'google.colab' in str(get_ipython()):
  from google.colab import files
  files.download(gif_path)

### Recursive Multi-Step Predictions

In [None]:
def predict_next_step(carry, _):
  """Advances a (position, momentum) carry by one `config.time_delta` step.

  Shaped for `jax.lax.scan`: returns the new carry and, as the per-step output,
  the predicted coordinates together with their auxiliary predictions.
  """
  position, momentum = carry
  next_position, next_momentum, step_auxiliary = train.compute_predictions(
      state, position, momentum, config.time_delta)
  next_carry = (next_position, next_momentum)
  return next_carry, (next_position, next_momentum, step_auxiliary)

# Feed each prediction back in as the next input for 1000 steps, starting
# from `test_positions[0, :1]` / `test_momentums[0, :1]`.
initial_carry = (test_positions[0, :1], test_momentums[0, :1])
_, recursive_predictions = jax.lax.scan(predict_next_step, initial_carry, None, length=1000)
(recursive_multi_step_predicted_positions,
 recursive_multi_step_predicted_momentums,
 recursive_multi_step_auxiliary_predictions) = recursive_predictions
(recursive_multi_step_auxiliary_predictions['angular_velocities'][0],
 recursive_multi_step_auxiliary_predictions['current_angles'][:10],
 recursive_multi_step_auxiliary_predictions['future_angles'][:10])

In [None]:
# First recursive prediction, still in scaled units.
(recursive_multi_step_predicted_positions[0],
 recursive_multi_step_predicted_momentums[0])

In [None]:
# Map the recursive predictions back to original units.
(recursive_multi_step_predicted_positions,
 recursive_multi_step_predicted_momentums) = train.inverse_transform_with_scaler(
    recursive_multi_step_predicted_positions,
    recursive_multi_step_predicted_momentums,
    scaler)

In [None]:
# First recursive prediction after the inverse transform.
(recursive_multi_step_predicted_positions[0],
 recursive_multi_step_predicted_momentums[0])

In [None]:
# Animate the first 200 recursive predictions. They start from
# `test_positions[0, :1]`, so use the matching slice of the test-set
# parameters: the global `simulation_parameters` now holds the
# coupled-oscillator parameters.
recursive_start_parameters = jax.tree.map(lambda arr: arr[:1], test_simulation_parameters)
test_multi_step_predicted_trajectories_anim = shm_simulation.plot_coordinates(
    recursive_multi_step_predicted_positions[:200], recursive_multi_step_predicted_momentums[:200], recursive_start_parameters,
    title='RECURSIVE MULTI-STEP PREDICTED TEST TRAJECTORIES')
test_multi_step_predicted_trajectories_anim

In [None]:
test_multi_step_predicted_trajectories_anim.save('test_recursive_multi_step_predicted_trajectories.gif')
# `google.colab.files` only exists on Colab; skip the browser download elsewhere.
if 'google.colab' in str(get_ipython()):
  from google.colab import files
  files.download('test_recursive_multi_step_predicted_trajectories.gif')

In [None]:
# Static phase-space plot of all recursive predictions.
recursive_title = 'RECURSIVE MULTI-STEP PREDICTED TEST TRAJECTORIES'
shm_simulation.static_plot_coordinates_in_phase_space(
    recursive_multi_step_predicted_positions,
    recursive_multi_step_predicted_momentums,
    title=recursive_title)

In [None]:
# Phase-space animation of the first 200 recursive predictions, using the
# test-set parameters matching their starting point `test_positions[0, :1]`.
# The global `simulation_parameters` now holds the coupled-oscillator parameters.
recursive_start_parameters = jax.tree.map(lambda arr: arr[:1], test_simulation_parameters)
test_multi_step_predicted_trajectories_phase_space_anim = shm_simulation.plot_coordinates_in_phase_space(
    recursive_multi_step_predicted_positions[:200], recursive_multi_step_predicted_momentums[:200], recursive_start_parameters,
    title='RECURSIVE MULTI-STEP PREDICTED TEST TRAJECTORIES')
test_multi_step_predicted_trajectories_phase_space_anim

In [None]:
test_multi_step_predicted_trajectories_phase_space_anim.save('test_recursive_multi_step_predicted_trajectories_phase_space.gif')
# `google.colab.files` only exists on Colab; skip the browser download elsewhere.
if 'google.colab' in str(get_ipython()):
  from google.colab import files
  files.download('test_recursive_multi_step_predicted_trajectories_phase_space.gif')

### True Trajectories

In [None]:
# Ground-truth test trajectories in original units, animated with the test-set
# parameters. (`test_target_positions` was never defined in this notebook, and
# `simulation_parameters` now holds the coupled-oscillator parameters.)
test_true_positions, test_true_momentums = train.inverse_transform_with_scaler(
    test_positions, test_momentums, scaler)
test_true_trajectories_anim = shm_simulation.plot_coordinates(
    test_true_positions[:200], test_true_momentums[:200], test_simulation_parameters,
    title='TRUE TEST TRAJECTORIES')
test_true_trajectories_anim

In [None]:
test_true_trajectories_anim.save('test_true_trajectories.gif')
# `google.colab.files` only exists on Colab; skip the browser download elsewhere.
if 'google.colab' in str(get_ipython()):
  from google.colab import files
  files.download('test_true_trajectories.gif')