# Rechunk Kyoko's CONUS404 output
Rechunk only variables contained in the DRB spreadsheet

In [None]:
import xarray as xr
import rechunker
import zarr
import os
import numpy as np
import time

In [None]:
import hvplot.xarray
import fsspec

In [None]:
import pandas as pd

In [None]:
# Load the table listing which CONUS404 2D variables to keep. It is read over
# HTTPS from GitHub, so this cell needs network access.
df = pd.read_csv('https://raw.githubusercontent.com/nhm-usgs/data-pipeline-helpers/main/conus404/wrf2d_vars_drb.csv')

In [None]:
# Names of the variables to rechunk, taken from the CSV's 'variable' column.
# NOTE(review): `vars` shadows the Python builtin of the same name; later cells
# refer to the list by this name, so it is left unchanged here.
vars = df['variable'].to_list()

In [None]:
len(vars)

In [None]:
vars

#### Create a list of 2D files for one water year

Use fsspec for file operations, even though we are on a local file system. If we use fsspec for everything (local files, https, s3, gcs), fewer code changes are needed when we switch between them.

In [None]:
fs = fsspec.filesystem('file')

In [None]:
# Collect every wrf2d_d01* file in the WY2017 directory. sorted() gives
# lexicographic path order (presumably chronological, assuming timestamped
# file names — confirm).
flist = sorted(fs.glob('/caldera/projects/usgs/water/impd/wrf-conus404/kyoko/wrfout_post/WY2017/wrf2d_d01*'))
print(flist[0])   # first file in the sorted list
print(flist[-1])  # last file in the sorted list
len(flist)

In [None]:
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

There is a useful environment variable called `SLURM_CLUSTER_NAME` that indicates whether we are on Denali or Tallgrass.  If the code bombs with `SLURM_CLUSTER_NAME` not defined, that means we didn't request an interactive node via SLURM before we launched the notebook and we are running on the main node!  So it's a good reminder also!

In [None]:
# Denali configuration. os.environ is indexed directly (no .get) on purpose: a
# KeyError means the notebook was not launched from a SLURM allocation.
if os.environ['SLURM_CLUSTER_NAME'] == 'denali':
    cluster = SLURMCluster(
        processes=16,
        cores=16,
        memory='160GB',
        interface='ipogif0',
        project='woodshole',
        walltime='03:00:00',
        job_extra={'hint': 'multithread', 'exclusive': 'user'},
    )
    cluster.scale(1)

In [None]:
# if os.environ['SLURM_CLUSTER_NAME']=='tallgrass':
#    cluster = SLURMCluster(processes=1, cores=36, memory='370GB', interface='ib0',
#                       project='woodshole', walltime='01:10:00')

In [None]:
#client.close(); cluster.close()

It turns out that on Tallgrass, requesting 16 processes with one core each and specifying `exclusive=user` to stay on a node is the most efficient way to run this workflow. I previously tried `processes=1, cores=16` and the performance was terrible. This isn't a compute-intensive workflow, so those cores were just sitting around doing nothing, I guess.

In [None]:
# Tallgrass configuration: single-process, single-core workers with 10GB each,
# scaled out to 16 of them.
if os.environ['SLURM_CLUSTER_NAME'] == 'tallgrass':
    cluster = SLURMCluster(
        processes=1,
        cores=1,
        memory='10GB',
        interface='ib0',
        project='woodshole',
        walltime='04:00:00',
        job_extra={'hint': 'multithread', 'exclusive': 'user'},
    )
    cluster.scale(16)

In [None]:
# Connect a Dask client to the SLURM cluster created above. `cluster` is only
# defined when one of the SLURM_CLUSTER_NAME branches ran, so a NameError here
# means the cluster name was neither 'denali' nor 'tallgrass'.
client = Client(cluster)

In [None]:
client

Determine how much memory each worker has:

In [None]:
# 70% of an equal 1/16 share of 148 (presumably GB of usable node memory split
# across 16 workers — confirm); used as a guide for max_mem below.
148/16*0.7

In [None]:
# Memory budget handed to rechunker (passed as `mem` to rechunker_wrapper).
# 6.4GB is roughly the 148/16*0.7 figure from the previous cell; keep it at 75%
# of a worker's memory or less.
# NOTE(review): the earlier comment here said "workers are 4GB", which cannot be
# right if 6.4GB is meant to be under 75% — confirm the actual worker size.
max_mem = '6.4GB'

