## Example processing Sentinel-2 data with Dask Futures

In [None]:
import datacube
from dask.distributed import Client, LocalCluster
import matplotlib.pyplot as plt
from odc.geo import CRS, Resolution, xy_
from odc.geo.geobox import GeoBox
from odc.geo.gridspec import GridSpec
import time
import xarray as xr
from typing import List, Tuple

In [None]:
# Selects how tiles are handed to the Dask cluster when futures are created:
#   1 = slice the lazy mean-ratio result into tiles and client.compute() each
#   2 = client.submit() process_tile once per tile index
METHOD = 2  # [1 or 2]

In [None]:
# Initialise datacube
# (no arguments are given, so the library's default configuration is used)

dc = datacube.Datacube()

In [None]:
# Area of interest (Central NSW) and time window to load.
# Coordinates are in metres in the CRS passed to dc.load (EPSG:3577).
x_min, x_max = 1200000, 1300000  # 100km wide
# y grows northwards, so the more negative value is the minimum (southern edge)
y_min, y_max = -3700000, -3600000  # 100km high
date_range = ("2024-01-01", "2024-02-28")

In [None]:
# Load datasets (lazy)

product = "ga_s2bm_ard_3"  # Sentinel-2 B
measurements = ["nbart_red", "nbart_blue", "oa_s2cloudless_mask"]
output_crs = "EPSG:3577"
resolution = [-30, 30]

# One Dask chunk per time step, each split into 500 x 500 pixel tiles
dask_chunks = dict(time=1, y=500, x=500)

ds = dc.load(
    product=product,
    measurements=measurements,
    # Spatial / temporal query (x and y are given in the CRS named here)
    crs="EPSG:3577",
    x=(x_min, x_max),
    y=(y_min, y_max),
    time=date_range,
    # Output grid
    output_crs=output_crs,
    resolution=resolution,
    # Passing dask_chunks keeps the load lazy
    dask_chunks=dask_chunks,
    # Keep only datasets whose maturity is "final"
    dataset_predicate=lambda candidate: candidate.metadata.dataset_maturity == "final",
    skip_broken_datasets=True,  # Important!
)

In [None]:
# Display the dataset summary (loaded with dask_chunks, so nothing is computed yet)
ds

In [None]:
# Define some computations (these only extend the lazy Dask graph)

# Keep pixels whose s2cloudless mask equals 1; everything else becomes NaN
is_clear = ds["oa_s2cloudless_mask"] == 1
no_clouds_ds = ds.where(is_clear)

# Per-pixel red / blue ratio for every time step
ratio_ds = no_clouds_ds["nbart_red"] / no_clouds_ds["nbart_blue"]

# Temporal mean of the ratio, ignoring the masked (NaN) observations
mean_ratio_ds = ratio_ds.mean(dim="time", skipna=True)

In [None]:
# Break-up job into smaller tiles for distribution on cluster

# Dataset coordinates mark pixel centres, whereas the GridSpec origin has to
# sit on a pixel edge. The first x/y coordinate is therefore shifted outwards
# by half a pixel width/height (15m).
HALF_PIXEL_WIDTH = 15
tile_origin = xy_(
    mean_ratio_ds.x.data[0] - HALF_PIXEL_WIDTH,
    mean_ratio_ds.y.data[0] + HALF_PIXEL_WIDTH,
)
gridspec = GridSpec(
    crs=CRS("EPSG:3577"),
    tile_shape=(500, 500),  # 500x500 pixels
    resolution=Resolution(30, -30),  # Pixel resolution (meters)
    origin=tile_origin,
    flipy=True,
)

# Geobox of the (lazy) mean-ratio result
geobox = mean_ratio_ds.odc.geobox

# Materialise the (tile index, tile geobox) pairs intersecting the dataset
tile_keys = [*gridspec.tiles(geobox.boundingbox)]

In [None]:
# Number of grid tiles intersecting the dataset's bounding box
len(tile_keys)  # 49 = 7x7

In [None]:
if METHOD == 1:
    # Cut the lazy mean-ratio result into one sub-array per grid tile.
    #
    # Tile bounding boxes are pixel *edges* while the x/y coordinates are
    # pixel *centres*, so each edge is pulled in by HALF_PIXEL_WIDTH (defined
    # together with the GridSpec) before the label-based slice is taken.
    #
    # A comprehension is used so that no loop variables leak into the
    # notebook namespace; in particular the dataset-wide `geobox` is not
    # overwritten by the last tile's geobox.
    tiles = [
        mean_ratio_ds.sel(
            x=slice(bbox.left + HALF_PIXEL_WIDTH, bbox.right - HALF_PIXEL_WIDTH),
            y=slice(bbox.top - HALF_PIXEL_WIDTH, bbox.bottom + HALF_PIXEL_WIDTH),
        )
        for bbox in (
            tile_geobox.extent.boundingbox for _index, tile_geobox in tile_keys
        )
    ]

