# NWM v2.1 Retrospective Zarr Usage Example
## Subset CHRTOUT to gages and re-rechunk for better data access times

_James McCreight_

Sometimes the chunks are not optimized to match your access pattern and you need to re-chunk (re-rechunk?) to maintain your cool.
As seen in [usage_example_streamflow_timeseries.ipynb](usage_example_streamflow_timeseries.ipynb), getting a single gage can take about 2 minutes. If you have to do that over and over again for 8000 gages you could be waiting like 11 days to get it all done. Instead of doing that, rechunk! 

This notebook takes the chrtout.zarr store, uses the `gage_id` variable to subset out just the gages, and then rechunks the subset to optimize access to the full timeseries at each point individually. Note that this is the inverse of the chunking used when the model natively writes the "chanobs" files, which consist of all the gages (a single space chunk) in separate files by time (effectively a chunk for each time). 

In [2]:
# Using virtual environment in requirements_vis.txt
# from climata.usgs import DailyValueIO
import dask
from dask.distributed import Client, progress, LocalCluster, performance_report
from dask_jobqueue import PBSCluster
import holoviews as hv
import hvplot
import numcodecs
import numpy as np
import pathlib
import pandas as pd
from rechunker import rechunk
import shutil
import xarray as xr

# Render holoviews output with the bokeh backend and give scatter plots a wide default size.
hv.extension('bokeh')
hv.opts.defaults(
    hv.opts.Scatter(width=1200, height=500) )
# Route pandas' `.plot` accessor through holoviews.
pd.options.plotting.backend = 'holoviews'

# Conversion factor from cubic feet per second to cubic meters per second (0.3048 ** 3);
# applied to the USGS observations, which are requested in ft3/s.
cfs_2_cms = 0.028316846592

ModuleNotFoundError: No module named 'rechunker'

In [None]:
# Cluster sizing knobs. n_workers fixes the adaptive cluster size below; n_cores,
# cluster_mem_gb and queue are passed straight to PBSCluster.
n_workers = 16
n_cores = 1
queue = "casper"
cluster_mem_gb = 25
# Fraction of the memory figure that is handed to rechunker as its max_mem budget.
chunk_mem_factor = 0.9

# NOTE(review): presumably turns off Blosc's internal threading so it does not
# compete with dask's own parallelism — confirm against the numcodecs docs.
numcodecs.blosc.use_threads = False

# Describe the PBS jobs that will back the dask workers.
cluster = PBSCluster(
    cores=n_cores,
    memory=f"{cluster_mem_gb}GB",
    queue=queue,
    project="NRAL0017",
    walltime="05:00:00",
    death_timeout=75,)

In [None]:
# minimum == maximum pins the adaptive cluster at exactly n_workers workers.
cluster.adapt(maximum=n_workers, minimum=n_workers)
# Connect a client to the cluster; it is used for the dashboard link below.
client = Client(cluster)

In [None]:
# Optional (disabled): override the template dask uses to build the dashboard URL.
#dask.config.set({"distributed.dashboard.link": "/{port}/status"})
# Display the dashboard URL for this client.
client.dashboard_link

In [None]:
# Location of the full CHRTOUT zarr store, opened as an xarray Dataset.
chrtout_file = '/glade/scratch/arezoo/HI/chrtout/chrtout.zarr'
ds_nwm_chrtout = xr.open_zarr(chrtout_file)

In [None]:
# Keep only the gage_id entries that are not the all-blank placeholder
# (15 spaces, as bytes); entries equal to the placeholder are dropped.
nwm_gages = ds_nwm_chrtout.gage_id.where(
    ds_nwm_chrtout.gage_id != b' ' * 15,
    drop=True,
)

In [None]:
nwm_gages

In [None]:
# Subset the full dataset to the entries whose gage_id appears in nwm_gages,
# dropping everything else.
ds_nwm_gages_0 = (
    ds_nwm_chrtout
    .where(ds_nwm_chrtout.gage_id.isin(nwm_gages), drop=True))