In [None]:
#ds2d = xr.open_dataset(flist[0], chunks={})   # open just one file

#### Exploring the data a bit before rechunking
Let's take a look at a few files. `open_mfdataset` doesn't do well with lots of files (e.g. thousands), but 144 should not be too bad. Let's find out!

In [None]:
%%time
# Open the first 144 files as one dataset, concatenated along 'Time' in list
# order (combine='nested'). chunks={} asks xarray for dask-backed (lazy) arrays;
# the "minimal"/'override' options presumably cut down on cross-file
# comparisons while combining — see the xarray open_mfdataset docs.
ds2d = xr.open_mfdataset(flist[:144], concat_dim='Time', combine='nested',
                         parallel=True, coords="minimal", data_vars="minimal", 
                         compat='override', chunks={})

In [None]:
ds2d.T2.shape

In [None]:
ds2d.T2.encoding

In [None]:
# assign_coords returns a new Dataset and leaves ds2d untouched, so the result
# has to be bound back to ds2d; otherwise the 'time' coordinate that the
# hvplot(x='time') calls below rely on never exists.
ds2d = ds2d.assign_coords({'time': ds2d.XTIME})
ds2d

In [None]:
ds2d.SMCWTD

In [None]:
# Time series at grid index (500, 500) for three variables, each plotted
# against 'time'; the next cell overlays them.
a, b, c = [
    ds2d[name][:, 500, 500].hvplot(x='time')
    for name in ('SWDOWN', 'SWDNTC', 'SWDNBC')
]


In [None]:
a * b * c

Good.  Seems that's okay.  So if we process the dataset in 144 time step chunks, we should be fine

In [None]:
len(ds2d.data_vars)

In [None]:
ds2d.T2

The `.encoding` attribute should tell us what type of compression and chunking the input NetCDF files have

In [None]:
ds2d.T2.encoding

In [None]:
def rechunker_wrapper(source_store, target_store, temp_store, chunks=None,
                      mem=None, consolidated=False, verbose=True):
    """Rechunk a dataset (or Zarr store) into ``target_store`` using rechunker.

    Parameters
    ----------
    source_store : xarray.Dataset or path-like
        Data to rechunk. An ``xarray.Dataset`` is handed to rechunker directly;
        anything else is converted with ``str()`` and opened as a Zarr group.
    target_store, temp_store
        Passed unchanged to ``rechunker.rechunk`` as the destination and the
        intermediate store.
    chunks : dict, optional
        Mapping of dimension name -> desired chunk length. Dimensions that are
        not listed get a single chunk spanning the whole dimension, and a
        requested length larger than the dimension is clamped to the dimension
        length. ``None`` (the default) is treated as an empty mapping.
    mem : optional
        Passed unchanged to ``rechunker.rechunk`` as ``max_mem``.
    consolidated : bool
        When true, consolidate the Zarr metadata of ``target_store`` after the
        rechunk has executed.
    verbose : bool
        When true, print the computed chunk plan and progress messages.

    Returns
    -------
    None
    """
    # The original default of None crashed on ``chunks.keys()``; an empty
    # mapping means "full-length chunks for every dimension".
    if chunks is None:
        chunks = {}

    if isinstance(source_store, xr.Dataset):
        g = source_store  # rechunk the in-memory/lazy dataset directly
        ds_chunk = g
    else:
        g = zarr.group(str(source_store))
        # open the same store with xarray to get dimension names and lengths
        ds_chunk = xr.open_zarr(str(source_store))

    dim_sizes = ds_chunk.sizes
    group_chunks = {}
    for var in ds_chunk.variables:
        # One entry per dimension of the variable: the requested chunk length
        # clamped to the dimension length, or the full length if unspecified.
        group_chunks[var] = tuple(
            min(chunks.get(dim, dim_sizes[dim]), dim_sizes[dim])
            for dim in ds_chunk.variables[var].dims
        )

    if verbose:
        print(f"Rechunking to: {group_chunks}")
        print(f"mem:{mem}")
    rechunked = rechunker.rechunk(g, target_chunks=group_chunks, max_mem=mem,
                                  target_store=target_store, temp_store=temp_store)
    rechunked.execute(retries=10)
    if consolidated:
        if verbose:
            print('consolidating metadata')
        zarr.convenience.consolidate_metadata(target_store)
    if verbose:
        print('done')

