## Example processing Sentinel-2 data with Dask (S3 write)
- Writing results directly to S3

In [None]:
import os
import time
from typing import Dict
from typing import Optional

import boto3
from dask import delayed
from dask.delayed import Delayed
from dask.distributed import Client
from dask_gateway import Gateway
import datacube
import datacube.utils.cog
import matplotlib.pyplot as plt
import xarray as xr

In [None]:
# Initialise datacube
# An application name is supplied so that database activity from this
# notebook can be identified when diagnosing query problems.
dc = datacube.Datacube(app="s2_dask_s3_write")

In [None]:
# Area of interest (Central NSW) and time range of the query.
# x/y are projected coordinates; each *_min is the smaller value of its
# pair so the names match the numbers (the covered extent is unchanged).
x_min, x_max = 1200000, 1300000  # 100km wide
y_min, y_max = -3700000, -3600000  # 100km high
date_range = ("2024-01-01", "2024-02-28")

In [None]:
# Load datasets (lazy)

product = "ga_s2bm_ard_3"  # Sentinel-2 B
measurements = ["nbart_red", "nbart_blue", "oa_s2cloudless_mask"]
output_crs = "EPSG:3577"
resolution = [-30, 30]

# One chunk per time step, tiled into 500 x 500 pixel blocks
dask_chunks = dict(time=1, y=500, x=500)


def _is_final(dataset):
    """Keep only datasets whose metadata reports a maturity of "final"."""
    return dataset.metadata.dataset_maturity == "final"


# Spatial and temporal extent of the query; x/y are given in `crs`
_extent = {
    "crs": "EPSG:3577",
    "x": (x_min, x_max),
    "y": (y_min, y_max),
    "time": date_range,
}

# Shape of the output grid and how it is chunked for Dask
_output_grid = {
    "output_crs": output_crs,
    "resolution": resolution,
    "dask_chunks": dask_chunks,
}

ds = dc.load(
    product=product,
    measurements=measurements,
    dataset_predicate=_is_final,
    # Important! (presumably stops one unreadable dataset from failing
    # the whole load - confirm against the datacube documentation)
    skip_broken_datasets=True,
    **_extent,
    **_output_grid,
)

In [None]:
# Show the notebook's summary of the loaded dataset
ds

In [None]:
# Start a remote Dask cluster through Dask Gateway

num_workers = 4

gateway = Gateway()

# Show the clusters the gateway currently reports (optional)
print(gateway.list_clusters())

# Resources requested for each worker
options = gateway.cluster_options()
for _option_name, _option_value in (
    ("worker_cores", 1),
    ("worker_threads", 1),
    ("worker_memory", 2),  # (GB)
):
    setattr(options, _option_name, _option_value)

# Create a new cluster with those options
cluster = gateway.new_cluster(cluster_options=options)

# Request a fixed number of workers (optional)
cluster.scale(num_workers)  # or .adapt(minimum=4, maximum=16)

# Connect a client to the new cluster
client = Client(cluster)

# Dashboard link (optional)
print(client.dashboard_link)

# Block until all requested workers have joined
client.wait_for_workers(n_workers=num_workers)

In [None]:
def get_session_client(credentials: Optional[Dict[str, str]]):
    """
    Return a new boto3 S3 client built from a fresh session.

    :param credentials: Mapping with the keys "AWS_ACCESS_KEY_ID" and
        "AWS_SECRET_ACCESS_KEY" and, optionally, "AWS_SESSION_TOKEN".
        When None or empty, the client is created without explicit
        credentials.
    :return: boto3 S3 client.
    :raises KeyError: If a non-empty ``credentials`` lacks the access key
        id or the secret access key.
    """
    session = boto3.session.Session()
    if not credentials:
        return session.client("s3")

    return session.client(
        "s3",
        aws_access_key_id=credentials["AWS_ACCESS_KEY_ID"],
        aws_secret_access_key=credentials["AWS_SECRET_ACCESS_KEY"],
        # The session token is optional: a missing key is passed on as None
        # instead of raising KeyError.
        aws_session_token=credentials.get("AWS_SESSION_TOKEN"),
    )

In [None]:
def write_cog(da: xr.DataArray, fname: str, nodata: float) -> None:
    """
    Write a DataArray to a Cloud Optimized GeoTIFF (COG) file.

    ``datacube.utils.cog.write_cog`` is Dask-aware: for a Dask-backed array
    it returns a delayed object and writes nothing until that object is
    computed. Because this wrapper discards the return value, the array is
    computed first so that the file has been written when the call returns.
    This loads the whole array into the memory of the calling process.

    :param da: Array to write.
    :param fname: Destination of the COG.
    :param nodata: Value to record as nodata in the output.
    """
    datacube.utils.cog.write_cog(geo_im=da.compute(), fname=fname, nodata=nodata)

In [None]:
def write_data_to_s3(
    binary_data: bytes,
    bucket_name: str,
    key: str,
    credentials: Optional[Dict[str, str]]
) -> None:
    """
    Upload binary data to S3 as a single object.

    :param binary_data: Object content.
    :param bucket_name: Name of the destination bucket.
    :param key: Object key within the bucket.
    :param credentials: Credentials dictionary forwarded to
        ``get_session_client``; may be None.
    """
    s3_client = get_session_client(credentials)
    s3_client.put_object(Body=binary_data, Bucket=bucket_name, Key=key)

