dask_gateway is required to set up the connection to the cluster. Beyond that, xarray has built-in support for dask, so no further magic is required.

In [None]:
%matplotlib inline

import os
import pathlib

import dask_gateway

import rioxarray
import rioxarray.merge
import numpy as np
import matplotlib.pyplot as plt
import IPython.display
import xarray as xr

Connect to the dask_gateway VM. Authentication is via a token available from the notebook's environment.

In [None]:
gw = dask_gateway.Gateway("https://dask-gateway.jasmin.ac.uk", auth="jupyterhub")

Have the gateway create a new dask cluster. This creates a dask scheduler job in LOTUS, running as the user who is running the notebook.
Since LOTUS jobs have to queue, we re-use an existing cluster if there is one, and pass shutdown_on_close=False so that the cluster is not destroyed at the end of the notebook.

In [None]:
clusters = gw.list_clusters()
if not clusters:
    cluster = gw.new_cluster(shutdown_on_close=False)
else:
    cluster = gw.connect(clusters[0].name)

Scale the cluster to have three workers. Each of these workers becomes its own job in LOTUS, and is managed by the scheduler created in the last step.
These workers have 4 cores and 16 GiB of RAM each, but it is possible to ask for less.

In [None]:
cluster.scale(3)
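
If smaller workers are enough, the resources can be requested through the gateway's cluster options when the cluster is created. The following is a minimal sketch only: `new_small_cluster` is a hypothetical helper, and the option names `worker_cores` and `worker_memory` are assumptions, since the available fields are defined by each gateway deployment. Inspecting `gw.cluster_options()` shows the fields that actually exist.

```python
def new_small_cluster(gateway, cores=2, memory_gib=8):
    """Create a cluster whose workers ask for fewer resources.

    ``gateway`` is expected to be a ``dask_gateway.Gateway``. The option
    names used below are assumptions and may differ on your deployment.
    """
    options = gateway.cluster_options()
    options.worker_cores = cores        # assumed option name
    options.worker_memory = memory_gib  # assumed option name, in GiB
    # Keep the cluster alive after the notebook closes, as in the cell above.
    return gateway.new_cluster(cluster_options=options, shutdown_on_close=False)
```

Options chosen this way only apply to a newly created cluster. A cluster that is re-used via `gw.connect` keeps the settings it was created with.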

Now that the cluster is up, we can get a client with which to interact with it.

In [None]:
client = cluster.get_client()

Inspecting the client object will give the dashboard URL, which is proxied from the scheduler job via the gateway VM.

In [None]:
client

Dask is now set up. xarray and some other libraries will use it automatically if you tell them to split the data up into chunks.
Otherwise, you can use dask.delayed or other tools.
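
As a rough illustration of the dask.delayed route, the sketch below wraps two ordinary Python functions so that calling them builds a task graph instead of running immediately. The functions `square` and `add_up` are made up for this example and are not part of the notebook's workflow.

```python
import dask


@dask.delayed
def square(x):
    return x * x


@dask.delayed
def add_up(values):
    return sum(values)


# Calling delayed functions only builds a task graph; nothing runs yet.
total = add_up([square(n) for n in range(5)])

# compute() executes the graph. If a distributed client has been created in
# the session, the tasks are sent to the cluster; otherwise they run locally.
result = total.compute()
print(result)
```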

In this example we will use xarray to load some Sentinel-2 data, then merge, plot and inspect it.
Without dask, this crashes your notebook kernel due to the memory required to load and process the files.

Find some Sentinel-2 data from Southeast England.

In [None]:
data_folder = pathlib.Path("/neodc/sentinel_ard/data/sentinel_2/2022/08/06")
images = [
    "S2B_20220806_lat51lon062_T30UYB_ORB094_utm30n_osgb",
    "S2B_20220806_lat51lon08_T30UXB_ORB094_utm30n_osgb",
    "S2B_20220806_lat52lon075_T30UXC_ORB094_utm30n_osgb",
    "S2B_20220806_lat52lon07_T30UYC_ORB094_utm30n_osgb"
]

#images = [images[0]]

cloud_files = []
col_files = []
for image in images:
    cloud_files += data_folder.glob(f"{image}*clouds.tif")
    col_files += data_folder.glob(f"{image}*stdsref.tif")
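
A glob that matches nothing simply yields no files, so a mistyped image name would go unnoticed until much later. One possible safeguard, sketched here with a hypothetical helper `find_products`, is to check each image name as it is searched for and to sort the matches so the order is predictable.

```python
import pathlib


def find_products(folder, image_ids, suffix):
    """Return sorted paths matching ``{image_id}*{suffix}`` for each ID.

    Raises FileNotFoundError if an ID matches no file, rather than
    silently skipping it.
    """
    folder = pathlib.Path(folder)
    found = []
    for image_id in image_ids:
        matches = sorted(folder.glob(f"{image_id}*{suffix}"))
        if not matches:
            raise FileNotFoundError(
                f"no file matching {image_id}*{suffix} in {folder}"
            )
        found.extend(matches)
    return found
```

With this helper, the two lists above could be built as `find_products(data_folder, images, "clouds.tif")` and `find_products(data_folder, images, "stdsref.tif")`.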

Load the files. The important argument here is "chunks" which tells xarray to create a dask array. Here we let it automagically choose chunk sizes.

In [None]:
# These files are ~1.5G each.
col_loaded = []
for file_ in col_files:
    col_loaded.append(rioxarray.open_rasterio(file_, chunks="auto"))
    
cloud_loaded = []
for file_ in cloud_files:
    cloud_loaded.append(rioxarray.open_rasterio(file_, chunks="auto"))

Inspecting the array shows how many chunks it was broken up into. Note that the data has not been loaded from disk into this notebook.

In [None]:
col_loaded[0]

We merge the files into one array, downscale them slightly and keep the new array as a dask array.

In [None]:
col = rioxarray.merge.merge_arrays(col_loaded, res=(40,40)).chunk("auto")
cloud = rioxarray.merge.merge_arrays(cloud_loaded, res=(40,40)).chunk("auto")

In [None]:
col

We can then plot the merged array. All the calculations happen on the workers and just the resulting plot is sent back to the notebook.

In [None]:
%%capture
fig = plt.figure(figsize=(12.8, 9.6))
ax = plt.axes()

In [None]:
col.sel(band=[3,2,1]).plot.imshow(ax=ax)

In [None]:
IPython.display.display(fig)

In [None]:
cloud.where(cloud > 1).squeeze().plot.imshow(ax=ax, cmap="Greys_r", vmin=0, vmax=0)

In [None]:
IPython.display.display(fig)

Operations on the arrays use dask and will not return a result until you call .compute().

In [None]:
cloud_cells = cloud.where(cloud > 1).count()
cloud_cells

In [None]:
cloud_cells.compute()
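
The same lazy behaviour can be seen on a tiny dask array, independent of the satellite data. This is a minimal sketch with made-up values standing in for the cloud mask.

```python
import dask.array as da
import numpy as np

# A small chunked array standing in for the cloud mask (values are made up).
mask = da.from_array(np.array([[0, 1, 2], [3, 0, 2]]), chunks=(1, 3))

# This only records the operation in a task graph; no data is processed.
lazy_count = (mask > 1).sum()
print(type(lazy_count))

# compute() executes the graph and returns a concrete value.
n_cloudy = int(lazy_count.compute())
print(n_cloudy)
```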

In [None]:
percent_cells_cloud = (cloud_cells / cloud.where(~np.isnan(cloud)).count()) * 100
percent_cells_cloud

In [None]:
percent_cells_cloud.compute()
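
To make the arithmetic concrete, here is a small worked example in plain numpy with made-up values. It relies on the fact that xarray's `where` replaces non-matching cells with NaN and `count` counts the cells that are not NaN.

```python
import numpy as np

# Toy 2x4 cloud mask; NaN marks cells with no data (values are made up).
toy_cloud = np.array([[0.0, 1.0, 2.0, np.nan],
                      [3.0, 0.0, 2.0, 1.0]])

# Counterpart of cloud.where(cloud > 1).count(): cells with a value above 1.
toy_cloud_cells = np.count_nonzero(toy_cloud > 1)

# Counterpart of cloud.where(~np.isnan(cloud)).count(): cells with valid data.
toy_valid_cells = np.count_nonzero(~np.isnan(toy_cloud))

toy_percent = toy_cloud_cells / toy_valid_cells * 100
print(toy_cloud_cells, toy_valid_cells, round(toy_percent, 2))
```

Here three of the seven valid cells exceed 1, so the result is 3 / 7 × 100.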