The data we are using are only sampled at a coarse interval of 3 hours, whereas the coarse-resolution model will probably use a 10-20 minute time step. This means that the dynamical model we are trying to fit is
$$ x^i_{n+1} = \underbrace{f(f(\ldots f}_{\text{m times}}(x^i_n))) + \int_{t_n}^{t_{n+1}} g(x(t), t) dt$$ 
where $i$ is the horizontal spatial index and $n$ is the time step. The number of times the function $f$ is applied is $m=\frac{\Delta t}{h}$, where $h$ is the GCM's time step and $\Delta t$ is the sampling interval of the stored output. The integral on the right represents the approximately known terms, such as advection, and $f$ represents one model time step of the unknown source terms.
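A minimal plain-Python sketch of this bookkeeping over one sampling interval follows. The helper names `n_substeps` and `coarse_step`, the toy one-step map `toy_f`, and the value of the integrated forcing are illustrative assumptions, not part of the notebook.

```python
def n_substeps(dt_sample, h):
    """Number of model steps m = dt_sample / h in one stored sampling interval."""
    m = dt_sample / h
    if abs(m - round(m)) > 1e-9:
        raise ValueError("the sampling interval must be a multiple of the model step")
    return int(round(m))


def coarse_step(x, f, m, g_integral):
    """Advance x over one sampling interval: apply the one-step map f m times,
    then add the integrated known forcing (the integral of g)."""
    for _ in range(m):
        x = f(x)
    return x + g_integral


def toy_f(x):
    """Hypothetical one-step map: damp the state by 1% per model step."""
    return 0.99 * x


sampling_interval = 3 * 3600.0                   # 3 hours, in seconds
print(n_substeps(sampling_interval, 10 * 60.0))  # 18
print(n_substeps(sampling_interval, 20 * 60.0))  # 9
print(coarse_step(1.0, toy_f, 18, g_integral=0.5))
```

With a 3-hour sampling interval, a 10-minute model step gives $m = 18$ and a 20-minute step gives $m = 9$.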

We find $f$, or equivalently the source term $a$, by solving the minimization problem
$$
\min_{a} \lim_{m \rightarrow \infty} \sum_{i,n} ||x^{i}_{n+1} - F^{(m)} x^i_{n} - g_n^{i}||_W^2 \quad \text{s.t.}\quad F^{(m)}(\cdot) = \underbrace{f(f(\ldots f}_{\text{m times}}(\cdot))),\ f(x) = x +  \frac{ \Delta t}{m} a(x).
$$
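Intuitively, the forward operator $F^{(m)}$ is the result of applying $m$ forward Euler steps of size $h = \Delta t/m$ to the system $\dot{x} = a(x)$; as $m \to \infty$, it converges to the exact solution of that system over one sampling interval. In practice, $m$ is fixed at a finite value. Here $g^i_n = \int_{t_n}^{t_{n+1}} g(x(t), t)\, dt$ is the integrated known forcing at location $i$, and $\|\cdot\|_W$ is a weighted norm.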
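The limit $m \to \infty$ can be checked on a toy problem. Assume a linear source $a(x) = -x/\tau$ (an illustrative choice, not the learned model). Composing $m$ Euler steps gives $F^{(m)}x = (1 - \Delta t/(m\tau))^m x$, which approaches the exact solution $e^{-\Delta t/\tau}x$ as $m$ grows. The sketch below uses a hypothetical $\tau$ of 6 hours and the 3-hour sampling interval:

```python
import math


def euler_flow(a, x, dt, m):
    """F^(m)(x): m forward Euler steps of size dt/m for dx/dt = a(x)."""
    h = dt / m
    for _ in range(m):
        x = x + h * a(x)
    return x


toy_tau = 6 * 3600.0              # hypothetical damping time scale: 6 hours
toy_dt = 3 * 3600.0               # sampling interval: 3 hours


def toy_a(x):
    """Linear toy source term a(x) = -x / tau."""
    return -x / toy_tau


exact = math.exp(-toy_dt / toy_tau)   # exact flow applied to x = 1

for m in (1, 2, 9, 18, 180):
    approx = euler_flow(toy_a, 1.0, toy_dt, m)
    print(f"m={m:4d}  F^(m)(1)={approx:.5f}  error={abs(approx - exact):.2e}")
```

The error shrinks roughly in proportion to $1/m$, as expected for a first-order scheme.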

Let's try performing this fit. First, we need to import the appropriate modules and load the data.

In [None]:
import numpy as np
import xarray as xr
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset

from xnoah.data_matrix import stack_cat

In [None]:
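# weights (note: `w` is reassigned from data.pkl before the loss is defined)
w = xr.open_dataset("../data/processed/ngaqua/w.nc")
# state variables qt and sl
X = xr.open_mfdataset("../data/calc/ngaqua/*.nc")[['qt', 'sl']].load()
# known forcing terms (advection)
G = xr.open_mfdataset("../data/calc/adv/ngaqua/*.nc").load()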

The coarse sampling interval, converted from days to seconds, is

In [None]:
dt = float(X.time[1] - X.time[0])*86400
dt

How many 10-minute model time steps fit in one sampling interval?

In [None]:
dt//(60*10)

Let's now define a torch module for the function $a$. It is a neural network with a single hidden layer that first standardizes its inputs. Before defining it, we need helpers that convert the xarray data into torch tensors.

In [None]:
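def prepvar(X):
    """Stack the vertical (z) profiles of all variables into one `feature` dimension."""
    return stack_cat(X, "feature", ["z"])

def prep_for_torch(X):
    """Convert an xarray object into a float32 torch tensor."""
    x = prepvar(X).pipe(np.asarray)
    return Variable(torch.FloatTensor(x))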
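`stack_cat` comes from the project's `xnoah` package, so its exact behavior is not shown here. Assuming it concatenates the vertical profiles of each variable end to end along a new `feature` axis, it is equivalent to the NumPy sketch below, where `stack_profiles` is a hypothetical stand-in. With the two variables `qt` and `sl` on 34 levels each (inferred from the 68 input features used below), every column becomes a 68-element vector.

```python
import numpy as np


def stack_profiles(profiles):
    """Hypothetical stand-in for stack_cat(X, "feature", ["z"]):
    concatenate per-variable profiles of shape (..., nz) along the last axis."""
    return np.concatenate(list(profiles), axis=-1)


rng = np.random.default_rng(0)
nt, ny, nx, nz = 4, 6, 8, 34               # toy grid; 34 levels per variable
qt_demo = rng.random((nt, ny, nx, nz))
sl_demo = rng.random((nt, ny, nx, nz))

features = stack_profiles([qt_demo, sl_demo])
print(features.shape)  # (4, 6, 8, 68)
```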

In [None]:
class ScaledSLP(nn.Module):
    """Scaled single-layer perceptron (one hidden layer) with standardized inputs"""

    def __init__(self, mu, sig):
        super(ScaledSLP, self).__init__()

        # per-feature mean and standard deviation used to standardize inputs
        self.mu = mu
        self.sig = sig

        # 68 input/output features: the stacked qt and sl profiles
        self.net = nn.Sequential(
            nn.Linear(68, 256), nn.ReLU(),
            nn.Linear(256, 68)
        )

    def forward(self, x):
        # standardize; the small constant avoids division by zero
        x = (x - self.mu) / (self.sig + 1e-5)
        return self.net(x)


class EulerStepper(nn.Module):
    """Apply n forward Euler steps of size h: x <- x + h * net(x)"""
    def __init__(self, net, n, h):
        super(EulerStepper, self).__init__()
        self.net = net
        self.n = n
        self.h = h

        # The default weight initialization is too large: applied n times,
        # it yields unreasonably large values. Zeroing *every* layer is not a
        # fix either, because then only the output bias receives gradients.
        # Instead, zero only the output layer, so the initial tendency is
        # exactly zero while the hidden layer keeps its random weights.
        linear_layers = [mod for mod in self.modules() if isinstance(mod, nn.Linear)]
        output_layer = linear_layers[-1]
        output_layer.weight.data.fill_(0.0)
        output_layer.bias.data.fill_(0.0)

    def forward(self, x):
        for _ in range(self.n):
            x = x + self.net(x) * self.h
        return x
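Initializing every weight and bias to zero has a side effect worth spelling out: with all-zero weights the hidden activations are zero, so the output weights get no gradient, and because the output weights are zero, no gradient reaches the hidden layer either. Only the output bias trains, so $a(x)$ stays constant in $x$. The NumPy sketch below checks this by computing the gradients of a one-hidden-layer ReLU network by hand; the helper `mlp_grads`, the toy sizes, data, and learning rate are assumptions for illustration. It also shows that zeroing only the output layer keeps the initial output at zero without blocking the gradient.

```python
import numpy as np


def mlp_grads(W1, b1, W2, b2, x, target):
    """Gradients of 0.5 * ||W2 relu(W1 x + b1) + b2 - target||^2, by hand."""
    pre = W1 @ x + b1
    hidden = np.maximum(pre, 0.0)
    out = W2 @ hidden + b2
    d_out = out - target                 # dL/d(out)
    d_hidden = W2.T @ d_out              # zero whenever W2 is zero
    d_pre = d_hidden * (pre > 0)         # ReLU derivative, taken as 0 at the origin
    grads = {
        "W1": np.outer(d_pre, x), "b1": d_pre,
        "W2": np.outer(d_out, hidden),   # zero whenever hidden is zero
        "b2": d_out,
    }
    return out, grads


rng = np.random.default_rng(0)
n_in, n_hidden = 4, 8
x_demo = rng.normal(size=n_in)
target_demo = rng.normal(size=n_in)
names = ["W1", "b1", "W2", "b2"]

# (1) every weight and bias zero: run a few gradient-descent steps
params = [np.zeros((n_hidden, n_in)), np.zeros(n_hidden),
          np.zeros((n_in, n_hidden)), np.zeros(n_in)]
for _ in range(10):
    _, grads = mlp_grads(*params, x_demo, target_demo)
    params = [p - 0.1 * grads[k] for p, k in zip(params, names)]
print({k: float(np.abs(p).max()) for k, p in zip(names, params)})  # only b2 moved

# (2) small random hidden weights (positive biases keep the toy ReLUs active)
#     and a zero output layer: the output still starts at zero, but the output
#     weights receive a gradient (and, after one update, so does W1)
W1 = 0.1 * rng.normal(size=(n_hidden, n_in))
b1 = np.ones(n_hidden)
out2, grads2 = mlp_grads(W1, b1, np.zeros((n_in, n_hidden)), np.zeros(n_in),
                         x_demo, target_demo)
print(float(np.abs(out2).max()), float(np.abs(grads2["W2"]).max()))
```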
        

Can we apply the stepper to a single column?

In [None]:
mu = prep_for_torch(X.mean(['x', 'y', 'time']))
sig = prep_for_torch(X.std(['x', 'y', 'time']))
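`ScaledSLP` standardizes each input feature with these statistics. The `1e-5` added to the standard deviation keeps the division finite for a feature that never varies; whether such features occur in this dataset is not checked here. A small NumPy illustration with made-up values:

```python
import numpy as np

# toy samples (rows) x features (columns); the last feature never varies
samples = np.array([[1.0, 300.0, 0.0],
                    [3.0, 310.0, 0.0],
                    [5.0, 320.0, 0.0]])
mu_demo = samples.mean(axis=0)
sig_demo = samples.std(axis=0)             # population standard deviation

scaled = (samples - mu_demo) / (sig_demo + 1e-5)
print(sig_demo)   # the last entry is 0.0
print(scaled)     # finite everywhere; the constant feature maps to 0
```

Without the small constant, the constant feature would give `0/0 = nan`.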

In [None]:
a = ScaledSLP(mu, sig)
stepper = EulerStepper(a, n=18, h=dt/18)

In [None]:
# a single column (x=0, y=0) at the first time
x = prep_for_torch(X.isel(x=0, y=0, time=0))
stepper(x)[1:10]

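It runs, and since the network's output layer is initialized to zero, the stepper initially returns its input unchanged. Now let's work on fitting this model.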

# Data Loading 

We need to define some utilities for loading the data from xarray.

In [None]:
def get_stacked_var(X):
    """Stack (x, y, time) into one `samples` dimension -> (samples, feature)."""
    return prepvar(X).stack(samples=['x', 'y', 'time']).transpose('samples', 'feature')


class ParamDataset(Dataset):
    """Samples of (x_n, x_{n+1}, g_n) for every column and time step"""

    def __init__(self, X, G):
        # initial condition and forcing at time n, converted once to float32
        # (the dtype of the model's weights)
        self.x_stacked = get_stacked_var(X.isel(time=slice(0, -1))).values.astype(np.float32)
        self.g_stacked = get_stacked_var(G.isel(time=slice(0, -1))).values.astype(np.float32)

        # state one sampling interval later, at time n+1
        self.xp_stacked = get_stacked_var(X.isel(time=slice(1, None))).values.astype(np.float32)

    def __len__(self):
        return len(self.x_stacked)

    def __getitem__(self, idx):
        return self.x_stacked[idx], self.xp_stacked[idx], self.g_stacked[idx]


# train on the six y rows centered on index 32
train_dataset = ParamDataset(X.isel(y=slice(32-3, 32+3)), G.isel(y=slice(32-3, 32+3)))
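The essential step in `ParamDataset` is pairing each snapshot with the one a sampling interval later, then flattening the column and time indices into one sample axis. Here is a NumPy sketch on a toy array, assuming a `(time, y, x, feature)` layout and a hypothetical helper `make_pairs`; the xarray version stacks dimensions by name and orders the samples differently, but the pairing is the same:

```python
import numpy as np


def make_pairs(data):
    """Return (x_n, x_{n+1}) arrays of shape (samples, feature)
    from data of shape (time, y, x, feature)."""
    current = data[:-1]                 # times 0 .. T-2
    future = data[1:]                   # times 1 .. T-1
    n_feature = data.shape[-1]
    return current.reshape(-1, n_feature), future.reshape(-1, n_feature)


nt, ny, nx, nf = 5, 2, 3, 68
toy = np.arange(nt * ny * nx * nf, dtype=np.float32).reshape(nt, ny, nx, nf)
x_now, x_next = make_pairs(toy)
print(x_now.shape, x_next.shape)  # (24, 68) (24, 68)
```

Row `k` of `x_next` is always the same column as row `k` of `x_now`, one time step later.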

Pass this dataset object to `DataLoader`, which will generate the batches.

In [None]:
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)

# Loss function

In [None]:
# TODO replace these functions
# sklearn.externals.joblib was removed in scikit-learn 0.23; use joblib directly
import joblib

data_dict = joblib.load("../data/ml/ngaqua/data.pkl")
w = data_dict['w'][1]
scale = data_dict['scale'][1]
# weight each squared error by w / scale**2
w_torch = Variable(torch.FloatTensor(w/scale**2))

In [None]:
def loss_function(y, output, g):
    """Weighted mean-squared error of the prediction F(x_n) + g_n * dt against x_{n+1}."""
    residual = output + g * dt - y
    return torch.mean(residual.pow(2) * w_torch)
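The loss is the weighted squared norm from the minimization problem, averaged over the batch and the features, with the known forcing entering as `g * dt`; since `dt` is in seconds, this presumes `G` holds tendencies per second. A NumPy version with a hand-checkable example (all numbers are made up, and `weighted_mse` is a hypothetical name):

```python
import numpy as np


def weighted_mse(y, output, g, weights, dt):
    """Mean over batch and features of weights * (output + g*dt - y)**2."""
    residual = output + g * dt - y
    return np.mean(weights * residual ** 2)


y_demo = np.array([[1.0, 2.0]])        # observed x_{n+1}
out_demo = np.array([[0.5, 2.0]])      # F^(m)(x_n)
g_demo = np.array([[0.25, 0.0]])       # forcing tendency
w_demo = np.array([2.0, 1.0])          # per-feature weights
print(weighted_mse(y_demo, out_demo, g_demo, w_demo, dt=1.0))  # 0.0625
```

Only the first feature contributes: its residual is $0.5 + 0.25 - 1 = -0.25$, which squared and weighted by 2 gives 0.125, and the mean over two features is 0.0625.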

# Learning

Now, let's train the model!

In [None]:
from tqdm import tqdm

In [None]:
num_epochs = 1

# Adam optimizer over the parameters of the network wrapped by the stepper
optimizer = torch.optim.Adam(stepper.parameters(), lr=.001)

for epoch in range(num_epochs):
    total_loss = 0.0
    for x, y, g in tqdm(train_loader, total=len(train_loader)):
        optimizer.zero_grad()  # gradients accumulate unless zeroed explicitly
        pred = stepper(x)      # m forward Euler steps of the learned source term
        loss = loss_function(y, pred, g)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # report the mean loss over all batches in this epoch
    print(f"Epoch: {epoch}\tMean loss: {total_loss / len(train_loader):.4g}")
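Training backpropagates through all $m$ Euler steps. For a scalar linear source $a(x) = \lambda x$ (a toy stand-in for the network), $F^{(m)}x = (1 + h\lambda)^m x$ with $h = \Delta t/m$, and the chain rule through the composition gives $\partial F^{(m)}x/\partial\lambda = m h (1 + h\lambda)^{m-1} x$. The sketch below fits $\lambda$ by plain gradient descent on synthetic data generated from the exact flow; the function names, data, learning rate, and iteration count are arbitrary choices for illustration.

```python
import numpy as np


def linear_euler_flow(lam, x, dt, m):
    """F^(m)(x) = (1 + h*lam)**m * x: m Euler steps of size h = dt/m for dx/dt = lam*x."""
    h = dt / m
    return (1.0 + h * lam) ** m * x


def grad_lam(lam, x, y, dt, m):
    """d/d(lam) of mean((F^(m)(x) - y)**2), via the chain rule through all m steps."""
    h = dt / m
    dF = m * h * (1.0 + h * lam) ** (m - 1) * x
    return np.mean(2.0 * (linear_euler_flow(lam, x, dt, m) - y) * dF)


rng = np.random.default_rng(0)
dt_toy = 1.0                                 # sampling interval (nondimensional)
m_toy = 18                                   # Euler steps per interval, as in the notebook
lam_true = -0.7
x_toy = rng.normal(size=200)
y_toy = np.exp(lam_true * dt_toy) * x_toy    # exact flow of dx/dt = lam_true * x

lam = 0.0
for _ in range(500):                         # plain gradient descent
    lam -= 0.2 * grad_lam(lam, x_toy, y_toy, dt_toy, m_toy)
print(lam)
```

The fitted value is close to $-0.687$ rather than $-0.7$: with finite $m$, the learned source term absorbs the time-stepping error of the Euler scheme.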


In [None]:
t = prep_for_torch(X.isel(y=32, x=0))

In [None]:
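import matplotlib.pyplot as plt

# learned source term a(x) for one column: rows are features, columns are times
with torch.no_grad():
    plt.pcolormesh(a(t).numpy().T)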

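The training basically failed. Fitting through multiple time steps is also very expensive: every sample requires $m = 18$ evaluations of the network in the forward pass, plus the corresponding backward pass.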