In [None]:
# Continue with a copy under a new name so the raw subset stays available as ds_nwm_gages_0.
ds_nwm_gages = ds_nwm_gages_0.copy()

In [None]:
# Alternative that was tried and left disabled: keep a single 'crs' value instead of dropping it.
# ds_nwm_gages['crs'] = ds_nwm_gages['crs'][0]
# Remove the 'crs' variable: retaining it gives a dask type issue during the
# rechunk (not sure why that doesn't happen elsewhere).
# `drop_vars` is the supported replacement for the deprecated `Dataset.drop`.
ds_nwm_gages = ds_nwm_gages.drop_vars('crs')

In [None]:
ds_nwm_gages

In [12]:
# Target layout: one feature per chunk, with the entire time axis in that chunk,
# so a full timeseries at a single point is a single chunk.
dim_chunk_sizes = {'feature_id': 1, 'time': len(ds_nwm_gages.time)}
ds_nwm_gages = ds_nwm_gages.chunk(chunks=dim_chunk_sizes)

In [13]:
# Per-variable target chunks for rechunker. 'streamflow' and 'velocity' take the
# per-dimension sizes from dim_chunk_sizes (in each variable's own dimension
# order); every other variable is kept as one chunk covering its full shape.
chunk_plan = {
    name: (
        tuple(dim_chunk_sizes[dim] for dim in ds_nwm_gages[name].dims)
        if name in ['streamflow', 'velocity']
        else ds_nwm_gages[name].shape
    )
    for name in ds_nwm_gages.variables
}

# Clear any chunk encoding carried on the variables
# (seems redundant, with ds.chunk() ?).
for name in ds_nwm_gages.variables:
    ds_nwm_gages[name].encoding['chunks'] = None

In [14]:
# for vv in ds_nwm_gages.variables:
#     print('\n')
#     print(vv)
#    print(ds_nwm_gages[vv].encoding)

In [15]:
# Re-apply the same target chunking (same call as before the encoding reset).
ds_nwm_gages = ds_nwm_gages.chunk(chunks=dim_chunk_sizes)

In [16]:
chunk_plan

{'streamflow': (367439, 1),
 'velocity': (367439, 1),
 'elevation': (7994,),
 'feature_id': (7994,),
 'gage_id': (7994,),
 'latitude': (7994,),
 'longitude': (7994,),
 'order': (7994,),
 'time': (367439,)}

In [17]:
# Output locations: the final rechunked store and rechunker's intermediate store.
dir_scratch = pathlib.Path('/glade/scratch/jamesmcc')
file_chanobs = dir_scratch.joinpath('chanobs.zarr')
file_chanobs_temp = dir_scratch.joinpath('chanobs_temp.zarr')

# Only the intermediate store is cleared here; an existing final store is left
# in place (the rechunk step is skipped when it already exists).
stale_stores = (file_chanobs_temp,)
for stale_store in stale_stores:
    if not stale_store.exists():
        continue
    shutil.rmtree(stale_store)

In [18]:
# Build the rechunked store only when it is not already on disk.
if not file_chanobs.exists():
    # Memory budget string handed to rechunker: the configured memory figure,
    # scaled by the safety factor and divided across the workers.
    mem_budget_gb = chunk_mem_factor * cluster_mem_gb / n_workers
    max_mem = f"{mem_budget_gb:.2f}GB"

    rechunk_obj = rechunk(
        ds_nwm_gages,
        chunk_plan,
        max_mem,
        str(file_chanobs),
        temp_store=str(file_chanobs_temp),
        executor="dask",
    )

    # Run the plan while recording a dask performance report to dask-report.html.
    with performance_report(filename="dask-report.html"):
        result = rechunk_obj.execute(retries=10)

