# Rechunking Larger Datasets with Dask

The goal of this notebook is to expand on the rechunking performed in the [Introductory Rechunking tutorial](../101/Rechunking.ipynb).
This notebook will perform the same operations, but will work on a **much** larger dataset and involve some parallelization using [Dask](https://www.dask.org/).

:::{Warning}
You should only run workflows like this tutorial on a cloud or HPC compute node.
In practice, this will require reading and writing **enormous** amounts of data.
On a typical network connection and a simple compute environment, you would saturate your bandwidth and max out your processor, and the rechunking could take days to complete.
:::

In [None]:
import os
import xarray as xr
import fsspec
from rechunker import rechunk
import zarr
import shutil
import numpy as np
import dask

## Read in a Zarr Store

Like the [Introductory Rechunking tutorial](../101/Rechunking.ipynb), we will use the data from the National Water Model Retrospective Version 2.1.
The full dataset is part of the [AWS Open Data Program](https://aws.amazon.com/opendata/), available via the S3 bucket at: `s3://noaa-nwm-retro-v2-zarr-pds/`.

As this is a `zarr` store, let's read it in with [`xarray.open_dataset()`](https://docs.xarray.dev/en/stable/generated/xarray.open_dataset.html) and `engine='zarr'`.

In [None]:
file = fsspec.get_mapper('s3://noaa-nwm-retro-v2-zarr-pds', anon=True)
ds = xr.open_dataset(file, chunks={}, engine='zarr')
ds

## Restrict for Tutorial

As we saw in the [Introductory Rechunking tutorial](../101/Rechunking.ipynb), this dataset is massive, taking up almost 32 TiB uncompressed.
As this is a tutorial, we will still restrict the data to a subset, as we don't really need to work on the entire dataset.
Following the [Introductory Rechunking tutorial](../101/Rechunking.ipynb), let's only look at `streamflow` and `velocity` for the first 15,000 `feature_id`s, but for the 2000s decade of water years (October 1999 through September 2009) instead of a single water year.
This will make our dataset larger-than-memory, but it should still run in a reasonable amount of time.
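As a back-of-the-envelope check of "larger-than-memory", the subset size can be estimated from its shape alone. This sketch does not touch the actual dataset; it assumes hourly time steps with no gaps and 8-byte (float64) values, which is consistent with the chunk sizes reported further below.

```python
from datetime import date

# Water years 2000-2009: 1999-10-01 through 2009-09-30 inclusive.
# Assumption: hourly time steps with no gaps.
n_days = (date(2009, 10, 1) - date(1999, 10, 1)).days
ntime = n_days * 24
nfeature = 15_000
bytes_per_value = 8  # assumption: float64 values once decoded

per_variable_GiB = ntime * nfeature * bytes_per_value / 2**30
both_variables_GiB = 2 * per_variable_GiB  # streamflow + velocity

print(f"{ntime} time steps")
print(f"~{per_variable_GiB:.1f} GiB per variable, ~{both_variables_GiB:.1f} GiB total")
```

Roughly 20 GiB for the two variables is more than many laptops or small compute nodes can hold in memory at once.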

For processing the full-sized dataset, you'd just skip this step where we slice off a representative example of the data.
Expect run time to increase in proportion to the size of the data being processed.

In [None]:
ds = ds[['streamflow', 'velocity']]
ds = ds.isel(feature_id=slice(0, 15000))
ds = ds.sel(time=slice('1999-10-01', '2009-09-30'))
ds

Now, our subset of data is only about 10 GiB per data variable and has a chunk shape of `{'time': 672, 'feature_id': 15000}` with size of 76.9 MiB.
However, this chunk shape is not a good choice for our analysis, as each chunk spans the entire `feature_id` dimension (i.e., all feature IDs for a given time range are read in a single chunk).
Following the [Introductory Rechunking tutorial](../101/Rechunking.ipynb), let's instead aim for time-series-wise chunking (i.e., all `time` for a given `feature_id` in one chunk) for `streamflow` and balanced chunking for `velocity`.
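To see why the current layout is a poor fit for time-series analysis, consider reading the full record for a single `feature_id`. The sketch below counts how many chunks of shape `{'time': 672, 'feature_id': 15000}` such a read touches and how many (uncompressed) bytes it pulls in. It assumes 8-byte values and that chunk boundaries line up with the start of our subset; with a different alignment the chunk count can be off by one.

```python
import math

ntime, nfeature = 87_672, 15_000          # subset shape (time, feature_id)
chunk_time, chunk_feature = 672, 15_000   # current chunk shape
bytes_per_value = 8                       # assumption: float64 values

# Reading one feature's full record crosses every chunk along time
# (assuming chunk boundaries line up with the start of the subset).
chunks_touched = math.ceil(ntime / chunk_time)

# Each of those chunks holds chunk_feature features, all of which must be
# read and decompressed even though we only want one of them.
bytes_read = ntime * chunk_feature * bytes_per_value
bytes_needed = ntime * bytes_per_value

print(f"chunks touched: {chunks_touched}")
print(f"read ~{bytes_read / 2**30:.1f} GiB (uncompressed) to use ~{bytes_needed / 2**20:.2f} MiB")
```

In other words, every single-feature time-series read pulls in the whole variable, which is why time-series-wise chunking is the target for `streamflow`.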

## Rechunk Plan

Using our general strategy of time-series wise chunking for streamflow and balanced for velocity,
let's compute how large the chunk sizes will be if we have chunk shapes of `{'time': 87672, 'feature_id': 1}` for streamflow and 3 chunks per dimension for velocity.
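For reference, the quantities printed by the next cell follow from two simple formulas. For a dimension of length $n$ split into chunks of length $c$, with $b$ bytes per value:

$$
\text{chunk size [MiB]} = \frac{c_\text{time} \, c_\text{feature} \, b}{2^{20}},
\qquad
r = n - c \left\lfloor \frac{n}{c} \right\rfloor = n \bmod c
$$

where $r$ is the length of the leftover partial chunk along that dimension. The floor is essential: without it, $r$ is always zero. For the time-series plan with $b = 8$ (float64), this gives $87672 \cdot 1 \cdot 8 / 2^{20} \approx 0.67$ MiB and $r = 0$ along both dimensions.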

In [None]:
nfeature = len(ds.feature_id)
ntime = len(ds.time)

streamflow_chunk_plan = {'time': ntime, 'feature_id': 1}
bytes_per_value = ds.streamflow.dtype.itemsize
total_bytes = streamflow_chunk_plan['time'] * streamflow_chunk_plan['feature_id'] * bytes_per_value
streamflow_MiB = total_bytes / (2 ** 10) ** 2
# Leftover length after filling as many whole chunks as possible (floor division)
partial_chunks = {'time': ntime - streamflow_chunk_plan['time'] * (ntime // streamflow_chunk_plan['time']),
                  'feature_id': nfeature - streamflow_chunk_plan['feature_id'] * (nfeature // streamflow_chunk_plan['feature_id']),}
print("STREAMFLOW \n"
      f"Chunk of shape {streamflow_chunk_plan} \n"
      f"Partial 'time' chunk remainder: {partial_chunks['time']} ({partial_chunks['time']/streamflow_chunk_plan['time']:.3%} of a chunk)\n"
      f"Partial 'feature_id' chunk remainder: {partial_chunks['feature_id']} ({partial_chunks['feature_id']/streamflow_chunk_plan['feature_id']:.3%} of a chunk)\n"
      f"Chunk size: {streamflow_MiB:.2f} [MiB] \n")

chunks_per_dim = 3
velocity_chunk_plan = {'time': ntime // chunks_per_dim, 'feature_id': nfeature // chunks_per_dim}
bytes_per_value = ds.velocity.dtype.itemsize
total_bytes = velocity_chunk_plan['time'] * velocity_chunk_plan['feature_id'] * bytes_per_value
velocity_MiB = total_bytes / (2 ** 10) ** 2
partial_chunks = {'time': ntime -  velocity_chunk_plan['time'] * chunks_per_dim,
                  'feature_id': nfeature -  velocity_chunk_plan['feature_id'] * chunks_per_dim,}
print("VELOCITY \n"
      f"Chunk of shape {velocity_chunk_plan} \n"
      f"Partial 'time' chunk remainder: {partial_chunks['time']} ({partial_chunks['time']/velocity_chunk_plan['time']:.3%} of a chunk)\n"
      f"Partial 'feature_id' chunk remainder: {partial_chunks['feature_id']} ({partial_chunks['feature_id']/velocity_chunk_plan['feature_id']:.3%} of a chunk)\n"
      f"Chunk size: {velocity_MiB:.2f} [MiB]")

Okay, we can see that the streamflow chunk size is way too small, by a factor of ~100.
So, let's include 100 feature IDs per chunk.
As for velocity, its chunks are about 10x too big.
Since it is an even split, the chunk size scales with the inverse square of the number of chunks per dimension, so we need to increase the number of chunks per dimension by a factor of ~$\sqrt{10} \approx 3$.
However, knowing that the time dimension is hourly (and our subset spans whole days), we get no partial chunks if our number of chunks per dimension is a divisor of 24.
Luckily, this also applies to the feature ID dimension, as 15000 is a multiple of 24.
So, rather than increasing our chunks per dimension by a factor of 3 (from 3 to 9), which would leave partial chunks since 9 is not a divisor of 24, let's increase them to 12, as this gives no partial chunks.
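The divisor argument can be checked directly by sweeping a few candidate numbers of chunks per dimension. This sketch uses the subset shape from above and assumes 8-byte values; the helper name `even_split` is purely illustrative. It reports the leftover partial chunk along each dimension and the resulting velocity chunk size, plus the streamflow chunk size for 100 feature IDs per chunk.

```python
ntime, nfeature = 87_672, 15_000  # subset shape (time, feature_id)
bytes_per_value = 8               # assumption: float64 values
MiB = 2**20

def even_split(n_chunks):
    """Chunk shape, leftover remainders, and size (MiB) for an n x n split."""
    shape = (ntime // n_chunks, nfeature // n_chunks)
    remainders = (ntime % n_chunks, nfeature % n_chunks)
    size_mib = shape[0] * shape[1] * bytes_per_value / MiB
    return shape, remainders, size_mib

for n in (3, 8, 9, 12):
    shape, rem, size = even_split(n)
    print(f"{n:>2} chunks/dim: shape={shape}, remainders={rem}, {size:.1f} MiB")

# Streamflow: the full time series for 100 feature IDs per chunk
streamflow_mib = ntime * 100 * bytes_per_value / MiB
print(f"streamflow (100 feature IDs): {streamflow_mib:.1f} MiB")
```

Note that 8 is also a divisor of 24 and leaves no remainders, but it produces velocity chunks roughly twice as large as 12 does.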

In [None]:
nfeature = len(ds.feature_id)
ntime = len(ds.time)

streamflow_chunk_plan = {'time': ntime, 'feature_id': 100}
bytes_per_value = ds.streamflow.dtype.itemsize
total_bytes = streamflow_chunk_plan['time'] * streamflow_chunk_plan['feature_id'] * bytes_per_value
streamflow_MiB = total_bytes / (2 ** 10) ** 2
# Leftover length after filling as many whole chunks as possible (floor division)
partial_chunks = {'time': ntime - streamflow_chunk_plan['time'] * (ntime // streamflow_chunk_plan['time']),
                  'feature_id': nfeature - streamflow_chunk_plan['feature_id'] * (nfeature // streamflow_chunk_plan['feature_id']),}
print("STREAMFLOW \n"
      f"Chunk of shape {streamflow_chunk_plan} \n"
      f"Partial 'time' chunk remainder: {partial_chunks['time']} ({partial_chunks['time']/streamflow_chunk_plan['time']:.3%} of a chunk)\n"
      f"Partial 'feature_id' chunk remainder: {partial_chunks['feature_id']} ({partial_chunks['feature_id']/streamflow_chunk_plan['feature_id']:.3%} of a chunk)\n"
      f"Chunk size: {streamflow_MiB:.2f} [MiB] \n")

chunks_per_dim = 12
velocity_chunk_plan = {'time': ntime // chunks_per_dim, 'feature_id': nfeature // chunks_per_dim}
bytes_per_value = ds.velocity.dtype.itemsize
total_bytes = velocity_chunk_plan['time'] * velocity_chunk_plan['feature_id'] * bytes_per_value
velocity_MiB = total_bytes / (2 ** 10) ** 2
partial_chunks = {'time': ntime -  velocity_chunk_plan['time'] * chunks_per_dim,
                  'feature_id': nfeature -  velocity_chunk_plan['feature_id'] * chunks_per_dim,}
print("VELOCITY \n"
      f"Chunk of shape {velocity_chunk_plan} \n"
      f"Partial 'time' chunk remainder: {partial_chunks['time']} ({partial_chunks['time']/velocity_chunk_plan['time']:.3%} of a chunk)\n"
      f"Partial 'feature_id' chunk remainder: {partial_chunks['feature_id']} ({partial_chunks['feature_id']/velocity_chunk_plan['feature_id']:.3%} of a chunk)\n"
      f"Chunk size: {velocity_MiB:.2f} [MiB]")

Nice!
Now, our chunks are a reasonable size and have no remainders.
So, let's use these chunk plans for our rechunking.

In [None]:
chunk_plan = {
    'streamflow': streamflow_chunk_plan,
    'velocity': velocity_chunk_plan,
    # We don't want any of the coordinates chunked
    'latitude': (nfeature,),
    'longitude': (nfeature,),
    'time': (ntime,),
    'feature_id': (nfeature,)
}
chunk_plan

## Rechunk with `Rechunker`

With this plan, we can now ask `rechunker` to re-write the data using the prescribed chunking pattern.

### Set up output location

Unlike with the smaller dataset in our previous rechunking tutorial, we will write this larger dataset to an object store (an S3 'bucket') in a datacenter.
So, we need to set that up so that `rechunker` will have a suitable place to write data.

TODO: Update these next three cells to properly use AWS or HPC. May need to add some markdown cells to describe what is being done.

In [None]:
os.environ['AWS_PROFILE'] = "osn-renci"
os.environ['AWS_S3_ENDPOINT'] = "https://renc.osn.xsede.org"
# %run ../AWS.ipynb

In [None]:
from getpass import getuser
uname = getuser()

fsw = fsspec.filesystem(
    's3', 
    anon=False, 
    default_fill_cache=False, 
    skip_instance_cache=True, 
    client_kwargs={'endpoint_url': os.environ['AWS_S3_ENDPOINT'], }
)

workspace = 's3://rsignellbucket2/'
testDir = workspace + "testing/"
myDir = testDir + f'{uname}/'
fsw.mkdir(testDir)

In [None]:
temp_store = fsw.get_mapper(myDir + 'tutorial_staging.zarr')
outfile = fsw.get_mapper(myDir + 'tutorial_rechunked.zarr')
for fname in [temp_store, outfile]:
    print(f"Ensuring {fname.root} is empty...", end='')
    try:
        fsw.rm(fname.root, recursive=True)
    except FileNotFoundError:
        # Nothing to remove: the store does not exist yet
        pass
    print(" Done.")

### Spin up Dask Cluster

Our rechunking operation will be able to work in parallel.
To do that, we will spin up a `dask` cluster on the cloud hardware to schedule the various workers.
Note that this cluster must be configured with a specific user **profile** with permissions to write to our eventual output location.

TODO: Ensure this is spinning up the cluster we want

In [None]:
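import logging

from dask.distributed import Client, LocalCluster

# Placeholder: a local cluster on the current node. Replace with the
# cloud/HPC cluster (with write permissions to the output location) once configured.
cluster = LocalCluster(n_workers=8, silence_logs=logging.ERROR)
client = Client(cluster)
client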

### Rechunk

Now, we are ready to rechunk!

In [None]:
result = rechunk(
    # Make sure the base chunks are correct
    ds.chunk({'time': 672, 'feature_id': 15000}),
    target_chunks=chunk_plan,
    max_mem="16GB",
    target_store=outfile,
    temp_store=temp_store
)
result

Remember that merely invoking `rechunk` does not do any work.
It just sorts out the rechunking plan and writes metadata.
We need to call `.execute()` on the `result` object to actually run the rechunking.

In [None]:
from dask.distributed import progress, performance_report

with performance_report(filename="dask-report.html"):
    r = result.execute(retries=10)  

# Also consolidate the metadata for fast reading into xarray
_ = zarr.consolidate_metadata(outfile)

## Results
Let's read in the resulting re-chunked dataset to see how it looks:

In [None]:
ds_rechunked = xr.open_zarr(outfile)
ds_rechunked

### Comparison


In [None]:
## Before:
ds

In [None]:
## After:
ds_rechunked

In [None]:
client.close()
cluster.close()