In [None]:
def da_to_mem_cog(da: xr.DataArray, nodata: float) -> bytes:
    """
    Create an in-memory COG from a DataArray.

    ``datacube.utils.cog.write_cog`` returns a delayed object rather than
    bytes when handed a Dask-backed array, so the array is computed first to
    honour the ``bytes`` return type. For an array that is already in memory
    this is expected to be a cheap no-op.

    :param da: Array to encode.
    :param nodata: Value to record as nodata in the output.
    :return: The encoded COG.
    """
    in_memory_da = da.compute()
    return datacube.utils.cog.write_cog(
        geo_im=in_memory_da, fname=":mem:", nodata=nodata
    )

In [None]:
def da_to_s3(
    da: xr.DataArray,
    nodata: float,
    bucket: str,
    key: str,
    credentials: Optional[Dict[str, str]]
) -> None:
    """
    Convert a DataArray to an in-memory COG and push it to S3.

    :param da: Array to export.
    :param nodata: Value to record as nodata in the COG.
    :param bucket: Name of the destination bucket.
    :param key: Object key within the bucket.
    :param credentials: Credentials dictionary forwarded to
        ``write_data_to_s3``; may be None.
    """
    cog_bytes = da_to_mem_cog(da, nodata)
    write_data_to_s3(cog_bytes, bucket, key, credentials)

In [None]:
def export_to_s3(
    da: xr.DataArray,
    nodata: float,
    bucket: str,
    key: str,
    credentials: Optional[Dict[str, str]]
) -> Delayed:
    """
    Return a delayed call that converts a DataArray to an in-memory COG and
    pushes it to S3.

    Nothing is converted or uploaded here; the work happens when the returned
    object is computed.

    :param da: Array to export.
    :param nodata: Value to record as nodata in the COG.
    :param bucket: Name of the destination bucket.
    :param key: Object key within the bucket.
    :param credentials: Credentials dictionary forwarded to ``da_to_s3``;
        may be None.
    :return: Delayed object wrapping the ``da_to_s3`` call.
    """
    deferred_upload = delayed(da_to_s3)
    return deferred_upload(da, nodata, bucket, key, credentials)

In [None]:
# Define the computation

# Keep only the pixels whose oa_s2cloudless_mask value is 1
_keep_pixel = ds["oa_s2cloudless_mask"] == 1
no_clouds_ds = ds.where(_keep_pixel)

# Red / blue ratio per pixel, then its mean through time (NaNs skipped)
_red = no_clouds_ds["nbart_red"]
_blue = no_clouds_ds["nbart_blue"]
ratio_ds = _red / _blue
mean_ratio_da = ratio_ds.mean(dim="time", skipna=True)

In [None]:
# Persist to compute and hold the result in memory (spread across cluster).
# The name is rebound to the persisted array, which the write task defined
# further down then receives.
mean_ratio_da = mean_ratio_da.persist()

In [None]:
# Define the task to write the result to S3
BUCKET = "easihub-csiro-user-scratch"

# The object key is prefixed with the caller's STS user id
_caller_identity = boto3.client("sts").get_caller_identity()
USER_ID = _caller_identity["UserId"]
KEY = f"{USER_ID}/test_cog.tif"

# Deferred: nothing is uploaded until the task is computed
write_task = export_to_s3(
    da=mean_ratio_da,
    nodata=-999,
    bucket=BUCKET,
    key=KEY,
    credentials=None,
)

In [None]:
# Alternative: write the result to Zarr instead of a COG (deferred; uncomment to use)
#import s3fs
#fs = s3fs.S3FileSystem(anon=False)
#store = fs.get_mapper("s3://<bucket>/<key>.zarr")
#write_task = mean_ratio_da.to_zarr(store, mode="w", consolidated=True, compute=False)

In [None]:
# Run the write to S3 task: submit the delayed task to the cluster.
# The returned handle is collected in the next step.
future = client.compute(write_task)

In [None]:
# Wait until complete and collect the task's return value.
# da_to_s3 is annotated as returning None, so `result` carries no data.
result = client.gather(future)

In [None]:
# Clean-up: close the client first, then close the cluster and ask for it
# to be shut down (shutdown=True)
client.close()
cluster.close(shutdown=True)

In [None]:
# Test that it worked by downloading the COG we just made from S3 to the local machine
# and visualising it.

In [None]:
# Copy file from S3 to local machine

LOCAL_COG_FP = "/home/jovyan/test_cog.tif"

# A dedicated name is used for the S3 client so that the `client` variable
# holding the Dask client is not rebound; re-running earlier cells that call
# `client.compute(...)` would otherwise hit a boto3 client.
s3_client = get_session_client(None)
s3_client.download_file(BUCKET, KEY, LOCAL_COG_FP)

In [None]:
# Open the downloaded COG as an xarray Dataset.
# No engine is named, so xarray chooses the backend itself.
cog_ds = xr.open_dataset(LOCAL_COG_FP)

In [None]:
# Visualise mean ratio dataset

# Select the first band of the "band_data" variable
_band_data = cog_ds["band_data"]
band = _band_data.isel(band=0)

# Draw the image through xarray's matplotlib wrapper
_plotter = band.plot
_plotter.imshow(cmap="viridis")  # or cmap='gray', 'RdYlGn', etc.

# Axis labels and title are set afterwards so they replace xarray's defaults
plt.ylabel("y")
plt.xlabel("x")
plt.title("Result")
plt.show()

In [None]:
# Release the dataset's handle on the file before deleting the local copy
cog_ds.close()
os.remove(LOCAL_COG_FP)