In [19]:
# Open the rechunked gage-only store written to file_chanobs.
ds_chanobs = xr.open_zarr(file_chanobs)

In [20]:
ds_chanobs

Unnamed: 0,Array,Chunk
Bytes,31.23 kiB,31.23 kiB
Shape,"(7994,)","(7994,)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 31.23 kiB 31.23 kiB Shape (7994,) (7994,) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",7994  1,

Unnamed: 0,Array,Chunk
Bytes,31.23 kiB,31.23 kiB
Shape,"(7994,)","(7994,)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,117.10 kiB,117.10 kiB
Shape,"(7994,)","(7994,)"
Count,2 Tasks,1 Chunks
Type,|S15,numpy.ndarray
"Array Chunk Bytes 117.10 kiB 117.10 kiB Shape (7994,) (7994,) Count 2 Tasks 1 Chunks Type |S15 numpy.ndarray",7994  1,

Unnamed: 0,Array,Chunk
Bytes,117.10 kiB,117.10 kiB
Shape,"(7994,)","(7994,)"
Count,2 Tasks,1 Chunks
Type,|S15,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,31.23 kiB,31.23 kiB
Shape,"(7994,)","(7994,)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 31.23 kiB 31.23 kiB Shape (7994,) (7994,) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",7994  1,

Unnamed: 0,Array,Chunk
Bytes,31.23 kiB,31.23 kiB
Shape,"(7994,)","(7994,)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,31.23 kiB,31.23 kiB
Shape,"(7994,)","(7994,)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 31.23 kiB 31.23 kiB Shape (7994,) (7994,) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",7994  1,

Unnamed: 0,Array,Chunk
Bytes,31.23 kiB,31.23 kiB
Shape,"(7994,)","(7994,)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,31.23 kiB,31.23 kiB
Shape,"(7994,)","(7994,)"
Count,2 Tasks,1 Chunks
Type,int32,numpy.ndarray
"Array Chunk Bytes 31.23 kiB 31.23 kiB Shape (7994,) (7994,) Count 2 Tasks 1 Chunks Type int32 numpy.ndarray",7994  1,

Unnamed: 0,Array,Chunk
Bytes,31.23 kiB,31.23 kiB
Shape,"(7994,)","(7994,)"
Count,2 Tasks,1 Chunks
Type,int32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,21.88 GiB,2.80 MiB
Shape,"(367439, 7994)","(367439, 1)"
Count,7995 Tasks,7994 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 21.88 GiB 2.80 MiB Shape (367439, 7994) (367439, 1) Count 7995 Tasks 7994 Chunks Type float64 numpy.ndarray",7994  367439,

Unnamed: 0,Array,Chunk
Bytes,21.88 GiB,2.80 MiB
Shape,"(367439, 7994)","(367439, 1)"
Count,7995 Tasks,7994 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,21.88 GiB,2.80 MiB
Shape,"(367439, 7994)","(367439, 1)"
Count,7995 Tasks,7994 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 21.88 GiB 2.80 MiB Shape (367439, 7994) (367439, 1) Count 7995 Tasks 7994 Chunks Type float64 numpy.ndarray",7994  367439,

Unnamed: 0,Array,Chunk
Bytes,21.88 GiB,2.80 MiB
Shape,"(367439, 7994)","(367439, 1)"
Count,7995 Tasks,7994 Chunks
Type,float64,numpy.ndarray


## Plot data at a single gage

In [21]:
usgs_station_id = "13317000"  ## Lower Salmon River is an unmanaged flow but you can pick your own.