#### these paths need to be something you have write access to:

In [None]:
# Rechunked output of a single batch; the loop below deletes and rewrites it on
# every iteration.
target_store = '/caldera/projects/usgs/water/zarr/conus404_chunk'
# rechunker's intermediate store; also deleted at the start of every iteration.
temp_store = '/caldera/projects/usgs/water/zarr/tmp'
# Final dataset: each rechunked batch is written or appended here along 'Time'.
concat_store = '/caldera/projects/usgs/water/zarr/conus404_2017'

In [None]:
#try:
#    fs.rm(concat_store, recursive=True)
#except:
#    pass

In [None]:
# Target chunk lengths: `time_chunk` along 'Time', `x_chunk` along 'west_east'
# and `y_chunk` along 'south_north'. `time_chunk` is also the number of files
# processed per loop iteration.
time_chunk, x_chunk, y_chunk = 144, 300, 300

In [None]:
# Number of complete batches of `time_chunk` files. Because this rounds down,
# the last len(flist) % time_chunk files do not fall into any batch and are
# not processed by a loop over range(n_time_chunks).
n_time_chunks = len(flist) // time_chunk
print(n_time_chunks)

In [None]:
%%time
start = time.time()
print("hello")

# Process the file list in batches of `time_chunk` files: open a batch, rechunk
# it into target_store, then write/append that result to concat_store.
# To resume after a failure, start the range at the number of batches already
# present in concat_store, e.g. range(42, n_time_chunks).
for i in range(n_time_chunks):
    i0 = i * time_chunk
    i1 = i0 + time_chunk
    elapsed_min = (time.time() - start) / 60.
    print(i, flist[i0], elapsed_min)
    ds2d = xr.open_mfdataset(flist[i0:i1], concat_dim='Time', combine='nested',
                             parallel=True, coords="minimal", data_vars="minimal",
                             compat='override', chunks={})
    # NOTE(review): this cell previously called
    # ds2d.assign_coords({'time': ds2d.XTIME}) without keeping the result, so
    # no 'time' coordinate was ever written. The no-op call is dropped here so
    # the output stays exactly as before; bind the result to ds2d if a 'time'
    # coordinate is wanted in the store.

    # rechunker requires empty tmp and target dirs
    for store in (temp_store, target_store):
        try:
            fs.rm(store, recursive=True)
        except FileNotFoundError:
            # Nothing to remove (e.g. the very first batch). Any other failure
            # (permissions, busy files) is allowed to surface, because
            # rechunker would otherwise run against a stale store.
            pass

    time.sleep(3)  # wait for files to be removed (necessary? hack!)

    rechunker_wrapper(ds2d[vars], target_store=target_store, temp_store=temp_store,
                      mem=max_mem, consolidated=True, verbose=False,
                      chunks={'Time': time_chunk, 'south_north': y_chunk, 'west_east': x_chunk})

    # read back in the zarr chunk rechunker wrote
    ds = xr.open_dataset(target_store, engine='zarr', backend_kwargs=dict(consolidated=True))

    # The first batch creates (or overwrites) the concatenated store; every
    # later batch is appended along 'Time'.
    if i == 0:
        ds.to_zarr(concat_store, consolidated=False, mode='w')
    else:
        ds.to_zarr(concat_store, consolidated=False, append_dim='Time')

### Inspect the concatenated Zarr dataset

In [None]:
#url = '/caldera/projects/usgs/hazards/cmgp/woodshole/rsignell/zarr/conus404a'
url = concat_store

In [None]:
%%time
# Open the concatenated store through an fsspec mapper with dask-backed arrays
# (chunks={}). consolidated=False matches how the loop wrote it
# (to_zarr(..., consolidated=False)).
ds = xr.open_dataset(fs.get_mapper(url), engine='zarr', chunks={}, 
                     backend_kwargs=dict(consolidated=False))

In [None]:
ds.T2

How many time chunks do we have? If there are fewer than the full number (`n_time_chunks`), change the starting index of the loop above to this value and rerun the loop cell.

In [None]:
# Number of complete `time_chunk`-sized batches currently in the store.
len(ds.Time) // time_chunk

In [None]:
ds.T2.encoding

In [None]:
ds.T2[:,500,500].plot()