## Example processing Sentinel-2 data with Dask (customized workers)
- Customizing remote workers

In [None]:
import datacube
import datacube.utils.cog
from dask_gateway import Gateway
from dask.distributed import Client, WorkerPlugin
import subprocess
import sys
import time
import xarray as xr

In [None]:
# Initialise datacube
# No arguments are passed, so the connection uses the environment's default
# datacube configuration.

dc = datacube.Datacube()

In [None]:
# Area of interest (Central NSW) and time window for the query.
# NOTE: y_min is numerically larger than y_max; the pair is passed unchanged
# as the y extent of the load call.
x_min, x_max = 1_200_000, 1_300_000  # 100km wide
y_min, y_max = -3_600_000, -3_700_000  # 100km high
date_range = ("2024-01-01", "2024-02-28")

In [None]:
# Load datasets (lazy): because dask_chunks is supplied, the result is backed
# by Dask arrays and no pixel data is read until it is computed.

product = "ga_s2bm_ard_3"  # Sentinel-2 B
measurements = ["nbart_red", "nbart_blue", "oa_s2cloudless_mask"]
output_crs = "EPSG:3577"
resolution = [-30, 30]

# One chunk per time step, tiled into 500 x 500 pixel blocks
dask_chunks = {"time": 1, "y": 500, "x": 500}

ds = dc.load(
    product=product,
    measurements=measurements,
    crs="EPSG:3577",  # CRS in which the x/y extents below are expressed
    x=(x_min, x_max),
    y=(y_min, y_max),
    time=date_range,
    output_crs=output_crs,
    resolution=resolution,
    dask_chunks=dask_chunks,
    # Keep only datasets whose maturity is "final"
    dataset_predicate=lambda dataset: dataset.metadata.dataset_maturity == "final",
    skip_broken_datasets=True,  # Important!
)

In [None]:
# Show the notebook's summary of the lazily loaded dataset
ds

In [None]:
class PipInstallerPlugin(WorkerPlugin):
    """Worker plugin that pip-installs extra packages on each worker.

    Args:
        packages: Requirement specifiers to install (anything accepted by
            ``pip install``). A single string is treated as one package.
    """

    def __init__(self, packages):
        # A bare string would otherwise be iterated character by character.
        if isinstance(packages, str):
            packages = [packages]
        self.packages = list(packages)

    def setup(self, worker):
        """Install the configured packages into the worker's interpreter.

        Raises:
            subprocess.CalledProcessError: If pip exits with a non-zero status.
        """
        if not self.packages:
            return
        # sys.executable is the interpreter running this worker, so the
        # packages land in the environment the tasks will import from.
        # One pip invocation lets pip resolve all requirements together.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", *self.packages]
        )

In [None]:
# Start a remote Dask cluster
# The statements below are order-dependent: options -> cluster -> scale ->
# client -> plugin -> wait.

gateway = Gateway()

# List the clusters already known to the gateway (optional)
print(gateway.list_clusters())

# Per-worker resources requested for the new cluster
options = gateway.cluster_options()

options.worker_cores = 1
options.worker_threads = 1
options.worker_memory = 3  # (GB)

# Create a new cluster
cluster = gateway.new_cluster(cluster_options=options)

# Scale workers (optional)
num_workers = 4
cluster.scale(num_workers)  # or .adapt(minimum=4, maximum=16)

# Connect to it
client = Client(cluster)

# Dashboard link (optional)
print(client.dashboard_link)

# Create and register the plugin to install packages on all workers
plugin = PipInstallerPlugin(packages=["PowerBlur"])
client.register_plugin(plugin)

# Await cluster initialisation: blocks until num_workers workers have joined
# (no timeout is given)
client.wait_for_workers(n_workers=num_workers)