In [22]:
# Select the single gage matching usgs_station_id.
# gage_id values are 15-byte strings, so the station id is left-padded with
# spaces to 15 characters and encoded to bytes before comparing.
# `isel` needs a {dimension: indexer} mapping, so the boolean match on gage_id is
# turned into integer positions along feature_id. The previous call,
# `isel(gage_id==...)`, referenced an undefined bare name `gage_id` and passed no
# dimension at all.
# The feature_id dimension is kept (length 1) so it can be squeezed out later.
# Equivalent but heavier alternative:
#   ds_chanobs.where(ds_chanobs.gage_id == <padded id>, drop=True)
ds_nwm_gage = ds_chanobs.isel(
    feature_id=np.flatnonzero(
        ds_chanobs.gage_id.values == usgs_station_id.rjust(15, " ").encode()))

In [23]:
%time streamflow_nwm = ds_nwm_gage.streamflow.load()

CPU times: user 103 ms, sys: 9.12 ms, total: 112 ms
Wall time: 3.24 s


In [24]:
streamflow_nwm

In [25]:
# Remove the feature_id dimension (a single gage was selected) and convert the
# modeled streamflow to a pandas DataFrame for joining with the observations.
streamflow_nwm_df = streamflow_nwm.squeeze('feature_id').to_dataframe()

# Bring in observations

In [26]:
# Request USGS daily values for the chosen station over the date range spanned
# by the first and last timestamps of the NWM CHRTOUT dataset.
# NOTE(review): DailyValueIO is not imported in the visible import cell — the
# `from climata.usgs import DailyValueIO` line there is commented out — so that
# import has to be re-enabled before this cell can run.
param_id = "00060"  # streamflow in ft3/s
data = DailyValueIO(
    start_date=pd.Timestamp(ds_nwm_chrtout.time[0].values).date(),
    end_date=pd.Timestamp(ds_nwm_chrtout.time[-1].values).date(),
    station=usgs_station_id,
    parameter=param_id,)

In [27]:
# Build date/flow lists from the USGS response. Each record is indexed as
# (date, value); values are multiplied by cfs_2_cms to convert ft3/s to m3/s.
# If the response holds more than one series, the last one wins.
streamflow_usgs_d = {}
for series in data:
    streamflow_usgs_d = {
        'streamflow_obs': [rec[1] * cfs_2_cms for rec in series.data],
        'time': [pd.to_datetime(rec[0]) for rec in series.data],
    }

# Observations as a DataFrame indexed by time.
streamflow_usgs_df = pd.DataFrame(streamflow_usgs_d)
streamflow_usgs_df = streamflow_usgs_df.set_index('time')

In [28]:
# Outer-join modeled and observed flows on their indexes so that timestamps
# present in either source are kept, then give the columns display names.
display_names = {'streamflow': 'NWM v2.1', 'streamflow_obs': 'observed'}
combo_df = streamflow_nwm_df.join(streamflow_usgs_df, how='outer')
combo_df = combo_df.rename(columns=display_names)

In [29]:
def plot_water_year(water_year: int):
    """Display a scatter plot of modeled vs. observed streamflow for one water year.

    The water year runs from October 1 of the previous calendar year
    (inclusive) to October 1 of ``water_year`` (exclusive). Rows are taken
    from the notebook-level ``combo_df`` and the plot title includes the
    notebook-level ``usgs_station_id``.

    Args:
        water_year: The water year to plot, e.g. 2018.

    Returns:
        None. The plot is shown through the notebook's ``display``.
    """
    wy_start = f'{water_year - 1}-10-01'
    wy_end = f'{water_year}-10-01'
    in_water_year = (combo_df.index >= wy_start) & (combo_df.index < wy_end)
    wy_df = combo_df[in_water_year]

    title = (
        f'Water year {water_year}, USGS station {usgs_station_id} : '
        f'National Water Model v2.1 retrospective and USGS observed streamflows')

    scatter = wy_df.plot.scatter(
        x='time',
        y=['NWM v2.1', 'observed'],
        title=title,
    )
    display(scatter.opts(ylabel='cubic meters per second', xlabel=''))

In [30]:
plot_water_year(2018)

In [31]:
plot_water_year(2003)

In [32]:
plot_water_year(1996)