# Dask Local Cluster - Larger than memory computation <img align="right" src="../resources/csiro_easi_logo.png">

In the ODC and Dask (LocalCluster) notebook we saw how dask can be used to speed up IO and computation by parallelising operations into _chunks_ and _tasks_, and by using _delayed tasks_ and _task graph_ optimisation to remove redundant tasks when results are not used.

Using _chunks_ provides one additional capability beyond parallelisation - _the ability to perform computations that are larger than available memory_.

Since dask operations are performed on _chunks_ it is possible for dask to perform operations on smaller pieces that each fit into memory. This is particularly useful if you have a large amount of data that is being reduced, say by performing a seasonal mean.

As with parallelisation, not all algorithms can be broken into smaller pieces, so this won't always be possible. Dask arrays, though, go a long way towards making this easy for a great many operations.

We'll continue using the same algorithm as before, but this time we're going to modify its memory usage to exceed the LocalCluster's available memory. This example notebook is set up to run on a compute node with 28 GiB of available memory and 8 cores for the LocalCluster. We'll make these limits explicit in the cluster configuration in case you are blessed with more resources.

Let's start the cluster...

In [None]:
from dask.distributed import Client, LocalCluster

# memory_limit is per worker: 2 workers x 14 GiB = 28 GiB, and 2 x 4 threads = 8 cores
cluster = LocalCluster(n_workers=2, threads_per_worker=4, memory_limit="14GiB")
client = Client(cluster)
client
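In a `LocalCluster` the `memory_limit` argument is a per-worker limit, so the totals for the whole cluster are the per-worker values multiplied by the number of workers. A minimal sketch of that arithmetic, assuming the 2 worker, 4 thread, 14 GiB per worker configuration described above, using `dask.utils.parse_bytes` to turn the memory string into bytes:

```python
from dask.utils import format_bytes, parse_bytes

# Cluster shape described in the text (2 workers, each with 4 threads and 14 GiB)
n_workers = 2
threads_per_worker = 4
memory_per_worker = parse_bytes("14GiB")  # bytes; parse_bytes understands binary (GiB) units

total_threads = n_workers * threads_per_worker
total_memory = n_workers * memory_per_worker

print(total_threads, format_bytes(total_memory))  # 8 28.00 GiB
```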

We can monitor memory usage on the workers using the dask dashboard and the Status tab. The workers are local so this will be memory on the same compute node that Jupyter is running in.  

In [None]:
import os
user = os.environ.get("JUPYTERHUB_USER")
dashboard_address=f'https://hub.csiro.easi-eo.solutions/user/{user}/proxy/8787/status'
print(dashboard_address)

In [None]:
import datacube
from datacube.utils import masking

dc = datacube.Datacube()

In [None]:
# Central Tasmania (near Little Pine Lagoon)
central_lat = -42.019
central_lon = 146.615

# Set the buffer to load around the central coordinates
# The buffer is a half-width in degrees, so the bounding box spans 2 x buffer in each dimension
buffer = 0.05

# Compute the bounding box for the study area
study_area_lat = (central_lat - buffer, central_lat + buffer)
study_area_lon = (central_lon - buffer, central_lon + buffer)

# Data products - Landsat 8 ARD from Geoscience Australia
products = ["ga_ls8c_ard_3"]

# Set the date range to load data over 
set_time = ("2021-01-01", "2021-12-31")

# Set the measurements/bands to load. None will load all of them
measurements = None

# Set the coordinate reference system and output resolution
# This choice corresponds to Aussie Albers, with resolution in metres
set_crs = "epsg:3577"
set_resolution = (-30, 30)
group_by = "solar_day"

In [None]:
dataset = None # clear results from any previous runs
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            dask_chunks={"time": 1},
            group_by=group_by,
        )
dataset

We can check the total size of the dataset using `nbytes`. We'll divide by 2**30 to have the result display in [gibibytes](https://simple.wikipedia.org/wiki/Gibibyte).

In [None]:
dataset.nbytes / 2**30
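`nbytes` is worked out from each variable's shape and dtype, so on a dask-backed dataset it reports the full in-memory size without loading any pixels. A small synthetic sketch (the variable name and sizes are made up for illustration):

```python
import dask.array as da
import xarray as xr

# Hypothetical lazy dataset: 3 time steps of 100 x 100 int16 pixels, one chunk per time step
red = da.zeros((3, 100, 100), dtype="int16", chunks=(1, 100, 100))
ds = xr.Dataset({"red": (("time", "y", "x"), red)})

# nbytes = number of elements x bytes per element; nothing is computed here
size_bytes = ds.nbytes
print(size_bytes)          # 60000 (3 * 100 * 100 * 2 bytes)
print(size_bytes / 2**30)  # the same size expressed in GiB
```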

As you can see, with this ROI and time range (1 year) the dataset is tiny. Let's scale up by increasing our ROI.


In [None]:
buffer = 0.8

# Compute the bounding box for the study area
study_area_lat = (central_lat - buffer, central_lat + buffer)
study_area_lon = (central_lon - buffer, central_lon + buffer)

In [None]:
dataset = None # clear results from any previous runs
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            dask_chunks={"time": 1},
            group_by=group_by,
        )
dataset.nbytes / 2**30

Okay, that is now larger than the 28 GiB of memory available to the cluster.

Let's take a look at the memory usage for one of the bands; we'll use `nbart_red`.

In [None]:
dataset.nbart_red

You can see the year now has more time observations (69) because we've expanded the ROI and picked up multiple satellite passes. The spatial dimensions are also much larger.

Take note of the _Chunk Bytes_: 61.58 MiB. This is the smallest unit of this dataset that dask will work on. To do an NDVI calculation, dask will need two bands, the mask, the result and a few temporaries in memory at once. So while this value is an indicator of the memory a worker needs to perform an operation, it is not the total, which will depend on the operation.
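A rough, back-of-the-envelope sketch of the per-task working set for the NDVI step, starting from the 61.58 MiB chunk. The assumptions are labelled in the code: `nbart_red` and `nbart_nir` are stored as int16 (2 bytes per pixel), the mask is 1 byte per pixel, and the masked bands, their difference, their sum and the NDVI result are all held as float64. Real usage depends on which dtypes xarray promotes to and when dask frees intermediates, so treat the result as an order-of-magnitude guide only.

```python
MiB = 2**20
GiB = 2**30

band_chunk_bytes = 61.58 * MiB  # Chunk Bytes reported for nbart_red (one time step)
int16_itemsize = 2              # assumption: reflectance bands stored as int16
pixels = band_chunk_bytes / int16_itemsize  # ~32.3 million pixels per chunk

# Assumed arrays alive at the same time while one chunk is processed
inputs = 2 * pixels * 2         # nbart_red and nbart_nir (int16, 2 bytes each)
mask = pixels * 1               # boolean mask (1 byte per pixel)
temporaries = 5 * pixels * 8    # masked red, masked nir, diff, sum, ndvi (float64 assumed)

per_task = inputs + mask + temporaries
print(f"{pixels / 1e6:.1f} M pixels, ~{per_task / GiB:.2f} GiB per task")

# Each worker runs several tasks at once (4 threads per worker in this notebook)
print(f"~{4 * per_task / GiB:.1f} GiB if 4 such tasks run concurrently on one worker")
```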

We can reduce the amount of memory per chunk further by _chunking_ the spatial dimensions as well. Let's split them into 2048 x 2048 pixel pieces.

In [None]:
dataset = None # clear results from any previous runs
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            dask_chunks={"time": 1, "x": 2048, "y": 2048},  # Adjust the chunking spatially as well
            group_by=group_by,
        )
dataset.nbytes / 2**30

As you can see the total dataset size stays the same. 

Look at the `nbart_red` data variable. You can see the chunk size has reduced to 8 MiB, and there are now 828 chunks - compared with 69 previously. The number of Tasks has increased proportionately too. This makes sense: smaller chunks, more tasks.
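The relationship between chunk size and chunk count can be reproduced with plain `dask.array`. This sketch uses a hypothetical 69 x 4500 x 7000 int16 cube; the spatial shape is made up for illustration (it is not the real extent of this load) and was picked so that the chunk counts line up with the 69 and 828 reported above:

```python
import dask.array as da

# Hypothetical cube: 69 time steps of 4500 x 7000 int16 pixels (illustrative shape only)
shape = (69, 4500, 7000)

# One chunk per time step (-1 means "the whole dimension")
time_only = da.zeros(shape, dtype="int16", chunks=(1, -1, -1))
# One chunk per time step and per 2048 x 2048 spatial tile
tiled = da.zeros(shape, dtype="int16", chunks=(1, 2048, 2048))

print(time_only.npartitions, tiled.npartitions)  # 69 828
print(tiled.numblocks)                           # (69, 3, 4)
print(2048 * 2048 * 2 / 2**20)                   # 8.0 MiB per full int16 tile
print(time_only.nbytes == tiled.nbytes)          # True: same total size
```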

> __TIP__: The _relationship between tasks and chunks_ is a critical tuning parameter.

Workers have limits in memory and compute capacity. The Dask Scheduler has limits in how many tasks it can manage efficiently (and remember it is tracking all of the data variables, not just this one). Later, when we move to a fully remote and distributed cluster, _chunks_ also become an important element in communicating between workers over networks.

If you look carefully at the figure you will see some internal lines showing the chunk boundaries in the spatial dimensions. 2048 doesn't divide the spatial dimensions evenly, so dask has made the chunks at the edges smaller. The `chunks` specification is a guide: the actual data (numpy arrays in this case) are split into pieces of the `chunk` shape or smaller. These pieces are called `blocks` in dask and represent the actual shape of the numpy array that will be processed.

Somewhat confusingly, the terms `blocks` and `chunks` are also used interchangeably in the dask literature, so you'll need to check the context to see whether they refer to the _specification_ or the _actual block of data_. For the moment this distinction doesn't matter, but when performing low-level custom operations it does matter that your `blocks` might be a different shape from the `chunks` you asked for.
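A small sketch of the difference between the chunk _specification_ and the actual _blocks_, on a hypothetical 2 x 5000 x 5000 array: the requested chunk is 2048 x 2048, but the last block along each spatial axis holds the 904-pixel remainder. `Array.chunks` lists the real block sizes, `Array.blocks` indexes individual blocks, and `map_blocks` hands each block to a function as a NumPy array of its actual shape.

```python
from itertools import product

import dask.array as da
import numpy as np

# Hypothetical array: 2 time steps of 5000 x 5000 pixels, requested chunks of 2048 x 2048
x = da.ones((2, 5000, 5000), dtype="int16", chunks=(1, 2048, 2048))

# The chunk specification was 2048, but the actual block sizes include a remainder
print(x.chunks[1])  # (2048, 2048, 904)

# Every distinct block shape that a per-block function could receive
block_shapes = sorted(set(product(*x.chunks)))
print(block_shapes)

# Index a single block (the bottom-right corner of the first time step)
corner = x.blocks[0, 2, 2]
print(corner.shape)  # (1, 904, 904)

# A per-block function must cope with any of these shapes, e.g. by using block.shape
rows_per_block = x.map_blocks(lambda block: np.full(block.shape, block.shape[1], dtype="int16"))
print(int(rows_per_block[0, -1, -1].compute()))  # 904
```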

In [None]:
dataset.nbart_red

We won't worry too much about tuning these parameters right now and will instead focus on processing this 130 GiB dataset. As before, we can exploit dask's ability to use _delayed_ tasks and apply our masking and NDVI calculation directly to the 130 GiB dataset. We'll also add an unweighted seasonal mean calculation using `groupby("time.season").mean("time")`. Dask will seek to complete the reductions (by chunk) first, as they reduce memory usage.
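To see what `groupby("time.season").mean("time")` does, here is a sketch on synthetic, dask-backed data (made up for illustration): each daily value equals its month number, so the seasonal means are easy to check by hand. Every time step counts equally, which is what "unweighted" means here. Note that a single calendar year puts January, February and December of the same year into `DJF`.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Synthetic daily series for 2021 on a 2 x 2 grid; every pixel holds its month number
time = pd.date_range("2021-01-01", "2021-12-31", freq="D")
months = time.month.to_numpy()[:, None, None]
values = np.broadcast_to(months, (len(time), 2, 2)).astype("float64")
data = xr.DataArray(values, dims=("time", "y", "x"), coords={"time": time}).chunk({"time": 1})

# Lazy seasonal mean, with one chunk per time step like the notebook's dataset
seasonal = data.groupby("time.season").mean("time")
result = seasonal.compute()

print(result.season.values)                   # the four season labels
print(float(result.sel(season="DJF")[0, 0]))  # (31*1 + 28*2 + 31*12) / 90 = 5.1
```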

It's worth monitoring the dask cluster memory usage via the dashboard's _Workers Memory_ panel to see just how little RAM is actually used during this calculation, despite it being performed on a 130 GiB dataset.

In [None]:
print(dashboard_address)

In [None]:
# Keep only pixels that fmask classifies as "valid"
cloud_free_mask = (
    masking.make_mask(dataset.oa_fmask, fmask="valid")
)
# Apply the mask (masked-out pixels become NaN)
cloud_free = dataset.where(cloud_free_mask)

# Calculate the components that make up the NDVI calculation
band_diff = cloud_free.nbart_nir - cloud_free.nbart_red
band_sum = cloud_free.nbart_nir + cloud_free.nbart_red
# Calculate NDVI lazily as a new DataArray (it is not added to `dataset`)
ndvi = None  # clear results from any previous runs
ndvi = band_diff / band_sum

ndvi_unweighted = ndvi.groupby("time.season").mean("time")  # Calculate the seasonal mean
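Masked pixels become NaN after `where`, and xarray's `mean` skips NaN for floating-point data by default, so cloudy observations simply drop out of the average rather than dragging it towards zero. A tiny synthetic sketch (the values and the boolean mask are made up; in the notebook the mask comes from `masking.make_mask`):

```python
import numpy as np
import xarray as xr

# Hypothetical single-pixel series over three observations
red = xr.DataArray([100.0, 200.0, 100.0], dims="time")
nir = xr.DataArray([300.0, 900.0, 500.0], dims="time")
# Hypothetical cloud-free mask: the second observation is "cloudy"
clear = xr.DataArray([True, False, True], dims="time")

red_clear = red.where(clear)  # the cloudy value becomes NaN
nir_clear = nir.where(clear)
ndvi = (nir_clear - red_clear) / (nir_clear + red_clear)

print(np.isnan(ndvi.values))     # [False  True False]
print(float(ndvi.mean("time")))  # mean of the two clear values, (0.5 + 0.667) / 2
```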

Let's check the shape of our result - it should have 4 seasons now.

In [None]:
ndvi_unweighted

Before we call `compute()` to get our result, we should make sure the final result will fit in the Jupyter kernel's memory.

In [None]:
ndvi_unweighted.nbytes  / 2**30

From 130 GiB down to < 1 GiB for the result.
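The size check can be wrapped in a small guard so that a large lazy result is never pulled into the Jupyter kernel by accident. `compute_if_small` and its `max_gib` threshold are hypothetical names for this sketch; it relies only on `nbytes` and `compute()`, which dask-backed xarray objects provide.

```python
import dask.array as da
import numpy as np
import xarray as xr


def compute_if_small(lazy, max_gib=2.0):
    """Compute a lazy xarray object only if its in-memory size is below max_gib."""
    size_gib = lazy.nbytes / 2**30  # from shape and dtype; no data is loaded
    if size_gib > max_gib:
        raise MemoryError(f"Result would need {size_gib:.2f} GiB (limit {max_gib} GiB)")
    return lazy.compute()


# Usage on a small synthetic array: 4 "seasons" of 100 x 100 float64 values
small = xr.DataArray(da.zeros((4, 100, 100), chunks=(1, 100, 100)), dims=("season", "y", "x"))
loaded = compute_if_small(small)
print(isinstance(loaded.data, np.ndarray))  # True once computed
```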

If you are monitoring the cluster at this point, you will notice a delay between running the next cell and the actual computation occurring. Dask performs a _task graph optimisation_ step on the _client_, not the cluster. How long this takes depends on the number of tasks and the complexity of the graph. We'll talk more about this later.

In the meantime, run the next cell and watch dask compute the result without running out of memory.

In [None]:
actual_result = ndvi_unweighted.compute()
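How long the client-side optimisation takes depends largely on how many tasks the graph holds. Dask collections (including dask-backed xarray objects) expose their graph through `__dask_graph__()`, so the task count can be inspected before computing. A sketch on hypothetical synthetic arrays, comparing two chunkings of the same reduction (`task_count` is a helper name made up for this example):

```python
import dask.array as da

shape = (12, 4000, 4000)  # hypothetical: 12 time steps of 4000 x 4000 pixels

coarse = da.zeros(shape, chunks=(1, -1, -1)).mean(axis=0)    # 12 input chunks
fine = da.zeros(shape, chunks=(1, 1000, 1000)).mean(axis=0)  # 12 * 4 * 4 = 192 input chunks


def task_count(collection):
    """Number of tasks in a dask collection's graph (materialises the graph keys)."""
    return len(dict(collection.__dask_graph__()))


print(task_count(coarse), task_count(fine))  # the finer chunking has many more tasks
```

Calling `task_count(ndvi_unweighted)` before `compute()` should give the equivalent figure for the seasonal NDVI result, though materialising a very large graph this way takes time itself.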

Let's plot the result for the summer (DJF). This will take a few seconds because the image is several thousand pixels across.

In [None]:
actual_result.sel(season='DJF').plot()

Not the most useful visualisation as a thumbnail, and a little sluggish. Dask can help with this too but that's a topic for another notebook.

# Be a good dask user - Clean up the cluster resources

In [None]:
client.close()

cluster.close()
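An alternative that cannot forget the cleanup is to use the cluster and client as context managers, which close them when the block exits, even after an error. This suits scripts more than interactive notebooks, where you usually want the client to persist across cells. A minimal sketch with a deliberately tiny in-process cluster (`processes=False` and `dashboard_address=None` are chosen only to keep the example light; they are not this notebook's settings):

```python
import dask.array as da
from dask.distributed import Client, LocalCluster

# Both the cluster and the client are closed automatically when the with-blocks exit
with LocalCluster(n_workers=1, threads_per_worker=2, processes=False, dashboard_address=None) as cluster:
    with Client(cluster) as client:
        total = da.ones((1000, 1000), chunks=(250, 250)).sum().compute()

print(total)  # 1000000.0, computed on the temporary cluster
```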