<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Training-process-(2-fig)" data-toc-modified-id="Training-process-(2-fig)-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Training process (2 fig)</a></span></li><li><span><a href="#Diagnostic-Results-(1-2-figs)" data-toc-modified-id="Diagnostic-Results-(1-2-figs)-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Diagnostic Results (1-2 figs)</a></span><ul class="toc-item"><li><span><a href="#Q1-and-Q2" data-toc-modified-id="Q1-and-Q2-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Q1 and Q2</a></span></li><li><span><a href="#Precipitation" data-toc-modified-id="Precipitation-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Precipitation</a></span></li></ul></li><li><span><a href="#Equilibrium-Statistics-comparison-(3-4-figures)" data-toc-modified-id="Equilibrium-Statistics-comparison-(3-4-figures)-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Equilibrium Statistics comparison (3-4 figures)</a></span><ul class="toc-item"><li><span><a href="#Prec-vs.-Lat-biases-for-NN" data-toc-modified-id="Prec-vs.-Lat-biases-for-NN-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>Prec vs. 
Lat biases for NN</a></span></li><li><span><a href="#Precipitation" data-toc-modified-id="Precipitation-3.2"><span class="toc-item-num">3.2&nbsp;&nbsp;</span>Precipitation</a></span></li><li><span><a href="#Biases-in-equilibrium-state-(equator)" data-toc-modified-id="Biases-in-equilibrium-state-(equator)-3.3"><span class="toc-item-num">3.3&nbsp;&nbsp;</span>Biases in equilibrium state (equator)</a></span></li><li><span><a href="#Standard-Deviation" data-toc-modified-id="Standard-Deviation-3.4"><span class="toc-item-num">3.4&nbsp;&nbsp;</span>Standard Deviation</a></span></li></ul></li><li><span><a href="#Transient-Comparisons" data-toc-modified-id="Transient-Comparisons-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Transient Comparisons</a></span><ul class="toc-item"><li><span><a href="#Single-Column-(1-3-figures)" data-toc-modified-id="Single-Column-(1-3-figures)-4.1"><span class="toc-item-num">4.1&nbsp;&nbsp;</span>Single Column (1-3 figures)</a></span></li><li><span><a href="#Error-growth-(equator)-(3-figures)" data-toc-modified-id="Error-growth-(equator)-(3-figures)-4.2"><span class="toc-item-num">4.2&nbsp;&nbsp;</span>Error growth (equator) (3 figures)</a></span><ul class="toc-item"><li><span><a href="#Humidity-Errors" data-toc-modified-id="Humidity-Errors-4.2.1"><span class="toc-item-num">4.2.1&nbsp;&nbsp;</span>Humidity Errors</a></span></li><li><span><a href="#Temperature-Errors" data-toc-modified-id="Temperature-Errors-4.2.2"><span class="toc-item-num">4.2.2&nbsp;&nbsp;</span>Temperature Errors</a></span></li><li><span><a href="#Column-Integrated-Error" data-toc-modified-id="Column-Integrated-Error-4.2.3"><span class="toc-item-num">4.2.3&nbsp;&nbsp;</span>Column Integrated Error</a></span></li></ul></li></ul></li></ul></div>

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import holoviews as hv
hv.extension('matplotlib')
import seaborn as sns
sns.set_style('whitegrid')
plt.rc("image", cmap='viridis')

In [None]:
import xarray as xr
from xnoah import swap_coord
from lib.plots import plot_soln
import lib
from lib import cam as lc
import glob
from toolz import valmap
from toolz.curried import get

from xnoah import integrate

In the following cell I load the truth, SCAM and NN single column datasets.

In [None]:
def load_data():

    # Truth
    truth = xr.open_dataset("../../data/processed/inputs.nc")
    force = xr.open_dataset("../../data/processed/forcings.nc")
    truth = truth.assign(prec=force.Prec)
    truth = swap_coord(truth, {'z': 'p'})

    # SCAM
    # cam = xr.open_dataset("../data/processed/iop/0-8/cam.nc").squeeze()\
    #         .sel(time=truth.time[:-10])

    
    # load and interpolate CAM
    cam = xr.open_dataset("../../data/output/scam.nc")
    cam = lib.pressure_interp_ds(cam, truth.p)

    # neural network scheme
    nn_cols = xr.open_dataset("../../data/output/model.VaryNHid-256/7.columns.nc").assign(p=truth.p)
    nn_cols = swap_coord(nn_cols, {'z': 'p'})
    
    # combine data
    datasets = [truth, nn_cols, cam]
    variables = ['qt', 'prec', 'sl']
    time = np.unique(np.intersect1d(cam.time.values, truth.time.values))
    data = [ds[variables].sel(time=time) for ds in datasets]
    model_idx = pd.Index(['Truth', 'Neural Network', 'CAM'], name="model")
    ds = xr.concat(data, dim=model_idx)
    
    # change y coordinates
    y = (ds.y-np.median(ds.y))
    ds['y'] = y

    return ds
    


def compute_errors(metric, ds, **kwargs):

    mad = {}
    truth = ds.sel(model='Truth')

    for key in ['Neural Network', 'CAM']:
        mad[key] = metric(truth, ds.sel(model=key), **kwargs).sortby('time')

    mad['Mean'] = metric(truth, truth.mean(['x', 'time']), **kwargs).sortby('time')
    mad['Persistence'] = metric(truth, truth.isel(time=0), **kwargs).sortby('time')
    
    return mad


def mean_squared_error(truth, pred, dims=('x',)):
    return ((truth-pred).fillna(0.0)**2).mean(dims)


def mean_absolute_dev(truth, pred, dims=('x',)):
    return np.abs((truth-pred).fillna(0.0)).mean(dims)
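
The time alignment in `load_data` relies on `np.intersect1d` to pick the times shared by CAM and the truth. Here is a minimal sketch of that step on made-up time axes (the values are hypothetical and not taken from the real files):

```python
import numpy as np

# Hypothetical time axes (in days) for two model outputs.
cam_time = np.array([0.0, 0.125, 0.25, 0.375, 0.5])
truth_time = np.array([0.25, 0.5, 0.75, 0.25])  # unsorted, with a repeated value

# intersect1d returns the sorted, unique values present in both arrays.
common = np.intersect1d(cam_time, truth_time)
print(common)
```

Since `np.intersect1d` already returns sorted, unique values, the extra `np.unique` call in `load_data` is harmless but not strictly needed.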

Load the data

In [None]:
ds = load_data()

# Training process (2 fig)

I still need to add some plots showing the training vs. testing error for different numbers of epochs.

In [None]:
from lib.plots import training as tp

data = tp.get_plotting_data("../../data/output/")

In [None]:
tp.plot_parameter_sensitivity(data)

In [None]:
tp.plot_epochs_vs_loss(data)

# Diagnostic Results (1-2 figs)

## Q1 and Q2

## Precipitation

# Equilibrium Statistics comparison (3-4 figures)

## Prec vs. Lat biases for NN

In this section I look at the bias of the neural network scheme, defined as the difference between its time and zonal mean and the corresponding mean of the SAM time series. Running SCAM is quite expensive, so I have only run SCAM for the equatorial points.

In [None]:
def plot_pres_vs_lat(bias, ax, levels=np.arange(-5, 6)*.5, title=None):
    im = ax.contourf(bias.y, bias.p, bias, levels, cmap='bwr', extend='both')

    plt.colorbar(im, pad=.01, ax=ax)
    ax.set_xlabel('y (km)')
    
    if title:
        ax.set_title(title)

    
bias = ds.sel(model='Neural Network').mean(['x', 'time']) - ds.sel(model='Truth').mean(['x', 'time'])

fig, axs = plt.subplots(1,2, figsize=(7,3), dpi=100, sharey=True)
plot_pres_vs_lat(bias.qt, axs[0], title="Humidity bias (g/kg)")
plot_pres_vs_lat(bias.sl, axs[1], title="Temperature bias (K)")
plt.subplots_adjust(wspace=.02)
axs[0].invert_yaxis()

## Precipitation 

Precipitation biases.

In [None]:
prec = ds.prec.mean(['x', 'time']).to_dataset('model')

plt.figure(figsize=(5, 5/1.61), dpi=100)

prec['Truth'].plot(label='Truth')
prec['Neural Network'].plot(label='NN')

plt.scatter([prec.y[8]], [prec.CAM.isel(y=8)], label='CAM')

plt.ylabel('Prec (mm/day)')
plt.legend(loc="upper right")

## Biases in equilibrium state (equator)

In this section I show vertical profiles of the bias of NN and CAM relative to the true mean state at the equator.

In [None]:
def plot_bias(ds, ax=None, title="", xlim=None, unit=None):
    obs = ds.sel(model='Truth')
    ds = ds.sel(model=['Neural Network', 'CAM'])
    
    if ax is None:
        fig, ax = plt.subplots(1,1, figsize=(2,3), dpi=100)

    for key, val in ds.groupby('model'):
        ax.plot(val.squeeze()-obs, val.p, label=key)
    ax.set_title(title, size=10)
    if xlim:
        ax.set_xlim(xlim)
    if unit:
        ax.set_xlabel(unit)


We can see that the bias is much smaller in the neural network model than in CAM.

In [None]:
loc = ds.isel(y=8).mean(['x', 'time'])

fig, axs = plt.subplots(1,2, figsize=(4,3), dpi=100, sharey=True)
plot_bias(loc.qt, title='Total water bias', unit='g/kg',
          ax=axs[0], xlim=[-1, 1])
plot_bias(loc.sl, title='Temperature bias', unit='K',
          ax=axs[1], xlim=[-2, 2])


axs[0].invert_yaxis()
axs[0].set_ylabel('p (hPa)')
axs[1].legend()
axs[1].set_ylim([1000, 80])

## Standard Deviation

In [None]:
dims = ['x', 'time']
sig = ds.std(dims)

In [None]:
%%opts Curve[invert_yaxis=True, invert_axes=True] {+framewise}

(hv.Dataset(sig.qt)
 .to.curve("p")
 .overlay("model")
 .redim.unit(qt="g/kg", p="hPa")
 .redim.label(qt=r"$q_T$")
 
 + hv.Dataset(sig.sl)
 .to.curve("p")
 .overlay("model")
 .redim.unit(sl="K", p="hPa")
.redim.label(sl=r"$s_L$"))\
.redim.unit(y="1000 km")

Again, the neural network has a similar standard deviation to the truth, whereas CAM has much larger variability in the lower atmosphere.

# Transient Comparisons

## Single Column (1-3 figures)

In this section I compare the observed time series for a given spatial location near the equator to the time series generated by forcing the neural network (NN) parametrization and the single column version of CAM.

I use the three-dimensional advective tendency and the surface fluxes to force both NN and CAM. I also tried to match the diurnal cycle between the runs, but I am not sure I have done this perfectly yet.

In [None]:
for name, data in ds.groupby('model'):
    axs = plot_soln(data.isel(x=0, y=8))
    
    for ax in axs[:-1]:
        ax.xaxis.set_visible(False)
        
    axs[-1].set_ylim([-10, 150])
    axs[-1].set_yticks([0, 50, 100])
    plt.subplots_adjust(hspace=0.0)
    axs[0].set_title(f"Model: {name}")

## Error growth (equator) (3 figures)

In this section, I plot the growth of errors after the start of the simulations. I do this to provide a more quantitative perspective on the pressure vs. time plots above.

I compare the zonally averaged mean absolute deviation (MAD) for four different predictions:

1. MAD between the time series and its time and zonal mean. This measures the magnitude of the fluctuations about the climatology.
2. Persistence forecast. This forecast assumes that humidity and temperature do not change over the course of the simulation. This gives us an estimate of the time-scale over which the fields change naturally.
3. SCAM-based prediction
4. NN-based prediction

I only do this at the equator because the SCAM simulation is too expensive to run for the whole domain.
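
To make the first two baselines concrete, here is a minimal NumPy sketch on a made-up field. It uses plain arrays instead of the xarray datasets in this notebook, and the helper below is a simplified stand-in for `mean_absolute_dev`, not the version defined earlier:

```python
import numpy as np


def mean_absolute_dev(truth, pred, axis=0):
    """MAD along `axis` (the zonal axis here), counting NaN differences as zero."""
    diff = truth - pred
    diff = np.where(np.isnan(diff), 0.0, diff)
    return np.abs(diff).mean(axis=axis)


# Made-up truth field with shape (x, time): 4 zonal points, 3 output times.
truth = np.array([[0.0, 1.0, 2.0],
                  [2.0, 3.0, 4.0],
                  [4.0, 5.0, 6.0],
                  [6.0, 7.0, 8.0]])

# 1. Climatology: the time and zonal mean (a single number for this toy field).
mad_mean = mean_absolute_dev(truth, truth.mean())

# 2. Persistence: hold the initial state fixed; shape (x, 1) broadcasts over time.
mad_persistence = mean_absolute_dev(truth, truth[:, :1])
```

Both results are time series. In this toy case the persistence error starts at zero and grows with time, while the climatology error stays flat. The SCAM and NN predictions would go through the same function.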

In [None]:
def plot_mses(mses, axs=None, label='', **kwargs):
    if axs is None:
        fig, axs = plt.subplots(2, 2, sharey=True, sharex=True, dpi=100,
                                figsize=(6, 3.5))
        
    keys = mses.keys()
    
    for ax, key in zip(axs.flat, keys):
        val = mses[key]
        im = ax.contourf(val.time, val.p, val.T, **kwargs)

        ax.text(.05, .8, key,
                transform=ax.transAxes,
                color='white', fontsize=13)
        
    
    axs[0,0].invert_yaxis()
    axs[0,0].set_ylabel('p (hPa)')
    axs[1,0].set_ylabel('p (hPa)')
    
    for ax in axs[1,:]:
        ax.set_xlabel('days')
    
    plt.subplots_adjust(wspace=.02, hspace=.02)
    cb = plt.colorbar(im, ax=axs, pad=.01)
    cb.set_label(label)
    
    axs[0,0].set_xlim([val.time.min(), val.time.min()+ 20])
    
    return axs, cb
    

### Humidity Errors

In [None]:
mad = compute_errors(mean_absolute_dev, ds.isel(y=8), dims=['x'])


In [None]:
mad_qt = valmap(get('qt'), mad)
_, cb = plot_mses(mad_qt, levels=np.arange(11)*.25,
                  extend='max',
                  label='MAD (g/kg)')

We can see that errors grow much more slowly for the NN than for the other forecasts. In practice, the NN is able to predict around 5-7 days before it decays to the mean. This is much better than either the persistence forecast or CAM.

### Temperature Errors

In [None]:
mad_sl = valmap(get('sl'), mad)
plot_mses(mad_sl, levels=.5*np.arange(11), extend='max',
          label='MAD (K)');

The temperature predicted by the NN diverges from the truth much faster than the humidity does. Some large biases in the temperature emerge in the NN and CAM schemes. Moreover, it appears the neural network misses the diurnal cycle of temperature.

### Column Integrated Error

In [None]:
fig, (axT, axQ) = plt.subplots(1, 2, figsize=(7,3), dpi=100)
mass_mad_sl = xr.Dataset({k: -integrate(v, 'p')/1015 for k, v in mad_sl.items()}).to_dataframe()
mass_mad_qt = xr.Dataset({k: -integrate(v, 'p')/1015 for k, v in mad_qt.items()}).to_dataframe()

mass_mad_sl.plot(ax=axT, legend=False)
axT.set_ylabel(r'vMAD (K)')

mass_mad_qt.plot(ax=axQ)
axQ.set_ylabel(r'vMAD (g/kg)')
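
The vertical averaging above can be sketched in plain NumPy. This assumes that `integrate` from `xnoah` behaves like a trapezoidal rule along `p` (I have not checked its implementation) and that pressure is ordered from the surface upward, which is why the integral needs a sign flip. The function name `mass_weighted_mean` is made up for this illustration:

```python
import numpy as np


def mass_weighted_mean(profile, p, p_surface=1015.0):
    """Trapezoidal integral of `profile` over pressure, normalized by `p_surface`.

    `p` is assumed to decrease along the array (surface first), so the raw
    integral is negative and the minus sign makes the result positive.
    """
    dp = np.diff(p)
    integral = np.sum(0.5 * (profile[1:] + profile[:-1]) * dp)
    return -integral / p_surface


# Made-up profile: a constant 2 K error on three pressure levels (hPa).
p = np.array([1015.0, 500.0, 0.0])
vmad = mass_weighted_mean(np.full(3, 2.0), p)
```

For a profile that is constant from the surface to the top of the atmosphere, this mass-weighted average returns the constant itself, which is a quick sanity check on the sign and the normalization.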