In [None]:
# Start a local Dask cluster
# (LocalCluster is created with its default settings; the Client connects to it)

cluster = LocalCluster()
client = Client(cluster)

# Optional: View the dashboard URL
print(client.dashboard_link)

In [None]:
def process_tile(
    tile_index: int,
    tile_keys: List[Tuple[Tuple[int, int], GeoBox]],
    ds: xr.Dataset,
    half_pixel: float = 15,
) -> xr.DataArray:
    """
    Compute the cloud-masked mean red/blue ratio for a single tile.

    The tile's geobox is looked up in ``tile_keys``, the matching window is
    cut out of ``ds``, pixels whose ``oa_s2cloudless_mask`` is not 1 are
    masked out, and the ``nbart_red / nbart_blue`` ratio is averaged over
    the ``time`` dimension.

    Parameters
    ----------
    tile_index
        Position of the tile within ``tile_keys``.
    tile_keys
        Sequence of ``(tile id, tile geobox)`` pairs.
    ds
        Dataset containing the ``nbart_red``, ``nbart_blue`` and
        ``oa_s2cloudless_mask`` variables with ``x``, ``y`` and ``time``
        dimensions.
    half_pixel
        Half of the pixel size in CRS units. The tile bounding box refers to
        pixel edges while the dataset coordinates are pixel centres, so each
        edge is moved inwards by this amount before slicing. The default of
        15 matches 30m pixels.

    Returns
    -------
    xr.DataArray
        The computed (in-memory) mean ratio for the tile. Dividing two data
        variables yields a ``DataArray``, not a ``Dataset``.

    Raises
    ------
    ValueError
        If ``tile_index`` is negative or not smaller than ``len(tile_keys)``.
    """

    if not 0 <= tile_index < len(tile_keys):
        raise ValueError("tile_index is outside valid range!")

    # Only the geobox half of the (id, geobox) pair is needed
    _, tile_geobox = tile_keys[tile_index]
    bbox = tile_geobox.extent.boundingbox

    # Extract tile to process; y is sliced from top to bottom, as in the
    # original edge-to-centre correction
    tile_ds = ds.sel(
        x=slice(bbox.left + half_pixel, bbox.right - half_pixel),
        y=slice(bbox.top - half_pixel, bbox.bottom + half_pixel),
    )

    # Define computations (still lazy if ``ds`` is Dask-backed)
    tile_no_clouds = tile_ds.where(tile_ds["oa_s2cloudless_mask"] == 1)
    tile_ratio = tile_no_clouds["nbart_red"] / tile_no_clouds["nbart_blue"]
    tile_mean_ratio = tile_ratio.mean(dim="time", skipna=True)

    # Compute and return result
    return tile_mean_ratio.compute()

In [None]:
# Create one future per tile, using the strategy selected by METHOD.
if METHOD == 1:
    # The tiles are lazy slices; hand each one to the cluster to compute
    futures = [client.compute(t) for t in tiles]
elif METHOD == 2:
    # process_tile takes an index into tile_keys, so submit one call per index
    futures = [client.submit(process_tile, ti, tile_keys, ds)
               for ti in range(len(tile_keys))]
else:
    # Fail here rather than with a NameError on `futures` in a later cell
    raise ValueError(f"METHOD must be 1 or 2, got {METHOD!r}")

In [None]:
%%time
# Wait for the futures and collect their results into a local list
# (the %%time cell magic above must remain the first line of the cell)
processed_tiles = client.gather(futures)

In [None]:
# Clean-up
# (the client is closed before the cluster it is connected to)
client.close()
cluster.close()

In [None]:
# Stitch the per-tile results back together based on their coordinates
result_ds = xr.combine_by_coords(processed_tiles)

In [None]:
# Display the combined result
result_ds

In [None]:
# Visualise mean ratio dataset

band = result_ds

# Draw the array through xarray's matplotlib wrapper
band.plot.imshow(cmap="viridis")  # or cmap='gray', 'RdYlGn', etc.

# Label the current axes in a single call, then render the figure
plt.gca().set(title="Result", xlabel="x", ylabel="y")
plt.show()