# HyTEST Tutorial: Cache NWIS data to Zarr

#### Rich Signell, last updated June 2022

* This tutorial uses the HyRiver `pygeohydro` package to extract streamflow observations from USGS NWIS
* Here we query all the gages identified in the National Water Model v2.1 over its simulation period and store the results in Zarr for faster access

[Rendered notebook with output](https://nbviewer.org/gist/3d38160704a7d8f606f99a3ee07680ec)

In [None]:
import logging
logging.getLogger().setLevel(logging.CRITICAL)

In [None]:
import warnings
warnings.filterwarnings('ignore') 

In [None]:
%%time
import os
import pandas as pd
import xarray as xr
import fsspec
import hvplot.xarray
from pathlib import Path
import numpy as np
import dask
from dask_jobqueue import SLURMCluster
from dask.distributed import Client, LocalCluster
from zarr.convenience import consolidate_metadata

### Import Modeled Data from URL

Load model dataset that contains stations and time range of interest:

In [None]:
fs2 = fsspec.filesystem('s3', requester_pays=True)

In [None]:
fs2.ls('s3://nhgf-development/nwm/')

In [None]:
url = 's3://nhgf-development/nwm/chanobs.zarr'

#### create xarray dataset of modeled data

In [None]:
%%time
ds_chanobs = xr.open_dataset(fs2.get_mapper(url), engine='zarr', 
                             backend_kwargs={'consolidated':False}, chunks={})

In [None]:
ds_chanobs

#### About modeled data
The dataset called "ds_chanobs" is the "channel observations" streamflow modeled output from the National Water Model v2.1.

This xarray dataset contains hourly streamflow predictions at 7994 streamflow stations.

In [None]:
ds_chanobs

In [None]:
## determine what USGS station ids are in the modeled output.
# Convert each gage_id value in the model output to a Python string and drop
# its leading whitespace padding, giving plain station ids for the NWIS queries.
# NOTE(review): only leading whitespace is stripped; trailing padding, if any, would remain.
gage_ids_str = [gage_id.astype('str').lstrip() for gage_id in ds_chanobs['gage_id'].values]

In [None]:
## what the gage IDs look like:
gage_ids_str[0:5]

In [None]:
# determine the start and end of the modeled timeseries
start = ds_chanobs.time[0].values
stop = ds_chanobs.time[-1].values
print(start,stop)

#### Extract observed data using HyRiver (pygeohydro)

In [None]:
#import pygeohydro
from pygeohydro import NWIS

In [None]:
nwis = NWIS()

In [None]:
# use the start and stop dates above from the modeled data to extract observational data from NWIS for the same time period
dates = (start,stop)
print(dates)

If we request only one station, the returned time series contains only the times with good data, so it may not span the full time window. We therefore request two stations and use the resulting time axis as the common time base:

In [None]:
%%time
ds_obs = nwis.get_streamflow(gage_ids_str[:2], dates, to_xarray=True)

In [None]:
# examine xarray dataset of pulled information for each USGS station ID and associated streamflow in cfs
ds_obs

In [None]:
# rename variables
ds_obs = ds_obs.rename_dims({'station_id':'gage_id'}).rename({'station_id':'gage_id','discharge':'streamflow'})

Define the time base used to interpolate all subsequent NWIS data requests:

In [None]:
time_base = ds_obs.time.values

In [None]:
fs = fsspec.filesystem('file')

### Choose the directory where the cached data will be stored

In [None]:
# cache data in this directory [change to your directory]:
# dir_scratch = Path('/caldera/projects/usgs/hazards/cmgp/woodshole/rsignell/conus404/zarr')
# file_chanobs = dir_scratch / 'nwis_chanobs2.zarr'

In [None]:
# edit this to your directory where you wish to save NWIS streamflow information
dir_scratch = Path('/caldera/projects/usgs/water/wbbp/')
file_chanobs = dir_scratch / 'nwis_chanobs2.zarr'

In [None]:
if file_chanobs.is_dir():
    fs.rm(str(file_chanobs),recursive=True)

In [None]:
len(gage_ids_str)

In [None]:
#source_dataset = ds_obs.drop_vars(drop_vars)
source_dataset = ds_obs

In [None]:
# Build a lazy, all-zeros template with one gage_id column per NWM gage.
# It only fixes the dims, dtypes and chunking the Zarr store is created with;
# the actual values are written region-by-region later.
template = xr.zeros_like(source_dataset.chunk())
# Collapse to the layout of a single station...
template = template.isel(gage_id=0, drop=True)
# ...then broadcast it back out along a trailing gage_id axis covering every gage.
template = template.expand_dims(gage_id=len(gage_ids_str), axis=-1)
template = template.assign_coords({'gage_id': [f'USGS-{gid}' for gid in gage_ids_str]})

# One chunk per gage spanning the whole time axis, so each station occupies
# its own chunk and can be written independently.
template = template.chunk({'time': len(ds_obs.time), 'gage_id': 1})

In [None]:
template

Specify appropriate dtypes and fill values (otherwise int64 and float64 are used by default):

In [None]:
# Per-variable Zarr encoding: use 32-bit dtypes instead of the int64/float64
# defaults, with explicit fill values (these appear to be the netCDF default
# int/float fill values). Latitude/longitude get no fill value at all.
encoding = {
    'alt_acy_va': {'_FillValue': -2147483647, 'dtype': np.int32},
    'alt_va': {'_FillValue': 9.96921e+36, 'dtype': np.float32},
    'dec_lat_va': {'_FillValue': None, 'dtype': np.float32},
    'dec_long_va': {'_FillValue': None, 'dtype': np.float32},
    'streamflow': {'_FillValue': 9.96921e+36, 'dtype': np.float32},
}

In [None]:
# Writes no data (yet)
template.to_zarr(file_chanobs, compute=False, encoding=encoding, consolidated=True, mode='w')

In [None]:
nt = len(ds_obs.time)

In [None]:
# Write the two stations fetched above into the first two gage_id slots of the store.
# NOTE(review): assumes ds_obs lists its stations in the same order as
# gage_ids_str[:2] (the order of the template's gage_id coordinate) — confirm.
ds_obs.to_zarr(file_chanobs, region={'time':slice(0, nt), 'gage_id': slice(0, 2)})

In [None]:
def ind2zarr(n):
    """Fetch NWIS streamflow for one gage and write it into its slot of the Zarr store.

    Uses the notebook globals ``gage_ids_str``, ``nwis``, ``dates``,
    ``time_base``, ``file_chanobs`` and ``nt`` defined in earlier cells.

    Parameters
    ----------
    n : int
        Index of the gage in ``gage_ids_str``; also the position of its
        ``gage_id`` column in the store.

    Returns
    -------
    bool
        True if the gage's data was written. False if fetching, reshaping or
        writing failed; in that case the gage's region is left unwritten.

    Raises
    ------
    IndexError
        If ``n`` is out of range for ``gage_ids_str``.
    """
    site_id = gage_ids_str[n]
    try:
        # NOTE(review): .interp linearly fills gaps inside a station's record;
        # .reindex(time=time_base) would leave them as NaN instead — confirm intent.
        ds_site = nwis.get_streamflow(site_id, dates, to_xarray=True).interp(time=time_base)
        # Match the template's naming: station_id -> gage_id, discharge -> streamflow.
        ds_site = (ds_site
                   .rename_dims({'station_id': 'gage_id'})
                   .rename({'station_id': 'gage_id', 'discharge': 'streamflow'}))
        ds_site.to_zarr(file_chanobs,
                        region={'time': slice(0, nt), 'gage_id': slice(n, n + 1)})
    except Exception as err:
        # Best effort: skip a gage whose fetch or write fails rather than abort
        # the whole batch. Catching Exception (not a bare except) still lets
        # KeyboardInterrupt/SystemExit through. The notebook sets the root
        # logger to CRITICAL, so lower that level to see these messages.
        logging.getLogger(__name__).warning(
            'Skipping gage %s (index %d): %r', site_id, n, err)
        return False
    return True

### Use a Dask cluster to make NWIS station requests in parallel:

In [None]:
resource = 'tallgrass' #choose from denali, tallgrass, local, esip-qhub-gateway-v0.4

In [None]:
# SLURM account to charge jobs to, or None when not running under SLURM
# (e.g. on Denali or a local machine); `.get` avoids a KeyError there.
project = os.environ.get('SLURM_JOB_ACCOUNT')

In [None]:
def configure_cluster(resource):
    '''Create a Dask cluster and connect a client to it for the given compute resource.

    Parameters
    ----------
    resource : str
        Name of the compute environment; one of 'denali', 'tallgrass',
        'local' or 'esip-qhub-gateway-v0.4'.

    Returns
    -------
    tuple
        ``(client, cluster)``: the ``dask.distributed.Client`` and the cluster
        object it is connected to.

    Raises
    ------
    ValueError
        If ``resource`` is not one of the supported names.
    KeyError
        For 'tallgrass', if ``SLURM_JOB_ACCOUNT`` is not set in the environment.
    '''
    supported = ('denali', 'tallgrass', 'local', 'esip-qhub-gateway-v0.4')
    # Fail fast with a clear message instead of reaching the final `return`
    # with `client`/`cluster` never assigned.
    if resource not in supported:
        raise ValueError(f'Unknown resource {resource!r}; choose from {supported}')

    if resource == 'denali':
        # Single-threaded workers on the local (Denali) node.
        cluster = LocalCluster(threads_per_worker=1)
        client = Client(cluster)

    elif resource == 'tallgrass':
        # SLURM account to charge the worker jobs to.
        project = os.environ['SLURM_JOB_ACCOUNT']

        # Each job_extra entry becomes one "#SBATCH <entry>" line in the job
        # script, so it must be a complete sbatch option string; a dict would
        # contribute only its keys (yielding the invalid directive "#SBATCH hint").
        cluster = SLURMCluster(processes=1, cores=1,
                               memory='10GB', interface='ib0',
                               project=project, walltime='01:00:00',
                               job_extra=['--hint=multithread'])
        cluster.scale(10)  # 10 single-process, single-core workers
        client = Client(cluster)

    elif resource == 'local':
        import warnings
        warnings.warn("Running locally can result in costly data transfers!\n")
        n_cores = os.cpu_count()  # set to match your machine
        cluster = LocalCluster(threads_per_worker=n_cores)
        client = Client(cluster)

    elif resource == 'esip-qhub-gateway-v0.4':
        import sys
        # Make the shared `ebdpy` helper library importable.
        sys.path.append(os.path.join(os.environ['HOME'], 'shared', 'users', 'lib'))
        import ebdpy as ebd
        # NOTE(review): this call looks redundant with the fuller
        # set_credentials call below; kept in case ebdpy relies on it.
        ebd.set_credentials(profile='esip-qhub')

        aws_profile = 'esip-qhub'
        aws_region = 'us-west-2'
        endpoint = f's3.{aws_region}.amazonaws.com'
        ebd.set_credentials(profile=aws_profile, region=aws_region, endpoint=endpoint)
        worker_max = 30
        client, cluster = ebd.start_dask_cluster(profile=aws_profile, worker_max=worker_max,
                                                 region=aws_region, use_existing_cluster=True,
                                                 adaptive_scaling=False, wait_for_cluster=False,
                                                 worker_profile='Medium Worker', propagate_env=True)

    return client, cluster

In [None]:
client, cluster = configure_cluster(resource)

In [None]:
client

### Begin saving and writing data
This is where all the work gets done (a list of delayed tasks is created and then executed by the Dask cluster):

In [None]:
%%time
# Build one delayed ind2zarr task per gage and execute them all on the Dask
# cluster; each task writes its own gage_id region of the Zarr store.
# takes less than 5 minutes with a local cluster on Denali:
# NOTE(review): `retries` only re-runs tasks that raise, but ind2zarr catches
# its own exceptions, so failed stations are skipped rather than retried.
_ = dask.compute(*[dask.delayed(ind2zarr)(i) for i in range(len(gage_ids_str))], retries=10);

Call Zarr convenience function to consolidate the metadata:

In [None]:
_ = consolidate_metadata(file_chanobs)

#### Check out the resulting dataset

In [None]:
# filename and path check, where the NWIS data is now stored.
file_chanobs

In [None]:
dst = xr.open_dataset(file_chanobs, engine='zarr', chunks={}, backend_kwargs=dict(consolidated=True))
dst

In [None]:
# check out a plot of the discharge over time for a random gage in the list:
dst.streamflow.isel(gage_id=100).hvplot(x='time', grid=True)

All done: close the client and cluster.

In [None]:
client.close(); cluster.close()