# Lagged predictions

In this notebook, we plot the output of the time stepping scheme over various prediction intervals and compare it to the truth. In particular, we define the lagged prediction of a given time point by
$$ \tilde{x}^n_m = f^m(x^{n-m}),$$
where $f^m$ is the time stepping operator defined recursively by
$$f^m(x) = f(f^{m-1}(x)) + \frac{\Delta t}{2}\left(g^{m-1} + g^{m}\right)$$
and $f^0(x) = x$. In the formula above, $g^m$ is the advection forcing evaluated at time step $m$.
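
The recursion above can be unrolled into a simple loop. The following is a minimal sketch, not the implementation behind `lagged_predictions`; the names `lagged_prediction`, `step`, `forcing`, and `dt` are hypothetical, and the forcing is assumed to be indexed by the number of steps taken from the starting state.

```python
def lagged_prediction(step, x0, forcing, dt, m):
    """Apply the forced time stepper m times, starting from x0.

    step    : callable implementing the unforced update f (hypothetical name)
    forcing : sequence with forcing[k] = g^k for k = 0..m (assumed indexing)
    dt      : time step
    """
    x = x0  # f^0(x) = x
    for k in range(1, m + 1):
        # f^k(x) = f(f^{k-1}(x)) + dt/2 * (g^{k-1} + g^k)
        x = step(x) + 0.5 * dt * (forcing[k - 1] + forcing[k])
    return x


# Toy check: identity stepper with constant unit forcing,
# so each step simply adds dt to the state.
x3 = lagged_prediction(lambda x: x, 0.0, [1.0] * 4, dt=0.5, m=3)
```

Because the function only uses addition and multiplication, the same sketch should work for scalars and for array-valued states.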

The neural network, or another machine learning model, provides the function $f$ for the unforced variability. The minimization problem used to find $f$ is
$$ \text{argmin}_{f} \sum_{i,n} \sum_{m=0}^{\ell} ||f^m(x_i^{n-m}) - x_i^n||^2_w.$$
Here $i$ is the index of the horizontal spatial location and $n$ is the time index.
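
To make the objective concrete, here is a minimal NumPy sketch of the loss for a single horizontal location. It is an illustration under stated assumptions rather than the training code: `multistep_loss`, `step`, and the toy data are hypothetical, and the forcing is assumed to be stored at the same time indices as the states.

```python
import numpy as np


def multistep_loss(step, states, forcing, weights, dt, nlags):
    """Sum of weighted squared errors of all predictions up to nlags steps.

    states  : array (ntime, nz), true states at one horizontal location
    forcing : array (ntime, nz), forcing at each time step (assumed layout)
    weights : array (nz,), weights defining the norm ||.||_w
    """
    ntime = states.shape[0]
    loss = 0.0
    for start in range(ntime):
        pred = states[start]  # m = 0 term: f^0(x) = x, zero error
        for m in range(1, nlags + 1):
            n = start + m
            if n >= ntime:
                break
            # one forced step from time n-1 to time n
            pred = step(pred) + 0.5 * dt * (forcing[n - 1] + forcing[n])
            loss += np.sum(weights * (pred - states[n]) ** 2)
    return loss


# Toy data generated by a known linear stepper (made up for illustration)
rng = np.random.default_rng(0)
ntime, nz, dt = 20, 3, 0.1
toy_forcing = rng.normal(size=(ntime, nz))
toy_states = np.zeros((ntime, nz))
toy_states[0] = 1.0
for n in range(1, ntime):
    toy_states[n] = 0.9 * toy_states[n - 1] + 0.5 * dt * (
        toy_forcing[n - 1] + toy_forcing[n]
    )

toy_weights = np.ones(nz)
loss_exact = multistep_loss(lambda x: 0.9 * x, toy_states, toy_forcing,
                            toy_weights, dt, nlags=10)
loss_wrong = multistep_loss(lambda x: 0.8 * x, toy_states, toy_forcing,
                            toy_weights, dt, nlags=10)
```

The stepper that generated the data gives a vanishing loss, while a mis-specified stepper is penalized at every lag, not only at the first step.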

This minimization problem more directly assesses the error of running the scheme in prognostic mode for $\ell$ time steps. My numerical experiments show that using $\ell=2$ tends to give an unstable scheme. Indeed, for $\ell=2$ the loss function is proportional to the error in predicting the discrete time derivative, which is what we were working on before. In previous reports, I stabilized that scheme by adding an L2 penalty on the weights of the network, but it turns out that this is not needed when $\ell=10$. This is analogous to the result that the accuracy of a numerical ODE scheme requires both consistency and stability. When $\ell=2$ the objective function above only rewards consistency, but with $\ell=10$ a stable scheme is needed to keep the loss small.
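
As a toy illustration of why a small one-step error does not guarantee a usable scheme, consider a scalar map that is slightly amplifying while the true dynamics decay slowly. This is not the experiment described above; the maps and numbers are invented for illustration.

```python
def rollout(step, x0, nsteps):
    """Iterate a one-step map nsteps times."""
    x = x0
    for _ in range(nsteps):
        x = step(x)
    return x


def toy_truth(x):
    return 0.99 * x  # true dynamics: slowly decaying


def toy_model(x):
    return 1.02 * x  # hypothetical learned model: slightly amplifying


x0 = 1.0
one_step_error = abs(toy_model(x0) - toy_truth(x0))
long_error = abs(rollout(toy_model, x0, 200) - rollout(toy_truth, x0, 200))
```

Here `one_step_error` is a few percent of the state, whereas `long_error` is many times larger than the initial state. A loss that only sees one step cannot distinguish this model from a stable one with the same one-step error.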

# Results

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
from lib.evaluation.single_column import xr_runsteps, lagged_predictions
import xarray as xr
import numpy as np
import torch

Here we load the data for a given horizontal location on the equator. We also load the trained neural network predictor $f_{NN}$.

In [None]:
def subset(x):
    return x.isel(x=0, y=32)

inputs = xr.open_mfdataset("../data/calc/ngaqua/*.nc", preprocess=subset)
forcings = xr.open_mfdataset("../data/calc/forcing/ngaqua/*.nc", preprocess=subset)
w = xr.open_dataarray("../data/processed/ngaqua/w.nc")
stepper = torch.load("../data/ml/ngaqua/multistep_objective.torch")

p = xr.open_dataset("../data/raw/ngaqua/stat.nc").p

Now, I calculate $\tilde{x}_m^n$ using $f_{NN}$.

In [None]:
lagged_preds = lagged_predictions(stepper, inputs, forcings, w, 100)

## Comparing the fields at lag 10

Here, I compare $x^n$ with $\tilde{x}^{n}_{10}$.

In [None]:
def plot_preds(lagged_preds, lag=10, **kwargs):
    """Plot the truth (lag 0), the prediction at `lag`, and their difference."""

    fig, axs = plt.subplots(3,1, sharex=True, sharey=True, figsize=(10,6))

    def plot_qt(ax, x):   
        im = ax.contourf(x.time.values, p.values, x.T, cmap='viridis', **kwargs)
        plt.colorbar(im, ax=ax)
        return im
    
    
    def add_label(ax, text):
        ax.text(105, 200, text, bbox=dict(color='white'))

    plot_qt(axs[0], lagged_preds.isel(lag=0))
    plot_qt(axs[1], lagged_preds.isel(lag=lag))

    err = lagged_preds.isel(lag=0) - lagged_preds.isel(lag=lag)
    im = axs[2].pcolormesh(err.time.values, p.values, err.T)
    plt.colorbar(im, ax=axs[2])

    axs[0].invert_yaxis()
    axs[-1].set_xlabel('time [d]')

    for ax in axs.flat:
        ax.set_ylabel('p [mb]')
        
    add_label(axs[0], 'Truth')
    add_label(axs[1], 'Lag %d'%lag)
    add_label(axs[2], 'Difference')
        
    plt.tight_layout()

Here is the plot for the humidity field $q_T$.

In [None]:
plot_preds(lagged_preds.qt, levels=np.arange(11)*2)
plt.xlim([100,130])

and for the temperature variable $s_L$

In [None]:
plot_preds(lagged_preds.sl-lagged_preds.sl.mean('time'), levels=[-5,-4,-3,-2,-1,0,1,2,3,4,5])
plt.ylim([1000,150])

## Vertically averaged MSE plots

Now, I summarize the difference using the mass-weighted mean squared error for each prediction lag, averaged over all time points. Up to normalization, this quantity is given by the formula
$$ \sum_n ||x^n - \tilde{x}^n_m ||^2_w.$$

I plot this quantity for each physical variable ($q_T$ and $s_L$) separately. I also plot a horizontal dashed line indicating the error of the time mean.
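
For readers who want to check the arithmetic by hand, here is a small NumPy sketch of a weighted mean squared error over the vertical axis. The function name `mse_weighted_np` and all numbers are hypothetical and chosen only so the result is easy to verify.

```python
import numpy as np


def mse_weighted_np(true, pred, weights):
    """Weighted mean over the last (vertical) axis of the squared error."""
    return np.sum(weights * (pred - true) ** 2, axis=-1) / np.sum(weights)


# Hypothetical column with three levels; the first level carries most mass
toy_w = np.array([3.0, 2.0, 1.0])
toy_true = np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])  # (time, z)
toy_pred = np.array([[1.0, 1.0, 4.0], [5.0, 2.0, 2.0]])

per_time = mse_weighted_np(toy_true, toy_pred, toy_w)
# time 0: error of 3 at the lightest level -> 1 * 9 / 6 = 1.5
# time 1: error of 3 at the heaviest level -> 3 * 9 / 6 = 4.5
toy_mse = per_time.mean()
```

The same absolute error counts three times as much when it occurs in the heaviest layer, which is the point of the mass weighting.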

In [None]:
def mse_weighted(true, pred, w=w):
    return (((pred-true)**2) * w).sum('z')/w.sum('z')

lag_errors = mse_weighted(inputs, lagged_preds).mean('time')
mean_error = mse_weighted(inputs, inputs.mean('time')).mean('time')


# this closure will be helpful
def plot_err(key):
    lag_errors[key].plot()
    plt.axhline(mean_error[key], c='k', ls='--')
    plt.title(key)
    plt.xlim([0, 20])

In [None]:
plot_err('sl')

In [None]:
plot_err('qt')

We can predict the moisture with much more accuracy than the temperature.

## Vertical structures in Error

What does the vertical structure of this error look like?

In [None]:
mse_over_time = ((lagged_preds-inputs)**2).mean('time')

Here is the error for $q_T$

In [None]:
plt.contourf(mse_over_time.lag, p, mse_over_time.qt.T, 11)
plt.colorbar()
plt.gca().invert_yaxis()
plt.xlabel('lag')
plt.ylabel('p [mb]')

and for $s_L$

In [None]:
import matplotlib.colors as mc

plt.contourf(mse_over_time.lag, p, mse_over_time.sl.T, np.arange(11)*.5)
plt.colorbar()
plt.gca().invert_yaxis()
plt.xlabel('lag [d]')
plt.ylabel('p [mb]')
plt.ylim([1000, 150])

Notice that this last plot uses fixed contour levels (from 0 to 5) rather than automatically chosen ones. I had to do this because the error at the tropopause (near 100 mb) is so large compared to the errors elsewhere.

# Revisiting $s_L$ errors

Let's make the MSE-versus-lag plot again, but this time we will exclude the large errors in the stratosphere.
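
When the average is restricted to a subset of levels, the weighted mean has to be normalized by the weights of the retained levels only. The NumPy sketch below uses made-up heights, weights, and errors (assuming `z` is a height in metres) to show how dividing by the full-column weight sum would understate the error.

```python
import numpy as np

toy_z = np.array([500.0, 5e3, 12e3, 20e3])   # hypothetical heights [m]
toy_w = np.array([4.0, 3.0, 2.0, 1.0])       # hypothetical layer masses
toy_sq_err = np.full(4, 2.0)                 # uniform squared error of 2

keep = toy_z <= 10e3                         # levels below 10 km

# Correct: normalize by the mass of the retained levels
subset_mean = np.sum(toy_w[keep] * toy_sq_err[keep]) / np.sum(toy_w[keep])

# Incorrect: numerator over the subset, denominator over the full column
mixed_mean = np.sum(toy_w[keep] * toy_sq_err[keep]) / np.sum(toy_w)
```

With a uniform squared error of 2, the correctly normalized mean is 2, while the mixed normalization gives 1.4.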

In [None]:
data = lagged_preds.assign(p=p, w=w).sel(z=slice(0,10e3))

mse = mse_weighted(data, data.isel(lag=0), data.w).mean('time')
# pass the subsetted weights so both errors are normalized over the same levels
mse_mean = mse_weighted(data.isel(lag=0).mean('time'), data.isel(lag=0), data.w).mean('time')

In [None]:
mse.sl.plot()
plt.axhline(mse_mean.sl, c='k', ls='--')
plt.xlim([0,10])

The performance does not improve much.

# Conclusions

1. Using the multiple time step objective function gives more stable results than the one-step objective function.
2. The neural network scheme works much better for the moisture than it does for the temperature.