# Tutorial: Time-varying Extension of $\texttt{JKOnet}^\ast$


This tutorial analyzes different approaches to extending $\texttt{JKOnet}^\ast$ so that it learns time-varying potentials. Interactive widgets let you experiment with different configurations and visualize the results.


In [1]:
import jax
import jax.numpy as jnp
from jax import random, grad, vmap
import optax
import matplotlib.pyplot as plt
from flax import linen as nn
from typing import Any, Callable, Sequence
import ipywidgets as widgets
from IPython.display import display
from IPython.display import clear_output


## Define the Real Potential Function
We deliberately select a potential that is discontinuous in time: it equals $-0.75x^2$ except on the windows $t \in [0.2, 0.3]$ and $t \in [0.7, 0.8]$, where it is identically zero. The abrupt switches produce clear, distinct changes in the dynamics, which makes it easier to detect whether a model has captured the time dependence. We use `jax.lax.cond` instead of a standard Python `if` statement so that the branch is part of the traced computation. A Python `if` is resolved while JAX traces the function: it differentiates correctly under `jax.grad` when the condition is a concrete value, but it raises an error when the condition depends on a traced value, for example under `jax.jit` or when `t` is batched with `jax.vmap`. `jax.lax.cond` executes only the selected branch, stays differentiable, and works under these transformations (with a batched predicate under `vmap`, JAX converts it to a select that evaluates both branches).

In [2]:
# Potential Function
def V_real(x, t):
    condition1 = (0.2 <= t) & (t <= 0.3)
    condition2 = (0.7 <= t) & (t <= 0.8)
    condition = condition1 | condition2
    return jax.lax.cond(condition, lambda _: 0.0 * x, lambda _: -0.75 * x**2, None)
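
As a quick sanity check of the definition above, the following sketch evaluates the gradient of the potential inside and outside the flat windows, and confirms that the `jax.lax.cond` version can be wrapped in `jax.jit` and `jax.vmap`. The function is repeated so that the block runs on its own; the names `dV_dx` and `batched` are illustrative and not part of the tutorial code.

```python
import jax
import jax.numpy as jnp

def V_real(x, t):
    # Same piecewise-in-time potential as in the cell above.
    flat = ((0.2 <= t) & (t <= 0.3)) | ((0.7 <= t) & (t <= 0.8))
    return jax.lax.cond(flat, lambda _: 0.0 * x, lambda _: -0.75 * x**2, None)

# Gradient with respect to x, compiled with jit (t is a traced value here).
dV_dx = jax.jit(jax.grad(V_real, argnums=0))

g_active = dV_dx(1.0, 0.1)   # outside the windows: d/dx(-0.75 x^2) = -1.5 x
g_flat = dV_dx(1.0, 0.25)    # inside a window: the potential is identically zero

# vmap over time: the predicate becomes batched, which lax.cond supports.
batched = jax.vmap(jax.grad(V_real, argnums=0), in_axes=(None, 0))
g_over_time = batched(1.0, jnp.array([0.1, 0.25, 0.5, 0.75]))
```

Replacing `jax.lax.cond` with a Python `if` in this sketch would fail at the `jax.jit` and `jax.vmap` calls, because the condition would then depend on a traced `t`.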

## Define the Multi-Layer Perceptron (MLP)
Here we define our MLP architecture using Flax's linen module.

In [3]:
# MLP Model Definition
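class MLP(nn.Module):
    features: Sequence[int]  # layer widths; the last entry is the output size

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, name=f'layers_{i}')(x)
            # Nonlinearity after every layer except the output layer.
            # Note: softmax normalizes across the feature axis; an elementwise
            # smooth activation such as nn.softplus is a common alternative.
            if i != len(self.features) - 1:
                x = nn.softmax(x)
        return x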


## Generation of training data

In [4]:
def generate_training_trajectories(potential, initial_conditions, t_values, tau):
    num_trajectories = len(initial_conditions)
    timesteps = len(t_values)
    trajectories = jnp.zeros((num_trajectories, timesteps))
    trajectories = trajectories.at[:, 0].set(initial_conditions)

    for i in range(1, timesteps):
        x_prev = trajectories[:, i - 1]
        t_prev = t_values[i - 1]
        grad_x = vmap(lambda x: grad(potential, argnums=0)(x, t_prev))(x_prev)
        x_next = x_prev - tau * grad_x
        trajectories = trajectories.at[:, i].set(x_next)

    return trajectories
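
To see what the generator produces, note that outside the flat windows the gradient is $-1.5x$, so each explicit step multiplies the state by $1 + 1.5\tau$, while inside the windows the state is frozen. The sketch below checks both facts on a small example. It repeats the potential and uses a condensed version of the generator so that it runs standalone; the grid size and initial conditions are arbitrary choices.

```python
import jax
import jax.numpy as jnp
from jax import grad, vmap

def V_real(x, t):
    flat = ((0.2 <= t) & (t <= 0.3)) | ((0.7 <= t) & (t <= 0.8))
    return jax.lax.cond(flat, lambda _: 0.0 * x, lambda _: -0.75 * x**2, None)

def generate_training_trajectories(potential, initial_conditions, t_values, tau):
    # Explicit Euler on the gradient flow: x_{k+1} = x_k - tau * dV/dx(x_k, t_k).
    columns = [initial_conditions]
    for t_prev in t_values[:-1]:
        grad_x = vmap(lambda x: grad(potential, argnums=0)(x, t_prev))(columns[-1])
        columns.append(columns[-1] - tau * grad_x)
    return jnp.stack(columns, axis=1)  # shape (num_trajectories, timesteps)

t_values = jnp.linspace(0.0, 1.0, 21)   # step of 0.05
tau = t_values[1] - t_values[0]
x0 = jnp.array([0.8, 1.0, 1.2])
traj = generate_training_trajectories(V_real, x0, t_values, tau)

first_step_expected = x0 * (1.0 + 1.5 * tau)  # step taken at t = 0.0 (active branch)
frozen = traj[:, 6] - traj[:, 5]              # step taken at t = 0.25 (flat window)
```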

## Generation of predicted trajectories
We explore two methods for generating predicted trajectories: implicit and explicit schemes. The implicit method updates the position using the gradient of the potential at the future state, while the explicit method uses the gradient at the current state. The equations for each approach are shown below:

Implicit scheme:
$$
x_{t+1} = x_{t} - \tau \nabla V(x_{t+1}, t+1).
$$

Explicit scheme:
$$
x_{t+1} = x_{t} - \tau \nabla V(x_{t}, t).
$$
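
As a worked example, take the active branch of the potential used in this tutorial, $V(x) = -0.75\,x^2$, so that $\nabla V(x) = -1.5\,x$. Both schemes then have closed-form updates:

$$
\text{Explicit:}\quad x_{t+1} = x_t + 1.5\,\tau\,x_t = (1 + 1.5\,\tau)\,x_t,
\qquad
\text{Implicit:}\quad x_{t+1} = x_t + 1.5\,\tau\,x_{t+1}
\;\Longrightarrow\;
x_{t+1} = \frac{x_t}{1 - 1.5\,\tau}.
$$

The implicit update is only well defined for $\tau < 2/3$. Since $1/(1 - 1.5\,\tau) = 1 + 1.5\,\tau + O(\tau^2)$, the two schemes agree to first order in $\tau$ and drift apart as the step grows. For $\tau = 0.1$, for instance, the growth factors are $1.15$ (explicit) and $1/0.85 \approx 1.176$ (implicit).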

In [5]:
# Helper Functions for Trajectory Generation
def generate_pred_trajectories_explicit(potential, params, x_initials, t_values, tau):
    num_trajectories = len(x_initials)
    num_timesteps = len(t_values)

    trajectories = jnp.zeros((num_trajectories, num_timesteps))
    trajectories = trajectories.at[:, 0].set(x_initials)

    for i in range(1, num_timesteps):
        x_prev = trajectories[:, i - 1]
        t_prev = t_values[i - 1]
        # Explicit step: gradient at the current state and current time,
        # evaluated for all trajectories at once with vmap
        grad_x = vmap(grad(lambda x: potential(params, x, t_prev)))(x_prev)
        trajectories = trajectories.at[:, i].set(x_prev - tau * grad_x)

    return trajectories

def generate_pred_trajectories_implicit(potential, params, x_initials, t_values, tau):
    def implicit_eq(x_next, x_prev, t_next):
        return x_next - x_prev + tau * grad(lambda x: potential(params, x, t_next))(x_next)
    
    num_trajectories = len(x_initials)
    trajectory_length = len(t_values)
    
    # Initialize a zero array to store all trajectories
    trajectories = jnp.zeros((num_trajectories, trajectory_length))
    
    for idx, x_initial in enumerate(x_initials):
        x_learned_trajectory = jnp.zeros_like(t_values)
        x_learned_trajectory = x_learned_trajectory.at[0].set(x_initial)
        
        for i in range(1, trajectory_length):
            x_prev = x_learned_trajectory[i - 1]
            t_next = t_values[i]
            
            # Solve the implicit equation for x_next with Newton-Raphson, starting from x_prev
            x_next = newton_raphson_step(lambda x_next: implicit_eq(x_next, x_prev, t_next), x_prev)
            x_learned_trajectory = x_learned_trajectory.at[i].set(x_next)
        
        # Store the trajectory
        trajectories = trajectories.at[idx].set(x_learned_trajectory)
    
    return trajectories

def newton_raphson_step(f, x0, tol=1e-5, max_iter=100):
    # Despite its name, this runs the full Newton-Raphson iteration
    # x <- x - f(x) / f'(x) until |f(x)| < tol or max_iter is reached.
    def body_fun(val):
        x, fx, iter_count = val
        dfx = grad(f)(x)  # derivative of the scalar residual
        x = x - fx / dfx
        return x, f(x), iter_count + 1

    def cond_fun(val):
        _, fx, iter_count = val
        return (jnp.abs(fx) >= tol) & (iter_count < max_iter)

    x, _, _ = jax.lax.while_loop(cond_fun, body_fun, (x0, f(x0), 0))
    return x
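
The implicit step can be checked against a case with a known answer. For the active branch $V(x) = -0.75x^2$ the implicit equation is linear, with solution $x_{t+1} = x_t / (1 - 1.5\tau)$, so Newton-Raphson should recover it. The sketch below uses a standalone copy of the solver under the hypothetical name `newton_solve`; the values of `tau` and `x_prev` are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import grad

def newton_solve(f, x0, tol=1e-5, max_iter=100):
    # Repeat x <- x - f(x) / f'(x) until |f(x)| < tol or max_iter is reached.
    def body_fun(val):
        x, _, it = val
        x = x - f(x) / grad(f)(x)
        return x, f(x), it + 1

    def cond_fun(val):
        _, fx, it = val
        return (jnp.abs(fx) >= tol) & (it < max_iter)

    x, _, _ = jax.lax.while_loop(cond_fun, body_fun, (x0, f(x0), jnp.int32(0)))
    return x

tau = 0.1
x_prev = jnp.float32(1.0)
V = lambda x: -0.75 * x**2  # active branch of the potential, time argument dropped

def residual(x_next):
    # Implicit Euler residual: x_next - x_prev + tau * dV/dx(x_next) = 0.
    return x_next - x_prev + tau * grad(V)(x_next)

x_next = newton_solve(residual, x_prev)
x_closed_form = x_prev / (1.0 - 1.5 * tau)
```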


## Loss scheme
Time must also enter the loss. The first option evaluates the potential at the later time of each pair, $t+1$, which results in a fully implicit scheme: the loss relates each observed displacement $x_{t+1} - x_t$ to the gradient of the potential at the subsequent state and time.

Implicit in time loss scheme:
$$
\sum_{t=0}^{T-1} \int_{\mathbb{R}^d \times \mathbb{R}^d} \left\| 
    \nabla V(x_{t+1}, t+1) + \frac{1}{\tau}(x_{t+1}-x_t) 
\right\|^2 \mathrm{d}\gamma_t(x_t, x_{t+1}).
$$

The second option evaluates the potential at the earlier time of each pair, $t$, which results in a scheme that is implicit in space and explicit in time.

Explicit in time loss scheme:

$$
\sum_{t=0}^{T-1} \int_{\mathbb{R}^d \times \mathbb{R}^d} \left\| 
    \nabla V(x_{t+1}, t) + \frac{1}{\tau}(x_{t+1}-x_t) 
\right\|^2 \mathrm{d}\gamma_t(x_t, x_{t+1}).
$$
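
In code, the two losses differ only in which time stamp is passed to the potential. The sketch below illustrates this with a hypothetical toy potential $V(x, t) = \tfrac{1}{2} a (1 + t) x^2$ and pairs generated by the fully implicit update, so the implicit-in-time loss vanishes at the true parameter while the explicit-in-time loss does not. The function and variable names are assumptions made for this illustration. The training code later in this tutorial uses the residual multiplied by $\tau$, which rescales the loss by $\tau^2$ without changing its minimizer.

```python
import jax
import jax.numpy as jnp
from jax import grad, vmap

def time_varying_loss(potential, params, x, y, t, tau):
    # Mean over pairs of | dV/dx(y, t) + (y - x) / tau |^2.
    # t = next time stamps  -> implicit-in-time loss
    # t = previous stamps   -> explicit-in-time loss
    dV_dx = vmap(grad(potential, argnums=1), in_axes=(None, 0, 0))(params, y, t)
    return jnp.mean((dV_dx + (y - x) / tau) ** 2)

def toy_potential(a, x, t):
    # Hypothetical time-varying quadratic potential with one parameter a.
    return 0.5 * a * (1.0 + t) * x**2

a_true, tau = 2.0, 0.1
x = jnp.array([0.8, 1.0, 1.2])
t_prev = jnp.array([0.0, 0.1, 0.2])
t_next = t_prev + tau
# Pairs that satisfy the fully implicit update y = x - tau * dV/dx(y, t_next).
y = x / (1.0 + tau * a_true * (1.0 + t_next))

loss_implicit = time_varying_loss(toy_potential, a_true, x, y, t_next, tau)
loss_explicit = time_varying_loss(toy_potential, a_true, x, y, t_prev, tau)
```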

## Main Function: `plot_trajectories`

This function takes several parameters related to the trajectory prediction scheme and performs the following tasks:
1. **Generates Trajectories**: Creates training trajectories.
2. **Subsampling**: Keeps only every $n$-th data point, where $n$ is the selected subsampling rate. A larger rate means a larger time step between consecutive training samples, which accentuates the difference between the schemes.
3. **Model Initialization**: Sets up the machine learning model and optimizer.
4. **Training**: Trains the model over a specified number of epochs.
5. **Prediction and Visualization**: Compares learned trajectories with real trajectories.
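
The subsampling step can be sketched in isolation as follows. The array of zeros stands in for real trajectories, and `tau_sub` is an illustrative name for the spacing between consecutive subsampled time stamps.

```python
import jax.numpy as jnp

timesteps = 50
t_values = jnp.linspace(0, 1, timesteps)
tau = t_values[1] - t_values[0]             # spacing of the full time grid

subsampling_rate = 5
x_trajectories = jnp.zeros((3, timesteps))  # placeholder for 3 trajectories

# Keep every 5th sample along the time axis.
x_sub = x_trajectories[:, ::subsampling_rate]
t_values_sub = t_values[::subsampling_rate]
tau_sub = t_values_sub[1] - t_values_sub[0]  # spacing between kept samples

# Consecutive pairs and their time stamps, as used by the loss.
x_pairs, y_pairs = x_sub[:, :-1], x_sub[:, 1:]
t_prev, t_next = t_values_sub[:-1], t_values_sub[1:]
```

With a rate of 5, the 50 samples per trajectory reduce to 10, and the spacing between consecutive samples is five times the original step.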

Here’s the code for the function:

In [6]:
# Main Function for Interactive Plotting
def plot_trajectories(traj_scheme, loss_scheme, subsampling_rate, learning_rate, num_epochs, num_trajectories):
    key = random.PRNGKey(42)
    timesteps = 50
    x_initials = random.uniform(key, (num_trajectories,), minval=0.7, maxval=1.3)
    t_values = jnp.linspace(0, 1, timesteps)
    tau = t_values[1] - t_values[0]
    x_trajectories = generate_training_trajectories(V_real, x_initials, t_values, tau)
    
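    # Subsampling: keep every `subsampling_rate`-th sample of each trajectory
    x_trajectories_sub = x_trajectories[:, ::subsampling_rate]
    t_values_sub = t_values[::subsampling_rate]
    # From here on, tau is the spacing between consecutive subsampled time stamps,
    # so the loss and the predicted trajectories use the step the learner actually sees
    tau = t_values_sub[1] - t_values_sub[0]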

    # Training data preparation
    y_batch = x_trajectories_sub[:, 1:]
    x_batch = x_trajectories_sub[:, :-1]
    t_x = jnp.tile(t_values_sub[:-1], (num_trajectories, 1))
    t_y = jnp.tile(t_values_sub[1:], (num_trajectories, 1))

    # Initialize model and optimizer
    key = random.PRNGKey(0)
    input_shape = (2,)
    features = [20, 20, 1]
    model = MLP(features)
    params = model.init(key, jnp.ones(input_shape))
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)

    def aux_fun(params, x, t):
        inputs = jnp.concatenate([x[..., None], t[..., None]], axis=-1)
        return model.apply(params, inputs).squeeze()

    grad_model = grad(aux_fun, argnums=1)

    def mlp_loss(params, x, y, t):
        return jnp.mean(jnp.square(y - x + tau * vmap(grad_model, in_axes=(None, 0, 0))(params, y, t)))

    x_concatenated = jnp.concatenate(x_batch)
    y_concatenated = jnp.concatenate(y_batch)
    t_x_concatenated = jnp.concatenate(t_x)
    t_y_concatenated = jnp.concatenate(t_y)
    
    # Depending on the loss scheme, evaluate the potential at the previous or the next time step
    if loss_scheme == 'Explicit':
        t_concatenated = t_x_concatenated
    else:
        t_concatenated = t_y_concatenated

    for epoch in range(num_epochs):
        loss, grads = jax.value_and_grad(mlp_loss)(params, x_concatenated, y_concatenated, t_concatenated)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss}")

    # Prediction and Plotting: roll out the learned potential from the first initial condition
    x_initial = x_trajectories_sub[:1, 0]
    
    if traj_scheme == 'Explicit':
        trajectories_pred = generate_pred_trajectories_explicit(aux_fun, params, x_initial, t_values_sub, tau)
    else:
        trajectories_pred = generate_pred_trajectories_implicit(aux_fun, params, x_initial, t_values_sub, tau)
        
    plt.figure(figsize=(8, 6))
    plt.plot(t_values_sub, trajectories_pred[0], label='Learned Trajectory', linestyle='--', marker='o')
    plt.plot(t_values, x_trajectories[0], label='Real Trajectory', linestyle='-', marker='x')
    plt.xlabel('Time')
    plt.ylabel('X')
    plt.title('Comparison of Learned and Real Trajectories')
    plt.legend()
    plt.show()


## Interactive Widgets

The following widgets allow you to customize the parameters used. 

- **Prediction Scheme**: Choose between the `Explicit` and `Implicit` methods for trajectory prediction.
- **Loss Scheme**: Choose whether the loss evaluates the potential at the previous time step (`Explicit`) or at the next time step (`Implicit`).
- **Subsampling Rate**: Keep every $n$-th data point of each trajectory; larger values increase the time step between training samples.
- **Learning Rate**: Set the step size for the optimizer.
- **Number of Epochs**: Define how many training iterations to run.
- **Number of Trajectories**: Specify how many trajectories to generate, which sets the amount of training data.

Here are the widgets:

In [7]:
# Create Widgets for Interaction
traj_scheme_widget = widgets.Dropdown(
    options=['Explicit', 'Implicit'],
    value='Implicit',
    description='Prediction Scheme:',
    disabled=False,
)

loss_scheme_widget = widgets.Dropdown(
    options=['Explicit', 'Implicit'],
    value='Implicit',
    description='Loss Scheme (time):',
    disabled=False,
)

subsampling_rate_widget = widgets.IntSlider(
    value=5,
    min=1,
    max=5,
    step=1,
    description='Subsampling:',
    disabled=False
)

learning_rate_widget = widgets.FloatSlider(
    value=0.05,
    min=0.001,
    max=0.1,
    step=0.001,
    description='Learning Rate:',
    disabled=False
)

num_epochs_widget = widgets.IntSlider(
    value=100,
    min=100,
    max=10000,
    step=100,
    description='Epochs:',
    disabled=False
)

num_trajectories_widget = widgets.IntSlider(
    value=10,
    min=1,
    max=50,
    step=1,
    description='Trajectories:',
    disabled=False
)

# Button to Update Plot
plot_button = widgets.Button(
    description='Update Plot',
    disabled=False,
    button_style='success'
)

output = widgets.Output()


# Callback Function for Button Click
def on_plot_button_clicked(b):
    
    with output:
        # Clear previous output in the output widget
        clear_output(wait=True)

        # Call the plot function with updated parameters
        plot_trajectories(traj_scheme_widget.value,
                          loss_scheme_widget.value,
                          subsampling_rate_widget.value,
                          learning_rate_widget.value,
                          num_epochs_widget.value,
                          num_trajectories_widget.value)

plot_button.on_click(on_plot_button_clicked)




## Update Plot Button

The button below will update the plot with the selected parameters when clicked. The callback function handles the interaction.

In [8]:
display(traj_scheme_widget, loss_scheme_widget, subsampling_rate_widget, learning_rate_widget, num_epochs_widget, num_trajectories_widget, plot_button, output)

