# ODC and Dask (LocalCluster) <img align="right" src="../resources/csiro_easi_logo.png">

This notebook explores the use of ODC with Dask LocalCluster. The goal is to introduce fundamental concepts and the role Dask can serve with `datacube` and subsequent computation using `xarray`.

The example computation is fairly typical of an EO data processing pipeline. We'll start with a small area and time period and progressively scale the example up. EO scientists may find some aspects of these examples unrealistic, but this isn't an EO science course :-)

For the base example we'll use the Australian island state of Tasmania as our Region of Interest (ROI). Initially the ROI is paddock sized; it is then progressively increased to the entire island.
The basic algorithm is:
  1. Specify Region of Interest, Satellite products, EO satellite bands, Time range, desired CRS for the `datacube` query
  1. Load data using `datacube.load()`
  1. Mask valid data
  1. Visualise the ROI
  1. Compute NDVI
  1. Visualise NDVI
  
  
__Some cells in this notebook will take minutes to run so be patient__

In [None]:
import datacube
from datacube.utils import masking

The next cell sets out all the query parameters used in our `datacube.load()`.
For this run we keep the ROI quite small.
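To put numbers on the buffer comment in the next cell: the buffer is added on either side of the centre, so the box spans twice the buffer in each direction. Here is a minimal sketch that assumes a spherical Earth (radius 6371 km) to turn degrees into an approximate ground extent. The helper names are hypothetical. `datacube` computes the real pixel grid in the output CRS, so the pixel counts you see later won't match these figures exactly.

```python
import math

# Assumption: spherical Earth with a mean radius of 6371 km
EARTH_RADIUS_KM = 6371.0


def bbox_from_centre(lat, lon, buffer_deg):
    """Hypothetical helper: (lat_range, lon_range) extending buffer_deg either side of the centre."""
    return (lat - buffer_deg, lat + buffer_deg), (lon - buffer_deg, lon + buffer_deg)


def approx_extent_km(lat_range, lon_range):
    """Hypothetical helper: approximate (north-south, east-west) extent in km."""
    km_per_deg = math.pi * EARTH_RADIUS_KM / 180.0
    mid_lat = math.radians((lat_range[0] + lat_range[1]) / 2)
    ns = (lat_range[1] - lat_range[0]) * km_per_deg
    # Lines of longitude converge towards the poles, so scale by cos(latitude)
    ew = (lon_range[1] - lon_range[0]) * km_per_deg * math.cos(mid_lat)
    return ns, ew


lat_range, lon_range = bbox_from_centre(-42.019, 146.615, 0.05)
ns_km, ew_km = approx_extent_km(lat_range, lon_range)
print(lat_range, lon_range)
print(f"~{ns_km:.1f} km north-south by ~{ew_km:.1f} km east-west")
```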

In [None]:
# Central Tasmania (near Little Pine Lagoon)
central_lat = -42.019
central_lon = 146.615

# Set the buffer to load around the central coordinates
# The buffer (in degrees) is added on either side of the centre, so the bbox spans 2x buffer in each dimension
buffer = 0.05

# Compute the bounding box for the study area
study_area_lat = (central_lat - buffer, central_lat + buffer)
study_area_lon = (central_lon - buffer, central_lon + buffer)

# Data products - Landsat 8 ARD from Geoscience Australia
products = ["ga_ls8c_ard_3"]

# Set the date range to load data over - just a month for the moment
set_time = ("2021-01-01", "2021-01-31")

# Set the measurements/bands to load. None will load all of them
measurements = None

# Set the coordinate reference system and output resolution
# This choice corresponds to Australian Albers, with resolution in metres
set_crs = "epsg:3577"
set_resolution = (-30, 30)
group_by = "solar_day"



Now initialise the `datacube`.

In [None]:
dc = datacube.Datacube()

Now load the data. We use `%%time` to keep track of how long things take to complete.

In [None]:
%%time
dataset = None # clear results from any previous runs
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"oa_fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            group_by=group_by,
        )

The result of the `datacube.load()` function is an `xarray.Dataset`. The notebook renders the `dataset` variable as an HTML block with a _lot of useful information_ about the structure of the data.
If you open up `Data variables` (click the > next to Data variables) and click the stacked-cylinders icon for one of them (nbart_red, nbart_green, ...), you will see that the actual data array is available and shown in summary form.

This visualisation becomes increasingly important once dask is enabled and as we scale out, so take a moment now to poke around the interface.
Notice that at this stage every band in the product has been loaded as a data variable (because `measurements=None`), there are 4 time observations, and each observation is 391 (y) by 323 (x) pixels of 30 m. We're at _paddock scale_.
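The load asks for `nearest` resampling for the fmask band and `average` for everything else. Averaging suits continuous reflectance values but not categorical class codes, as this small NumPy sketch shows. It assumes the GA fmask codes (0 nodata, 1 valid, 2 cloud, 3 shadow, 4 snow, 5 water) and uses a hand-made 2x2 block. It only illustrates the idea; it is not how `datacube` implements resampling.

```python
import numpy as np

# Hand-made 2x2 block of fmask class codes, assuming the GA convention
# 1 = valid and 5 = water
fmask_block = np.array([[1, 5],
                        [1, 5]], dtype=np.uint8)

# "average" resampling of the 2x2 block down to one output pixel: the mean code
averaged = fmask_block.mean()

# "nearest" keeps one of the original codes. Simplification: we take the
# top-left sample; real nearest-neighbour picks the sample closest to the
# output pixel centre
nearest = fmask_block[0, 0]

print(averaged)  # 3.0
print(nearest)   # 1
```

Averaging a valid pixel with a water pixel produces code 3, a class that isn't in the block at all and that, under the assumed convention, would read as cloud shadow.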

In [None]:
dataset

Next, filter out everything that isn't an fmask `valid` pixel and compute the NDVI. We aren't selecting a particular time slice, so this is done for every observation.
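Under the hood, the masking and NDVI steps in the next cell are element-wise array arithmetic, with NDVI = (NIR - Red) / (NIR + Red). Here is a minimal NumPy sketch on made-up pixel values. It assumes fmask code 1 means valid and 2 means cloud (the GA convention). The real cell uses `masking.make_mask` and `xarray` rather than these hand-rolled arrays.

```python
import numpy as np

FMASK_VALID = 1  # assumed GA fmask code for "valid" (clear) pixels

# Made-up pixel values standing in for the int16 bands of one observation
red_band = np.array([1000, 2000, 500], dtype=np.int16)
nir_band = np.array([3000, 2000, 2500], dtype=np.int16)
fmask = np.array([1, 2, 1], dtype=np.uint8)  # middle pixel flagged as cloud (2)

# Keep valid pixels and replace the rest with NaN. Like xarray's default
# `where`, this promotes the integers to floating point
valid = fmask == FMASK_VALID
red = np.where(valid, red_band, np.nan)
nir = np.where(valid, nir_band, np.nan)

# NDVI = (NIR - Red) / (NIR + Red)
ndvi_values = (nir - red) / (nir + red)
print(ndvi_values)  # roughly [0.5, nan, 0.667]
```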

In [None]:
%%time
# Identify pixels that fmask flags as "valid" (this excludes cloud, shadow, snow and water)
cloud_free_mask = (
    masking.make_mask(dataset.oa_fmask, fmask="valid")
)
# Apply the mask
cloud_free = dataset.where(cloud_free_mask)

# Calculate the components that make up the NDVI calculation
band_diff = cloud_free.nbart_nir - cloud_free.nbart_red
band_sum = cloud_free.nbart_nir + cloud_free.nbart_red
# Calculate NDVI
ndvi = None
ndvi = band_diff / band_sum

The result `ndvi` is an `xarray.DataArray`. Let's take a look at it. Again the notebook will render an html version of the data in summary form.
Notice again that the actual data values are shown, and that there are 4 time slices with the same spatial shape as before.

In [None]:
ndvi

Raw numbers aren't nice to look at so let's draw a time slice. We'll select just one of them to draw and pick one that didn't get masked out by cloud completely. The masked out white bit is _Little Pine Lagoon_, a water body.

In [None]:
ndvi.isel(time=1).plot()

# Exploring Dask with the ODC - Concepts

Let's set our time range to a couple of weeks, or approximately two passes of Landsat 8 for this ROI. Less data will allow us to explore how dask works with the `datacube` and `xarray` libraries.

In [None]:
set_time = ("2021-01-01", "2021-01-14")

In [None]:
%%time
dataset = None # clear results from any previous runs
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"oa_fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            group_by=group_by,
        )
dataset

As before, you can see the actual data in the results, but this time there are only 2 observation times.

Now let's create a `LocalCluster` as we did in the earlier notebook.

In [None]:
from dask.distributed import Client, LocalCluster

cluster = LocalCluster()
client = Client(cluster)
client

You may like to open up the dashboard for the cluster, although for this notebook we won't be talking about the dashboard (that's for a later discussion).

In [None]:
import os
user = os.environ.get("JUPYTERHUB_USER")
dashboard_address=f'https://hub.csiro.easi-eo.solutions/user/{user}/proxy/8787/status'
print(dashboard_address)

If the `dask_chunks` parameter is specified, `datacube.load()` returns `dask`-backed arrays, and computing them uses the default `dask` client (the one we just created).

The chunk shape and memory size are critical parameters in tuning `dask`, and we will discuss them in great detail as scale increases. For now we're simply going to specify that the `time` dimension should be chunked individually (`1` slice of time per chunk). Because we don't specify any chunking for the other dimensions, each of them forms a single contiguous block.

If that made no sense whatsoever, that's fine, because we will look at an example.
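One way to picture `chunks = {"time": 1}` is to work out the chunk sizes along each dimension. The result below uses the same nested-tuple form that dask reports through an array's `.chunks` attribute. The `chunk_layout` helper is hypothetical (it is not part of dask or `datacube`). Following the description above, it assumes that any dimension left out of `chunks` becomes a single chunk.

```python
import math


def chunk_layout(shape, dims, chunks):
    """Hypothetical helper: per-dimension chunk sizes, dask `.chunks` style.

    Dimensions missing from `chunks` form one chunk spanning the whole dimension.
    """
    layout = []
    for size, dim in zip(shape, dims):
        step = chunks.get(dim, size)
        full, remainder = divmod(size, step)
        # Full-size chunks, plus a smaller trailing chunk if it doesn't divide evenly
        layout.append((step,) * full + ((remainder,) if remainder else ()))
    return tuple(layout)


shape = (2, 391, 323)  # (time, y, x) for the two-week load
dims = ("time", "y", "x")

layout = chunk_layout(shape, dims, {"time": 1})
print(layout)  # ((1, 1), (391,), (323,))

# Total number of chunks is the product of the chunk counts per dimension
n_chunks = math.prod(len(sizes) for sizes in layout)
print(n_chunks)  # 2

# Chunking x as well would split each time slice further
print(chunk_layout(shape, dims, {"time": 1, "x": 100}))
# ((1, 1), (391,), (100, 100, 100, 23))
```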

In [None]:
chunks = {"time":1}

In [None]:
%%time
dataset = None # clear results from any previous runs
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"oa_fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            dask_chunks = chunks, ###### THIS IS THE ONLY LINE CHANGED. #####
            group_by=group_by,
        )
dataset

The first thing you probably noticed is that, although only one line changed, the load time dropped to under a second!
The second thing you probably noticed is that if you poke around the `Data variables` as before, there is no data, just a nice diagram. It's really fast because it didn't do anything!

When `datacube` has `dask_chunks` specified, it switches to creating `xarray` objects backed by `dask.array`s and _lazy loads_ them: no data is loaded until it is used. If you look at one of the data variables, you will see it now shows `dask.array<....>` rather than values, and the cylinder icon shows the Array _and_ Chunk parameters, not actual data.
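The idea of lazy loading can be shown without dask at all. This toy `LazyArray` class is purely illustrative and far simpler than `dask.array`, which also records chunking and a full task graph. It demonstrates only two things: creating the object does no work, and `compute()` triggers the load.

```python
class LazyArray:
    """Hypothetical stand-in for a lazily loaded array: it stores *how* to
    produce the data and only runs that function when compute() is called."""

    def __init__(self, loader):
        self._loader = loader

    def compute(self):
        return self._loader()


load_calls = []


def read_red_band():
    # Pretend this is the expensive file read
    load_calls.append("nbart_red")
    return [[1000, 1100], [1200, 1300]]


red = LazyArray(read_red_band)  # returns immediately, nothing read yet
print(len(load_calls))          # 0

values = red.compute()          # the read happens here
print(len(load_calls))          # 1
```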

The `datacube.load()` call has built `dask` arrays backed by a graph of delayed _tasks_, and those tasks will not run until the _result_ of a `task` is actually required. We'll load the data in a moment, but first let's look at the parameters in that pretty visualisation. Click the cylinder for the `nbart_red` data variable and look at the table and the figure. You can see that:
  1. The Array is `493.33 kiB` in total size and is broken into Chunks which have size `246.67 kiB`
  2. The Array shape is `(2, 391, 323) (time, y, x)` but each chunk is `(1,391,323)` because we specified the `time` dimension should have chunks of length `1`.
  3. The Array has `4` tasks - this is the number of tasks that will be executed in order to load the data. There are `2` chunk tasks, one for each time slice.
  4. The Array type is `int16` and is split up into chunks which are `numpy.ndarrays`.
  
The chunking has split the array loading into two Chunks. __Dask can execute these in parallel.__
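The sizes in the list above follow directly from the shape and the `int16` data type (2 bytes per value). Here is a quick check of the arithmetic:

```python
import math

ITEMSIZE_INT16 = 2  # bytes per int16 value

array_shape = (2, 391, 323)  # (time, y, x)
chunk_shape = (1, 391, 323)  # one time slice per chunk

array_kib = math.prod(array_shape) * ITEMSIZE_INT16 / 1024
chunk_kib = math.prod(chunk_shape) * ITEMSIZE_INT16 / 1024
# Number of chunks (assumes the chunk shape divides the array shape evenly)
n_chunks = math.prod(a // c for a, c in zip(array_shape, chunk_shape))

print(f"array: {array_kib:.2f} KiB, chunk: {chunk_kib:.2f} KiB, chunks: {n_chunks}")
# array: 493.33 KiB, chunk: 246.67 KiB, chunks: 2
```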

We can look at the delayed tasks and how they will be executed by visualising the task graph for one of the variables. We'll use the red band measurement.

In [None]:
dataset.nbart_red.data.visualize()

Details on the task graph can be found in the dask user guide, but what's clear is that there are two independent paths of execution, each producing one time slice: (0,0,0) and (1,0,0). These are the two chunks that the full array has been split into.
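Running two independent paths at the same time is the core of what the scheduler does here. As a rough analogy only, this sketch uses Python's standard `concurrent.futures` thread pool rather than dask. `load_time_slice` is a hypothetical loader that sleeps to stand in for file IO. Because neither task depends on the other, the pair finishes in about the time of one.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def load_time_slice(t):
    """Hypothetical chunk loader: sleeps to mimic IO, returns a tiny 2x2 'image'."""
    time.sleep(0.2)
    return [[t, t], [t, t]]


# Two independent chunk tasks, one per time slice, like (0, 0, 0) and (1, 0, 0)
time_indices = [0, 1]

start = time.perf_counter()
with ThreadPoolExecutor(max_workers=2) as pool:
    # Neither task depends on the other, so both can run at once
    slices = list(pool.map(load_time_slice, time_indices))
elapsed = time.perf_counter() - start

# Stack along time: the full "array" is the list of slices in time order
full_array = slices
print(full_array)           # [[[0, 0], [0, 0]], [[1, 1], [1, 1]]]
print(f"{elapsed:.2f} s")  # close to 0.2 s rather than 0.4 s
```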

To retrieve the actual data we need to `compute()` the result. This executes all the delayed tasks for the variable being computed. Let's `compute()` the red variable.

In [None]:
%%time
actual_red = dataset.nbart_red.compute()
actual_red

As you can see we now have actual data. You can do the same thing for all arrays in the dataset in one go by computing the dataset itself.

In [None]:
%%time
actual_dataset = dataset.compute()
actual_dataset

## The impact of dask on ODC

From the above we can see that specifying `dask_chunks` in `datacube.load()` splits up the `load()` operation into a set of `chunk` shaped arrays and `delayed` _tasks_. Dask can now perform those tasks in _parallel_. Dask will only _compute_ the results for those parts of the data we are using but we can force the computation of all the `delayed` _tasks_ using `compute()`.

There is a _lot_ more opportunity than described in this simple example but let's just focus on the impact of dask on ODC for this simple case.

The time period and ROI are far too small to be interesting, so let's change our time range to a full year of data.

In [None]:
set_time = ("2021-01-01", "2021-12-31")

First, load the data without dask (no `dask_chunks` specified). This will take several minutes, so be patient.

In [None]:
%%time
dataset = None # clear results from any previous runs
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"oa_fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            group_by=group_by,
        )
dataset

That's 46 time observations, which take on the order of 4-5 minutes to load.

Let's enable dask and repeat the load. We're chunking by time (length one), so dask will be able to load each time slice in parallel. The data variables are also independent, so they will be loaded in parallel as well.

In [None]:
chunks = {"time":1}

In [None]:
%%time
dataset = None # clear results from any previous runs
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"oa_fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            dask_chunks = chunks, ###### THIS IS THE ONLY LINE CHANGED. #####
            group_by=group_by,
        )
dataset

Woah! That was fast, but we didn't actually compute anything, so no load has occurred and all tasks are pending.
Open up the Data Variables, click the stacked cylinders and take a look at the delayed task counts. These exist for every variable.

Let's visualise the _task graph_ for the `nbart_red` band.

In [None]:
dataset.nbart_red.data.visualize()

Well that's not as useful, is it!

You should just be able to make out that each _chunk_ can be loaded independently. The `time` _chunk_ length is 1, so each chunk is an individual time. The same holds for all the bands, so dask can spread these tasks out across its workers and threads.

> __Tip__: Visualising task graphs is less effective as your task graph complexity increases. You may need to use simpler examples to see what is going on.

Let's get the actual data.

In [None]:
%%time
actual_dataset = dataset.compute()

How fast this step runs depends on how many cores your Jupyter notebook's local cluster has. For an 8-core cluster, the load takes roughly 1/4 of the time it took without `dask`. This is great!

Why not 1/8 of time?

Dask has overheads, and `datacube.load()` itself is IO bound. All sorts of things limit the speed-up, and part of the art of parallel computing is tuning your algorithm to reduce their impact and achieve greater performance. As we scale up this example we'll explore some of these.

> __Tip__: Do not expect 8x as many cores to produce an 8x speed-up. Algorithms can be tuned to perform better (or worse) as scale increases. This is part of the art of parallel programming. Dask does its best, and you can often do better.
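One simple way to reason about the gap is Amdahl's law. Dask and `datacube` don't compute this; it's a model swapped in here for illustration. It treats a job as a serial part that can't be spread across cores plus a parallel part that can. IO contention and scheduling overhead behave much like a serial fraction. Working backwards from the roughly 1/4 runtime on 8 cores mentioned above shows how much of the work effectively parallelised under this model.

```python
def amdahl_speedup(parallel_fraction, n_workers):
    """Ideal speed-up when only `parallel_fraction` of the work can run in parallel."""
    serial = 1.0 - parallel_fraction
    return 1.0 / (serial + parallel_fraction / n_workers)


def implied_parallel_fraction(speedup, n_workers):
    """Invert Amdahl's law: the parallel fraction consistent with an observed speed-up."""
    return (1.0 - 1.0 / speedup) / (1.0 - 1.0 / n_workers)


# A ~4x speed-up on 8 cores fits a parallel fraction of 6/7 (about 86%)
p = implied_parallel_fraction(4.0, 8)
print(round(p, 3))                     # 0.857
print(round(amdahl_speedup(p, 8), 3))  # 4.0

# Even with unlimited cores, the ceiling under this model is 1 / (1 - p)
print(round(1.0 / (1.0 - p), 3))       # 7.0
```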

# Exploiting delayed tasks

Now let's repeat the full example, with NDVI calculation and masking, but this time with `dask` and `compute` to load the data in.

First the `dc.load()`...

In [None]:
chunks = {"time":1}

In [None]:
%%time
dataset = None # clear results from any previous runs
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"oa_fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            dask_chunks = chunks,
            group_by=group_by,
        )
actual_dataset = dataset.compute()

In [None]:
actual_dataset

Now use `actual_dataset` to compute the NDVI for all observation times.

In [None]:
%%time
# Identify pixels that fmask flags as "valid" (this excludes cloud, shadow, snow and water)
cloud_free_mask = (
    masking.make_mask(actual_dataset.oa_fmask, fmask="valid")
)
# Apply the mask
cloud_free = actual_dataset.where(cloud_free_mask)

# Calculate the components that make up the NDVI calculation
band_diff = cloud_free.nbart_nir - cloud_free.nbart_red
band_sum = cloud_free.nbart_nir + cloud_free.nbart_red
# Calculate NDVI (this creates a new DataArray; the original dataset is unchanged)
ndvi = None
ndvi = band_diff / band_sum

Most of the time is spent in IO; the actual calculation takes < 1 second.

Now let's repeat that entire load and NDVI calculation in a single cell and time it - this is just to get the total time for later comparison.

In [None]:
%%time
dataset = None # clear results from any previous runs
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"oa_fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            dask_chunks = chunks, 
            group_by=group_by,
        )
actual_dataset = dataset.compute() ### Compute the dataset ###
# Identify pixels that fmask flags as "valid" (this excludes cloud, shadow, snow and water)
cloud_free_mask = (
    masking.make_mask(actual_dataset.oa_fmask, fmask="valid")
)
# Apply the mask
cloud_free = actual_dataset.where(cloud_free_mask)

# Calculate the components that make up the NDVI calculation
band_diff = cloud_free.nbart_nir - cloud_free.nbart_red
band_sum = cloud_free.nbart_nir + cloud_free.nbart_red
# Calculate NDVI (this creates a new DataArray; the original dataset is unchanged)
ndvi = None
ndvi = band_diff / band_sum

40-50 seconds (for an 8-core cluster) or so. We can do better...

## Data and computational locality

When `compute()` is called, `dask` not only executes all the tasks but also consolidates all the distributed chunks back into a normal array on the client machine, in this case the notebook's kernel. In the previous cell we have two variables that both refer to the data we are loading:
1. _dataset_ refers to the `delayed` version of the data. The `delayed` _tasks_ and the _chunks_ that make it up will be __on the cluster__.
2. _actual_dataset_ refers to the actual array in the notebook kernel's memory after the _tasks_ have executed. _actual_dataset_ is a complete array in memory in the notebook kernel (__on the _client___).

So in the previous cell everything _after_ the `actual_dataset = dataset.compute()` line is computed in the Jupyter kernel and doesn't use the dask cluster at all for computation.

If we shift the location of this `compute()` call we can perform more _tasks_ in parallel on the dask cluster. 

> __Tip__: Locality is an important concept and applies to both data and computation
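Element-wise steps like masking and NDVI only need the pixels within a single chunk. That is why they can run on the cluster, next to the data, instead of after everything has been gathered to the client. Here is a NumPy sketch of that property using made-up data: computing NDVI chunk by chunk and then joining gives the same answer as joining first. It only illustrates the idea; it doesn't use dask.

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up (time, y, x) reflectance stacks standing in for nbart_red / nbart_nir
red = rng.integers(100, 3000, size=(4, 5, 6)).astype(np.float64)
nir = rng.integers(100, 3000, size=(4, 5, 6)).astype(np.float64)


def compute_ndvi(nir_block, red_block):
    # Purely element-wise: each output pixel depends only on the same input pixel
    return (nir_block - red_block) / (nir_block + red_block)


# "Client-side": gather everything, then compute once
gathered_first = compute_ndvi(nir, red)

# "Cluster-side": compute per time chunk (length 1, as in this notebook), then gather
per_chunk = [compute_ndvi(nir[t:t + 1], red[t:t + 1]) for t in range(red.shape[0])]
computed_first = np.concatenate(per_chunk, axis=0)

print(np.array_equal(gathered_first, computed_first))  # True
```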

Now let's repeat the load and NDVI calculation, but this time, rather than calling `compute()` on the full `dataset`, we'll run `cloud_free = dataset.where(cloud_free_mask).compute()` so the masking operation can be performed in parallel. Let's see what the impact is...


In [None]:
%%time
dataset = None # clear results from any previous runs
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"oa_fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            dask_chunks = chunks, 
            group_by=group_by,
        )

# Identify pixels that fmask flags as "valid" (this excludes cloud, shadow, snow and water)
cloud_free_mask = (
    masking.make_mask(dataset.oa_fmask, fmask="valid")
)
# Apply the mask
cloud_free = dataset.where(cloud_free_mask).compute()    ### COMPUTE MOVED HERE ###

# Calculate the components that make up the NDVI calculation
band_diff = cloud_free.nbart_nir - cloud_free.nbart_red
band_sum = cloud_free.nbart_nir + cloud_free.nbart_red
# Calculate NDVI (this creates a new DataArray; the original dataset is unchanged)
ndvi = None
ndvi = band_diff / band_sum
actual_ndvi = ndvi

Not that different... Not too surprising since the masking operation is pretty quick (it's all numpy) and the IO is the bulk of the processing.

Dask can see the entire task graph for both load and mask computation. As a result _some_ of the computation can be performed concurrently with file IO, and CPUs are busier as a result, so it will be slightly faster in practice but with IO dominating we won't see much overall improvement.

Perhaps doing more of the calculation on the cluster will help. Let's move the `compute()` call to the very end (`ndvi.compute()`), so the entire calculation is done on the cluster and only the final result is returned to the client.

In [None]:
%%time
dataset = None # clear results from any previous runs
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"oa_fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            dask_chunks = chunks, 
            group_by=group_by,
        )

# Identify pixels that fmask flags as "valid" (this excludes cloud, shadow, snow and water)
cloud_free_mask = (
    masking.make_mask(dataset.oa_fmask, fmask="valid")
)
# Apply the mask
cloud_free = dataset.where(cloud_free_mask)

# Calculate the components that make up the NDVI calculation
band_diff = cloud_free.nbart_nir - cloud_free.nbart_red
band_sum = cloud_free.nbart_nir + cloud_free.nbart_red
# Calculate NDVI (this creates a new DataArray; the original dataset is unchanged)
ndvi = None
ndvi = band_diff / band_sum
actual_ndvi = ndvi.compute()    ### COMPUTE MOVED HERE ###

Now we are seeing a huge difference!

You may be thinking "Hold on a sec, the NDVI calculation is pretty quick in this example with such a small dataset, why such a big difference?" - and you'd be right. There is more going on.

Remember that `dataset` is a _task graph_ with `delayed` tasks waiting to be executed __when the result is required__. In the example `dataset`, 22 data variables are available but _only 3 are used_ to produce the `ndvi` (`oa_fmask`, `nbart_red` and `nbart_nir`). As a result _`dask` doesn't load the other 19 variables_ and because computation time in this case is mostly IO related the execution time is a LOT faster.
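Dask can skip the other 19 variables because it walks the task graph backwards from the requested result and runs only what that result depends on. Here is a minimal pure-Python sketch of that idea. The graph is a dict mapping each key to a function and the keys it depends on; this is loosely modelled on dask's own graph format but is not identical to it. The 19 unused band names are hypothetical placeholders.

```python
SAMPLE_VALUES = {"oa_fmask": 1, "nbart_red": 1000.0, "nbart_nir": 3000.0}
load_log = []


def make_loader(name):
    """Return a task that 'loads' one band and records that it ran."""
    def load():
        load_log.append(name)
        return SAMPLE_VALUES.get(name, 0.0)
    return load


def ndvi_task(fmask, red, nir):
    # Masked pixels become NaN; valid ones get (NIR - Red) / (NIR + Red)
    return (nir - red) / (nir + red) if fmask == 1 else float("nan")


# Hypothetical graph: 22 band loads (19 unused placeholders) plus one NDVI task.
# Each entry maps a key to (function, [keys it depends on]).
band_names = ["oa_fmask", "nbart_red", "nbart_nir"] + [f"unused_band_{i}" for i in range(19)]
graph = {name: (make_loader(name), []) for name in band_names}
graph["ndvi"] = (ndvi_task, ["oa_fmask", "nbart_red", "nbart_nir"])


def evaluate(graph, key, cache=None):
    """Compute `key`, running only the tasks it depends on (each at most once)."""
    if cache is None:
        cache = {}
    if key not in cache:
        func, deps = graph[key]
        cache[key] = func(*[evaluate(graph, dep, cache) for dep in deps])
    return cache[key]


result = evaluate(graph, "ndvi")
print(result)            # 0.5
print(len(load_log))     # 3 of the 22 loads actually ran
print(sorted(load_log))  # ['nbart_nir', 'nbart_red', 'oa_fmask']
```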

Of course, we can save `dask` the trouble of figuring this out on our behalf by only loading the `measurements` we need in the first place. Let's check that now; we should see a similar performance figure.


In [None]:
%%time
dataset = None # clear results from any previous runs
measurements = [ "oa_fmask", "nbart_red", "nbart_nir"]
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"oa_fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            dask_chunks = chunks, 
            group_by=group_by,
        )

# Identify pixels that fmask flags as "valid" (this excludes cloud, shadow, snow and water)
cloud_free_mask = (
    masking.make_mask(dataset.oa_fmask, fmask="valid")
)
# Apply the mask
cloud_free = dataset.where(cloud_free_mask)

# Calculate the components that make up the NDVI calculation
band_diff = cloud_free.nbart_nir - cloud_free.nbart_red
band_sum = cloud_free.nbart_nir + cloud_free.nbart_red
# Calculate NDVI (this creates a new DataArray; the original dataset is unchanged)
ndvi = None
ndvi = band_diff / band_sum
actual_ndvi = ndvi.compute()

Pretty similar, as expected.
It can still pay to give `dask` a hand and not clutter the _task graph_ with tasks you are not going to use. Still, it's nice to see that `dask` can save you some time by only computing what is required, when you need it.

# A quick check on the task graph

For completeness we will take a look at the _task graph_ for the full calculation, all the way to the NDVI result. Given the complexity of the full graph we'll simplify it to 2 time observations like we did when the task graph was introduced previously.


In [None]:
set_time = ("2021-01-01", "2021-01-14")

In [None]:
%%time
dataset = None # clear results from any previous runs
measurements = [ "oa_fmask", "nbart_red", "nbart_nir"]
dataset = dc.load(
            product=products,
            x=study_area_lon,
            y=study_area_lat,
            time=set_time,
            measurements=measurements,
            resampling={"oa_fmask": "nearest", "*": "average"},
            output_crs=set_crs,
            resolution=set_resolution,
            dask_chunks = chunks, 
            group_by=group_by,
        )

# Identify pixels that fmask flags as "valid" (this excludes cloud, shadow, snow and water)
cloud_free_mask = (
    masking.make_mask(dataset.oa_fmask, fmask="valid")
)
# Apply the mask
cloud_free = dataset.where(cloud_free_mask)

# Calculate the components that make up the NDVI calculation
band_diff = cloud_free.nbart_nir - cloud_free.nbart_red
band_sum = cloud_free.nbart_nir + cloud_free.nbart_red
# Calculate NDVI (this creates a new DataArray; the original dataset is unchanged)
ndvi = None
ndvi = band_diff / band_sum


In [None]:
ndvi.data.visualize()

The computation flows from bottom to top in the _task graph_. You can see there are two main paths, one for each time (since the time chunk is length 1). You can also see that the three data sources are loaded independently. After that it gets a little more difficult to follow, but you can see `oa_fmask` being used to produce the mask (`and_`, `eq`). The mask is then combined, via the `where` function, with the other two bands. Finally comes the NDVI calculation: a subtract, an add and a divide (`sub`, `add`, `truediv`).
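The `and_` and `eq` steps in the graph come from how the mask is built. `datacube.utils.masking.make_mask` ANDs the variable with a bit mask derived from the product's flag definition, then compares the result with the requested flag value. The sketch below reproduces that pattern in NumPy. It assumes the fmask class code occupies bits 0-7 and that "valid" is 1 (GA convention); check your product's `flags_definition` before relying on these numbers.

```python
import numpy as np

# Assumed flag definition for oa_fmask: class code in bits 0-7, "valid" == 1
FMASK_BITS = 0b1111_1111
FMASK_VALID = 1

oa_fmask = np.array([0, 1, 2, 3, 4, 5, 1], dtype=np.uint8)

# The two operations visible in the task graph: a bitwise "and", then an "eq"
valid = (oa_fmask & FMASK_BITS) == FMASK_VALID
print(valid)  # True only where the class code is 1
```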

Dask has lots of internal optimisations that it uses to identify the dependencies and parallel components of a task graph. Where possible, it will reorder or prune operations to optimise further (for example, not loading _data variables_ that aren't used in the NDVI calculation).

> __Tip__: The _task graph_ can be complex but it is a useful tool in understanding your algorithm and how it scales.

# Be a good dask user - Clean up the cluster resources

In [None]:
client.close()

cluster.close()