The lack of an explicit momentum source could cause the biases we are seeing in the SAM simulation. In this notebook, I train a neural network for the momentum source in the same manner as I did for QT and SLI. I also want to see the parametrized source that this gives.

# Imports, Functions and Data Loading

In [None]:
from toolz import pipe
import uwnet.interface
from uwnet.model import MLP
import xarray as xr

import torch
from torch.utils.data import DataLoader
from uwnet.datasets import XRTimeSeries
from uwnet.utils import concat_dicts


# Paths to the training dataset (netCDF) and to the saved neural-network model.
# Both are relative to the notebook's working directory.
train_data_path = "../data/processed/2018-10-02-ngaqua-subset.nc"
model_path = "../models/4/9.pkl"

# Open the training data. The model is not loaded here: load_and_predict()
# loads it from model_path each time it is called.
data = xr.open_dataset(train_data_path)
# Disabled eager model load, kept for reference:
# mlp = MLP.from_path(model_path)

# Analysis

In [None]:
# Visualization imports

import matplotlib.pyplot as plt
import numpy as np
from gnl.colorblind import colorblind_matplotlib
from ipywidgets import interact, FloatSlider
# NOTE(review): presumably configures matplotlib with a colorblind-friendly
# style for all later plots -- confirm against gnl.colorblind.
colorblind_matplotlib()

In [None]:
def forward_xr(model, ds, **kwargs):
    """Run column simulation with prescribed forcings.

    The dataset is restricted to the vertical levels in ``model.z``, wrapped
    in an ``XRTimeSeries`` (time / (x, y) / z grouping), and fed through
    ``model`` in batches of 1024 columns with gradients disabled.

    Parameters
    ----------
    model
        Callable model exposing a ``z`` attribute (used to index the ``z``
        dimension of ``ds``). It is called as ``model(batch, **kwargs)`` and
        must return a dict of tensors.
        NOTE: ``model.add_forcing`` is set to ``True`` as a side effect.
    ds
        xarray dataset holding the inputs; it is loaded into memory.
    **kwargs
        Forwarded unchanged to every ``model(...)`` call.

    Returns
    -------
    xarray.Dataset
        The model outputs, unstacked and ordered (time, z, y, x) where those
        dimensions exist, merged with the input data. Input variables whose
        names collide with a model output are renamed with an ``OBS`` suffix.
    """
    ds = ds.isel(z=model.z)
    data = XRTimeSeries(ds.load(), [['time'], ['x', 'y'], ['z']])
    loader = DataLoader(data, batch_size=1024, shuffle=False)

    constants = data.torch_constants()

    print("Running model")
    model.add_forcing = True
    outputs = []
    with torch.no_grad():
        for batch in loader:
            # every batch needs the time-independent fields as well
            batch.update(constants)
            outputs.append(model(batch, **kwargs))

    # concatenate the per-batch output dicts along the batch axis
    out = concat_dicts(outputs, dim=0)

    def unstack(val):
        """Convert one stacked output tensor into an unstacked DataArray."""
        # .cpu() is a no-op for CPU tensors and makes .numpy() safe otherwise
        val = val.detach().cpu().numpy()
        dims = ['xbatch', 'xtime', 'xfeat'][:val.ndim]
        coords = {key: data._ds.coords[key] for key in dims}

        # Squeeze a singleton feature axis. Only do so when the trailing axis
        # really is 'xfeat'; for lower-rank outputs the trailing axis is a
        # batch/time axis and must be kept (the unconditional pop used to
        # raise KeyError in that case).
        if 'xfeat' in coords and val.shape[-1] == 1:
            dims.pop()
            coords.pop('xfeat')
            val = val[..., 0]

        da = xr.DataArray(val, dims=dims, coords=coords)
        for dim in dims:
            da = da.unstack(dim)

        # put the dimensions in a canonical order
        dim_order = [dim for dim in ['time', 'z', 'y', 'x'] if dim in da.dims]
        return da.transpose(*dim_order)

    print("Reshaping and saving outputs")
    out_da = {key: unstack(val) for key, val in out.items()}

    # inputs that share a name with a model output are kept as "<name>OBS"
    truth_vars = set(out) & set(data.data)
    rename_dict = {key: key + 'OBS' for key in truth_vars}

    return xr.Dataset(out_da).merge(data.data.rename(rename_dict))


def load_and_predict(model_path, data, **kw):
    """Load a saved MLP from ``model_path`` and run it over ``data``.

    Extra keyword arguments are passed through to ``forward_xr`` (and from
    there to the model call). Returns whatever ``forward_xr`` returns.
    """
    return forward_xr(MLP.from_path(model_path), data, **kw)

