# Analyzing the National Water Model with Xarray, Dask, and Coiled

_This example was adapted from [this notebook](https://github.com/dcherian/dask-demo/blob/main/nwm-aws.ipynb) by Deepak Cherian, Kevin Sampson, and Matthew Rocklin._

<iframe width="560" height="315" src="https://www.youtube.com/embed/blxvfGt9av8?si=-F_kY5K3VK4UvuPc" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>

## The National Water Model Dataset

In this example, we'll perform a county-wise aggregation of output from the National Water Model (NWM) available on the [AWS Open Data Registry](https://registry.opendata.aws/nwm-archive/). You can read more on the NWM from the [Office of Water Prediction](https://water.noaa.gov/about/nwm).

## Problem description

Datasets with high spatio-temporal resolution can get large quickly, vastly exceeding the resources you may have on your laptop. Dask integrates with Xarray to support parallel computing and you can use Coiled to scale to the cloud.

We'll calculate the mean depth to soil saturation for each US county:

- Year: 2020
- Temporal resolution: 3-hourly land surface output
- Spatial resolution: 250 m grid
- Data volume: 6 TB
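
As a rough sanity check on that figure, the back-of-the-envelope sketch below multiplies out the number of values in one year of one variable. The grid shape (15360 x 18432 cells) and the 8-byte `float64` dtype are assumptions not stated in this example; check them against the dataset repr before relying on the result.

```python
import calendar

# Assumptions (not stated in this example): grid shape and dtype.
NY, NX = 15360, 18432   # assumed size of the 250 m grid (y, x)
ITEMSIZE = 8            # assumed float64 storage, in bytes

year = 2020
days = 366 if calendar.isleap(year) else 365
n_steps = days * (24 // 3)           # 3-hourly output gives 8 steps per day

n_values = n_steps * NY * NX
size_tib = n_values * ITEMSIZE / 2**40

print(f"{n_steps} time steps, about {size_tib:.1f} TiB for one variable")
```

Under these assumptions the estimate lands close to the 6 TB quoted above.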

This example relies on a few tools:
- `dask` + `coiled` to process the dataset in parallel in the cloud
- `xarray` + `flox` to work with the multi-dimensional Zarr dataset and aggregate from the 250 m grid to county-level means

## Start a Coiled cluster

To demonstrate a calculation on a cloud-hosted dataset, we will use Coiled to set up a Dask cluster in AWS `us-east-1`.

In [None]:
import coiled

cluster = coiled.Cluster(
    name="xarray-nwm",
    region="us-east-1", # close to dataset, avoid egress charges
    n_workers=10,
    tags={"project": "nwm"},
    scheduler_vm_types="r7g.xlarge", # memory optimized AWS EC2 instances
    worker_vm_types="r7g.2xlarge"
)

client = cluster.get_client()

cluster.adapt(minimum=10, maximum=50)

### Load NWM data

In [None]:
import fsspec
import xarray as xr

ds = xr.open_zarr(
    fsspec.get_mapper("s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr", anon=True),
    consolidated=True,
    chunks={"time": 896, "x": 350, "y": 350}
)
ds
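
The `chunks` argument above determines how much data each Dask task holds in memory. The sketch below works out the in-memory size of one chunk, assuming the variable is stored as 8-byte `float64` (an assumption; confirm with `ds.zwattablrt.dtype`).

```python
import math

# Chunk shape taken from the xr.open_zarr call above.
chunks = {"time": 896, "x": 350, "y": 350}
ITEMSIZE = 8  # assumption: float64 values, 8 bytes each

values_per_chunk = math.prod(chunks.values())
chunk_bytes = values_per_chunk * ITEMSIZE

print(f"{values_per_chunk:,} values per chunk, about {chunk_bytes / 1e6:.0f} MB")
```

Chunks of this size are one reason memory-optimized worker instances are a reasonable choice here.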

Each field in this dataset is big!

In [None]:
ds.zwattablrt

Subset to a single year for demo purposes.

In [None]:
subset = ds.zwattablrt.sel(time=slice("2020-01-01", "2020-12-31"))
subset

### Load county raster for grouping

Load a raster TIFF file identifying counties by unique integer with [rioxarray](https://corteva.github.io/rioxarray/html/rioxarray.html).

In [None]:
import fsspec
import rioxarray

fs = fsspec.filesystem("s3", requester_pays=True)

counties = rioxarray.open_rasterio(
    fs.open("s3://nwm-250m-us-counties/Counties_on_250m_grid.tif"), chunks="auto"
).squeeze()

# remove any small floating point error in coordinate locations
_, counties_aligned = xr.align(subset, counties, join="override")

counties_aligned
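
To illustrate why `join="override"` is used above, here is a minimal sketch on toy data (the arrays `a` and `b` are hypothetical, not part of the NWM dataset). With the default inner join, coordinates that differ by a tiny floating point error do not match and the result is empty; `join="override"` instead copies the coordinates of the first object onto the second.

```python
import numpy as np
import xarray as xr

x = np.array([0.0, 250.0, 500.0])

# Two toy arrays on "the same" grid, except for a tiny coordinate error in b.
a = xr.DataArray(np.arange(3.0), dims="x", coords={"x": x}, name="a")
b = xr.DataArray(np.array([10, 20, 30]), dims="x", coords={"x": x + 1e-9}, name="b")

# Inner join keeps only exactly matching coordinate values: nothing matches.
_, b_inner = xr.align(a, b, join="inner")

# Override rewrites b's coordinates to be those of a; the data is unchanged.
_, b_override = xr.align(a, b, join="override")

print(b_inner.sizes["x"], b_override.x.values, b_override.values)
```

Note that `join="override"` only requires the indexes to have the same size, so it is appropriate only when you already know both grids are meant to be identical.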

We'll need the unique county IDs later, so let's calculate them now.

In [None]:
import numpy as np

county_id = np.unique(counties_aligned.data).compute()
county_id = county_id[county_id != 0]
print(f"There are {len(county_id)} counties!")
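
The same two steps can be sketched on a tiny in-memory raster. The toy values below are hypothetical, and treating `0` as "no county" is an assumption inferred from the filter above.

```python
import numpy as np

# Hypothetical 3x3 county raster; 0 is assumed to mark cells outside any county.
toy_raster = np.array(
    [
        [0, 0, 8013],
        [8013, 48201, 48201],
        [0, 48201, 48201],
    ]
)

toy_ids = np.unique(toy_raster)     # sorted unique labels, including 0
toy_ids = toy_ids[toy_ids != 0]     # drop the "no county" label

print(toy_ids)
```

Because `np.unique` returns sorted values, the resulting IDs can be passed directly as a sorted list of expected group labels.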

### GroupBy with flox

We could run the computation as:

```python
subset.groupby(counties_aligned).mean()
```

This would use flox in the background; however, it would also load `counties_aligned` into memory. To avoid this (and the associated egress charges), you can use `flox.xarray`, which lets you lazily group by a Dask array (here `counties_aligned`) as long as you pass the expected group labels in `expected_groups`. See the [flox documentation](https://flox.readthedocs.io/en/latest/intro.html#with-dask).

In [None]:
import flox.xarray

county_mean = flox.xarray.xarray_reduce(
    subset,
    counties_aligned.rename("county"),
    func="mean",
    expected_groups=(county_id,),
)

county_mean
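
Conceptually, the reduction above computes, for each time step, the mean of all grid cells that share a county label. The NumPy sketch below shows that idea for a single 2D field. It is an illustration only, not how flox is implemented: the helper `grouped_mean` is hypothetical, it assumes `expected_groups` is sorted, and it ignores chunking and missing values.

```python
import numpy as np

def grouped_mean(values, labels, expected_groups):
    """Mean of `values` per label, for the (sorted) labels in `expected_groups`."""
    values = np.asarray(values, dtype=float).ravel()
    labels = np.asarray(labels).ravel()
    expected_groups = np.asarray(expected_groups)

    # Position of each label in expected_groups; unlisted labels are masked out.
    idx = np.searchsorted(expected_groups, labels)
    idx = np.clip(idx, 0, len(expected_groups) - 1)
    valid = expected_groups[idx] == labels

    n = len(expected_groups)
    sums = np.bincount(idx[valid], weights=values[valid], minlength=n)
    counts = np.bincount(idx[valid], minlength=n)
    with np.errstate(invalid="ignore", divide="ignore"):
        return sums / counts   # NaN for expected groups with no cells

toy_labels = np.array([[0, 8013, 8013], [48201, 48201, 0]])
toy_values = np.array([[1.0, 2.0, 4.0], [10.0, 20.0, 99.0]])
toy_means = grouped_mean(toy_values, toy_labels, expected_groups=[8013, 48201])
print(toy_means)
```

Knowing the group labels up front is what fixes the output size before any data is read, which is why `expected_groups` is needed when the labels are a lazy Dask array.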

In [None]:
county_mean.load()

### Cleanup

In [None]:
# since our dataset is much smaller now, we no longer need cloud resources
cluster.shutdown()

## Visualize

Data prep

In [None]:
# Read county shapefile, combo of state FIPS code and county FIPS code as multi-index
import geopandas as gpd
import hvplot.pandas  # registers the `.hvplot` accessor used for plotting below

counties = gpd.read_file(
    "https://www2.census.gov/geo/tiger/GENZ2022/shp/cb_2022_us_county_20m.zip"
).to_crs("EPSG:3395")
counties["STATEFP"] = counties.STATEFP.astype(int)
counties["COUNTYFP"] = counties.COUNTYFP.astype(int)
# Drop Alaska (2), Hawaii (15), and Puerto Rico (72)
continental = counties[~counties["STATEFP"].isin([2, 15, 72])].set_index(["STATEFP", "COUNTYFP"])

# Interpret `county` as combo of state FIPS code and county FIPS code. Set multi-index:
yearly_mean = county_mean.mean("time")
yearly_mean.coords["STATEFP"] = (yearly_mean.county // 1000).astype(int)
yearly_mean.coords["COUNTYFP"] = np.mod(yearly_mean.county, 1000).astype(int)
yearly_mean = yearly_mean.drop_vars("county").set_index(county=["STATEFP", "COUNTYFP"])

# join
continental["zwattablrt"] = yearly_mean.to_dataframe()["zwattablrt"]
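
The integer arithmetic used above to split the combined county ID can be checked on a few sample values. This sketch assumes, as the code above does, that each ID is the state FIPS code times 1000 plus the three-digit county FIPS code; the sample IDs are chosen for illustration.

```python
import numpy as np

# Sample combined IDs: state FIPS * 1000 + county FIPS (assumed encoding).
sample_ids = np.array([1001, 8013, 48201])

statefp = sample_ids // 1000          # leading digits: state code
countyfp = np.mod(sample_ids, 1000)   # last three digits: county code

# Zero-padded five-character GEOID strings, for comparison with census tables.
geoids = [f"{s:02d}{c:03d}" for s, c in zip(statefp.tolist(), countyfp.tolist())]

print(statefp, countyfp, geoids)
```

Casting `STATEFP` and `COUNTYFP` to integers on the shapefile side, as done above, is what lets the two multi-indexes line up in the join.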

Plot

In [None]:
continental.hvplot(
    c="zwattablrt",
    cmap="turbo_r",
    title="Mean Depth to Soil Saturation in 2020 by US County (meters)",
    xaxis=None,
    yaxis=None
)