The data we are using are only sampled at a coarse time step of 3 hours. On the other hand, the coarse-resolution model will probably use a time step of 10 to 20 minutes. This means that the dynamical model we are trying to fit is
$$ x^i_{n+1} = \underbrace{f(f(\ldots f}_{\text{m times}}(x^i_n))) + \int_{t_n}^{t_{n+1}} g(x(t), t) dt$$ 
where $i$ is the horizontal spatial index and $n$ is the time step. The function $f$ is applied $m=\frac{\Delta t}{h}$ times, where $h$ is the GCM's time step and $\Delta t$ is the sampling interval of the stored output. The integral on the right represents the approximately known terms, such as advection, and $f$ represents the unknown source terms.
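The sketch below spells out the update above for one location: it computes $m$ from the sampling interval and a GCM step, then applies $f$ $m$ times and adds the contribution of the known terms. It is only an illustration under stated assumptions: the 15-minute step, the toy source term $a(x) = -x$, and the constant `known_increment` (standing in for the integral of $g$) are all hypothetical.

```python
import numpy as np


def compose(f, m):
    """Return the m-fold composition x -> f(f(...f(x)))."""
    def F(x):
        for _ in range(m):
            x = f(x)
        return x
    return F


# Sampling interval of the stored output and a hypothetical GCM time step, in minutes.
sample_minutes = 3 * 60
h_minutes = 15                       # assumption: inside the 10-20 minute range
m = sample_minutes // h_minutes      # applications of f per stored sample

h = h_minutes / (24 * 60)            # GCM step in days


def f(x):
    """One forward Euler step of size h for the toy source term a(x) = -x."""
    return x + h * (-x)


x_n = np.array([1.0, 2.0])
known_increment = np.array([0.1, -0.1])  # stands in for the integral of g over [t_n, t_{n+1}]
x_next = compose(f, m)(x_n) + known_increment
```

With a 3-hour sampling interval, a GCM step between 10 and 20 minutes gives $m$ between 9 and 18; the 15-minute choice here gives $m = 12$.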

We find $f$ by solving the minimization problem
$$
\min_{a} \lim_{m \rightarrow \infty} \sum_{i,n} ||x^{i}_{n+1} - F^{(m)} x^i_{n} - g_n^{i}||_W^2 \quad \text{s.t.}\quad F^{(m)}(\cdot) = \underbrace{f(f(\ldots f}_{\text{m times}}(\cdot))),\ f(x) = x +  \frac{ \Delta t}{m} a(x).
$$
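Here $g^i_n$ denotes the integral of the known terms over $[t_n, t_{n+1}]$, and $W$ defines a weighted norm. Intuitively, the forward operator $F^{(m)}$ is the result of applying $m$ forward Euler steps of size $\Delta t/m$ to the system $\dot{x} = a(x)$. As $m \rightarrow \infty$, these Euler steps converge to the exact flow of that system over one sampling interval $\Delta t$.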
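A minimal PyTorch sketch of this objective, assuming a diagonal weight $W$ given as a vector `w` and a finite $m$ in place of the limit. The names `forward_operator` and `weighted_loss` are hypothetical (they are not the project's `torch_cli.py` routines), and the fit uses synthetic data generated from a known linear source term rather than the NG-Aqua data.

```python
import torch


def forward_operator(a, x, dt, m):
    """F^(m): apply m forward Euler steps of size dt/m to dx/dt = a(x)."""
    h = dt / m
    for _ in range(m):
        x = x + h * a(x)
    return x


def weighted_loss(a, x, x_next, g, w, dt, m):
    """Sum over samples of ||x_{n+1} - F^(m) x_n - g_n||_W^2, with W = diag(w)."""
    residual = x_next - forward_operator(a, x, dt, m) - g
    return (w * residual ** 2).sum()


# Synthetic example: recover a linear source term from sampled data.
torch.manual_seed(0)
n_samples, n_features = 64, 4
dt, m = 3 / 24, 12                     # 3-hour sampling, 15-minute inner steps (in days)


def true_a(z):
    """Synthetic 'unknown' source term, for illustration only."""
    return -2.0 * z


x = torch.randn(n_samples, n_features)
g = torch.zeros(n_samples, n_features)   # no known forcing in this toy problem
with torch.no_grad():
    x_next = forward_operator(true_a, x, dt, m) + g
w = torch.ones(n_features)

a = torch.nn.Linear(n_features, n_features)
opt = torch.optim.Adam(a.parameters(), lr=1e-2)
with torch.no_grad():
    initial = weighted_loss(a, x, x_next, g, w, dt, m).item()
for _ in range(500):
    opt.zero_grad()
    loss = weighted_loss(a, x, x_next, g, w, dt, m)
    loss.backward()                      # gradients flow through all m Euler steps
    opt.step()
with torch.no_grad():
    final = weighted_loss(a, x, x_next, g, w, dt, m).item()
```

Because the gradient is backpropagated through every Euler step, the cost of one loss evaluation grows linearly with $m$.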

Let's try this fit. First, we import the required modules and load the data.

# TODO

- [x] Implement torch-only preprocessing routine.
- [ ] Implement the Euler time-stepping routine. This can be based on `torch_cli.py`.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import numpy as np
import xarray as xr
import torch

This is the coarse sampling time step, in days (3 hours = 3/24 day):

In [None]:
dt = 3/24  # 3 hours, in days

Let's now define a torch module for the function $a$. It will be a single-layer perceptron that first scales its inputs appropriately. To compute that scaling, we first load the data, which also stores precomputed scales and weights.
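One way such a module could look, as a sketch: inputs are standardized with a fixed mean `mu` and scale `sigma` (stored as non-trainable buffers) and then passed through a single linear layer. The class name `ScaledPerceptron` and the choice of mean and standard deviation as the scaling are assumptions; the stored `scales` array might play the role of `sigma`, but that is not confirmed here.

```python
import numpy as np
import torch
from torch import nn


class ScaledPerceptron(nn.Module):
    """Hypothetical single-layer model for a(x) that standardizes its inputs first."""

    def __init__(self, n_in, n_out, mu, sigma):
        super().__init__()
        # Buffers are saved with the module (and moved with .to()) but not trained.
        self.register_buffer("mu", torch.as_tensor(mu, dtype=torch.float32))
        self.register_buffer("sigma", torch.as_tensor(sigma, dtype=torch.float32))
        self.linear = nn.Linear(n_in, n_out)

    def forward(self, x):
        return self.linear((x - self.mu) / self.sigma)


# Usage with synthetic data: compute the scaling from the training inputs.
rng = np.random.default_rng(0)
data_np = rng.normal(size=(100, 5)) * 3 + 1
mu = data_np.mean(axis=0)
sigma = data_np.std(axis=0)
model = ScaledPerceptron(5, 3, mu, sigma)
y = model(torch.as_tensor(data_np, dtype=torch.float32))
```

Keeping the scaling inside the module means a saved network carries its own normalization, so callers can pass unscaled states directly.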

In [None]:
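# Load the preprocessed time series (states X, known forcing G, scales, and weights)
data = np.load("../data/ml/ngaqua/time_series_data.npz")

X = data['X']
G = data['G']
scale = data['scales']
w = data['w']  # presumably the weights W of the norm in the loss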

In [None]:
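# Select a single column (index 8, 0 along the two middle dimensions)
# and pair each state x_n with the next sample x_{n+1}
x = X[:-1, 8, 0, :]   # x_n
xp = X[1:, 8, 0, :]   # x_{n+1}
g = G[:-1, 8, 0, :]   # known forcing at t_n

# `a` is not defined at this point, so subtract the time mean of the first
# 34 columns instead of `a.mu` (assumption: `a.mu` was this input mean)
plt.pcolormesh(x[:, :34] - x[:, :34].mean(axis=0))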

In [None]:
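def plot_q(x):
    """Plot columns 34 onward (q) with time along the horizontal axis."""
    plt.pcolormesh(x[:, 34:].T)

def plot_t(x):
    """Plot the first 34 columns (t) with time along the horizontal axis."""
    plt.figure(figsize=(12, 2))
    plt.pcolormesh(x[:, :34].T, vmin=0, vmax=50)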

In [None]:
plot_t((xp-x)/dt-g)
plt.colorbar()

In [None]:
from lib.models.torch_models import predict

In [None]:
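# The file holds a pickled network; PyTorch >= 2.6 defaults to weights_only=True,
# which rejects full modules, so disable it explicitly
net = torch.load("../data/ml/ngaqua/time_series_fit.torch", weights_only=False)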

Now let's use the fitted network to make a prediction for the same column:

In [None]:
plot_t(predict(net, x))
plt.colorbar()

Qualitatively, the prediction looks similar to the residual $(x_{n+1} - x_n)/\Delta t - g$ plotted above.