It turns out that training Q3 on the full dataset significantly helps performance.

# Imports, Functions and Data Loading

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from toolz import pipe

import uwnet.interface
from uwnet.interface import call_with_xr as forward_xr
from uwnet.model import MLP

# Paths to the processed training data and to a trained network.
# NOTE: ``model_path`` is reassigned to a different checkpoint in a later cell.
train_data_path = "../data/processed/2018-10-02-ngaqua-subset.nc"
model_path = "../models/4/9.pkl"

# Open the training dataset; it is the ``data`` global used by later cells.
data = xr.open_dataset(train_data_path)
# Disabled: models are loaded on demand (see ``load_and_predict``) instead.
# mlp = MLP.from_path(model_path)

# Analysis

In [None]:
# Visualization imports

# NOTE(review): ``interact`` and ``FloatSlider`` are not used in the cells of
# this notebook section; kept for interactive exploration.
from ipywidgets import interact, FloatSlider
from gnl.colorblind import colorblind_matplotlib
# Presumably switches matplotlib to a colorblind-friendly colour cycle —
# confirm in ``gnl.colorblind``.
colorblind_matplotlib()

In [None]:
def load_and_predict(model_path, data, **kw):
    """Run a saved MLP over ``data`` and return the xarray result.

    The network is deserialized with ``MLP.from_path(model_path)`` and then
    evaluated with ``forward_xr``; any extra keyword arguments are passed
    straight through to ``forward_xr``.
    """
    return forward_xr(MLP.from_path(model_path), data, **kw)


def plot_mean_drift(prediction, xlim=2e-5):
    """Plot horizontally and time averaged momentum tendencies against height.

    Draws, on a new figure, vertical profiles of the mean ``FU``, the mean
    ``FUNN``, their sum, and the observed mean tendency of ``UOBS``. The
    observed tendency is the end-minus-start difference divided by the
    elapsed time.

    Parameters
    ----------
    prediction : xarray.Dataset
        Must contain ``FU``, ``FUNN`` and ``UOBS`` with ``x``, ``y``,
        ``time`` and ``z`` dimensions, plus a ``time`` coordinate.
    xlim : float, optional
        Half-width of the symmetric x-axis range. Defaults to ``2e-5``,
        the value that used to be hard-coded.
    """
    mean_dims = ['x', 'time', 'y']
    fu_mean = prediction.FU.mean(mean_dims)
    funn_mean = prediction.FUNN.mean(mean_dims)

    # Select the end points by name rather than by position, so this does not
    # depend on ``time`` being the leading dimension of UOBS.
    uobs = prediction.UOBS
    elapsed = prediction.time.isel(time=-1) - prediction.time.isel(time=0)
    # NOTE(review): the extra /86400 assumes ``time`` is measured in days,
    # turning the drift into a per-second rate — confirm.
    du_obs = (uobs.isel(time=-1) - uobs.isel(time=0)) / elapsed / 86400
    du_obs = du_obs.mean(['x', 'y'])

    plt.figure(figsize=(3, 6))

    fu_mean.plot(y='z', label='FU')
    funn_mean.plot(y='z', label='FUNN')
    # This curve is the SUM of the two sources (total modelled tendency), so
    # label it as such; it was previously mislabelled 'FU-FUNN'.
    (funn_mean + fu_mean).plot(label='FU+FUNN', y='z')
    du_obs.plot(label=r'$\Delta U / \Delta t$', y='z')

    plt.xlim([-xlim, xlim])
    plt.legend()

def plot_mean_drift_file(model_path, dataset=None, **kw):
    """Load a saved model, run it on a dataset, and plot its mean drift.

    Parameters
    ----------
    model_path : str
        Path to a model file readable by ``MLP.from_path``.
    dataset : xarray.Dataset, optional
        Input data for the model. When omitted, falls back to the
        notebook-level ``data`` global, which is what this function always
        used before the parameter existed.
    **kw
        Forwarded to ``forward_xr`` (e.g. ``n=1``, as used elsewhere in this
        notebook).
    """
    if dataset is None:
        # Backward-compatible default: the global opened in the loading cell.
        dataset = data
    diagnosis = load_and_predict(model_path, dataset, **kw)
    plot_mean_drift(diagnosis)

