# On Chunks - The Art of Dask Part 1 <img align="right" src="../../resources/csiro_easi_logo.png">

In this notebook we'll explore the impact of chunking choices for dask arrays. We'll use an ODC example, but this isn't specific to ODC; it applies to all usage of dask `Array`s. Chunking choices have a _significant_ impact on performance for three reasons:
1. Chunks are the unit of work during processing
2. Chunks are the unit of transport in communicating information between workers
3. Chunks are directly related to the number of _tasks_ being executed

Performance is therefore affected in multiple ways, and it is all about tradeoffs:
* if chunks are too small, there will be too many _tasks_ and processing may be inefficient, _BUT_
* if chunks are too big, communication between workers may take too long, and the combined total of all chunks required for a calculation may exceed worker memory, causing spilling to disk or, worse, workers being killed
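As a rough illustration of this tradeoff, the sketch below builds lazy dask arrays of the same (hypothetical) shape with two different chunk sizes and compares the number of chunks, which sets the number of tasks per operation, against the memory per chunk. The shape and chunk sizes are arbitrary assumptions, not the ones used later in this notebook.

```python
import dask.array as da
import numpy as np

def chunk_stats(shape, chunks, dtype=np.int16):
    """Return (number of chunks, MiB per chunk) for a lazy dask array."""
    arr = da.zeros(shape, chunks=chunks, dtype=dtype)  # lazy: no data is allocated
    n_chunks = int(np.prod(arr.numblocks))             # one task per chunk per operation
    chunk_mib = float(np.prod(arr.chunksize)) * arr.dtype.itemsize / 2**20
    return n_chunks, chunk_mib

shape = (100, 6000, 6000)  # hypothetical (time, y, x) shape
for chunks in [(10, 500, 500), (100, 3000, 3000)]:
    n, mib = chunk_stats(shape, chunks)
    print(f"chunks={chunks}: {n} chunks, {mib:.1f} MiB per chunk")
```

The small chunks give 1440 chunks of under 5 MiB each; the large chunks give only 4 chunks, but each is well over a GiB, which is more than many workers can comfortably hold alongside intermediate results.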

It's not just size that matters either; the relative contiguity of dimensions matters too:
* Temporal processing benefits from larger chunks along the time dimension
* Spatial processing benefits from larger chunks along the spatial dimensions, _BUT_
* Earth Observation data can be spatially sparse, so if chunks are too large spatially they will contain a lot of empty (nodata) pixels

Thankfully it is possible to _re-chunk_ data for different stages of computation. Whilst _re-chunking is an *expensive* operation_, the efficiency gains for downstream computation can be very significant, and re-chunking is sometimes essential to support the numerical processing required. For example, it is often necessary to have a single chunk on the time dimension for temporal calculations.
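A minimal sketch of re-chunking so that the whole time dimension sits in one chunk, as a temporal reduction such as a per-pixel median needs. The array is synthetic and its shape and chunk sizes are assumptions; the spatial chunks are shrunk so each chunk keeps the same number of elements.

```python
import dask.array as da
import numpy as np

# Synthetic (time, y, x) array chunked one time step at a time
arr = da.ones((20, 1000, 1000), chunks=(1, 1000, 1000), dtype=np.float32)

# All time steps in one chunk (-1 means "the full dimension");
# 20 x 200 x 250 has the same element count as 1 x 1000 x 1000
rechunked = arr.rechunk({0: -1, 1: 200, 2: 250})

print(arr.chunksize, arr.numblocks)              # (1, 1000, 1000) (20, 1, 1)
print(rechunked.chunksize, rechunked.numblocks)  # (20, 200, 250) (1, 5, 4)

# A temporal reduction along axis 0 now works within each chunk (still lazy)
median = da.median(rechunked, axis=0)
```

For an xarray `Dataset` returned by `dc.load`, the equivalent is `dataset.chunk({'time': -1})`, optionally with smaller `x`/`y` chunk sizes. The rechunk step itself becomes part of the task graph and moves data between workers, which is where its cost comes from.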

To understand the impact of chunking choices on _your code_ (it is very algorithm-dependent), it is essential to understand both the:
* _Static_ impact of chunking (e.g. task count, chunk size in memory), and
* _Dynamic_ impact of chunking (e.g. CPU load, thread utilisation, network communication, task count and scheduler load).

`Dask` provides tools for viewing all of these when you print out arrays in the notebook (static) and when viewing the various graphs in the dask dashboard (dynamic).

## Our example

The code below will be familiar; it's the same example from previous notebooks (seasonal mean NDVI over a large area). A normalised burn ratio (NBR2) calculation has been added to provide some additional load, which makes the performance differences more noticeable in the various graphs. The NBR2 uses two additional bands but is effectively the same type of calculation as the NDVI (a normalised difference ratio).

The primary difference in this example is that the calculation (both NDVI and NBR2) is performed 4 times, each with a different chunking regime. See the `chunk_settings` list.

__When running this notebook, be sure to have the dask dashboard open and preferably visible as calculations proceed.__

There are several sections to pay attention to:
* Status
   * Short term snapshot of the Memory use (total and per worker - also shows Managed, Unmanaged and Spilled to Disk splits), Processing and CPU usage (change the Tab to switch between them)
   * Progress of the optimized Task graph
   * Near term Task Stream (Red is comms, White space is "doing nothing", other colours mostly match the tasks and you can hover over them with the mouse to get more information)
* Tasks
   * Longer term Task Stream. This is a more comprehensive and accurate view of the Execution over time
* System
   * Scheduler CPU, Memory and Communications load. You can zoom the graphs out using the control to get a longer term view.
* Groups
   * High level view of the Task Graph Groups and their execution. The actual task graph is too detailed to display so this provides some insight into how high level aspects of your algorithm are executing.

_All_ of these graphs are dynamic and should be interpreted over time.

The dask _scheduler_ itself is also dynamic: as your code executes, it records information about how the tasks are executing and the communication occurring, and adjusts scheduling accordingly. It can take a few minutes for the scheduler to settle into a steady pattern. That pattern may also change, particularly in the later parts of a computation when work is completing and there are fewer tasks to execute.

Yes, that is a LOT of information. Thankfully you don't necessarily need to learn it all at once. In time, reading the information available will become easier as will knowing what to do about it.

Now let's run this notebook. Remember to watch the execution in the Dask Dashboard.

> __Tip__: It's likely you will want to repeat the calculation in this notebook several times. Because the results are `persisted` to the cluster, simply calling it again will result in no execution (none is required because the result was `persisted`). Rather than doing `cluster.shutdown()` and creating a new cluster each time, you can clear the `persisted` result by performing a `client.restart()`. This will clear out all previous calculations so you can `persist` again. You can do this either by creating a new cell or by using a Python Console for this Notebook (right click on the notebook and select _New Console for Notebook_).

### Create a cluster
A modest cluster will do... _and Open the dashboard_

In [None]:
# Initialize the Gateway client
from dask.distributed import Client
from dask_gateway import Gateway

number_of_workers = 5 

gateway = Gateway()

clusters = gateway.list_clusters()
if not clusters:
    print('Creating new cluster. Please wait for this to finish.')
    cluster = gateway.new_cluster()
else:
    print(f'An existing cluster was found. Connecting to: {clusters[0].name}')
    cluster=gateway.connect(clusters[0].name)

cluster.scale(number_of_workers)

client = cluster.get_client()
client

### Setup all our functions and query parameters

Nothing special here

In [None]:
import pyproj
pyproj.set_use_global_context(True)

import git
import sys, os
from dateutil.parser import parse
from dateutil.relativedelta import relativedelta
from dask.distributed import Client, LocalCluster, wait
import datacube
from datacube.utils import masking
from datacube.utils.aws import configure_s3_access

# EASI defaults
os.environ['USE_PYGEOS'] = '0'
repo = git.Repo('.', search_parent_directories=True).working_tree_dir
if repo not in sys.path: sys.path.append(repo)
from easi_tools import EasiDefaults, notebook_utils
easi = EasiDefaults()

In [None]:
dc = datacube.Datacube()
configure_s3_access(aws_unsigned=False, requester_pays=True, client=client)

In [None]:
# Get the centroid of the coordinates of the default extents
central_lat = sum(easi.latitude)/2
central_lon = sum(easi.longitude)/2
# central_lat = -42.019
# central_lon = 146.615

# Set the buffer (in degrees) to load around the central coordinates
# The bounding box extends `buffer` either side of the centre, so it spans 2 x buffer in each dimension
buffer = 1

# Compute the bounding box for the study area
study_area_lat = (central_lat - buffer, central_lat + buffer)
study_area_lon = (central_lon - buffer, central_lon + buffer)

# Data product
product = easi.product('landsat')

# Set the date range to load data over
set_time = easi.time
set_time = (set_time[0], parse(set_time[0]) + relativedelta(months=6))
#set_time = ("2021-07-01", "2021-12-31")

# Selected measurement names (used in this notebook)
alias = easi.aliases('landsat')
measurements = [alias[x] for x in ['qa_band', 'red', 'nir', 'swir1', 'swir2']]

# Set the QA band name and mask values
qa_band = alias['qa_band']
qa_mask = easi.qa_mask('landsat')

# Set the resampling method for the bands
resampling = {qa_band: "nearest", "*": "average"}

# Set the coordinate reference system and output resolution
set_crs = easi.crs('landsat')  # If defined, else None
set_resolution = easi.resolution('landsat')  # If defined, else None
# set_crs = "epsg:3577"
# set_resolution = (-30, 30)

# Set the scene group_by method
group_by = "solar_day"

In [None]:
def calc_ndvi(dataset):
    # Calculate the components that make up the NDVI calculation
    band_diff = dataset[alias['nir']] - dataset[alias['red']]
    band_sum = dataset[alias['nir']] + dataset[alias['red']]
    # Calculate NDVI
    ndvi = band_diff / band_sum
    return ndvi

def calc_nbr2(dataset):
    # Calculate the components that make up the NBR2 calculation
    band_diff = dataset[alias['swir1']] - dataset[alias['swir2']]
    band_sum = dataset[alias['swir1']] + dataset[alias['swir2']]
    # Calculate NBR2
    nbr2 = band_diff / band_sum
    return nbr2

def mask(dataset, bands):
    # Identify pixels that are either "valid", "water" or "snow"
    cloud_free_mask = masking.make_mask(dataset[qa_band], **qa_mask)
    # Apply the mask
    cloud_free = dataset[bands].astype('float32').where(cloud_free_mask)
    return cloud_free

def seasonal_mean(dataset):
    return dataset.resample(time="QS-DEC").mean('time') # perform the seasonal mean for each quarter
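`seasonal_mean` relies on the `QS-DEC` frequency alias, which starts quarters in December, March, June and September, so each bin is a meteorological season (DJF, MAM, JJA, SON) labelled by its first day. A quick pandas check on a synthetic daily series (the date range is an arbitrary assumption) shows the bin labels:

```python
import numpy as np
import pandas as pd

# Synthetic daily series covering the first half of 2021
dates = pd.date_range("2021-01-01", "2021-06-30", freq="D")
series = pd.Series(np.arange(len(dates), dtype=float), index=dates)

# Same frequency alias as seasonal_mean(); bins are labelled by the quarter's start date
quarterly = series.resample("QS-DEC").mean()
labels = quarterly.index.strftime("%Y-%m-%d").tolist()
print(labels)  # ['2020-12-01', '2021-03-01', '2021-06-01']
```

Note that January data falls in the season labelled with the _previous_ December, and a six-month query will typically touch three seasons, some of them only partially.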

We have an array of chunk settings to trial.

* Notice the `chunk_settings` are nominally the same size
* Notice we're varying the temporal chunking from large to small and adjusting the spatial chunking to keep the overall volume similar (about 9 million elements per chunk for the original dataset, so the size in memory depends on the data type at each stage)

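There are two `time:1` settings. The first uses square spatial chunks (10 x 300 pixels per side, i.e. 100 times the area of the `time:100` chunks); the second simply changes the chunks to rectangles of similar area (21 x 300 by 5 x 300), because no one said the spatial dimensions needed to be the same. The `time:50` setting is rectangular for a related reason: doubling the area of a square 300 x 300 chunk would need sides of about 424 pixels, so a 300 x 600 rectangle is used instead.

Given that the chunk size in memory, the cluster and the calculation are all roughly the same, any differences in execution are a result of the different chunk shapes.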
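Multiplying out each setting shows how close the chunk volumes are. The sketch copies the `chunks` dictionaries from `chunk_settings` in the next cell and assumes 2-byte (`int16`) source data and 4-byte `float32` intermediates.

```python
# Chunk shapes copied from chunk_settings
chunk_shapes = [
    {"time": 100, "x": 300, "y": 300},
    {"time": 50, "x": 1 * 300, "y": 2 * 300},
    {"time": 1, "x": 10 * 300, "y": 10 * 300},
    {"time": 1, "x": 21 * 300, "y": 5 * 300},
]

for chunks in chunk_shapes:
    n_elements = chunks["time"] * chunks["x"] * chunks["y"]
    # Assumed element sizes: int16 on load, float32 after masking
    int16_mib = n_elements * 2 / 2**20
    float32_mib = n_elements * 4 / 2**20
    print(f"{chunks}: {n_elements:,} elements, "
          f"{int16_mib:.1f} MiB as int16, {float32_mib:.1f} MiB as float32")
```

The first three settings hold exactly 9,000,000 elements per chunk and the rectangular `time:1` setting 9,450,000, so any timing differences come from chunk _shape_ rather than chunk _volume_.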

In [None]:
chunk_settings = [
    {"chunks": {"time":100, "x":300, "y":300}, "comment": "This run has small spatial chunks but each chunk has a lot of time steps. This results in many small file reads, but there are more total tasks for the scheduler to handle."},
    {"chunks": {"time":50, "x":1*300, "y":2*300}, "comment": "This second run has slightly larger spatial chunks but smaller temporal extents in each chunk. This results in fewer total tasks, but each one takes longer to load."},
    {"chunks": {"time":1, "x":10*300, "y":10*300}, "comment": "This run has only a single time step in each chunk, but large, square spatial extents. As a result, workers need to store much more data in memory and some data is spilled to disk."},
    {"chunks": {"time":1, "x":21*300, "y":5*300}, "comment": "Again this run has a single time step per chunk, but the spatial extents are rectangles."},
]

Now we can loop over all our `chunk_settings` and create all the required `delayed task graphs`. This will take a moment as the ODC database will be interrogated for all the necessary dataset information.

_You will notice the calculation is split up so we can see the interim results_ (well, the last one at least, given it's a loop and we're overwriting them).

Different stages of computation will produce different data types and calculations, and thus different _chunk_ sizes and _task_ counts. We may find that an interim result has a terrible chunk size (e.g. `int16` data variables become `float64`, making your chunks 4x the size, or a dimension is reduced and chunks become too small). When tuning, it is thus advisable to make these interim stages viewable so you can see their _static_ impact.
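The `int16` to `float64` promotion is easy to check without a cluster. The sketch uses small synthetic `int16` dask arrays (the shape and chunks are arbitrary assumptions) and compares chunk sizes for a normalised difference computed with and without an explicit cast to `float32`.

```python
import dask.array as da
import numpy as np

# Two synthetic int16 "bands", each a single 1000 x 1000 chunk
nir = da.ones((1000, 1000), chunks=(1000, 1000), dtype=np.int16)
red = da.ones((1000, 1000), chunks=(1000, 1000), dtype=np.int16)

# int16 / int16 is promoted to float64 by NumPy's type rules
ratio_default = (nir - red) / (nir + red)

# Casting first keeps the result in float32, half the size of float64
nir32, red32 = nir.astype("float32"), red.astype("float32")
ratio_f32 = (nir32 - red32) / (nir32 + red32)

for name, arr in [("input", nir), ("default", ratio_default), ("float32", ratio_f32)]:
    chunk_mib = np.prod(arr.chunksize) * arr.dtype.itemsize / 2**20
    print(f"{name}: dtype={arr.dtype}, {chunk_mib:.1f} MiB per chunk")
```

The dtypes are inferred lazily, so the 4x growth from `int16` to `float64` is visible in the array metadata before anything is computed.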

__Remember__: there is a single task graph executing to provide the final result. There is no need to `persist()` or `compute()` the interim results to see their _static_ attributes. In fact, it may be unwise to `persist()` as this will chew up resources on the cluster if you don't intend to use the results.

In [None]:
for chunkset in chunk_settings:
    chunks = chunkset["chunks"]
    print(chunks)

In [None]:
import numpy as np
results = []
for chunkset in chunk_settings:
    chunks = chunkset["chunks"]
    dataset = dc.load(
                product=product,
                x=study_area_lon,
                y=study_area_lat,
                time=set_time,
                measurements=measurements,
                resampling=resampling,
                output_crs=set_crs,
                resolution=set_resolution,
                dask_chunks = chunks,
                group_by=group_by,
            )
    
    num_time = dataset.sizes['time']
    time_ind = np.linspace(1, num_time, 100, dtype='int') - 1
    dataset = dataset.isel(time=time_ind) # load exactly 100 evenly spaced timesteps so that we can work more easily with different chunks

    masked_dataset = mask(dataset, [alias[x] for x in ['red', 'nir', 'swir1', 'swir2']])
    ndvi = calc_ndvi(masked_dataset)
    nbr2 = calc_nbr2(masked_dataset)
    seasonal_mean_ndvi = seasonal_mean(ndvi)
    seasonal_mean_nbr2 = seasonal_mean(nbr2)
    seasonal_mean_ndvi.name = 'ndvi'
    seasonal_mean_nbr2.name = 'nbr2'
    results.append([seasonal_mean_ndvi, seasonal_mean_nbr2])
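The `isel` step in the loop picks 100 roughly evenly spaced time indices, so each chunk setting's time chunk (100, 50 or 1) divides the time dimension cleanly. The same index arithmetic on a hypothetical count of 237 time steps shows what it produces; if fewer than 100 time steps are available, some indices repeat.

```python
import numpy as np

def evenly_spaced_indices(num_time, n=100):
    # Same expression as in the loop: n integer positions spread over 0..num_time-1
    return np.linspace(1, num_time, n, dtype="int") - 1

idx = evenly_spaced_indices(237)  # hypothetical number of time steps
print(len(idx), idx[0], idx[-1], len(np.unique(idx)))  # 100 0 236 100

# With fewer time steps than requested, indices are duplicated
short = evenly_spaced_indices(60)
print(len(short), len(np.unique(short)))  # 100 60
```

Duplicated time steps would still give 100 entries along `time`, but the repeated scenes would carry extra weight in the seasonal means.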

### Inspecting _static_ information

Let's take a look at the vital statistics for the final iteration of the loop. All the calculations are the same and only the `chunk` parameters vary, so from these we can easily infer the _static_ parameters for the other runs.


In [None]:
print(f"dataset size (GiB) {dataset.nbytes / 2**30:.2f}")
print(f"seasonal_mean_ndvi size (GiB) {seasonal_mean_ndvi.nbytes / 2**30:.2f}")
display(dataset)

So the source `dataset` is about 150 GiB in size, mostly of `int16` data type. _We need to be mindful that our calculation will convert these to `floats`._ The `mask` function above does an explicit type conversion to `float32`, which can fully represent an `int16`. Without the explicit type conversion, the result could end up as `float64` (NumPy promotes `int16 / int16` to `float64`), doubling the memory usage for no good reason (for this algorithm).
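`float32` can hold every `int16` value exactly because its 24-bit significand covers all integers up to 2**24 in magnitude. A brute-force NumPy check over all 65,536 `int16` values confirms this and shows the per-element cost of each candidate type:

```python
import numpy as np

# Every possible int16 value (built as int32 first to avoid overflow at the range limits)
all_int16 = np.arange(-2**15, 2**15, dtype=np.int32).astype(np.int16)

# Round-trip through float32: no value changes if the representation is exact
round_trip = all_int16.astype(np.float32).astype(np.int16)
print("float32 exact for all int16:", np.array_equal(all_int16, round_trip))

# Bytes per element for each candidate type
for dt in (np.int16, np.float32, np.float64):
    print(np.dtype(dt).name, np.dtype(dt).itemsize, "bytes")
```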

Open the _cylinder_ to show the `red` dask array details. The chunk is about 100 MiB in size. Generally this is a healthy size though it can be larger and may need to be smaller depending on the calculation involved and communication between workers.

Now let's look at the NDVI and NBR2 results for the first chunk setting:

In [None]:
display(results[0][0])
display(results[0][1])

Notice the result has a _much smaller_ chunk size: 4 MiB. This is due to the seasonal mean, which collapses the time dimension. This may affect downstream use of the result, as the _chunks_ may be too small and create too many tasks, reducing later processing performance.

Notice also the task count. With both results we're pushing towards 100_000 tasks in the scheduler, depending on task graph optimisation. The scheduler has its own overheads: about 1 ms per active task, plus memory for tracking all tasks, including executed ones, because it keeps the history in case it needs to reproduce results (e.g. if a worker is lost). Again, it is possible to have more than 100_000 tasks and still be efficient, depending on your algorithm, but it's something to keep an eye on. We will be below it in this case (especially after optimisation).
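The task count shown in the array display can also be read programmatically. A minimal sketch on a synthetic array (the shape, chunks and operations are arbitrary assumptions) counts the tasks in the raw graph and again after `dask.optimize`, which applies graph optimisations such as fusing element-wise steps and culling unneeded tasks.

```python
import dask
import dask.array as da
import numpy as np

# Synthetic (time, y, x) array with 4 x 4 x 4 = 64 chunks
arr = da.ones((40, 400, 400), chunks=(10, 100, 100), dtype=np.int16)

# A chain of element-wise steps followed by a reduction over time
result = ((arr.astype("float32") - 1) / (arr + 1)).mean(axis=0)

# len() of the graph counts every task (this materialises the high-level graph)
raw_tasks = len(result.__dask_graph__())
(optimized,) = dask.optimize(result)
opt_tasks = len(optimized.__dask_graph__())
print(f"tasks before optimisation: {raw_tasks}, after: {opt_tasks}")
```

The same `len(x.__dask_graph__())` call works on the dask-backed `results` above (xarray objects expose the dask collection interface), which makes it easy to compare task counts across chunk settings before persisting anything.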

### Persist the results

Theoretically we could `persist` all of the `results` at once, though we would then be well above the 100_000 task guideline.
More importantly, we want to see the difference in the _dynamics_ of the execution.
The loop below persists each result one at a time and _waits_ for it to complete.

__You should monitor execution in the Dask Dashboard__

Look at the various tabs as execution proceeds. You will notice differences in memory per worker, communication between workers (red bars in the Task Stream), white space (idle time), and CPU utilisation (remember to click on the CPU tab to see this detail).
The `Tasks` section of the dashboard is particularly useful for comparing the dynamics of all four runs, because its longer time window still shows all four blocks of computation at once.

Don't forget, if you want to run the code again use `client.restart()` to clear out the previous results from the cluster.

>__Tip:__ If you leave your computer while this step is running, make sure that it doesn't go to sleep by adjusting your power settings.

In [None]:
client.wait_for_workers(n_workers=number_of_workers)

for i, result in enumerate(results):
    print(f'Run number {i+1}:')
    print(f'Chunks: {chunk_settings[i]["chunks"]}')
    print(chunk_settings[i]["comment"])
    client.restart()
    f = client.persist(result)
    %time wait(f)
    client.restart()  # clear the cluster so each run is cleanly separated
    print()

## Understanding the dynamics



# Be a good dask user - Clean up the cluster resources

Disconnecting your client is good practice, but the cluster will still be up, so we need to shut it down as well.

In [None]:
client.close()

cluster.shutdown()