In [None]:
# Run the saved model twice over the same data: once with default arguments
# and once with n=1 forwarded to the model call.
# NOTE(review): presumably n selects a prognostic (stepped-forward) run versus
# a purely diagnostic evaluation -- confirm against the MLP call signature.
diagnosis = load_and_predict(model_path, data)
prediction = load_and_predict(model_path, data, n=1)

In [None]:
# Model output U at the first x index, plotted with time on the horizontal axis.
prediction.U.isel(x=0).plot(x='time')

In [None]:
# Same view for UOBS, the input-data U kept under the 'OBS' suffix, for comparison.
prediction.UOBS.isel(x=0).plot(x='time')

## Dissipation

In [None]:
# Effective linear rate of the network source acting on U:
#   <FUNN * U> / <U**2>, averaged over x and time.
# This is the least-squares slope of FUNN regressed on U (no intercept);
# negative values damp U, positive values amplify it.
dims = ['x', 'time']
dissip_x = (prediction.FUNN * prediction.U).mean(dims)/(prediction.U**2).mean(dims)


# Scale by the number of seconds per day.
# NOTE(review): assumes FUNN is in (units of U) per second, giving 1/day -- confirm.
plt.plot(dissip_x.values*86400)
plt.grid()
plt.xlabel('Vertical grid number')

The model is mostly damping in the free troposphere, but it is amplifying in the lowest few grid points.

In [None]:

# Time scale = 1 / |rate|, divided by 86400 s/day.
# NOTE(review): in days if dissip_x is in 1/s -- confirm. The absolute value
# discards the sign, so damping and growth are not distinguished here.
plt.plot(1/np.abs(dissip_x)/86400)
plt.grid()
ax = plt.gca()

# Label every 5th vertical level with its z coordinate instead of its index.
# Assumes z is the leading dimension of dissip_x.
ticks = np.arange(0, dissip_x.shape[0], 5)
ax.set_xticks(ticks)
ax.set_xticklabels(dissip_x.z[ticks].values)
plt.xlabel('Height')
plt.ylabel('Damping/growth time-scale')

The time scales vary from around 1 day in the boundary layer to around 20 days in the free troposphere.

## Drift in Mean state

In [None]:
def plot_mean_drift(prediction):
    """Plot vertical profiles of the mean momentum sources and the observed drift.

    Draws four curves against ``z`` on a new 3x6 figure:

    * ``FU``      -- mean of ``prediction.FU`` over x, y and time
    * ``FUNN``    -- mean of ``prediction.FUNN`` over x, y and time
    * ``FU+FUNN`` -- the sum of the two means
    * ``dU/dt``   -- (last UOBS - first UOBS) / elapsed time / 86400,
      averaged over x and y

    Parameters
    ----------
    prediction
        Dataset with variables ``FU``, ``FUNN`` and ``UOBS`` on dimensions
        including x, y, z and time, plus a ``time`` coordinate.
        NOTE(review): the division by 86400 assumes ``time`` is in days and
        the sources are per second -- confirm.
    """
    fu_mean = prediction.FU.mean(['x', 'time', 'y'])
    funn_mean = prediction.FUNN.mean(['x', 'time', 'y'])

    # Select the first/last time step by name rather than by position, so the
    # result does not depend on 'time' being the leading dimension of UOBS.
    u_obs = prediction.UOBS
    elapsed = prediction.time.isel(time=-1) - prediction.time.isel(time=0)
    du_obs = (u_obs.isel(time=-1) - u_obs.isel(time=0)) / elapsed / 86400
    du_obs = du_obs.mean(['x', 'y'])

    plt.figure(figsize=(3, 6))

    fu_mean.plot(y='z', label='FU')
    funn_mean.plot(y='z', label='FUNN')
    # this curve is the sum of the two sources, so label it as such
    (funn_mean + fu_mean).plot(label='FU+FUNN', y='z')
    du_obs.plot(label=r'$\Delta U / \Delta t$', y='z')

    # symmetric horizontal range so the sign of each term is easy to read
    a = 2e-5
    plt.xlim([-a, a])
    plt.legend()


In [None]:
# Mean sources versus observed drift for the n=1 run.
plot_mean_drift(prediction)

Is this problem also in the diagnosis?

In [None]:
# Same comparison for the run made with default model arguments.
plot_mean_drift(diagnosis)

It is. Perhaps we should penalize the difference in the mean drifts? Would this work batchwise? Perhaps we should also make sure the output of the neural network is reasonably smooth.