# **SHARPlib -- Working with Gridded Data (HRRR Example)**

This tutorial notebook provides an example of how to use SHARPlib with 3D gridded data sources, such as the **High Resolution Rapid Refresh (HRRR)** model. This notebook leverages some advanced tools and topics, such as [reading remote GRIB2 data using `kerchunk`](https://nbviewer.org/gist/peterm790/92eb1df3d58ba41d3411f8a840be2452), and [parallelizing computations on chunked arrays using Dask](https://tutorial.dask.org/02_array.html). For the sake of brevity, this notebook will not go over these in any detail, but rather, show how they can be used with SHARPlib. 

## HRRR Data
JSON reference files that map to a single GRIB2 file on Google Cloud are included in this repository so that the data can be opened as a virtual Zarr store. There are three files, one for each group of variables on a different coordinate system: `hrrr-hybrid.json` for hybrid vertical level data, `hrrr-2m.json` for fields defined at 2 meters AGL, and `hrrr-surface.json` for variables defined on the model surface. The surface pressure and orography (`orog`) are read from the surface group, 2-meter temperature and specific humidity are read from the 2-meter group, and the remaining 3D data `[pressure, geopotential height, temperature, specific humidity, u-wind, v-wind]` are read from the hybrid level group. 

## Dask Client
A Dask client is used to set up a "local cluster" using your computer's CPUs. ***It is highly recommended that you adjust the `n_workers` and `memory_limit` parameters to something appropriate for your system***. Setting `n_workers` or `memory_limit` to values that exceed what your system supports can cause significant slowdowns and even crashes.

Once the client starts, you can click on the generated URL to see how your parallel tasks are being executed!
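
If you are unsure what values to pick, the sketch below shows one possible way to derive them. The helper `suggest_cluster_size` and its `reserve_fraction` parameter are hypothetical (they are not part of Dask or SHARPlib), and the total memory is passed in by hand rather than detected. It assumes that `memory_limit` applies per worker, so the total is split across workers.

```python
import os


def suggest_cluster_size(total_mem_gb, n_cpus=None, reserve_fraction=0.25):
    """Suggest (n_workers, memory_limit) that fit on a given machine.

    Hypothetical helper: keeps `reserve_fraction` of RAM free for the
    operating system and the notebook, then splits the remainder
    evenly across the workers.
    """
    if n_cpus is None:
        n_cpus = os.cpu_count() or 1
    # Leave one core free for the scheduler/notebook when possible
    n_workers = max(1, n_cpus - 1)
    usable_gb = total_mem_gb * (1.0 - reserve_fraction)
    per_worker_gb = max(1, int(usable_gb // n_workers))
    return n_workers, f"{per_worker_gb}GB"


# Example: a machine with 8 cores and 32 GB of RAM
n_workers, memory_limit = suggest_cluster_size(total_mem_gb=32, n_cpus=8)
print(n_workers, memory_limit)
```

The two values can then be passed to `Client(n_workers=..., memory_limit=...)` in place of the hard-coded ones in the next cell.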

In [None]:
## Dask throws some noise about 
## large task graphs. This is unavoidable
## and clutters the output, so turn it off
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

import dask
import logging
import cmocean
import numpy as np
import xarray as xr
import cartopy.crs as ccrs
from distributed import Client
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from cartopy import feature as cfeature
from scipy.ndimage import gaussian_filter
from mpl_toolkits.axes_grid1 import make_axes_locatable


from nwsspc.sharp.calc import constants
from nwsspc.sharp.calc import parcel
from nwsspc.sharp.calc import params
from nwsspc.sharp.calc import thermo
from nwsspc.sharp.calc import winds
from nwsspc.sharp.calc import layer
from nwsspc.sharp.calc import interp

client = Client(
    n_workers=8, 
    memory_limit='8GB', 
    silence_logs=logging.ERROR
)
client

## Dataset Chunking
One of the ways we can control performance and parallelism is by controlling how the dataset is chunked. The chunk sizes determine how much memory each task uses, how many tasks are generated, and how well each task performs, depending on how CPU cache-friendly the chunk strategy is.  

First and foremost, the `hybrid` coordinate should be one, single chunk and not sub-divided. Our calculations with SHARPlib rely on contiguous, vertical profile arrays of data for computation, and splitting these into chunks in the vertical would break that requirement. This is effectively hard-coded in the call: `.chunk(dict(hybrid=-1, **chunks))`. 

Second, CPUs work most efficiently with linear, cache-friendly data. Arrays are typically stored in memory using the C (row-major) memory layout, meaning that in an array of shape `(NZ, NY, NX)`, values along `X` sit next to each other in memory. This means that, generally speaking, we make the best use of the CPU cache by passing entire rows of data for computation and chunking along `y`. The `-1` is shorthand for "use the entire dimension". 

Depending on your system, number of cores, and amount of memory, you may want to adjust these values. I recommend adjusting the y-chunk first before x, but if you do adjust x, make sure it is sufficiently large for the CPU to leverage linearity! If you don't quite believe it, experiment with what happens when you use a small x-chunk and a large y-chunk, and compare the timing...
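
To get a feel for what a given chunk choice implies, here is a rough sketch that counts the chunks and estimates the memory one 3D `float32` variable needs per chunk. The helper `chunk_summary` is hypothetical, and the grid size used (1059 x 1799 points with 50 hybrid levels) is an assumption about the HRRR CONUS grid; check it against the dataset printed later in this notebook.

```python
import math


def chunk_summary(ny, nx, nz, y_chunk, x_chunk, itemsize=4):
    """Return (number of chunks, MiB per full-size 3D chunk).

    Hypothetical helper. The vertical dimension is never split, so
    each chunk holds all `nz` levels. Edge chunks may be smaller than
    the value reported here.
    """
    n_chunks = math.ceil(ny / y_chunk) * math.ceil(nx / x_chunk)
    bytes_per_chunk = nz * y_chunk * x_chunk * itemsize
    return n_chunks, bytes_per_chunk / 2**20


# Assumed grid dimensions; verify against the printed dataset
n_chunks, mib = chunk_summary(
    ny=1059, nx=1799, nz=50, y_chunk=64, x_chunk=256
)
print(f"{n_chunks} chunks, {mib:.3f} MiB per 3D variable per chunk")
```

Each task receives several 3D variables at once, so the per-task memory is a multiple of this number.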

## Dataset Precision and Downloading
SHARPlib works using 32-bit floating point precision, while most of Python operates at 64-bit by default. A call to `.astype("float32")` is made in order to convert all of the arrays to 32-bit precision up front. Additionally, a call to `.compute()` is included in order to download the dataset into memory. While Dask would and could handle downloading the data during later computation, sometimes it is simpler to just retain the whole dataset in memory -- particularly if additional computation is desired. 
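
As a small illustration of what the conversion does, the sketch below uses a made-up pressure profile (not HRRR data) to show that `.astype("float32")` halves the memory footprint at the cost of a small rounding error.

```python
import numpy as np

# Made-up 50-level pressure profile in Pa; NumPy defaults to float64
profile64 = np.linspace(100000.0, 10000.0, 50)
profile32 = profile64.astype("float32")

print(profile64.dtype, profile64.nbytes)
print(profile32.dtype, profile32.nbytes)

# Largest difference introduced by dropping to 32-bit precision
max_err = np.max(np.abs(profile64 - profile32.astype("float64")))
print(max_err)
```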



In [None]:
chunks = {"y": 64, "x": 256}

ds_hybrid = xr.open_dataset(
    "hrrr-hybrid.json", 
    engine="kerchunk", 
    decode_timedelta=True
)[["pres", "gh", "t", "q", "u", "v"]] \
    .astype("float32") \
    .compute() \
    .chunk(dict(hybrid=-1, **chunks)) 


ds_2m = xr.open_dataset(
    "hrrr-2m.json", 
    engine="kerchunk", 
    decode_timedelta=True
)[["t2m", "sh2"]] \
    .astype("float32") \
    .compute() \
    .chunk(chunks) 


ds_sfc = xr.open_dataset(
    "hrrr-surface.json", 
    engine="kerchunk", 
    decode_timedelta=True
)[["sp", "orog"]] \
    .astype("float32") \
    .compute() \
    .chunk(chunks) 


In [None]:
print(ds_hybrid)
print("==========")
print(ds_2m)
print("==========")
print(ds_sfc)
print("==========")

## Calling SHARPlib on Gridded Data
SHARPlib is designed to work on vertical arrays of profile data (as mentioned in the chunking section), but our model data is a combination of 2D and 3D fields. There is also a need for computing additional array data, such as `water vapor mixing ratio`, `virtual temperature` and `dewpoint temperature`. While the aforementioned variables could reasonably be pre-computed on the 3D grids at a relatively low computational cost, other things such as `CAPE`, `CINH`, and the `Effective Inflow Layer` are much more computationally expensive... and iterating over the HRRR grid multiple times to compute each variable is usually more expensive than just computing everything at once. Once a gridpoint's vertical profile is loaded into memory, why not go ahead and get everything you need out of it?
<hr>

### Function Arguments
The function arguments are going to be structured like so...

- 1D Vertical Arrays: `[pres, hght, tmpk, spfh, uwin, vwin]`<br>
- Scalar values: `[sp, orog, t2m, sh2]`<br>

The magic of how we get 1D and scalar values out of 3D and 2D arrays is elaborated on further in the next cell. The primary goal is to show that computational logic should be logically grouped on a "per-profile" basis. 
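
To make the "per-profile" idea concrete before SHARPlib enters the picture, here is a minimal NumPy-only sketch. The function `per_profile` and the synthetic arrays are hypothetical stand-ins. `np.vectorize` with a `signature` calls the function once per grid point, handing it 1D columns and single values; it expects the vertical axis to be last, hence the `np.moveaxis` calls.

```python
import numpy as np


def per_profile(pres, tmpk, sp):
    """Toy stand-in for a profile routine: 1D arrays plus one scalar."""
    # Hypothetical diagnostics: column-mean temperature and the gap
    # between surface pressure and the lowest model level
    return tmpk.mean(), sp - pres[0]


nz, ny, nx = 5, 3, 4
rng = np.random.default_rng(0)
levels = np.linspace(100000.0, 50000.0, nz, dtype=np.float32)
pres3d = np.broadcast_to(levels[:, None, None], (nz, ny, nx))
tmpk3d = rng.uniform(250.0, 300.0, size=(nz, ny, nx)).astype(np.float32)
sp2d = np.full((ny, nx), 100500.0, dtype=np.float32)

# "(z),(z),()->(),()": two 1D inputs, one scalar input, two scalar outputs
vfunc = np.vectorize(
    per_profile,
    signature="(z),(z),()->(),()",
    otypes=[np.float32, np.float32],
)

mean_t, dp = vfunc(
    np.moveaxis(pres3d, 0, -1),  # (ny, nx, nz)
    np.moveaxis(tmpk3d, 0, -1),  # (ny, nx, nz)
    sp2d,                        # (ny, nx)
)
print(mean_t.shape, dp.shape)
```

Both outputs come back as 2D `(ny, nx)` grids, which is the same pattern used for the full computation later on.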

### SHARPlib Computations
Within this function, we compute and return the following:

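- **Computed**
    - Water Vapor Mixing Ratio
    - Potential Temperature
    - Virtual Temperature 
    - Dewpoint Temperature
<br>

- **Computed and Returned**
    - Mixed-Layer Parcel CAPE
    - Mixed-Layer Parcel CINH
    - Effective-Layer Significant Tornado Parameter (STP)
    - Effective Inflow Layer Bottom 
    - Effective Inflow Layer Top
    - Effective Storm-Relative Helicity (SRH)
    - Effective Bulk Wind Difference (u and v components)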

In [None]:
def compute_everything(
    pres, hght, 
    tmpk, spfh, 
    uwin, vwin, 
    sp, orog, 
    t2m, sh2, 
):
    def qc(val):
        # Replace SHARPlib's missing-value sentinel with NaN
        return np.nan if val == constants.MISSING else val
            
    sp = sp.item()
    orog = orog.item()
    sh2 = sh2.item()
    t2m = t2m.item()
    
    mixr = thermo.mixratio(spfh)
    theta = thermo.theta(pres, tmpk)

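    mixr_2m = thermo.mixratio(sh2)

    # Clamp mixing ratios to a small positive floor (constants.TOL)
    # before they are used in any further calculation
    mixr[mixr < constants.TOL] = constants.TOL
    if mixr_2m < constants.TOL:
        mixr_2m = constants.TOL

    dwpk_2m = thermo.temperature_at_mixratio(mixr_2m, sp)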

    vtmp = thermo.virtual_temperature(tmpk, mixr)
    dwpk = thermo.temperature_at_mixratio(mixr, pres)

    # get the mixed-layer parcel
    mix_lyr = layer.PressureLayer(
        pres[0], pres[0] - 10000.0)
    mlpcl = parcel.Parcel.mixed_layer_parcel(
        mix_lyr,
        pres,
        theta,
        mixr
    )

    lifter = parcel.lifter_cm1()
    lifter.ma_type = thermo.adiabat.pseudo_liq
    lifter.converge = 0.15
    
    # lift the parcel and get CAPE
    pcl_vtmp = mlpcl.lift_parcel(lifter, pres)
    pcl_buoy = thermo.buoyancy(pcl_vtmp, vtmp)
    cape, cinh = mlpcl.cape_cinh(pres, hght, pcl_buoy)

    # Get the LCL height in meters AGL
    mllcl_hght = interp.interp_pressure(
        mlpcl.lcl_pressure,
        pres,
        hght
    ) - orog

    # Get the effective inflow layer for effective SRH
    mupcl = parcel.Parcel()
    eil = params.effective_inflow_layer(
        lifter,
        pres,
        hght,
        tmpk,
        dwpk,
        vtmp,
        mupcl=mupcl
    )
    
    eil_hght = layer.pressure_layer_to_height(
        eil,
        pres,
        hght
    )

    # Get the storm relative helicity 
    # for the effective inflow layer
    storm_mtn = params.storm_motion_bunkers(
        pres,
        hght,
        uwin,
        vwin,
        eil, mupcl
    )
    esrh = winds.helicity(
        eil,
        storm_mtn,
        pres,
        uwin,
        vwin
    )

    if mupcl.eql_pressure != constants.MISSING:
        # Get the effective bulk wind difference
        eql_hght = interp.interp_pressure(
            mupcl.eql_pressure,
            pres,
            hght
        )
        depth = (eql_hght - eil_hght.bottom)*0.5
        ebwd_lyr = layer.HeightLayer(
            eil_hght.bottom, 
            float(eil_hght.bottom + depth)
        )
        ebwd_cmp = winds.wind_shear(
            ebwd_lyr,
            hght,
            uwin,
            vwin
        )
        ebwd = winds.vector_magnitude(ebwd_cmp.u, ebwd_cmp.v)
    else:
        ebwd_cmp = winds.WindComponents(constants.MISSING, constants.MISSING)
        ebwd = 0


    
    stp = params.significant_tornado_parameter(
        mlpcl,
        mllcl_hght,
        esrh,
        ebwd
    )

    return (
        mlpcl.cape, mlpcl.cinh, stp,
        qc(eil_hght.bottom), qc(eil_hght.top), qc(esrh),
        qc(ebwd_cmp.u), qc(ebwd_cmp.v),
    )
    

## Parallelizing Computations with Xarray and Dask
This is the magic of how to parallelize the profile-based computation across an entire gridded dataset. It relies on two key variables: `input_core_dims` and `output_core_dims` that tell Xarray about how to decompose the input and reconstruct the output.

### Input Core Dims
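This is a list of core dimensions for each input variable, ordered by argument order. The first six arguments to `compute_everything` are the 1D arrays that we need to get from the 3D fields, and so the first six entries of `input_core_dims` are `["hybrid"]`. The `hybrid` dimension is our vertical dimension from the GRIB2 file, so we are telling Xarray that this dimension is "core" to our computation and should be present for these arguments/arrays in our function. The remaining four arguments come from the 2D fields, whose core dimensions are `["surface"]` and `["heightAboveGround"]`. Each of these arrives as a single-element array, which is why `compute_everything` calls `.item()` on them. 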

### Output Core Dims
Similarly to `input_core_dims`, `output_core_dims` tells Xarray how to reconstruct the returned values from our output. In the case of `compute_everything`, we are returning 8 scalar values, each of which should be assembled into a 2D array. They have no "core" dimensions, since they are independent of any other neighbor or coordinate, so this is a list of empty lists. However, if we wanted to return a 3D array of parcel virtual temperature, we could add another output argument that would have a core dim of `["hybrid"]`, much like the inputs. 
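
A toy example may help show how core dims shape the outputs. The function `toy_profile_func` and the synthetic arrays below are hypothetical; the sketch returns one scalar output (no core dims) and one profile output (core dim `["hybrid"]`). It runs eagerly on NumPy-backed arrays, without Dask, so it can be tried without a cluster.

```python
import numpy as np
import xarray as xr


def toy_profile_func(tmpk, sfc_t):
    """Toy per-profile function: one scalar and one profile output."""
    warmest = tmpk.max()    # scalar -> no core dims
    anomaly = tmpk - sfc_t  # 1D profile -> core dim ["hybrid"]
    return warmest, anomaly


nz, ny, nx = 4, 2, 3
tmpk = xr.DataArray(
    np.arange(nz * ny * nx, dtype=np.float32).reshape(nz, ny, nx) + 250.0,
    dims=("hybrid", "y", "x"),
)
sfc_t = xr.DataArray(
    np.full((ny, nx), 250.0, dtype=np.float32), dims=("y", "x")
)

warmest, anomaly = xr.apply_ufunc(
    toy_profile_func,
    tmpk, sfc_t,
    input_core_dims=[["hybrid"], []],
    output_core_dims=[[], ["hybrid"]],
    vectorize=True,
    output_dtypes=[np.float32, np.float32],
)

print(warmest.dims)
print(anomaly.dims)
```

Note that Xarray places core dimensions last on the output, so the profile output may need a `.transpose(...)` call if you want `hybrid` as the leading dimension again.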

### Output DTypes
Having the same length as the number of output variables, this tells Xarray the expected return type and precision. SHARPlib returns float32 types, so that's what we specify. 

### Other arguments
The `vectorize=True` and `dask="parallelized"` arguments are required in order to make this work. `vectorize=True` tells Xarray to loop the function over every non-core dimension, so that each call receives NumPy arrays holding a single profile (required to work with SHARPlib) and single values for `sp`, `orog`, `t2m`, and `sh2` rather than full 2D arrays. `dask="parallelized"` tells Dask to apply the function to each chunk in parallel. 

### Compute 
The return type from `apply_ufunc` is an object that tells Dask how to parallelize, but it does not compute a result until `.compute()` is called. We do this for all of our variables at once so that we do not iterate over the grid multiple times. 

In [None]:
%%time
input_core_dims = [
    ["hybrid"], ["hybrid"],
    ["hybrid"], ["hybrid"],
    ["hybrid"], ["hybrid"],
    ["surface"], ["surface"],
    ["heightAboveGround"], 
    ["heightAboveGround"],
]

output_core_dims = [
    [],[],
    [],[],
    [],[],
    [],[],
]

output_dtypes = [
    np.float32, np.float32, 
    np.float32, np.float32,
    np.float32, np.float32,
    np.float32, np.float32,
]

result_func = xr.apply_ufunc(
    compute_everything, 
    ds_hybrid["pres"], ds_hybrid["gh"],
    ds_hybrid["t"], ds_hybrid["q"],
    ds_hybrid["u"], ds_hybrid["v"],
    ds_sfc["sp"], ds_sfc["orog"],
    ds_2m["t2m"], ds_2m["sh2"],
    input_core_dims=input_core_dims,
    output_core_dims=output_core_dims,
    vectorize=True,
    dask="parallelized",
    output_dtypes=output_dtypes,
)

output_names = [
    "mixed_layer_cape", "mixed_layer_cinh", 
    "effective_layer_stp", "effective_inflow_layer_height_bottom_agl",
    "effective_inflow_layer_height_top_agl", "effective_srh",
    "effective_bulk_wind_difference_u", "effective_bulk_wind_difference_v",
]

result_dict = {name: data for name, data in zip(output_names, result_func)}
result_ds = xr.Dataset(result_dict)
result_ds = result_ds.compute()

In [None]:
print(result_ds)
print(result_ds["effective_layer_stp"].max(), result_ds["effective_layer_stp"].min())

In [None]:
%%time

lon = ds_2m.longitude.values
lat = ds_2m.latitude.values

stp = result_ds["effective_layer_stp"]
ml_cape = result_ds["mixed_layer_cape"]
ml_cinh = result_ds["mixed_layer_cinh"]
ebwd_u = result_ds["effective_bulk_wind_difference_u"]*1.94384
ebwd_v = result_ds["effective_bulk_wind_difference_v"]*1.94384

stp = gaussian_filter(stp, sigma=2.5)
ml_cinh = gaussian_filter(ml_cinh, sigma=2.5)

cape_alpha = np.clip(ml_cape / 100.0, 0, 1)

proj = ccrs.LambertConformal(
    central_longitude=-95.0, 
    standard_parallels=(25.0, 25.0)
)

fig = plt.figure(figsize=(16*2, 9*2), dpi=200)
fig.patch.set_visible(False)
ax = plt.axes(projection=proj)
ax.axis('off')

ax.set_extent(
    [-110, -80, 
     26, 45
    ], crs=ccrs.PlateCarree()
)

ax.add_feature(
    cfeature.OCEAN, 
    facecolor='#285970', 
    zorder=1
)
ax.add_feature(
    cfeature.LAND, 
    facecolor='#757575', 
    zorder=1
)
ax.coastlines(
    '50m', 
    color='k', 
    zorder=3
)
ax.add_feature(
    cfeature.BORDERS, 
    color='k', 
    zorder=3
)

ax.add_feature(
    cfeature.NaturalEarthFeature(
        category='cultural',
        name='admin_1_states_provinces_lines',
        scale='50m',
        facecolor='none',
        edgecolor='k'
    ), zorder=3
)

cm = ax.pcolormesh(
    lon, 
    lat, 
    ml_cape, 
    alpha=cape_alpha,
    edgecolors=None,
    cmap='cmo.amp', 
    vmin=0, 
    vmax=5500,
    shading="nearest",
    transform=ccrs.PlateCarree(), 
    rasterized=True,
    zorder=2
)

cf = ax.contourf(
    lon,
    lat,
    ml_cinh,
    levels=[-250, -150, -100, -50, -25, -10],
    cmap="Blues_r",
    extend="min",
    alpha=0.3,
    vmin=-250, 
    vmax=0,
    zorder=2,
    hatches=['/','/','/','','',''],
    transform=ccrs.PlateCarree(), 
)

ct = ax.contour(
    lon,
    lat,
    stp,
    linewidths=[1, 2, 2.5, 4, 4],
    levels=[0.5, 1.0, 3.0, 5.0, 10.0],
    cmap="cmo.thermal",
    vmin=-20, 
    vmax=12,
    zorder=3,
    transform=ccrs.PlateCarree(), 
)

ax.barbs(
    lon[::25, ::25],
    lat[::25, ::25], 
    ebwd_u[::25, ::25].values, 
    ebwd_v[::25, ::25].values,
    barbcolor="k",
    edgecolor="w",
    length=8,
    linewidth=1.5,
    transform=ccrs.PlateCarree(),
    zorder=3,
)

## background for colorbars to improve label
## legibility... 
rect = patches.Rectangle(
    (0, 0), 
    0.15, 
    1, 
    transform=ax.transAxes, 
    facecolor="w", 
    alpha=0.75
)
ax.add_patch(rect)

divider = make_axes_locatable(ax)
cbaxes = ax.inset_axes((0.075, 0.015, 0.025, 0.975), zorder=20)
cb = fig.colorbar(cm, cax=cbaxes)
cb.ax.tick_params(labelsize=24, labelcolor="k")
cb.set_label("ML CAPE (J/kg)", labelpad=-98, fontsize=20)

cbaxes2 = ax.inset_axes((0.005, 0.015, 0.025, 0.975), zorder=20)
cb2 = fig.colorbar(cf, cax=cbaxes2)
cb2.ax.tick_params(labelsize=24, labelcolor="k")
cb2.set_label("ML CINH (J/kg)", labelpad=-101, fontsize=20)

plt.tight_layout()
plt.savefig("plot.png", bbox_inches="tight")
plt.show()
fig.clf()
plt.close(fig)