In [None]:
def process(ds: "xr.Dataset | xr.DataArray") -> bytes:
    """Render a mean red/blue ratio as a blurred greyscale PNG.

    Meant to be shipped to a Dask worker with ``client.submit``; all heavy
    imports are done inside the function so they resolve on the worker.

    Args:
        ds: Either the mean red/blue ratio itself (``xr.DataArray``,
            assumed to be 2-D), or a dataset holding ``nbart_red``,
            ``nbart_blue`` and ``oa_s2cloudless_mask`` from which the ratio
            is derived here. The input may be lazy (Dask-backed).

    Returns:
        The PNG-encoded image as bytes.
    """
    from io import BytesIO

    import numpy as np
    import xarray as xr
    from PIL import Image
    import PowerBlur

    if isinstance(ds, xr.Dataset):
        # Backward-compatible path: derive the ratio from the raw dataset
        # instead of depending on a notebook-level global being in scope.
        cloud_free = ds.where(ds["oa_s2cloudless_mask"] == 1)
        ratio = cloud_free["nbart_red"] / cloud_free["nbart_blue"]
        ratio = ratio.mean(dim="time", skipna=True)
    else:
        ratio = ds

    # Materialise only the reduced result here, not the raw input bands
    ratio = ratio.compute()

    # A zero blue value produces +/-inf, which would flatten the stretch
    # below; treat such pixels like missing data.
    ratio = ratio.where(np.isfinite(ratio))

    # Convert result to greyscale image: normalise to 0-1 ...
    min_val = float(ratio.min())
    max_val = float(ratio.max())
    span = max_val - min_val
    if np.isfinite(span) and span > 0:
        da_norm = (ratio - min_val) / span
    else:
        # Constant or fully masked input: avoid dividing by zero / NaN
        da_norm = xr.zeros_like(ratio)

    # ... then scale to 0-255. NaN has no defined uint8 value, so missing
    # pixels are rendered as black (0) before the cast.
    da_uint8 = (da_norm.fillna(0) * 255).astype(np.uint8)

    # Convert to PIL Image
    image = Image.fromarray(da_uint8.values)

    # Blur image using package installed on workers
    width, height = image.size

    # Apply the power blur effect to the box inset 10% from each edge.
    # NOTE(review): the return value is discarded, so this assumes
    # power_blur modifies `image` in place - confirm against PowerBlur docs.
    PowerBlur.power_blur(
        image,
        (int(width * 0.1),
         int(height * 0.1),
         int(width * 0.9),
         int(height * 0.9))
    )

    # Encode in memory and hand the raw bytes back to the caller
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()

In [None]:
# Define computation (lazy - nothing is evaluated yet)
# Keep only pixels whose s2cloudless mask equals 1; everything else becomes NaN
# (presumably 1 marks clear pixels - confirm against the product definition)
is_clear = ds["oa_s2cloudless_mask"] == 1
no_clouds_ds = ds.where(is_clear)

# Per-pixel red/blue ratio, averaged over time while ignoring masked values
ratio_ds = no_clouds_ds["nbart_red"] / no_clouds_ds["nbart_blue"]
mean_ratio_da = ratio_ds.mean(dim="time", skipna=True)

In [None]:
# Start computation and save result on cluster
# The name is rebound to the persisted array, so later cells use the
# cluster-held result rather than the original lazy graph.
mean_ratio_da = mean_ratio_da.persist()

In [None]:
# Run image blur on cluster
# Hand over the persisted mean-ratio array rather than the raw input dataset,
# so the task only has to gather the already-reduced result instead of
# loading every band and time step onto a single worker.
future = client.submit(process, mean_ratio_da)

In [None]:
%%time
# Wait for all processing to finish
# Blocks until the submitted task completes, then fetches its return value
# (the PNG bytes produced by process)
result = future.result()

In [None]:
# Save the PNG bytes returned from the cluster to the notebook's local disk
with open("/home/jovyan/blurred.png", "wb") as png_file:
    png_file.write(result)

In [None]:
# Clean-up
# Disconnect the client first, then close the cluster; shutdown=True asks for
# the remote cluster to be shut down rather than left running.
client.close()
cluster.close(shutdown=True)