# Plot Q1 for different base states

We can start with mid tropospheric humidity as the relevant quantity.

In [None]:
import torch
import numpy as np

from src.data import open_data
from uwnet.thermo import *


def groupby_and_compute_nn(tropics, model, key, bins):
    """Bin-average ``tropics`` by ``key`` and attach the model output per bin.

    The ``x``, ``y`` and ``time`` dimensions are stacked into a single
    ``gridcell`` dimension, grouped into ``bins`` of the variable ``key`` and
    averaged. The bin-mean dataset is then passed to ``model.call_with_xr``
    and each returned variable is stored back on the averages under its
    original name prefixed with ``'NN'``.

    Parameters
    ----------
    tropics
        Dataset with ``x``, ``y`` and ``time`` dimensions containing ``key``.
    model
        Object exposing ``call_with_xr``.
    key : str
        Name of the variable to bin by; the binned dimension is named
        ``key + '_bins'``.
    bins
        Bin specification forwarded to ``groupby_bins``.

    Returns
    -------
    The bin-averaged dataset, augmented with the ``'NN'``-prefixed outputs.
    """
    bin_dim = key + '_bins'

    stacked = tropics.stack(gridcell=['x', 'y', 'time'])
    averages = stacked.groupby_bins(key, bins=bins).mean('gridcell')

    # call_with_xr needs the extra (bin) dimension to be called "time", plus
    # x and y dimensions, so rename and pad before the call.
    as_time = averages.rename({bin_dim: 'time'})
    nn_input = as_time.expand_dims(['x', 'y'], [-1, -2])

    # Undo the renaming and drop the length-one dimensions again.
    nn_output = model.call_with_xr(nn_input)
    nn_output = nn_output.rename({'time': bin_dim}).squeeze()

    for name in nn_output:
        averages['NN' + name] = nn_output[name]

    return averages

def plot_line_cmap(arr, lower_val=4, key='path_bins', upper_val=25):
    """Plot one profile per bin of ``key``, colored by the bin midpoint.

    ``arr`` is grouped along ``key``; every group is drawn against ``z`` with
    a color taken from the ``inferno`` colormap at
    ``(mid + lower_val) / (upper_val + lower_val)``, where ``mid`` is the
    midpoint of the group's bin interval. The midpoint is also used as the
    legend label.

    NOTE(review): ``plt`` is expected to already be available in the
    surrounding namespace (it is not imported in the visible source) --
    confirm where it comes from.

    Parameters
    ----------
    arr
        Array with a ``z`` coordinate and a binned coordinate named ``key``
        whose labels expose a ``.mid`` attribute.
    lower_val : number, optional
        Offset added to the midpoint so that the lowest bins do not all map
        to the very start of the colormap.
    key : str, optional
        Name of the binned coordinate to group by.
    upper_val : number, optional
        Midpoint value mapped (together with ``lower_val``) to the top of the
        colormap. Defaults to 25, the value that was previously hard-coded.
    """
    color_span = upper_val + lower_val
    for interval, profile in arr.groupby(key):
        mid = interval.mid
        color = plt.cm.inferno((mid + lower_val) / color_span)
        profile.plot(y='z', hue=key, color=color, label=mid)
    plt.legend()


# Load the trained network. NOTE: torch.load unpickles the file, so only use
# it on checkpoints from a trusted source.
model = torch.load("../../nn/NNLowerDecayLR/5.pkl")
# Restrict the training data to time coordinate values between 120 and 140
# (label-based slice, so both ends are included).
ds = open_data('training').sel(time=slice(120,140))

# Pressure data, passed to the diagnostics below and used for the freezing-level plot.
p = open_data('pressure')

# Convert the y coordinate to latitude (presumably degrees -- TODO confirm).
lat = ngaqua_y_to_lat(ds.y)
# Compute lower tropospheric stability (LTS) and the mid-tropospheric
# moisture path, and store both on the dataset so they can be binned on later.
ds['lts'] = lower_tropospheric_stability(ds.TABS, p, ds.SST, ds.Ps)
ds['path'] = midtropospheric_moisture(ds.QV, p, bottom=850, top=600)

# Keep only the rows whose |lat| is below 11, then average over x, time and y
# and make p the dimension in place of z.
tropics = ds.isel(y=(np.abs(lat)<  11))
tropical_mean = tropics.mean(['x', 'time', 'y']).swap_dims({'z': 'p'})

In [None]:
# Index of the level whose tropical-mean TABS is closest to 273.15
# (the freezing point if TABS is in kelvin -- TODO confirm units).
i = np.abs((tropical_mean.TABS - 273.15)).argmin()

# Tropical-mean QV profile with p on an inverted y-axis.
tropical_mean.QV.plot(y='p', yincrease=False, label='QV')
# Dashed black horizontal line at the pressure of that level.
plt.axhline(p[i], c='k', ls='--', label='Freezing level')
plt.legend()

The freezing level is at 600 mb, so we define the mid-tropospheric moisture as the water vapor path between 850 mb and 600 mb.

In [None]:
# Quick look at the mid-tropospheric moisture path computed above.
ds['path'].plot()

In [None]:
# Bin edges for the moisture path: 0, 2.5, ..., 27.5 (the stop value 28 is excluded).
moisture_bins = np.r_[:28:2.5]
# Bin-average the tropical data by 'path' and attach the NN outputs for each bin.
output = groupby_and_compute_nn(tropics, model, 'path', moisture_bins)

In [None]:
# NN output for QT, one profile per moisture-path bin.
plot_line_cmap(output.NNQT)

In [None]:
# NN output for SLI, one profile per moisture-path bin.
plot_line_cmap(output.NNSLI)

## Dependence on lower tropospheric stability.

LTS seems easier to calculate than estimated inversion strength (EIS). This analysis is isolated to the tropics, so we are less worried about the temperature dependence of LTS.

In [None]:
# Quick look at the LTS values computed above; the trailing semicolon
# suppresses the notebook's echo of the return value.
ds['lts'].plot();

In [None]:
# Bin edges for LTS: 2, 3, ..., 19 (the stop value 20 is excluded).
lts_bins = np.r_[2:20:1]

In [None]:
# Bin-average the tropical data by 'lts' and attach the NN outputs for each bin.
output_lts_binned = groupby_and_compute_nn(tropics, model=model, key='lts', bins=lts_bins)

In [None]:
# NN output for SLI, one profile per LTS bin.
plot_line_cmap(output_lts_binned.NNSLI, key='lts_bins')

In [None]:
# NN output for QT, one profile per LTS bin.
plot_line_cmap(output_lts_binned.NNQT, key='lts_bins')

Smaller LTS corresponds to stronger and deeper heating/drying.