# <span style="color:teal"> Introduction to surrogate modelling in the geosciences </span>

#### Marc Bocquet¹ [marc.bocquet@enpc.fr](mailto:marc.bocquet@enpc.fr) and Alban Farchi¹ [alban.farchi@enpc.fr](mailto:alban.farchi@enpc.fr)
##### (1) CEREA, École des Ponts and EdF R&D, IPSL, Île-de-France, France

During this session, we will apply standard machine learning methods to learn the dynamics of the Lorenz 1996 model. The objective here is to get a preview of how machine learning can be applied to geoscientific models, using a low-order model where testing is quick.

## <span style="color:green"> The Lorenz 1996 model </span>

The Lorenz 1996 model (L96, [Lorenz and Emanuel 1998](https://journals.ametsoc.org/view/journals/atsc/55/3/1520-0469_1998_055_0399_osfswo_2.0.co_2.xml)) is a low-order chaotic model commonly used in data assimilation to assess the performance of new algorithms. It represents the evolution of some unspecified scalar meteorological quantity (perhaps vorticity or temperature) over a latitude circle.

The model **dynamics** is driven by the following set of ordinary differential equations (ODEs):
$$
    \forall n \in [1, N_{x}], \quad \frac{\mathrm{d}x_{n}}{\mathrm{d}t} =
    (x_{n+1}-x_{n-2})x_{n-1}-x_{n}+F,
$$
where the indices are periodic: $x_{-1}=x_{N_{x}-1}$, $x_{0}=x_{N_{x}}$, and $x_{1}=x_{N_{x}+1}$, and where the system size $N_{x}$ can take arbitrary values.
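As a minimal sketch of the periodic indexing above (the function names `l96_tendency_loop` and `l96_tendency_roll` are hypothetical and only used for illustration), the tendency can be written once with explicit modular indices and once with `np.roll`, and the two versions can be compared on a random state. The sketch also checks that the uniform state $x_{n}=F$ is a fixed point of the equations.

```python
import numpy as np

def l96_tendency_loop(x, F):
    """Reference L96 tendency with explicit modular (periodic) indexing."""
    Nx = x.shape[0]
    dx_dt = np.zeros(Nx)
    for n in range(Nx):
        # Python indices are 0-based; the modulo implements the periodicity
        dx_dt[n] = (x[(n+1) % Nx] - x[(n-2) % Nx])*x[(n-1) % Nx] - x[n] + F
    return dx_dt

def l96_tendency_roll(x, F):
    """Same tendency, vectorised with np.roll."""
    xp = np.roll(x, -1)   # x_{n+1}
    xm = np.roll(x, 1)    # x_{n-1}
    xmm = np.roll(x, 2)   # x_{n-2}
    return (xp - xmm)*xm - x + F

# compare both versions on a random state
rng = np.random.default_rng(0)
x = rng.normal(size=40)
same = np.allclose(l96_tendency_loop(x, 8.0), l96_tendency_roll(x, 8.0))

# the uniform state x_n = F has zero tendency: (F - F)*F - F + F = 0
fixed_point_tendency = l96_tendency_roll(np.full(40, 8.0), 8.0)
```

The vectorised version is the one that scales to batches of states, since `np.roll` can act on the last axis of an array of any shape.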

In the standard configuration, $N_{x}=40$ and the forcing coefficient is $F=8$. The ODEs are integrated using a fourth-order Runge-Kutta scheme with a time step of $0.05$ model time unit (MTU). The resulting dynamics is **chaotic** with a doubling time of errors around $0.42$ MTU (the corresponding Lyapunov time is hence $0.61$ MTU). For comparison, $0.05$ MTU represents six hours of real time and corresponds to an average autocorrelation around $0.967$. Finally, the model variability (spatial average of the time standard deviation per variable) is $3.64$.
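The link between the doubling time and the Lyapunov time can be checked with a few lines of arithmetic. This is only a sketch, assuming that small errors grow as $\exp(t/T_{\mathrm{L}})$, so that the doubling time satisfies $T_{\mathrm{d}} = T_{\mathrm{L}}\ln 2$; the variable names are illustrative.

```python
import math

# values quoted in the text
doubling_time = 0.42      # MTU
dt = 0.05                 # MTU per integration step
hours_per_step = 6.0      # real time represented by one step

# error ~ exp(t / lyap_time) doubles when t = lyap_time * ln(2)
lyap_time = doubling_time/math.log(2)

# conversion between model time and real time
days_per_mtu = hours_per_step/dt/24
doubling_time_days = doubling_time*days_per_mtu

print(f'Lyapunov time = {lyap_time:.3f} MTU')
print(f'1 MTU = {days_per_mtu:.1f} days')
print(f'doubling time = {doubling_time_days:.2f} days')
```

With these conventions, one MTU corresponds to five days, so the doubling time is of the order of two days of real time.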

In this series of experiments, we will try to emulate the dynamics of the L96 model using artificial neural networks (NN).
1. We start by running the **true model** to build a training dataset.
2. We build and **train neural networks** using this dataset.
3. We evaluate the **forecast skill** of the surrogate models (the NNs).

## <span style="color:green"> The true model dynamics </span>

Before building the training dataset, let us illustrate the model dynamics.

### <span style="color:blue"> Import all modules and define some visualisation functions</span>

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from tqdm.notebook import tqdm, trange

def plot_l96_traj(
    x,
    model,
    linewidth,
):
    fig = plt.figure(figsize=(linewidth, linewidth/3))
    plt.grid(False)
    im = plt.imshow(
        x.T, 
        aspect = 'auto',
        origin = 'lower',
        interpolation = 'spline36',
        cmap = sns.diverging_palette(240, 60, as_cmap=True),
        extent = [0, model.dt*x.shape[0], 0, model.Nx],
        vmin = -10,
        vmax = 15,
    )
    plt.colorbar(im)
    plt.xlabel('Time (MTU)')
    plt.ylabel('Lorenz 96 variables')
    plt.tick_params(direction='out', left=True, bottom=True)
    plt.show()

def plot_l96_compare_traj(
    x_ref,
    x_pred,
    model,
    linewidth,
):
    error = x_pred - x_ref
    fig = plt.figure(figsize=(linewidth, linewidth))
    ax = plt.subplot(311)
    ax.grid(False)
    im = plt.imshow(
        x_ref.T, 
        aspect = 'auto',
        origin = 'lower',
        interpolation = 'spline36',
        cmap = sns.diverging_palette(240, 60, as_cmap=True),
        extent = [0, model.dt*x_pred.shape[0], 0, model.Nx],
        vmin = -10,
        vmax = 15,
    )
    ax.set_title('true model integration')
    plt.colorbar(im)
    ax.set_ylabel('Lorenz 96 variables')
    ax.tick_params(direction='out', left=True, bottom=True)
    ax.set_xticklabels([])
    ax = plt.subplot(312)
    ax.grid(False)
    im = plt.imshow(
        x_pred.T,
        aspect = 'auto',
        origin = 'lower',
        interpolation = 'spline36',
        cmap = sns.diverging_palette(240, 60, as_cmap=True),
        extent = [0, model.dt*x_pred.shape[0], 0, model.Nx],
        vmin = -10,
        vmax = 15,
    )
    ax.set_title('surrogate model integration')
    plt.colorbar(im)
    ax.set_ylabel('Lorenz 96 variables')
    ax.tick_params(direction='out', left=True, bottom=True)
    ax.set_xticklabels([])
    ax = plt.subplot(313)
    ax.grid(False)
    im = ax.imshow(
        error.T, 
        aspect = 'auto',
        origin = 'lower',
        interpolation = 'spline36',
        cmap = sns.diverging_palette(240, 10, as_cmap=True),
        extent = [0, model.dt*error.shape[0], 0, model.Nx],
        vmin = -15,
        vmax = 15,
    )
    ax.set_title('signed error')
    plt.colorbar(im)
    ax.set_xlabel('Time (MTU)')
    ax.set_ylabel('Lorenz 96 variables')
    ax.tick_params(direction='out', left=True, bottom=True)
    plt.show()

def get_plotly_color_palette(alpha=None):
    if alpha is None:
        return [
            'rgb(99, 110, 250)',
            'rgb(239, 85, 59)',
            'rgb(0, 204, 150)',
            'rgb(171, 99, 250)',
            'rgb(255, 161, 90)',
            'rgb(25, 211, 243)',
            'rgb(255, 102, 146)',
            'rgb(182, 232, 128)',
            'rgb(255, 151, 255)',
            'rgb(254, 203, 82)'
        ]
    else:
        return [
            f'rgba(99, 110, 250, {alpha})',
            f'rgba(239, 85, 59, {alpha})',
            f'rgba(0, 204, 150, {alpha})',
            f'rgba(171, 99, 250, {alpha})',
            f'rgba(255, 161, 90, {alpha})',
            f'rgba(25, 211, 243, {alpha})',
            f'rgba(255, 102, 146, {alpha})',
            f'rgba(182, 232, 128, {alpha})',
            f'rgba(255, 151, 255, {alpha})',
            f'rgba(254, 203, 82, {alpha})'
        ]

def plot_l96_forecast_skill(
    fss,
    model,
    p1,
    p2,
    xmax,
    linewidth,
):
    fig = go.Figure()
    palette = get_plotly_color_palette()
    spalette = get_plotly_color_palette(alpha=0.2)
    
    for (index, key) in enumerate(fss):
        
        time = (model.dt/model.lyap_time)*np.arange(fss[key].shape[0])
        rmse_m = fss[key].mean(axis=1) / model.model_var
        rmse_p1 = np.percentile(fss[key], p1, axis=1) / model.model_var
        rmse_p2 = np.percentile(fss[key], p2, axis=1) / model.model_var
        
        fig.add_scatter(
            x=time,
            y=rmse_m,
            name=key,
            customdata=np.arange(len(time)),
            hovertemplate='index = %{customdata}, value = %{y:.3f}',
            line_color=palette[index]
        )
        fig.add_scatter(
            x=np.concatenate([time, time[::-1]]),
            y=np.concatenate([rmse_p1, rmse_p2[::-1]]),
            fill='toself',
            name=key+' (CI)',
            hoverinfo='skip',
            fillcolor=spalette[index],
            line_width=0,
            mode='lines'
        )        
        
    fig.update_xaxes(title_text='Time (Lyapunov time)')
    fig.update_yaxes(title_text='normalised RMSE')
    fig.update_layout(
        title='Forecast skill', 
        xaxis_range=[0, xmax], 
        yaxis_range=[0, 2], 
        width=linewidth, 
        height=0.7*linewidth,
        hovermode='x unified',
    )
    fig.add_hline(
        y=np.sqrt(2), 
        line_width=1,
        line_dash='dash',
        line_color='black',
        label_text=r'$\sqrt{2}$',
        label_textposition='start',
    )
    fig.show()
    
def plot_learning_curve(
    loss,
    val_loss,
    title,
    linewidth,
):
    
    fig = go.Figure()
    palette = get_plotly_color_palette()
    
    fig.add_scatter(
        x=np.arange(len(loss)),
        y=loss,
        name='training loss',
        customdata=np.arange(len(loss)),
        hovertemplate='epoch = %{customdata}, value = %{y:.3f}',
        line_color=palette[0]
    )
    
    fig.add_scatter(
        x=np.arange(len(val_loss)),
        y=val_loss,
        name='validation loss',
        customdata=np.arange(len(val_loss)),
        hovertemplate='epoch = %{customdata}, value = %{y:.3f}',
        line_color=palette[1]
    )
    
    fig.update_xaxes(title_text='Number of epochs')
    fig.update_yaxes(title_text='MSE', type='log')
    fig.update_layout(title=title, width=linewidth, height=0.7*linewidth, hovermode='x unified')

    fig.show()

class TQDMCallback(tf.keras.callbacks.Callback):
    
    def __init__(self, desc, loss=None, val_loss=None):
        super().__init__()
        self.desc = desc
        self.metrics = {'loss':loss, 'val_loss':val_loss}
    
    def on_train_begin(self, logs=None):
        self.epoch_bar = tqdm(total=self.params['epochs'], desc=self.desc)
    
    def on_train_end(self, logs=None):
        self.epoch_bar.close()
        
    def on_epoch_end(self, epoch, logs=None):
        for name in self.metrics:
            self.metrics[name]  = logs.get(name, self.metrics[name])
        self.epoch_bar.set_postfix(mse=self.metrics['loss'], val_mse=self.metrics['val_loss'], refresh=False)
        self.epoch_bar.update()

### <span style="color:blue"> Defining the true model </span>

In the following cell, we define the true Lorenz 1996 model using standard values: 
- the number of variables $N_{x}$ is set to `Nx=40`;
- the forcing coefficient $F$ is set to `F=8`;
- the integration time step is set to `dt=0.05`.

<span style="color:red"> Exercise </span>
- Implement the `tendency()` method of the `Lorenz1996Model` class.
  This method should compute the model tendencies. You may use the
  [`roll()`](https://numpy.org/doc/stable/reference/generated/numpy.roll.html)
  function of `numpy`.
- Implement the `forward()` method of the `Lorenz1996Model` class.
  This method should compute an integration step forward in time.
  The Runge--Kutta scheme is explained in the method's docstring.
  A simple straightforward implementation with six statements is more 
  than enough for the present set of experiments.

In [None]:
class Lorenz1996Model:
    """Implementation of the Lorenz 1996 model.
    
    Use the `tendency()` method to compute the model tendencies (i.e., dx/dt)
    and use the `forward()` method to apply an integration step forward in time,
    using a fourth order Runge--Kutta scheme.
    
    Attributes
    ----------
    Nx : int
        The number of variables in the model.
    F : float
        The model forcing.
    dt : float
        The model integration time step.
    """

    def __init__(self, Nx, F, dt):
        """Initialise the model."""
        self.Nx = Nx
        self.F = F
        self.dt = dt

    def tendency(self, x):
        """Compute the model tendencies dx/dt.
        
        The tendencies are computed by batch using
        `numpy` vectorisation.
        
        Parameters
        ----------
        x : np.ndarray, shape (..., Nx)
            Batch of input states.
            
        Returns
        -------
        dx_dt : np.ndarray, shape (..., Nx)
            Model tendencies computed at the input states.
        """
        # TODO: implement it!
        xp = np.roll(x, shift=-1, axis=-1)
        xmm = np.roll(x, shift=+2, axis=-1)
        xm = np.roll(x, shift=+1, axis=-1)
        return (xp - xmm)*xm - x + self.F

    def forward(self, x):
        """Apply an integration step forward in time.
        
        This method uses a fourth-order Runge--Kutta scheme:
        k1 <- dx/dt at x
        k2 <- dx/dt at x + dt/2*k1
        k3 <- dx/dt at x + dt/2*k2
        k4 <- dx/dt at x + dt*k3
        k <- (k1 + 2*k2 + 2*k3 + k4)/6
        x <- x + dt*k
        
        Parameters
        ----------
        x : np.ndarray, shape (..., Nx)
            Batch of input states.
            
        Returns
        -------
        integrated_x : np.ndarray, shape (..., Nx)
            The integrated states after one step.
        """
        # TODO: implement it!
        k1 = self.tendency(x)
        k2 = self.tendency(x+self.dt/2*k1)
        k3 = self.tendency(x+self.dt/2*k2)
        k4 = self.tendency(x+self.dt*k3)
        k = (k1 + 2*k2 + 2*k3 + k4)/6
        return x + self.dt*k

In [None]:
# create model
true_model = Lorenz1996Model(Nx=40, dt=0.05, F=8)

# save some statistics about the model
true_model.model_var = 3.64
true_model.doubling_time = 0.42
true_model.lyap_time = 0.61

### <span style="color:blue"> Short model integration </span>

In the following cells, we perform a rather short model integration, in order to illustrate the model dynamics. The initial condition is a random field.

<span style="color:red"> Exercise </span>
- Implement the true model integration in the `perform_true_model_integration()` function.
  A simple implementation with a `for-loop` should do the job.

In [None]:
def perform_true_model_integration(Nt, Ne=1, seed=None):
    """Perform an integration in time using the true model.
    
    The initial state is a batch of random fields.
    
    Parameters
    ----------
    Nt : int
        The number of integration steps to perform.
    Ne : int
        The batch size.
    seed : int
        The random seed for the initialisation.
        
    Returns
    -------
    xt : np.ndarray, shape (Nt+1, Ne, Nx)
        The integrated batch of trajectories.
    """
    # define rng
    rng = np.random.default_rng(seed=seed)

    # allocate memory
    xt = np.zeros((Nt+1, Ne, true_model.Nx))

    # initialisation
    xt[0] = rng.normal(loc=3, scale=1, size=(Ne, true_model.Nx))
    
    # TODO: implement the model integration for Nt steps
    for t in trange(Nt, desc='model integration'):
        xt[t+1] = true_model.forward(xt[t])
    
    # return the trajectory
    return xt

In [None]:
# short model integration for visualisation purpose
xt_plot = perform_true_model_integration(Nt=500, Ne=1, seed=3)[:, 0]

In [None]:
# plot the trajectory
plot_l96_traj(
    xt_plot, 
    true_model,
    linewidth=18,
)

We see first a spin-up period of about $1$ MTU, where the initial condition is progressively forgotten and the trajectory progressively gets back to the model attractor. After this spin-up period, the dynamics is characterised by waves moving slowly towards the east (i.e. decreasing variable index). 

## <span style="color:green"> Prepare the dataset </span>

### <span style="color:blue"> A long model integration for the training data</span>

We now use a true model trajectory to make the **training dataset**. This trajectory starts from a random field (different from the one used for the plotting trajectory) and we discard the first $100$ time steps to get rid of the spin-up process.

In [None]:
# long model integration for the training data
xt_train = perform_true_model_integration(Nt=10_000+100, Ne=1, seed=31)[:, 0]

# discard the spin-up process
xt_train = xt_train[100:]

### <span style="color:blue"> Preprocess the training data </span>

The training dataset is made of input/output pairs, where the input is the state at a given time, and the output is the state at the following time.

<span style="color:red"> Exercise </span>
- Implement the `extract_input_output()` function, in which the 
  neural network input and output are extracted from a given
  trajectory. Use `numpy` slicing for this.

In [None]:
def extract_input_output(xt):
    # TODO: extract x (input)
    x = xt[:-1]
    # TODO: extract y (output)
    y = xt[1:]
    # return input/output
    return (x, y)
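To make the pairing explicit, here is a small sketch on a toy trajectory (the array `xt_toy` and its dimensions are made up for illustration). Slicing off the last state gives the inputs, slicing off the first state gives the outputs, and both arrays are aligned so that `y_toy[t]` is the state that follows `x_toy[t]`.

```python
import numpy as np

# toy trajectory with Nt+1 = 6 states of Nx = 3 variables
Nt, Nx = 5, 3
xt_toy = np.arange((Nt+1)*Nx, dtype=float).reshape(Nt+1, Nx)

# inputs: all states but the last one
x_toy = xt_toy[:-1]
# outputs: all states but the first one
y_toy = xt_toy[1:]

# a trajectory of Nt+1 states yields Nt input/output pairs
print(x_toy.shape, y_toy.shape)
```

Note that both slices are views on the original array, so no data is copied at this stage.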

In [None]:
# extract input/output from the training data
x_train, y_train = extract_input_output(xt_train)

We compute the normalisation using the training data.

In [None]:
# compute input/output mean/std
x_mean = x_train.mean()
y_mean = y_train.mean()
x_std = x_train.std()
y_std = y_train.std()

# define normalisation/denormalisation functions
def normalise_x(x):
    return (x - x_mean)/x_std
def normalise_y(y):
    return (y - y_mean)/y_std
def denormalise_x(x_norm):
    return x_norm*x_std + x_mean
def denormalise_y(y_norm):
    return y_norm*y_std + y_mean
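The functions above use a single scalar mean and standard deviation for all variables, which seems reasonable here because the L96 equations are invariant under a shift of the variable index. The following sketch, on made-up Gaussian data, checks that such a scalar normalisation yields zero mean and unit standard deviation, and that denormalisation recovers the original data.

```python
import numpy as np

# made-up data standing in for a trajectory of shape (Nt, Nx)
rng = np.random.default_rng(1)
data = rng.normal(loc=2.3, scale=3.6, size=(1000, 40))

# scalar statistics, computed over time and variables at once
mean = data.mean()
std = data.std()

# normalise, then denormalise
data_norm = (data - mean)/std
data_back = data_norm*std + mean

print(f'normalised mean = {data_norm.mean():.2e}')
print(f'normalised std = {data_norm.std():.6f}')
```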

Finally, the training data is normalised. 

In [None]:
# normalise the training data
x_train_norm = normalise_x(x_train)
y_train_norm = normalise_y(y_train)

### <span style="color:blue"> Shorter model integrations for the validation and testing data</span>

We repeat the same process to make the **validation** and **testing** data. In this case, the trajectories start from two other random fields (and we still get rid of the spin-up processes) and can be somewhat shorter, but the normalisation must be the same as for the training data.

In [None]:
# short model integration for the validation data
xt_valid = perform_true_model_integration(Nt=1_000+100, Ne=1, seed=314)[:, 0]

# discard the spin-up process
xt_valid = xt_valid[100:]

# extract input/output from the validation data
x_valid, y_valid = extract_input_output(xt_valid)

# normalise the validation data
x_valid_norm = normalise_x(x_valid)
y_valid_norm = normalise_y(y_valid)

In [None]:
# short model integration for the testing data
xt_test = perform_true_model_integration(Nt=1_000+100, Ne=1, seed=3141)[:, 0]

# discard the spin-up process
xt_test = xt_test[100:]

# extract input/output from the testing data
x_test, y_test = extract_input_output(xt_test)

# normalise the testing data
x_test_norm = normalise_x(x_test)
y_test_norm = normalise_y(y_test)

### <span style="color:blue"> An ensemble model integration for the forecast skill data</span>

In order to assess the forecast skill of the surrogate model, we will use a different test dataset, in which we record an ensemble of **trajectories** (instead of an ensemble of input/output pairs). This will allow us to measure the accuracy of the forecast for longer integration times.

In [None]:
# ensemble integration for the forecast skill data
xt_fs = perform_true_model_integration(Nt=400+100, Ne=512, seed=31415)

# discard the spin-up process
xt_fs = xt_fs[100:]

## <span style="color:green"> The baseline model: persistence </span>

In this first test series, we use **persistence** as surrogate model. This will provide baselines for our NN results. Persistence is defined as the model for which there is no time evolution.

### <span style="color:blue"> Evaluate the model</span>

The mean square error (MSE) is the loss function that we will use to train our NNs later. Therefore, the test MSE is a measure of the efficiency of the learning/training process.

In [None]:
# compute test MSE
test_mse_baseline = np.mean(np.square(y_test_norm - x_test_norm))

# show test MSE
print('-'*100)
print(f'test mse of persistence = {test_mse_baseline}')
print('-'*100)

The test MSE of persistence is a number whose absolute value is not that important per se (because the input and output data have been normalised) but it will be useful to normalise the test MSE of our trained NNs.
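For standardised inputs and outputs, the MSE of persistence is tied to the lag-one correlation $\rho$ between input and output: $\mathrm{MSE} = 2(1-\rho)$. The sketch below checks this identity on a synthetic AR(1) series, which is only a stand-in for the L96 data; the value $\rho=0.967$ is the average autocorrelation quoted in the introduction. If the test data has similar statistics, one would therefore expect a persistence MSE of roughly $2(1-0.967)\approx 0.066$, which is an estimate and not a measured result.

```python
import numpy as np

# synthetic AR(1) series with prescribed lag-one correlation (assumption)
rng = np.random.default_rng(2)
rho = 0.967
n = 50_000
z = np.zeros(n)
z[0] = rng.normal()
noise = rng.normal(size=n)
for t in range(n-1):
    z[t+1] = rho*z[t] + np.sqrt(1 - rho**2)*noise[t]

# input/output pairs, each standardised with its own statistics
x, y = z[:-1], z[1:]
x_norm = (x - x.mean())/x.std()
y_norm = (y - y.mean())/y.std()

# MSE of persistence versus the closed-form expression
mse_persistence = np.mean(np.square(y_norm - x_norm))
rho_hat = np.corrcoef(x, y)[0, 1]
expected = 2*(1 - rho_hat)

print(f'persistence MSE = {mse_persistence:.4f}')
print(f'2*(1 - rho_hat) = {expected:.4f}')
```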

In the following cell, we compute the forecast skill of persistence. It will be illustrated in the following sections.

In [None]:
# compute forecast skill
fs_baseline = np.sqrt(np.mean(np.square(xt_fs-xt_fs[0]), axis=2))

### <span style="color:blue"> Example of surrogate model integration</span>

In the following cell, we show one example of model integration.

In [None]:
# compare the true and surrogate model integration for one trajectory
plot_l96_compare_traj(
    xt_fs[:, 0],
    np.broadcast_to(xt_fs[0, 0], shape=xt_fs[:, 0].shape),
    true_model,
    linewidth=18,
)

### <span style="color:blue"> Forecast skill</span>

In the following cell, we plot the average forecast skill, normalised by the model variability. The shaded area delimits the 90% confidence interval (percentiles 5 and 95).

In [None]:
# plot the forecast skill
plot_l96_forecast_skill(
    dict(
        persistence=fs_baseline,
    ),
    true_model,
    p1=5,
    p2=95,
    xmax=4,
    linewidth=1000,
)

The error rapidly grows as time evolves. After about $1$ Lyapunov time, the error oscillates around $\sqrt{2}$, which is the theoretical asymptotic value due to the normalisation and which is consistent with the wave behaviour of the dynamics.
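The asymptotic value $\sqrt{2}$ can be illustrated with a short numerical check. Once the forecast and the truth are decorrelated, they behave like two independent draws from the same distribution with standard deviation $\sigma$, and the expected squared difference is $2\sigma^{2}$. The sketch below uses Gaussian samples, which is an assumption made only for illustration; the mean value `loc=2.0` is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(3)
sigma = 3.64   # model variability quoted in the text

# two independent samples from the same distribution
a = rng.normal(loc=2.0, scale=sigma, size=1_000_000)
b = rng.normal(loc=2.0, scale=sigma, size=1_000_000)

# RMSE between the samples, normalised by the standard deviation
normalised_rmse = np.sqrt(np.mean(np.square(a - b)))/sigma

print(f'normalised RMSE = {normalised_rmse:.4f}')
print(f'sqrt(2) = {np.sqrt(2):.4f}')
```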

## <span style="color:green"> A naive ML model </span>

### <span style="color:blue"> Construct and train the model</span>

In this second test series, we train and evaluate a dense NN (sequential NN with only dense layers). In order to create this model, we use the [sequential API of tensorflow](https://www.tensorflow.org/api_docs/python/tf/keras/Sequential).

<span style="color:red"> Exercise </span>
- Implement the `make_sequential_network()` function, in which a 
  sequential neural network is created. The neural network should
  take as input the current state and return the forecasted state.

In [None]:
def make_sequential_network(seed, num_layers, num_nodes, activation):
    """Build a sequential neural network.
    
    Parameters
    ----------
    seed : int
        The random seed.
    num_layers : int
        The number of hidden layers.
    num_nodes : int
        The number of nodes per hidden layer.
    activation : str
        The activation function for the hidden layers.
        
    Returns
    -------
    network : tf.keras.Sequential
    """
    # set seed
    tf.keras.utils.set_random_seed(seed=seed)
    # TODO: create a sequential network
    network = tf.keras.models.Sequential()
    # TODO: add the input layer
    network.add(tf.keras.Input(shape=(true_model.Nx,)))
    # TODO: add the hidden layers
    for i in range(num_layers):
        network.add(tf.keras.layers.Dense(num_nodes, activation=activation))
    # TODO: add the output layer
    network.add(tf.keras.layers.Dense(true_model.Nx))
    # compile the neural network
    network.compile(loss='mse', optimizer='adam')
    # print short summary
    network.summary()
    # return the network
    return network

def train_network(seed, num_epochs, description, patience, model):
    """Train a neural network.
    
    Parameters
    ----------
    seed : int
        The random seed.
    num_epochs : int
        The number of epochs.
    description : str
        The progress bar description.
    patience : int
        The patience for EarlyStopping.
    model : tf.keras.Model
        The network to train.
    
    Returns
    -------
    history : dict
        The training history.
    """
    # train the ML model
    # set random seed
    tf.keras.utils.set_random_seed(seed=seed)
    # tqdm callback
    tqdm_callback = TQDMCallback(description)
    # early stopping callback
    early_stopping_callback = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=patience,
        verbose=0,
        restore_best_weights=True,
    )
    fit = model.fit(
        x_train_norm, 
        y_train_norm,
        verbose=0,
        epochs=num_epochs, 
        validation_data=(x_valid_norm, y_valid_norm),
        callbacks=[tqdm_callback, early_stopping_callback],
    )
    return fit.history

In the following cell, we build a dense NN with $4$ internal layers and $128$ nodes per layer. The total number of parameters of this model is $59944$. This is actually quite large for a $40$-variable system. This is because the dense architecture is rather "inefficient" in terms of parameters.
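The parameter count can be recovered by hand: each dense layer with $n_{\mathrm{in}}$ inputs and $n_{\mathrm{out}}$ outputs has $n_{\mathrm{in}}n_{\mathrm{out}}$ weights and $n_{\mathrm{out}}$ biases. The helper `dense_param_count` below is a hypothetical function written for this check only.

```python
def dense_param_count(num_inputs, num_layers, num_nodes, num_outputs):
    """Count weights and biases of a fully connected network."""
    # layer sizes: input, hidden layers, output
    sizes = [num_inputs] + [num_nodes]*num_layers + [num_outputs]
    # each layer contributes n_in*n_out weights and n_out biases
    return sum(n_in*n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))

# 40 -> 128 -> 128 -> 128 -> 128 -> 40
num_params_naive = dense_param_count(
    num_inputs=40,
    num_layers=4,
    num_nodes=128,
    num_outputs=40,
)
print(num_params_naive)
```

Most of the parameters come from the three $128\times128$ hidden-to-hidden weight matrices.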

In [None]:
# construct the naive neural network
naive_network = make_sequential_network(seed=314159, num_layers=4, num_nodes=128, activation='relu')

In the following cell, we train the model for up to $256$ epochs. We use an EarlyStopping callback to end the training when the validation loss stops improving. This should avoid overfitting.

In [None]:
# train the network
fit_naive = train_network(
    seed=3141592, 
    num_epochs=256, 
    description='naive NN training', 
    patience=16, 
    model=naive_network,
)

In the following cell we plot the training history, that is, the evolution of the training MSE (the `loss`) and the validation MSE (the `val_loss`) as a function of the number of epochs.

In [None]:
# plot the learning history
plot_learning_curve(
    fit_naive['loss'],
    fit_naive['val_loss'],
    title='Naive NN training',
    linewidth=1000,
)

The two curves closely follow each other. The validation MSE is noisier than the training MSE, which is expected because the training dataset is ten times as large as the validation dataset. After several epochs, the validation MSE gets a bit higher than the training MSE. This is explained by the fact that the validation data is not used in the gradient descent algorithm. Finally, towards the end of the training, the validation MSE stops improving. This is a sign that the model is starting to overfit the training data and that we should stop the training.
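The stopping rule can be sketched in plain Python. This is a simplified mimic of the patience logic, not the actual Keras implementation; the function `early_stopping_epoch` and the list of validation losses are made up for illustration.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (stopping epoch, best epoch) for a patience-based rule."""
    best = float('inf')
    best_epoch = -1
    wait = 0
    for epoch, value in enumerate(val_losses):
        if value < best:
            # improvement: record it and reset the counter
            best, best_epoch, wait = value, epoch, 0
        else:
            # no improvement: stop once the patience is exhausted
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # the training ran until the last epoch
    return len(val_losses) - 1, best_epoch

# made-up validation losses: the best value is reached at epoch 2
val_losses = [1.0, 0.5, 0.4, 0.45, 0.41, 0.42, 0.43, 0.2]
stop_epoch, best_epoch = early_stopping_epoch(val_losses, patience=3)
print(stop_epoch, best_epoch)
```

In this toy example the last value ($0.2$) is never reached, which illustrates the trade-off behind the choice of the patience. With `restore_best_weights=True`, as used in `train_network()`, the weights kept at the end are those of the best epoch rather than those of the stopping epoch.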

### <span style="color:blue"> Evaluate the model</span>

We now compute the test MSE to evaluate our surrogate model. 

In [None]:
# compute test MSE
test_mse_naive = naive_network.evaluate(x_test_norm, y_test_norm, verbose=0, batch_size=x_test_norm.shape[0])

# show test MSE
print('-'*100)
print(f'test mse of persistence = {test_mse_baseline}')
print(f'test mse of naive model = {test_mse_naive}')
print()
print(f'relative test mse of naive model = {test_mse_naive/test_mse_baseline}')
print('-'*100)

We obtain a reduction of about 80%, which is already quite good, but we will see later that it is possible to do much better.

We now compute and illustrate the forecast skill of the neural network.

<span style="color:red"> Exercise </span>
- Implement the `compute_forecast_skill()` function, in which we 
  use a surrogate model to predict the trajectories and then
  compute the forecast skill. Use the `predict()` method of
  `tf.keras.Model` inside a `for-loop`.

In [None]:
def compute_forecast_skill(model):
    """Compute the forecast skill.
    
    Parameters
    ----------
    model : tf.keras.Model
        The model to evaluate.
        
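    Returns
    -------
    xt : np.ndarray, shape (Nt, Ne, Nx)
        The trajectories predicted by the surrogate model.
    fs : np.ndarray, shape (Nt, Ne)
        The forecast skill (RMSE over the state variables).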
    """
    # allocate memory
    (Nt, Ne, Nx) = xt_fs.shape
    xt = np.zeros((Nt, Ne, Nx))
    
    # initialisation
    xt[0] = xt_fs[0]
    
    # TODO: implement the neural network integration
    for t in trange(Nt-1, desc='surrogate model integration'):
        x_norm = normalise_x(xt[t])
        y_norm = model.predict(x_norm, batch_size=Ne, verbose=0)
        xt[t+1] = denormalise_y(y_norm)
        
    # compute and return the forecast skill
    fs = np.sqrt(np.mean(np.square(xt_fs-xt), axis=2))
    return (xt, fs)

In [None]:
# compute forecast skill
xt_naive, fs_naive = compute_forecast_skill(naive_network)

### <span style="color:blue"> Example of surrogate model integration</span>

In the following cell, we show once again one example of model integration.

In [None]:
# compare the true and surrogate model integration for one trajectory
plot_l96_compare_traj(
    xt_fs[:, 0],
    xt_naive[:, 0],
    true_model,
    linewidth=18,
)

The error is lower than in the first test series, but only during the first few integration steps.

### <span style="color:blue"> Forecast skill</span>

In the following cell, we plot once again the average forecast skill, normalised by the model variability.

In [None]:
# plot the forecast skill
plot_l96_forecast_skill(
    dict(
        persistence=fs_baseline,
        naive=fs_naive,
    ),
    true_model,
    p1=5,
    p2=95,
    xmax=4,
    linewidth=1000,
)

This curve confirms that the naive surrogate model is more accurate than persistence for one integration step, and that it remains more accurate until about $2$ Lyapunov times.

## <span style="color:green"> A smart ML model </span>

### <span style="color:blue"> Build and train the model</span>

In this third and last test series, we train and evaluate a smart NN. This NN uses a sparse architecture with convolutional NN and controlled nonlinearity to reproduce the **model tendencies**, as well as a Runge-Kutta integration scheme to **emulate the dynamics**. In order to implement this NN, we use both the [functional API](https://www.tensorflow.org/guide/keras/functional) (for the model tendency) and the [subclassing API](https://www.tensorflow.org/guide/keras/custom_layers_and_models) (for the integration scheme) of tensorflow.

In [None]:
class SmartNetwork(tf.keras.Model):
    """Smart neural network for the Lorenz 1996 model.
    
    Attributes
    ----------
    dt : float
        The integration time step.
    tendency : tf.keras.Model
        The network to compute the tendencies.
    """
    
    def __init__(self, num_filters, kernel_size, dt=0.05, **kwargs):
        """Initialise the smart network.
        
        Parameters
        ----------
        num_filters : int
            Number of filters to use in the convolutional layer.
        kernel_size : int
            Size of the convolutional kernel.
        dt : float
            Integration time step.
        kwargs : dict
            Additional parameters forwarded to `tf.keras.Model.__init__()`.
        """
        super().__init__(**kwargs)
        self.dt = dt
        
        # reshape layers
        reshape_input = tf.keras.layers.Reshape((true_model.Nx, 1))
        reshape_output = tf.keras.layers.Reshape((true_model.Nx,))
        
        # padding layer
        border = kernel_size//2
        def apply_padding(x):
            x_left = x[..., -border:, :]
            x_right = x[..., :border, :]
            return tf.concat([x_left, x, x_right], axis=-2)
        padding_layer = tf.keras.layers.Lambda(apply_padding)
        
        # convolutional layers
        conv_layer_1 = tf.keras.layers.Conv1D(num_filters, kernel_size)
        conv_layer_2 = tf.keras.layers.Conv1D(1, 1)
        
        # network for the model tendencies
        x_in = tf.keras.Input(shape=(true_model.Nx,))
        # reshape the input to be able to use convolutional layers
        x = reshape_input(x_in)
        # apply convolution with periodic padding
        x = padding_layer(x)
        x1 = conv_layer_1(x)
        # construct non-linear terms
        x2 = x1 * x1
        # concatenate linear and non-linear terms
        x3 = tf.concat([x1, x2], axis=-1)
        # combine all channels into one
        # there is no actual convolution here 
        # because the kernel_size is one for this layer
        x_out = conv_layer_2(x3)
        # reshape the output after the convolutional layers
        x_out = reshape_output(x_out)
        # pack everything into a tf.keras.Model
        self.tendency = tf.keras.Model(inputs=x_in, outputs=x_out)
    
    @tf.function
    def call(self, x):
        """Apply the network."""
        dx_dt_0 = self.tendency(x)
        dx_dt_1 = self.tendency(x+0.5*self.dt*dx_dt_0)
        dx_dt_2 = self.tendency(x+0.5*self.dt*dx_dt_1)
        dx_dt_3 = self.tendency(x+self.dt*dx_dt_2)
        dx_dt =  (dx_dt_0 + 2*dx_dt_1 + 2*dx_dt_2 + dx_dt_3)/6
        return x + self.dt*dx_dt
    
def make_smart_network(seed, num_filters, kernel_size):
    """Build the smart neural network.
    
    Parameters
    ----------
    seed : int
        The random seed.
    num_filters : int
        The number of filters.
    kernel_size : int
        The size of the convolution kernel.
        
    Returns
    -------
    network : SmartNetwork
        The smart network.
    """
    # set seed
    tf.keras.utils.set_random_seed(seed=seed)
    # create the network
    network = SmartNetwork(
        num_filters=num_filters, 
        kernel_size=kernel_size, 
        dt=true_model.dt,
    )
    # compile the neural network
    network.compile(loss='mse', optimizer='adam')
    # print short summary
    network.tendency.summary()
    # return the network
    return network
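The periodic padding used in `apply_padding()` above can be illustrated with `numpy` alone. In this sketch (the function names are hypothetical), the state is padded with its last and first `border` entries, and a "valid" cross-correlation is then applied, which is the operation that a convolutional layer without padding performs on a single channel. The result is compared with a circular stencil built from `np.roll`. The sketch assumes an odd kernel size of at least three.

```python
import numpy as np

def periodic_conv1d(x, kernel):
    """Periodic padding followed by a 'valid' cross-correlation."""
    border = len(kernel)//2   # assumes an odd kernel size >= 3
    x_pad = np.concatenate([x[-border:], x, x[:border]])
    return np.correlate(x_pad, kernel, mode='valid')

def circular_reference(x, kernel):
    """Same operation written as a weighted sum of shifted states."""
    border = len(kernel)//2
    # kernel[k] multiplies x_{n+k-border}
    return sum(w*np.roll(x, border - k) for k, w in enumerate(kernel))

rng = np.random.default_rng(4)
x = rng.normal(size=40)
kernel = rng.normal(size=5)

out = periodic_conv1d(x, kernel)
print(out.shape)
print(np.allclose(out, circular_reference(x, kernel)))
```

Thanks to the padding, the output has the same size as the input, and the first and last variables are treated like all the others.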

In [None]:
# construct the smart neural network
smart_network = make_smart_network(seed=31415926, num_filters=6, kernel_size=5)

The total number of parameters is only $49$. Furthermore in this case, with well-chosen parameters it is possible to reproduce the true dynamics up to machine precision: the model is said to be **identifiable**. Also note that this network is built in such a way that we don't need the input and output data to be normalised.
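One possible way to see why the model is identifiable is the polarisation identity $ab = \left[(a+b)^{2}-(a-b)^{2}\right]/4$, with $a = x_{n+1}-x_{n-2}$ and $b = x_{n-1}$. The sketch below, written with `numpy` in non-normalised variables, builds three linear filters on the five-point stencil and combines them and their squares, as the network does. This particular choice of weights is only an illustration and is not necessarily the solution found by the training. The parameter count is also recomputed.

```python
import numpy as np

F = 8.0

def l96_tendency(x):
    """Direct L96 tendency."""
    return (np.roll(x, -1) - np.roll(x, 2))*np.roll(x, 1) - x + F

def tendency_from_filters(x):
    """L96 tendency from linear filters and their squares."""
    x_m2, x_m1, x_p1 = np.roll(x, 2), np.roll(x, 1), np.roll(x, -1)
    # three linear filters on the stencil (x_{n-2}, ..., x_{n+2})
    f1 = -x_m2 + x_m1 + x_p1   # a + b
    f2 = -x_m2 - x_m1 + x_p1   # a - b
    f3 = x                     # linear damping term
    # combination of linear and squared channels, plus a bias F
    return 0.25*f1**2 - 0.25*f2**2 - f3 + F

rng = np.random.default_rng(6)
x = rng.normal(loc=3, scale=1, size=40)
identical = np.allclose(l96_tendency(x), tendency_from_filters(x))

# parameter count for num_filters=6 and kernel_size=5
num_filters, kernel_size = 6, 5
params_conv_1 = kernel_size*1*num_filters + num_filters   # weights + biases
params_conv_2 = 1*(2*num_filters)*1 + 1                   # weights + bias
num_params_smart = params_conv_1 + params_conv_2
print(identical, num_params_smart)
```

Only three of the six filters are needed in this construction, which suggests that the architecture has some redundancy.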

In the following cell, we train the model for up to $128$ epochs. Once again, we use an EarlyStopping callback to end the training when the validation loss stops improving in order to avoid overfitting.

In [None]:
# train the network
fit_smart = train_network(
    seed=314159265, 
    num_epochs=128, 
    description='smart NN training', 
    patience=8, 
    model=smart_network,
)

In the following cell we plot the training history.

In [None]:
# plot the learning history
plot_learning_curve(
    fit_smart['loss'],
    fit_smart['val_loss'],
    title='Smart NN training',
    linewidth=1000,
)

Once again, the training and validation MSE are visually closely related. However, in contrast with the previous test series, after about $35$ epochs, the MSEs have decreased to $10^{-9}$, which should be very close to numerical zero (tensorflow works in single precision for real numbers). Past $40$ epochs, the MSEs oscillate at very low values. This behaviour can be considered numerical noise.

### <span style="color:blue"> Evaluate the model</span>

We now compute the test MSE to evaluate our surrogate model. 

In [None]:
# compute test MSE
test_mse_smart = smart_network.evaluate(x_test_norm, y_test_norm, verbose=0, batch_size=x_test_norm.shape[0])

# show test MSE
print('-'*100)
print(f'test mse of persistence = {test_mse_baseline}')
print(f'test mse of naive model = {test_mse_naive}')
print(f'test mse of smart model = {test_mse_smart}')
print()
print(f'relative test mse of naive model = {test_mse_naive/test_mse_baseline}')
print(f'relative test mse of smart model = {test_mse_smart/test_mse_baseline}')
print('-'*100)

The test MSE is sufficiently close to zero so that we can consider that our surrogate model reproduces the true model dynamics up to numerical precision.

In [None]:
# compute forecast skill
xt_smart, fs_smart = compute_forecast_skill(smart_network)

### <span style="color:blue"> Example of surrogate model integration</span>

In the following cell, we show once again one example of model integration.

In [None]:
# compare the true and surrogate model integration for one trajectory
plot_l96_compare_traj(
    xt_fs[:, 0],
    xt_smart[:, 0],
    true_model,
    linewidth=18,
)

This time, the error is so low that it is not visible until about $4$ MTU. At that time, the true model trajectory and the surrogate model trajectory diverge from each other. Indeed, the two models are equivalent up to numerical precision, but they are not bit-wise equivalent, which means that this divergence is unavoidable because of the chaotic nature of the dynamics. 
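This unavoidable divergence can be reproduced without any neural network. In the sketch below, the same L96 model is integrated in double and in single precision from the same initial state. This is a stand-in for two models that are equivalent but not bit-wise identical, not the experiment of this notebook. The two trajectories start out almost identical and end up decorrelated.

```python
import numpy as np

def tendency(x, F=8.0):
    """L96 tendency, preserving the dtype of x."""
    return (np.roll(x, -1) - np.roll(x, 2))*np.roll(x, 1) - x + F

def rk4_step(x, dt=0.05):
    """One fourth-order Runge-Kutta step."""
    k1 = tendency(x)
    k2 = tendency(x + dt/2*k1)
    k3 = tendency(x + dt/2*k2)
    k4 = tendency(x + dt*k3)
    return x + dt*(k1 + 2*k2 + 2*k3 + k4)/6

# spin-up in double precision
rng = np.random.default_rng(5)
x64 = rng.normal(loc=3, scale=1, size=40)
for _ in range(100):
    x64 = rk4_step(x64)

# same state, stored in single precision
x32 = x64.astype(np.float32)

# integrate both versions for 400 steps (20 MTU) and record the RMSE
rmse = []
for _ in range(400):
    x64 = rk4_step(x64)
    x32 = rk4_step(x32)
    rmse.append(np.sqrt(np.mean(np.square(x64 - x32))))
rmse = np.array(rmse)

print(f'RMSE after 1 step = {rmse[0]:.2e}')
print(f'RMSE after 400 steps = {rmse[-1]:.2e}')
```

The initial discrepancy is of the order of the single precision round-off error, and it is amplified by the chaotic dynamics until it saturates.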

### <span style="color:blue"> Forecast skill</span>

In the following cell, we plot once again the average forecast skill, normalised by the model variability.

In [None]:
# plot the forecast skill
plot_l96_forecast_skill(
    dict(
        persistence=fs_baseline,
        naive=fs_naive,
        smart=fs_smart,
    ),
    true_model,
    p1=5,
    p2=95,
    xmax=30,
    linewidth=1000,
)

This curve confirms that the smart surrogate model is equivalent to the true model up to numerical precision. The numerical divergence between the true and surrogate model happens on average after about $5$ Lyapunov times.