In this notebook, I compare the impact of a few different neural networks (NNs) within SAM:

1. active (from last report)
2. passive (from last report)
3. NGAqua training data (from last report)
4. active neural network that does not use FQT and FSLI as inputs.

In [None]:
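import os

# numerical libraries used throughout the notebook
import matplotlib.pyplot as plt
import numpy as np
import torch
import xarray as xr

# plotting libraries
import holoviews as hv
from holoviews.operation.datashader import datashade, shade, dynspread, rasterize
from toolz import valmap

import gnl
# provides open_debug_and_training_data and concat_datasets
from uwnet.analysis.sam_debug import *
from uwnet.interface import step_with_xarray_inputs
from uwnet.model import MLP

# colorblind-friendly defaults
hv.extension('bokeh')
gnl.colorblind()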

In [None]:
# run ids of the active and passive SAM simulations
ids = {'active': 'compassionate_chandrasekhar', 'passive': 'romantic_volta'}
training_data_path = "../data/training_data.nc"
model_path = "../data/samNN/curious_kilby/NG1/data.pkl"

paths = valmap(lambda x: os.path.join("../data", "samNN", x), ids)


def _open_debug(t):
    ds = open_debug_and_training_data(t, paths, training_data_path)
    # convert the step index to minutes (assumes a 30 s time step)
    return t * 30 / 60, ds


training_ds = xr.open_dataset(training_data_path)
# stack the first 20 steps along a new time dimension 't'
ds = concat_datasets([_open_debug(t) for t in range(20)], name='t')
model = MLP.from_dict(torch.load(model_path)['dict'])

In [None]:
ds

# Precipitation

Let's find a point with strong precipitation:

In [None]:
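# (y, x) indices where the active run's final precipitation exceeds 115 mm/day
i, j = (ds.Prec.sel(tag='active').isel(t=-1) > 115).values.nonzero()
# keep only the first match so that `loc` is a single column
loc = ds.isel(x=j[0], y=i[0]).squeeze()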
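The threshold test above returns every grid cell where the active run exceeds 115 mm/day, so it can match more than one column. A minimal sketch of a stricter alternative in plain NumPy, on a made-up field: pick the single wettest cell and accept it only if it passes the threshold. `strongest_point` is a hypothetical helper, not part of `uwnet`.

```python
import numpy as np


def strongest_point(prec, threshold):
    """Return (i, j) of the maximum of a 2D (y, x) field, or None.

    Hypothetical helper: unlike np.nonzero(prec > threshold), which may
    return several cells, this returns only the wettest cell, and only
    if it exceeds the threshold.
    """
    prec = np.asarray(prec)
    i, j = np.unravel_index(np.argmax(prec), prec.shape)
    if prec[i, j] <= threshold:
        return None
    return int(i), int(j)


# Toy 3 x 4 precipitation field (mm/day) with two cells above 115
field = np.array([[10.0, 20.0, 130.0, 5.0],
                  [0.0, 120.0, 15.0, 2.0],
                  [1.0, 3.0, 4.0, 8.0]])
print(np.nonzero(field > 115))      # two candidate cells
print(strongest_point(field, 115))  # (0, 2)
```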

In [None]:
loc.Prec.plot(hue='tag')

We can see that the initial precipitation is way too high for this point. Why? Let's compute the neural network's prediction of precipitation for every point in the training data.

In [None]:
def _compute_precip(x):
    # NN precipitation prediction for one time slice of the training data
    out_ds = step_with_xarray_inputs(model.step, x, 10.0)
    return out_ds.Prec

prec = training_ds.groupby('time').apply(_compute_precip)

And plot it:

In [None]:
x = prec.mean(['x', 'y'])
y = training_ds.Prec.mean(['x', 'y'])

# domain-mean NN vs. NGAqua precipitation, with a one-to-one reference line
one_to_one = hv.Curve([(x.min(), x.min()), (x.max(), x.max())])
scatter = hv.Scatter((x, y)) * one_to_one

time_series = hv.Curve(x, label="NN") * hv.Curve(y, label="NGAqua")

scatter.redim(x="NN Prec", y="NGAqua Prec") + time_series.opts(plot=dict(width=500))

In fact, the neural network appears to underestimate the domain-mean precipitation at the times when the domain-averaged precipitation is high. The neural network struggles with the diurnal cycle of precipitation. The time points used for the two-dimensional fields are probably misaligned in the training data.
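If the two-dimensional fields really are misaligned in time, the NN and NGAqua domain-mean series should correlate better after shifting one of them. Here is a minimal sketch of that check in plain NumPy, assuming evenly spaced 1D arrays of equal length (for example `x.values` and `y.values` from the plotting cell above). `best_lag` is a hypothetical helper, and the sine curves are synthetic stand-ins for the real series.

```python
import numpy as np


def best_lag(a, b, max_lag):
    """Shift (in samples) of `b` relative to `a` with the highest correlation.

    A positive result means `b` lags `a`, i.e. b[t + lag] ~ a[t].
    Both inputs are evenly spaced 1D series of equal length.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    best, best_r = 0, -np.inf
    for lag in range(-max_lag, max_lag + 1):
        if lag >= 0:
            r = np.corrcoef(a[:len(a) - lag], b[lag:])[0, 1]
        else:
            r = np.corrcoef(a[-lag:], b[:len(b) + lag])[0, 1]
        if r > best_r:
            best, best_r = lag, r
    return best


# Synthetic hourly diurnal cycles over 5 days; the "truth" peaks 3 h later
hours = np.arange(24 * 5)
nn_mean = 1 + np.sin(2 * np.pi * hours / 24)
truth_mean = 1 + np.sin(2 * np.pi * (hours - 3) / 24)
print(best_lag(nn_mean, truth_mean, max_lag=6))  # 3
```

A clear nonzero lag on the real data would support the misalignment hypothesis, although it would not prove it, since the NN prediction also depends on the state it is given.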

In [None]:
# flatten the NN and NGAqua precipitation fields after broadcasting them together
args = tuple(np.ravel(da) for da in
             xr.broadcast(prec, training_ds.Prec))
datashade(hv.Scatter(args)).redim(x="NN", y="Training Data").relabel("Precipitation (mm/day)")
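The datashaded scatter can also be summarized with a few numbers. A minimal sketch, assuming the two flattened precipitation arrays plotted above are passed in as `pred` and `truth`; `summarize` is a hypothetical helper and the toy inputs below are made up.

```python
import numpy as np


def summarize(pred, truth):
    """Bias, RMSE and Pearson correlation of two fields, ignoring NaNs."""
    pred = np.ravel(pred).astype(float)
    truth = np.ravel(truth).astype(float)
    # drop points where either field is missing
    ok = np.isfinite(pred) & np.isfinite(truth)
    pred, truth = pred[ok], truth[ok]
    err = pred - truth
    return {
        "bias": err.mean(),
        "rmse": np.sqrt((err ** 2).mean()),
        "corr": np.corrcoef(pred, truth)[0, 1],
    }


stats = summarize([1.0, 2.0, 3.0, np.nan], [1.0, 2.0, 5.0, 4.0])
print(stats)
```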

# Vertical Velocity

What is the vertical velocity $W$ doing at this point?

In [None]:
g = loc.W.plot(row='tag', x='t', aspect=2)
# label every facet, not just the last axes
g.set_xlabels('Minutes')

We can see that this point has a very strong upward vertical velocity throughout the troposphere. In the active simulation, this vertical velocity increases dramatically within 10 minutes.

In [None]:
loc.W.isel(z=8).plot(hue='tag')

There is a nearly linear increase in W. I would expect this to look more exponential if it were some kind of neural network instability.
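To make the linear-versus-exponential judgment less visual, one can compare the residuals of a straight-line fit and of an exponential fit (a line fitted to $\log W$). A minimal sketch on synthetic data, assuming the series is positive, as it is for a strong updraft; on the real data it could be applied to `loc.W.isel(z=8)` for one `tag`. `growth_fit_errors` is a hypothetical helper.

```python
import numpy as np


def growth_fit_errors(t, w):
    """RMS residuals of a linear fit and an exponential fit to w(t).

    Linear:      w ~ a + b * t
    Exponential: w ~ A * exp(r * t), fitted as a straight line to log(w),
                 so every value of w must be positive.
    """
    t = np.asarray(t, dtype=float)
    w = np.asarray(w, dtype=float)

    def rms(fit):
        return np.sqrt(np.mean((w - fit) ** 2))

    linear_fit = np.polyval(np.polyfit(t, w, 1), t)
    exp_fit = np.exp(np.polyval(np.polyfit(t, np.log(w), 1), t))
    return rms(linear_fit), rms(exp_fit)


# Synthetic updraft growing linearly over 10 minutes (half-minute samples)
t_min = np.arange(0.0, 10.0, 0.5)
lin_err, exp_err = growth_fit_errors(t_min, 1.0 + 0.3 * t_min)
print(lin_err < exp_err)  # True: the straight line fits much better
```

For the growth of a perturbation rather than of $W$ itself, fitting the anomaly from the initial value may be more appropriate.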

In [None]:
sli_anom = loc.SLI - loc.SLI.isel(t=0)
g = sli_anom.plot(row='tag', x='t', aspect=2)
plt.suptitle('SLI anomaly from initial value', y=1.05)
g.set_xlabels('Minutes')

In [None]:
qt_anom = loc.QT - loc.QT.isel(t=0)
g = qt_anom.plot(row='tag', x='t', aspect=2)
plt.suptitle('QT anomaly from initial value (g/kg)', y=1.05)
g.set_xlabels('Minutes')

In [None]:
ds.Prec.isel(t=-1).plot(col='tag', col_wrap=2)

# Running the noForcingInput network for longer times

Here I show some runs performed over a 1-day period. Without `dodamping=.true.`, the domain-averaged precipitation eventually diverges.

Declare some default plotting options:

In [None]:
%opts Image[width=400, height=200, colorbar=True] (cmap='viridis')
%opts Image.W (cmap='RdBu_r')

In [None]:
# files = !ls ~/Data/0/72/c4093327a86a43f49340dc1cba8137/NG1*.pkl
files = !ls ~/Data/0/f3/91f3038360fe0bea70c33ab27a0903/NG1*.pkl # damping
# files = !ls ~/Data/0/6e/7484a67685a170effbecac09bfa7ca/NG1*.pkl # sgs
# files = !ls /Users/noah/Data/0/f4/407c81ec353da245cc43f488e06131/NG1*.pkl # microphysics
# files = !ls ../data/samNN/dmaping/NG1_*.pkl

In [None]:
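def curve_t(file):
    d = torch.load(file)
    out = d['out']
    args, dt = d['args']
#     FSLI = (d['out']['SLI'] - args['SLI'])/dt * 86400
    # the step number is the last 6 characters of the file name before ".pkl"
    nstep = int(os.path.splitext(file)[0][-6:])
    # convert to hours, assuming a 30 s time step
    time = (nstep - 1) * 30 / 3600

    # W at vertical level index 5, stacked above the precipitation field
    return time, (hv.Image(args['W'][5], kdims=['x', 'y'], vdims=['W'], label='W')#.opts(style=dict(cmap='RdBu'))
          + hv.Image(out['Prec'][0], kdims=['x', 'y'], vdims=['Prec'], label='Prec')).cols(1)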


hmap = hv.HoloMap(dict(curve_t(file) for file in files), kdims=['time']).collate()

In [None]:
%%output size=150
hmap.redim.range(W=(-.1, .1)).redim.unit(time='hr', W='m/s', Prec='mm/day')

A checkerboard pattern appears in both the W and precipitation fields after a few hours.
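One way to quantify the checkerboard is to project a 2D snapshot onto the two-grid-point pattern $(-1)^{i+j}$. A minimal sketch in NumPy, assuming a (y, x) array such as one `W` level or a `Prec` field from the `.pkl` files; `checkerboard_amplitude` is a hypothetical diagnostic, not something computed in the runs above. On a grid with an even number of points in each direction, constant and linearly varying fields project to zero, so growth of this number over time would track the checkerboard.

```python
import numpy as np


def checkerboard_amplitude(field):
    """Amplitude of the 2-grid-point checkerboard mode in a 2D (y, x) field.

    Projects the field onto (-1)**(i + j). A pure checkerboard
    c * (-1)**(i + j) gives |c|; on grids with even dimensions,
    constant and linear fields give zero.
    """
    field = np.asarray(field, dtype=float)
    i, j = np.indices(field.shape)
    pattern = (-1.0) ** (i + j)
    return abs(np.mean(field * pattern))


# Toy check: a constant plus a checkerboard of amplitude 0.5 on an 8 x 8 grid
yy, xx = np.indices((8, 8))
noisy = 3.0 + 0.5 * (-1.0) ** (yy + xx)
print(checkerboard_amplitude(noisy))  # 0.5
```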