# RO47019: Intelligent Control Systems Practical Assignment
* Period: 2022-2023, Q3
* Course homepage: https://brightspace.tudelft.nl/d2l/home/500969
* Instructor: Cosimo Della Santina (C.DellaSantina@tudelft.nl)
* Teaching assistant: Ruben Martin Rodriguez (R.MartinRodriguez@student.tudelft.nl)
* (c) TU Delft, 2023

Make sure you fill in any place that says `YOUR CODE HERE` or `YOUR ANSWER HERE`, and remove the corresponding `raise NotImplementedError()` afterwards. Moreover, if you see an empty cell, please DO NOT delete it; instead, run that cell as you would run all other cells. Please fill in your name(s) and other required details below:

In [None]:
# Please fill in your names, student numbers, netID, and emails below.
STUDENT_1_NAME = ""
STUDENT_1_STUDENT_NUMBER = ""
STUDENT_1_NETID = ""
STUDENT_1_EMAIL = ""

In [None]:
# Note: this block is a check that you have filled in the above information.
# It will throw an AssertionError until all fields are filled
assert STUDENT_1_NAME != ""
assert STUDENT_1_STUDENT_NUMBER != ""
assert STUDENT_1_NETID != ""
assert STUDENT_1_EMAIL != ""

### General announcements

* Do *not* share your solutions, and do *not* copy solutions from others. By submitting your solutions, you claim that you alone are responsible for this code.

* Do *not* email questions directly, since we want to provide everybody with the same information and avoid repeating the same answers. Instead, please post your questions regarding this assignment in the correct support forum on Brightspace; this way, everybody can benefit from the response. If you have a particular question that you want to ask directly, please ask the TA during the scheduled Q&A hours.

* There is a strict deadline for each assignment. Students are responsible for ensuring that they upload their work in time. So, please double-check that your upload to Brightspace succeeded to avoid any late penalties.

* This [Jupyter notebook](https://jupyter.org/) uses `nbgrader` to help us with automated tests. `nbgrader` makes various cells in this notebook "uneditable" or "unremovable" and gives them a special id in the cell metadata. This way, when we run our checks, the system checks that these cell ids exist and determines the number of points and which checks must be run. While there are ways to edit the metadata and work around the restrictions to delete or modify these special cells, you should not do that, since our nbgrader backend will then not be able to parse your notebook and give you points for the assignment. You are free to add additional cells, but if you find a cell that you cannot modify or remove, please know that this is on purpose.

* In various places, this notebook contains a line that throws a `NotImplementedError` exception. These are locations where the assignment requires you to adapt the code! These lines are only there to remind you that you have not yet adapted that particular piece of code, especially when you execute all the cells. Once your solution code has replaced these lines, it should *not* throw any exceptions anymore.

Before you turn this problem in, make sure everything runs as expected. First, **restart the kernel** (in the menubar, select Kernel$\rightarrow$Restart) and then **run all cells** (in the menubar, select Cell$\rightarrow$Run All).

# Lagrangian Neural Network (LNN) training

**Author:** Maximilian Stölzle (M.W.Stolzle@tudelft.nl)

This notebook contains functions that assist with training the Lagrangian neural network. The task assignment can be found in the notebook `task_2c-3_train_lnn.ipynb`.

In [None]:
from distutils.util import strtobool
from flax.core import FrozenDict
from functools import partial
from jax.config import config as jax_config

jax_config.update("jax_platform_name", "cpu")  # set default device to 'cpu'
jax_config.update("jax_enable_x64", True)  # double precision
import dill
from flax.training.train_state import TrainState
import jax
from jax import debug, jit
import jax.numpy as jnp
from jax import random
from matplotlib import rcParams
import matplotlib.pyplot as plt
import numpy as np
import optax
import os
from pathlib import Path
from tqdm.notebook import tqdm  # progress bar
from typing import Callable, Dict, List, Tuple

from jax_double_pendulum.utils import normalize_link_angles

# import the discrete forward dynamics and the neural network modules from lnn.ipynb
from ipynb.fs.full.lnn import (
    discrete_forward_dynamics,
    MassMatrixNN,
    PotentialEnergyNN,
)

# define boolean to check if the notebook is run for the purposes of autograding
AUTOGRADING = strtobool(os.environ.get("AUTOGRADING", "false"))

In [None]:
def load_datasets(
    filepath: Path, rng: random.KeyArray, val_ratio: float = 0.2
) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]:
    """
    Loads the datasets from a .npz file.
    The dataset needs to have the following entries:
        - dt_ss: array of shape (N, ) containing the time step between the current and the next state [s]
        - th_curr_ss: Array of shape (N, 2) containing the current link angles of the double pendulum. [rad]
        - th_d_curr_ss: Array of shape (N, 2) containing the current link angular velocities of the double pendulum. [rad/s]
        - th_next_ss: Array of shape (N, 2) containing the next link angles of the double pendulum. [rad]
        - th_d_next_ss: Array of shape (N, 2) containing the next link angular velocities of the double pendulum. [rad/s]
    Args:
        filepath: Path to the .npz file.
        rng: PRNG key for pseudo-random number generation.
        val_ratio: Ratio of validation dataset with respect to the entire dataset size. Needs to be in interval [0, 1].
    Returns:
        train_ds: Dictionary containing the training dataset with the same keys as the original dataset.
        val_ds: Dictionary containing the validation dataset with the same keys as the original dataset.
    """
    assert 0.0 <= val_ratio <= 1.0, "Validation ratio needs to be in interval [0, 1]."

    dataset = jnp.load(filepath)
    num_samples = dataset["th_curr_ss"].shape[0]

    # Randomly split the dataset into a training set and a validation set
    # Important: make use of the provided PRNG key
    train_ds, val_ds = {}, {}
    # YOUR CODE HERE
    raise NotImplementedError()

    return train_ds, val_ds
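
A random train/validation split can be obtained by permuting the sample indices with `jax.random.permutation` and indexing every array of the dataset with the same index arrays, so that inputs and targets stay aligned. The sketch below is a minimal illustration under stated assumptions: it uses a toy in-memory dictionary instead of the `.npz` file, and the helper name `split_dataset_sketch` and the toy arrays are hypothetical, not part of the assignment template.

```python
import jax
import jax.numpy as jnp


def split_dataset_sketch(dataset, rng, val_ratio=0.2):
    """Hypothetical helper: randomly split a dict of equally long arrays."""
    num_samples = dataset["th_curr_ss"].shape[0]
    num_val = int(val_ratio * num_samples)
    # a single permutation shared by all entries keeps the samples aligned
    perm = jax.random.permutation(rng, num_samples)
    val_idx, train_idx = perm[:num_val], perm[num_val:]
    train_ds = {k: v[train_idx] for k, v in dataset.items()}
    val_ds = {k: v[val_idx] for k, v in dataset.items()}
    return train_ds, val_ds


# toy dataset with N = 10 samples (made-up values)
N = 10
toy = {
    "dt_ss": jnp.full((N,), 0.01),
    "th_curr_ss": jnp.arange(2 * N, dtype=jnp.float32).reshape(N, 2),
}
train_ds, val_ds = split_dataset_sketch(toy, jax.random.PRNGKey(0), val_ratio=0.2)
print(train_ds["th_curr_ss"].shape, val_ds["th_curr_ss"].shape)  # (8, 2) (2, 2)
```

Passing a different PRNG key yields a different, but reproducible, split.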

In [None]:
def create_learning_rate_fn(
    num_epochs: int, steps_per_epoch: int, base_lr: float, warmup_epochs: int = 0
) -> Callable:
    """
    Creates a learning rate schedule function. The learning rate scheduler implements the following procedure:
        1. A linear increase of the learning rate for a specified number of warmup epochs up to the base lr
        2. A cosine decay of the learning rate throughout the remaining epochs
    Args:
        num_epochs: Number of epochs to train for.
        steps_per_epoch: Number of steps per epoch.
        base_lr: Base learning rate.
        warmup_epochs: Number of epochs for warmup.
    Returns:
        learning_rate_fn: A function that takes the current step and returns the current learning rate.
            It has the signature learning_rate_fn(step: int) -> lr.
    """
    # Create the learning rate function implementing the procedure documented in the docstring
    # Hint: use the following optax functions:
    # optax.linear_schedule, optax.cosine_decay_schedule, optax.join_schedules
    # YOUR CODE HERE
    raise NotImplementedError()

    return learning_rate_fn
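
To make the intended shape of the schedule concrete, the following pure-Python sketch evaluates a linear warmup followed by a cosine decay over the remaining steps. It assumes that the warmup starts at a learning rate of 0 and that the cosine decays all the way to 0. It only illustrates the math and is not the optax-based implementation that the hint asks for.

```python
import math


def warmup_cosine_lr(step, num_epochs, steps_per_epoch, base_lr, warmup_epochs=0):
    """Learning rate at `step` for a linear warmup followed by a cosine decay (sketch)."""
    warmup_steps = warmup_epochs * steps_per_epoch
    # guard against a zero-length decay phase
    decay_steps = max((num_epochs - warmup_epochs) * steps_per_epoch, 1)
    if step < warmup_steps:
        # linear ramp from 0 up to base_lr
        return base_lr * step / warmup_steps
    # cosine decay from base_lr down to 0, held at 0 after the last step
    t = min(step - warmup_steps, decay_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * t / decay_steps))


# 10 epochs of 100 steps each, of which 2 epochs are warmup
for step in (0, 100, 200, 600, 1000):
    print(step, warmup_cosine_lr(step, 10, 100, 1e-3, warmup_epochs=2))
```

With optax, `optax.join_schedules` switches from the first to the second schedule at the given boundary (here `warmup_epochs * steps_per_epoch`) and passes the step count offset by that boundary to the second schedule. This is why the decay phase above restarts its own counter `t` at 0.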

In [None]:
def initialize_train_states(
    rng: random.KeyArray, learning_rate_fn: Callable, weight_decay: float = 0.0
) -> Dict[str, TrainState]:
    """
    Initialize the train states of the two neural networks.
    Args:
        rng: PRNG key for pseudo-random number generation.
        learning_rate_fn: A function that takes the current step and returns the current learning rate.
            It has the signature learning_rate_fn(step: int) -> lr.
        weight_decay: Weight decay of the AdamW optimizer used for training the neural networks.
    Returns:
        states: Dictionary containing the current states of the training of the two neural networks.
            Entries of dictionary:
                - MassMatrixNN: TrainState of the mass matrix neural network
                - PotentialEnergyNN: TrainState of the potential energy neural network
    """
    # initialize the neural network objects
    mass_matrix_nn = MassMatrixNN()
    potential_energy_nn = PotentialEnergyNN()

    # initialize parameters of the neural networks by passing a dummy input through the network
    # Hint: pass the `rng` and a dummy input to the `init` method of the neural network object
    mass_matrix_nn_params = FrozenDict()
    potential_energy_nn_params = FrozenDict()
    # YOUR CODE HERE
    raise NotImplementedError()

    # initialize the AdamW (Adam with decoupled weight decay) optimizer for both neural networks
    # Hint: use optax.adamw
    mass_matrix_nn_tx, potential_energy_nn_tx = None, None
    # YOUR CODE HERE
    raise NotImplementedError()

    # create the TrainState object for both neural networks
    mass_matrix_nn_train_state = TrainState.create(
        apply_fn=mass_matrix_nn.apply,
        params=mass_matrix_nn_params,
        tx=mass_matrix_nn_tx,
    )
    potential_energy_nn_train_state = TrainState.create(
        apply_fn=potential_energy_nn.apply,
        params=potential_energy_nn_params,
        tx=potential_energy_nn_tx,
    )

    # save the TrainState objects into a dictionary, which we can pass to the various training functions
    states = {
        "MassMatrixNN": mass_matrix_nn_train_state,
        "PotentialEnergyNN": potential_energy_nn_train_state,
    }

    return states

In [None]:
@jit
def mse_loss_fn(pred: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
    """
    Mean Squared Error (MSE) loss function to train the neural network.
    Args:
        pred: Predictions of the neural network.
        target: Targets (i.e. labels) for the supervised training.
    Returns:
        loss: Mean squared error (MSE) as array of shape () between the predicted and target values.
    """
    # YOUR CODE HERE
    raise NotImplementedError()
    return loss

In [None]:
@jit
def compute_metrics(
    batch: Dict[str, jnp.ndarray], preds: Dict[str, jnp.ndarray]
) -> Dict[str, jnp.ndarray]:
    """
    Computes the metrics of the current batch measuring the performance of the dynamics prediction.
    Args:
        batch: dictionary of batch data (i.e. inputs and targets)
        preds: dictionary of batch predictions (i.e. outputs of the neural network) with keys:
            - th_next_ss: predicted next link angles of the double pendulum. [rad]
            - th_d_next_ss: predicted next link angular velocities of the double pendulum. [rad/s]
    Returns:
        metrics: Dictionary of metrics
    """
    error_th = normalize_link_angles(preds["th_next_ss"] - batch["th_next_ss"])
    metrics = {
        "rmse_th_next": jnp.sqrt(jnp.mean(jnp.square(error_th))),
        "rmse_th_d_next": jnp.sqrt(
            mse_loss_fn(preds["th_d_next_ss"], batch["th_d_next_ss"])
        ),
    }

    return metrics
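
Angle errors are wrapped before the RMSE is computed because two link angles that differ by a multiple of $2\pi$ describe the same configuration. The sketch below assumes that `normalize_link_angles` wraps angles into the interval $[-\pi, \pi)$. The helper `wrap_to_pi` is hypothetical and only illustrates how much the wrapping changes the metric for a prediction that lies just across the $\pm\pi$ boundary.

```python
import jax.numpy as jnp


def wrap_to_pi(th):
    """Hypothetical angle wrapping into [-pi, pi)."""
    return jnp.mod(th + jnp.pi, 2 * jnp.pi) - jnp.pi


def mse(pred, target):
    # mean over all samples and both links -> array of shape ()
    return jnp.mean(jnp.square(pred - target))


th_pred = jnp.array([[jnp.pi - 0.1, 0.0]])
th_true = jnp.array([[-jnp.pi + 0.1, 0.0]])  # only 0.2 rad apart on the circle

rmse_raw = jnp.sqrt(mse(th_pred, th_true))
rmse_wrapped = jnp.sqrt(jnp.mean(jnp.square(wrap_to_pi(th_pred - th_true))))
print(rmse_raw, rmse_wrapped)
```

Without wrapping, the error of the first link is reported as almost $2\pi$ rad instead of 0.2 rad.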

In [None]:
# Vectorize the discrete forward dynamics of the LNN
# the discrete forward dynamics for a single state are available through the function `discrete_forward_dynamics`
# It should have the following signature
#   discrete_forward_dynamics_vmapped(
#       mass_matrix_nn_params,
#       potential_energy_nn_params,
#       dt_ss,
#       th_curr_ss,
#       th_d_curr_ss,
#       tau_ss,
#   ) -> th_next_pred_ss, th_d_next_pred_ss, th_dd_ss
# where `dt_ss` has the shape (N, ), `th_curr_ss` has the shape (N, 2), `th_d_curr_ss` has the shape (N, 2),
# `tau_ss` has the shape (N, 2), `th_next_pred_ss` has the shape (N, 2), `th_d_next_pred_ss` has the shape (N, 2)
# `th_dd_ss` has the shape (N, 2). N is the number of samples in the batch.

discrete_forward_dynamics_vmapped = None
# YOUR CODE HERE
raise NotImplementedError()
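
The network parameters are shared by all samples, while the time steps, states, and torques carry a leading batch dimension. As a result, the vectorization maps over the last four arguments only. The sketch below shows this `in_axes` pattern of `jax.vmap` on a hypothetical single-sample function `toy_dynamics` with a made-up semi-implicit Euler update. It assumes that `discrete_forward_dynamics` takes the same six arguments in the same order for a single sample. Its actual internals are defined in `lnn.ipynb`.

```python
import jax
import jax.numpy as jnp


def toy_dynamics(params_a, params_b, dt, th, th_d, tau):
    """Hypothetical single-sample dynamics: th, th_d, tau have shape (2,), dt has shape ()."""
    # made-up acceleration model, only used to illustrate the shapes
    th_dd = params_a["gain"] * tau - params_b["stiffness"] * th
    th_d_next = th_d + dt * th_dd
    th_next = th + dt * th_d_next
    return th_next, th_d_next, th_dd


# None: reuse the same parameters for every sample; 0: map over the leading axis
toy_dynamics_vmapped = jax.vmap(toy_dynamics, in_axes=(None, None, 0, 0, 0, 0))

N = 5
params_a, params_b = {"gain": 2.0}, {"stiffness": 0.5}
dt_ss = jnp.full((N,), 0.01)
th_ss, th_d_ss, tau_ss = jnp.ones((N, 2)), jnp.zeros((N, 2)), jnp.ones((N, 2))
th_next_ss, th_d_next_ss, th_dd_ss = toy_dynamics_vmapped(
    params_a, params_b, dt_ss, th_ss, th_d_ss, tau_ss
)
print(th_next_ss.shape, th_d_next_ss.shape, th_dd_ss.shape)  # (5, 2) (5, 2) (5, 2)
```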

In [None]:
@partial(
    jit,
    static_argnums=2,
    static_argnames="learning_rate_fn",
)
def train_step(
    states: Dict[str, TrainState],
    batch: Dict[str, jnp.ndarray],
    learning_rate_fn: Callable,
) -> Tuple[Dict[str, TrainState], Dict[str, jnp.ndarray]]:
    """
    Performs a single training step (one parameter update) of both neural networks.
    Args:
        states: Dictionary containing the current states of the training of the two neural networks
            Entries of dictionary:
                - MassMatrixNN: TrainState of the mass matrix neural network
                - PotentialEnergyNN: TrainState of the potential energy neural network
        batch: dictionary of batch data
        learning_rate_fn: A function that takes the current step and returns the current learning rate.
            It has the signature learning_rate_fn(step: int) -> lr.
    Returns:
        states: Dictionary of updated training states
            Entries of dictionary:
                - MassMatrixNN: TrainState of the mass matrix neural network
                - PotentialEnergyNN: TrainState of the potential energy neural network
        metrics: Dictionary of training metrics
    """

    def loss_fn(
        mass_matrix_nn_params: Dict, potential_energy_nn_params: Dict
    ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
        """
        Loss function to train the neural network.
        Args:
            mass_matrix_nn_params: Parameters of the mass matrix neural network.
            potential_energy_nn_params: Parameters of the potential energy neural network.
        Returns:
            loss: Mean squared error between the predicted and target link angular velocities at the next timestep.
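            preds: Dictionary of batch predictions (i.e. outputs of the neural networks) with keys:
                - th_next_ss: predicted link angles at the next timestep. [rad]
                - th_d_next_ss: predicted link angular velocities at the next timestep. [rad/s]
                - th_dd_ss: predicted link angular accelerations. [rad/s^2]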
        """

        # Compute the estimated link angles and velocities
        # save the batch of predicted link angles and velocities at the next timestep and the link angular
        # accelerations into the variables `_th_next_pred_ss`, `_th_d_next_pred_ss`, and `_th_dd_pred_ss`
        # Hint: use `discrete_forward_dynamics_vmapped`
        _th_next_pred_ss = jnp.zeros_like(batch["th_curr_ss"])
        _th_d_next_pred_ss = jnp.zeros_like(batch["th_d_curr_ss"])
        _th_dd_pred_ss = jnp.zeros_like(batch["th_curr_ss"])
        # YOUR CODE HERE
        raise NotImplementedError()

        # Compute the MSE loss on the predicted link angular velocities at the next timestep
        _loss = jnp.array(0.0)
        # YOUR CODE HERE
        raise NotImplementedError()

        # write the predictions into a dictionary
        _preds = {
            "th_next_ss": _th_next_pred_ss,
            "th_d_next_ss": _th_d_next_pred_ss,
            "th_dd_ss": _th_dd_pred_ss,
        }

        return _loss, _preds

    # compute loss and gradients with respect to the parameters of the two neural networks
    # save the loss to `loss`, the dictionary with predictions to `preds`
    # save the gradients to `grad_mass_matrix_nn` and `grad_potential_energy_nn`
    # Hint: look at the `TrainState` source code to find out how to access the NN params from the
    # TrainState object.
    # https://flax.readthedocs.io/en/latest/_modules/flax/training/train_state.html
    # Hint: consider using `jax.value_and_grad` to get the loss and its gradient with respect to the NN parameters
    loss = jnp.array(0.0)
    grad_mass_matrix_nn, grad_potential_energy_nn = None, None
    # YOUR CODE HERE
    raise NotImplementedError()

    # update the mass matrix neural network parameters by applying the gradients with its optimizer (AdamW)
    states["MassMatrixNN"] = states["MassMatrixNN"].apply_gradients(
        grads=grad_mass_matrix_nn
    )
    # update the potential energy neural network parameters by applying the gradients with its optimizer (AdamW)
    states["PotentialEnergyNN"] = states["PotentialEnergyNN"].apply_gradients(
        grads=grad_potential_energy_nn
    )

    # compute metrics
    metrics = compute_metrics(batch, preds)
    metrics["loss"] = loss
    # save the currently active learning rates to the `metrics` dictionary
    metrics["lr_mass_matrix_nn"] = learning_rate_fn(states["MassMatrixNN"].step)
    metrics["lr_potential_energy_nn"] = learning_rate_fn(
        states["PotentialEnergyNN"].step
    )

    return states, metrics
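
`jax.value_and_grad` can differentiate a loss with respect to several parameter pytrees at once. With `argnums=(0, 1)`, it returns one gradient per argument. With `has_aux=True`, it also passes through auxiliary outputs, such as a prediction dictionary, without differentiating them. The toy loss and parameters below are made up for illustration and only mirror the `(loss, preds)` return convention of `loss_fn` above.

```python
import jax
import jax.numpy as jnp


def toy_loss_fn(params_a, params_b):
    """Hypothetical loss with two parameter pytrees and auxiliary predictions."""
    x = jnp.array([1.0, 2.0, 3.0])
    target = jnp.array([2.0, 4.0, 6.0])
    pred = params_a["w"] * x + params_b["b"]
    loss = jnp.mean(jnp.square(pred - target))
    return loss, {"pred": pred}


params_a = {"w": jnp.array(1.0)}
params_b = {"b": jnp.array(0.0)}

# has_aux=True: the function returns (value, aux) and only `value` is differentiated
(loss, preds), (grad_a, grad_b) = jax.value_and_grad(
    toy_loss_fn, argnums=(0, 1), has_aux=True
)(params_a, params_b)
print(loss, grad_a["w"], grad_b["b"])
```

Each gradient has the same pytree structure as its parameter argument, which is the structure expected by `TrainState.apply_gradients`.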

In [None]:
@jit
def eval_step(
    states: Dict[str, TrainState], batch: Dict[str, jnp.ndarray]
) -> Dict[str, jnp.ndarray]:
    """
    One validation step of the neural networks.
    Args:
        states: Dictionary containing the current states of the training of the two neural networks
            Entries of dictionary:
                - MassMatrixNN: TrainState of the mass matrix neural network
                - PotentialEnergyNN: TrainState of the potential energy neural network
        batch: dictionary of batch data
    Returns:
        metrics: Dictionary of validation metrics
    """

    # Compute the estimated link angles and velocities
    # save the batch of predicted link angles, velocities at the next timestep and the link angular acceleration
    # into the variables `th_next_pred_ss`, `th_d_next_pred_ss`, and `th_dd_pred_ss`
    # Hint: use `discrete_forward_dynamics_vmapped`
    # Hint: look at the `TrainState` source code to find out how to access the NN params from the TrainState object
    # https://flax.readthedocs.io/en/latest/_modules/flax/training/train_state.html
    th_next_pred_ss = jnp.zeros_like(batch["th_curr_ss"])
    th_d_next_pred_ss = jnp.zeros_like(batch["th_d_curr_ss"])
    th_dd_pred_ss = jnp.zeros_like(batch["th_curr_ss"])
    # YOUR CODE HERE
    raise NotImplementedError()

    # write the predictions into a dictionary
    preds = {
        "th_next_ss": th_next_pred_ss,
        "th_d_next_ss": th_d_next_pred_ss,
    }

    # Compute the MSE loss on the predicted link angular velocities at the next timestep
    loss = jnp.array(0.0)
    # YOUR CODE HERE
    raise NotImplementedError()

    # compute metrics
    metrics = compute_metrics(batch, preds)
    metrics["loss"] = loss

    return metrics

In [None]:
def train_epoch(
    states: Dict[str, TrainState],
    train_ds: Dict[str, jnp.ndarray],
    batch_size: int,
    epoch: int,
    learning_rate_fn: Callable,
    rng: random.KeyArray,
) -> Tuple[Dict[str, TrainState], float, Dict[str, float]]:
    """
    Train for a single epoch.
    Args:
        states: Dictionary containing the current states of the training of the two neural networks.
            Entries of dictionary:
                - MassMatrixNN: TrainState of the mass matrix neural network
                - PotentialEnergyNN: TrainState of the potential energy neural network
        train_ds: Dictionary containing the training dataset.
        batch_size: Batch size of training loop.
        epoch: Index of current epoch.
        learning_rate_fn: A function that takes the current step and returns the current learning rate.
            It has the signature learning_rate_fn(step: int) -> lr.
        rng: PRNG key for pseudo-random number generation.
    Returns:
        states: Dictionary of updated training states.
        train_loss: Training loss of the current epoch.
        train_metrics: Dictionary of training metrics.
    """
    train_ds_size = int(train_ds["th_curr_ss"].shape[0])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, train_ds_size)  # get a randomized index array
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape(
        (steps_per_epoch, batch_size)
    )  # index array, where each row is a batch
    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        states, metrics = train_step(states, batch, learning_rate_fn)
        batch_metrics.append(metrics)

    # compute mean of metrics across each batch in epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean(jnp.array([metrics[k] for metrics in batch_metrics_np])).item()
        for k in batch_metrics_np[0]
    }  # jnp.mean does not work on lists

    return states, epoch_metrics_np["loss"], epoch_metrics_np
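
The index bookkeeping in `train_epoch` is easiest to see on a small example. With a made-up training set of 10 samples and a batch size of 3, each epoch runs 3 steps, and one randomly chosen sample is skipped because it would form an incomplete batch.

```python
import jax

train_ds_size, batch_size = 10, 3
steps_per_epoch = train_ds_size // batch_size  # 3

perms = jax.random.permutation(jax.random.PRNGKey(0), train_ds_size)
perms = perms[: steps_per_epoch * batch_size]  # drop the incomplete last batch
perms = perms.reshape((steps_per_epoch, batch_size))  # one row of indices per batch
print(perms.shape)  # (3, 3)
```

Because a new key is used in each epoch, a different sample is dropped each time, so no sample is permanently excluded from training.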

In [None]:
def eval_model(
    states: Dict[str, TrainState],
    val_ds: Dict[str, jnp.ndarray],
) -> Tuple[float, Dict[str, float]]:
    """
    Validate the model on the validation dataset.
    Args:
        states: Dictionary containing the current states of the training of the two neural networks.
            Entries of dictionary:
                - MassMatrixNN: TrainState of the mass matrix neural network
                - PotentialEnergyNN: TrainState of the potential energy neural network
        val_ds: Dictionary containing the validation dataset.
    Returns:
        val_loss: Validation loss.
        val_metrics: Dictionary of metrics.
    """
    val_metrics = eval_step(states, val_ds)
    val_metrics = jax.device_get(val_metrics)
    val_metrics = jax.tree_util.tree_map(
        lambda x: x.item(), val_metrics
    )  # map the function over all leaves in metrics

    return val_metrics["loss"], val_metrics
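
`eval_model` converts the 0-d metric arrays returned by the jitted `eval_step` into plain Python floats, which are easier to log and to store in the history lists. The same pattern applied to a made-up metrics dictionary:

```python
import jax
import jax.numpy as jnp

metrics = {"loss": jnp.array(0.25), "rmse_th_next": jnp.array(0.1)}
metrics = jax.device_get(metrics)  # copy the arrays to host memory as NumPy arrays
# apply `.item()` to every leaf of the (possibly nested) dictionary
metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics)
print(metrics)
```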

In [None]:
def run_lnn_training(
    rng: random.KeyArray,
    train_ds: Dict[str, jnp.ndarray],
    val_ds: Dict[str, jnp.ndarray],
    num_epochs: int,
    batch_size: int,
    base_lr: float,
    warmup_epochs: int = 0,
    weight_decay: float = 0.0,
    verbose: bool = True,
) -> Tuple[
    jnp.ndarray,
    List[Dict[str, jnp.ndarray]],
    List[Dict[str, jnp.ndarray]],
    List[Dict[str, TrainState]],
]:
    """
    Run the training loop.
    Args:
        rng: PRNG key for pseudo-random number generation.
        train_ds: Dictionary of jax arrays containing the training dataset.
        val_ds: Dictionary of jax arrays containing the validation dataset.
        num_epochs: Number of epochs to train for.
        batch_size: The size of a minibatch (i.e. number of samples in a batch).
        base_lr: Base learning rate (after warmup and before decay).
        warmup_epochs: Number of epochs for warmup.
        weight_decay: Weight decay of the AdamW optimizer.
        verbose: If True, print the training progress.
    Returns:
        val_loss_history: Array of shape (num_epochs, ) with the validation loss of each epoch.
        train_metrics_history: List of dictionaries containing the training metrics for each epoch.
        val_metrics_history: List of dictionaries containing the validation metrics for each epoch.
        states_history: List of dictionaries containing the training states for each epoch.
    """

    # number of training samples
    num_train_samples = len(train_ds["th_curr_ss"])

    # initialize the learning rate scheduler
    learning_rate_fn = None
    # YOUR CODE HERE
    raise NotImplementedError()

    # split the PRNG key into two keys
    # the 1st is used for training,
    # the 2nd to initialize the neural network weights.
    rng, init_train_states_rng = jax.random.split(rng, 2)

    # initialize the train states
    states = initialize_train_states(
        init_train_states_rng, learning_rate_fn, weight_decay
    )

    # initialize the lists for the training history
    val_loss_history = []  # list with validation losses
    train_metrics_history = []  # list with train metric dictionaries
    val_metrics_history = []  # list with validation metric dictionaries
    states_history = []  # list with dictionaries of model states

    if verbose:
        print(f"Training the Lagrangian neural network for {num_epochs} epochs...")

    for epoch in (pbar := tqdm(range(1, num_epochs + 1))):
        # Split the `rng` PRNG key into two new keys
        # use the 1st PRNG to update the `rng` variable
        # store the 2nd PRNG key in the variable `epoch_rng`
        epoch_rng = None
        # YOUR CODE HERE
        raise NotImplementedError()

        # Run the training for the current epoch
        # Use the `epoch_rng` to randomly shuffle the batches
        train_loss, train_metrics = jnp.array(0.0), {}
        # YOUR CODE HERE
        raise NotImplementedError()

        # Evaluate the current set of neural network parameters on the validation set
        val_loss, val_metrics = jnp.array(0.0), {}
        # YOUR CODE HERE
        raise NotImplementedError()

        # Save the validation loss, the training and validation metrics, and the model states of the current epoch
        val_loss_history.append(val_loss)
        train_metrics_history.append(train_metrics)
        val_metrics_history.append(val_metrics)
        states_history.append(states)

        if verbose:
            pbar.set_description(
                "Epoch: %d, lr: %.6f, train loss: %.7f, val loss: %.7f"
                % (epoch, train_metrics["lr_mass_matrix_nn"], train_loss, val_loss)
            )

    # array of shape (num_epochs, ) with the validation losses of each epoch
    val_loss_history = jnp.array(val_loss_history)

    return val_loss_history, train_metrics_history, val_metrics_history, states_history
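
JAX PRNG keys are explicit values rather than global state, so reusing the same key would reproduce the same shuffle in every epoch. The loop below sketches the split-and-carry pattern described in the comments of the epoch loop, using a toy number of epochs and samples: `rng` is carried forward, and a fresh `epoch_rng` is consumed in each epoch.

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
# one split for the network initialization, as done before the training loop
rng, init_rng = jax.random.split(rng, 2)

epoch_keys = []
for epoch in range(1, 4):
    # keep the first key for later epochs and consume the second one now
    rng, epoch_rng = jax.random.split(rng)
    epoch_keys.append(epoch_rng)

# each epoch key defines its own (reproducible) shuffle of 10 training samples
perms = [jax.random.permutation(k, 10) for k in epoch_keys]
```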