In [None]:
# Switch to a different checkpoint; this overrides the ``model_path`` set in
# the data-loading cell for every cell run after this one.
model_path = "../models/18/4.pkl"
# Model output with ``forward_xr``'s default options.
diagnosis = load_and_predict(model_path, data)
# Same model with ``n=1`` forwarded to ``forward_xr``.
# NOTE(review): presumably ``n`` selects a multi-step prediction vs. a
# single-step diagnosis — confirm in ``uwnet.interface.call_with_xr``.
prediction = load_and_predict(model_path, data, n=1)

In [None]:
# Plot the predicted U at the first x index, with time on the horizontal axis.
prediction.U.isel(x=0).plot(x='time')

## Dissipation

In [None]:
# Rate at which the NN source FUNN acts on U: <FUNN*U> / <U**2>, averaged over
# x and time only (any ``y`` dimension is kept). A negative value means FUNN
# opposes U (damping); a positive value means it amplifies U.
dims = ['x', 'time']
dissip_x = (prediction.FUNN * prediction.U).mean(dims)/(prediction.U**2).mean(dims)


# NOTE(review): the factor 86400 presumably converts a per-second rate into a
# per-day rate — confirm the units of FUNN.
plt.plot(dissip_x.values*86400)
plt.grid()
plt.xlabel('Vertical grid number')

The model mostly damps U in the free troposphere, but it amplifies it in the lowest few grid points.

In [None]:

# e-folding time of the damping/growth rate computed in the previous cell.
# NOTE(review): dividing by 86400 gives days only if ``dissip_x`` is a
# per-second rate — confirm.
plt.plot(1/np.abs(dissip_x)/86400)
plt.grid()
ax = plt.gca()

# Label every 5th vertical level with its height coordinate.
# NOTE(review): assumes ``z`` is the first dimension of ``dissip_x``.
ticks = np.arange(0, dissip_x.shape[0], 5)
ax.set_xticks(ticks)
ax.set_xticklabels(dissip_x.z[ticks].values)
plt.xlabel('Height')
plt.ylabel('Damping/growth time-scale')

The time scales vary from around 1 day in the boundary layer to around 20 days in the free troposphere.

## Drift in Mean state

In [None]:
# Mean FU/FUNN profiles vs. the observed drift for the ``n=1`` run.
plot_mean_drift(prediction)

Does the same mean-state drift also appear in the diagnosis?

# Analysis

Why are these biases happening? Could it be something with the loss function?

In [None]:
def get_mom_budget(data, dt=3*3600):
    """Assemble the zonal momentum budget terms from a model output dataset.

    Parameters
    ----------
    data : xarray.Dataset
        Must contain ``U``, ``FU`` and ``FUNN`` and a ``time`` dimension.
    dt : float, optional
        Spacing between consecutive ``time`` samples, in seconds. Defaults to
        3 hours, the value that used to be hard-coded; pass the real spacing
        for data saved at a different interval.

    Returns
    -------
    xarray.Dataset
        ``FU`` and ``FUNN`` copied from ``data``, plus:

        - ``STOR``, the finite-difference storage term ``diff(U) / dt``;
        - ``Q3 = STOR - FU``, the part of the storage not explained by ``FU``.
    """
    # ``diff`` drops one time sample, so STOR and Q3 are one step shorter than
    # FU/FUNN. Assigning them into the dataset aligns them on ``time``; the
    # missing step presumably becomes NaN — confirm.
    stor = data.U.diff('time') / dt
    budget = xr.Dataset({'FU': data.FU, 'FUNN': data.FUNN})
    return budget.assign(STOR=stor, Q3=stor - data.FU)

    
    
# Leftover scratch lines (disabled):
#     data.FUNN.plot()
# path = "../models/11/13.pkl"

# Run the current ``model_path`` on ``data`` and build its momentum budget.
# NOTE(review): when cells run top to bottom, this repeats the forward pass
# that already produced ``diagnosis`` (same model and data) and overwrites it;
# reusing the earlier result would avoid the extra computation.
diagnosis = forward_xr(MLP.from_path(model_path), data)
mom = get_mom_budget(diagnosis)

In [None]:
# Plot every budget variable at the first x index against time, one facet per
# variable stacked in a single column.
mom.to_array(name='SRC').isel(x=0).plot(col='variable', col_wrap=1, aspect=4, size=2, x='time')