## Example processing Sentinel-2 data with Dask Delayed

In [None]:
import datacube
from dask import delayed
from dask.distributed import Client, LocalCluster
import matplotlib.pyplot as plt
from odc.geo import CRS, Resolution, xy_
from odc.geo.geobox import GeoBox
from odc.geo.gridspec import GridSpec
import time
import xarray as xr
from typing import List, Tuple

In [None]:
# Initialise the datacube connection used by every query in this notebook.
# NOTE(review): called without arguments, so it presumably relies on the
# environment's default datacube configuration — confirm before running elsewhere.

dc = datacube.Datacube()

In [None]:
# Area of interest (Central NSW) and time range for the query.
# The x/y bounds span 100 km in each direction; the load call in this
# notebook passes them with crs="EPSG:3577".
x_min, x_max = 1200000, 1300000  # 100km wide
# y values are negative, so the minimum is the more negative number.
y_min, y_max = -3700000, -3600000  # 100km high
date_range = ("2024-01-01", "2024-02-28")

In [None]:
# Load the datasets lazily (dask-backed, nothing is read until computed)

product = "ga_s2bm_ard_3"  # Sentinel-2 B
measurements = ["nbart_red", "nbart_blue", "oa_s2cloudless_mask"]
output_crs = "EPSG:3577"
resolution = [-30, 30]

# One chunk per time step, each 500 pixels high by 500 pixels wide
dask_chunks = {"time": 1, "y": 500, "x": 500}


def _is_final_dataset(dataset):
    """Return True when the dataset's maturity is "final"."""
    return dataset.metadata.dataset_maturity == "final"


ds = dc.load(
    product=product,
    measurements=measurements,
    crs="EPSG:3577",  # CRS of the x/y query bounds
    x=(x_min, x_max),
    y=(y_min, y_max),
    time=date_range,
    output_crs=output_crs,
    resolution=resolution,
    dask_chunks=dask_chunks,
    dataset_predicate=_is_final_dataset,
    # Flagged as important by the original author.
    # NOTE(review): presumably avoids one unreadable dataset failing the
    # whole load — confirm against the datacube documentation.
    skip_broken_datasets=True,
)

In [None]:
# Show the dataset summary (the last expression of a cell is rendered)
ds

In [None]:
# Split the job into smaller tiles that can be distributed over the cluster

# Dataset coordinates are pixel centres, whereas the grid origin has to be a
# pixel edge, so the first x/y coordinate is shifted by half a pixel
# width/height (15m).
HALF_PIXEL_WIDTH = 15
grid_origin = xy_(
    ds.x.data[0] - HALF_PIXEL_WIDTH,
    ds.y.data[0] + HALF_PIXEL_WIDTH,
)

gridspec = GridSpec(
    crs=CRS("EPSG:3577"),
    tile_shape=(500, 500),  # 500x500 pixels per tile
    resolution=Resolution(30, -30),  # Pixel resolution (meters)
    origin=grid_origin,
    flipy=True,
)

# Geobox of the loaded dataset
geobox = ds.odc.geobox

# Collect the tile keys that GridSpec.tiles() reports for the dataset's
# bounding box
tile_keys = list(gridspec.tiles(geobox.boundingbox))

In [None]:
# Number of tiles found for the dataset; the author's note records 49 (7x7)
len(tile_keys)  # 49 = 7x7

In [None]:
# Cut the dataset into one sub-dataset per grid tile.
#
# Each tile's bounding box is expressed in pixel edges, while the dataset
# coordinates are pixel centres, so the box is shrunk by half a pixel on
# every side before selecting by coordinate.
HALF_PIXEL_WIDTH = 15

tiles = []
# Each entry of tile_keys is an (identifier, geobox) pair; only the geobox is
# needed here. Loop-local names are used so the dataset-level `geobox`
# defined earlier in the notebook is not overwritten.
for _tile_ident, tile_geobox in tile_keys:
    tile_bbox = tile_geobox.extent.boundingbox

    tile = ds.sel(
        x=slice(
            tile_bbox.left + HALF_PIXEL_WIDTH,
            tile_bbox.right - HALF_PIXEL_WIDTH,
        ),
        # The y slice runs from top to bottom.
        # NOTE(review): this assumes the dataset's y coordinate is descending.
        y=slice(
            tile_bbox.top - HALF_PIXEL_WIDTH,
            tile_bbox.bottom + HALF_PIXEL_WIDTH,
        ),
    )
    tiles.append(tile)

In [None]:
@delayed
def process_tile(
    tile_ds: xr.Dataset
) -> xr.DataArray:
    """
    Compute the temporal mean of the red/blue ratio for one tile.

    Values are kept only where ``oa_s2cloudless_mask`` equals 1; everywhere
    else they are masked out and skipped by the mean. Because the function
    is wrapped with ``dask.delayed``, calling it only builds a task; the body
    runs when that task is computed.

    Parameters
    ----------
    tile_ds
        Tile dataset with ``nbart_red``, ``nbart_blue`` and
        ``oa_s2cloudless_mask`` variables and a ``time`` dimension.

    Returns
    -------
    xr.DataArray
        ``nbart_red / nbart_blue`` averaged over ``time``. This is a data
        array, not a dataset, because it is derived from two single
        variables.
    """

    # Keep only pixels whose cloud mask value is 1
    # NOTE(review): presumably 1 means "clear" — confirm against the product
    # definition.
    clear_ds = tile_ds.where(tile_ds["oa_s2cloudless_mask"] == 1)

    # Per-pixel, per-time red/blue ratio
    ratio = clear_ds["nbart_red"] / clear_ds["nbart_blue"]

    # Collapse the time dimension, ignoring masked (NaN) values
    mean_ratio = ratio.mean(dim="time", skipna=True)

    # Possible other code that runs when this function is executed (non-lazily)

    # Return result
    return mean_ratio

In [None]:
# Start a local Dask cluster and connect a client to it

cluster = LocalCluster()
client = Client(cluster)

# Optional: print the dashboard URL so the run can be monitored in a browser
print(client.dashboard_link)

In [None]:
# Wrap every tile in a delayed processing task (process_tile is decorated
# with @delayed, so nothing is computed at this point)
delayed_tiles = list(map(process_tile, tiles))

In [None]:
# Hand the delayed tasks to the cluster for execution
futures = client.compute(delayed_tiles)

In [None]:
%%time
# Collect the result of every submitted task from the cluster
processed_tiles = client.gather(futures)

In [None]:
# Clean-up: close the client first, then the cluster it was connected to.
# The gathered results in `processed_tiles` are used after this point.
client.close()
cluster.close()

In [None]:
# Stitch the processed tiles back together using their coordinates.
# NOTE(review): process_tile returns the ratio of two variables, so despite
# the `_ds` suffix this is presumably a DataArray rather than a Dataset —
# confirm.
result_ds = xr.combine_by_coords(processed_tiles)

In [None]:
# Show the combined result summary (the last expression of a cell is rendered)
result_ds

In [None]:
# Visualise the combined mean-ratio result

band = result_ds

# Draw with xarray's wrapper around matplotlib, then label the axes the
# image was drawn on
image = band.plot.imshow(cmap="viridis")  # or cmap='gray', 'RdYlGn', etc.
axes = image.axes
axes.set_title("Result")
axes.set_xlabel("x")
axes.set_ylabel("y")
plt.show()