## On Demand Training Data from Radiant MLHub and Planetary Computer

<img src='https://radiant-assets.s3-us-west-2.amazonaws.com/PrimaryRadiantMLHubLogo.png' alt='Radiant MLHub Logo' width='300'/>

In this tutorial, we will walk through the process of requesting on-demand training data from the [Planetary Computer Data Catalog](https://planetarycomputer.microsoft.com/catalog) to pair with the [BigEarthNet](https://mlhub.earth/data/bigearthnet_v1) dataset downloaded from Radiant MLHub. This is an important workflow for someone in the geospatial community who wants to train an ML model on a data source outside of a prepackaged dataset, such as those found on MLHub. They can start with any dataset containing source image and label collections in STAC, obtain a random sample to work with, fetch source images from a different collection or satellite product, and then reproject and crop those images to match the spatial and temporal extent of the original dataset.

**NOTE:** because the workflow documented below uses libraries like `pystac_client` and `stackstac`, the datasets queried need to be organized into STAC Collections.

Let's start by importing the Python libraries we'll use in this notebook.

In [None]:
!pip install --upgrade pillow wget graphviz # not installed on PC by default

In [None]:
import getpass
import tempfile
from pathlib import Path
import os
import json
from glob import glob
import requests
from typing import List, Tuple
from datetime import datetime as dt
from datetime import timedelta as td

from radiant_mlhub import Collection
import planetary_computer
from pystac_client import Client as ps_client
from pystac import ItemCollection, Item, Asset
from dask.distributed import Client as dd_client
from dask import delayed, compute, persist

import numpy as np
from stackstac import stack
from geopandas import GeoDataFrame
import rasterio as rio
import rioxarray
from xarray import DataArray
from shapely.geometry import shape
from shapely.geometry import Polygon
from pyproj import CRS

### Define global variables

In addition to the API key, we need to define a few global variables to get our workflow started: a temporary working directory to download and write data to, the STAC API endpoints, the names of the Collections, and other values such as the RGB bands for those Collections. These are flexible and can be adapted to your individual needs.

In [None]:
# Temporary working directory on local machine or PC instance
TMP_DIR = tempfile.gettempdir()

# API endpoints for MLHub and Planetary Computer catalogs
MLHUB_API_URL = "https://api.radiant.earth/mlhub/v1"
MSPC_API_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"

# Names of Collections that will be queried against using pystac_client
BIGEARTHNET_SOURCE_COLLECTION = "bigearthnet_v1_source"  # sentinel-2 source imagery
BIGEARTHNET_LABEL_COLLECTION = "bigearthnet_v1_labels"  # geojson classification labels
PLANETARY_COMPUTER_LANDSAT_8 = "landsat-8-c2-l2"  # landsat 8 source imagery on PC

# Default variables that will be used in the API queries
BIGEARTHNET_TIME_RANGE = "2017-06-01/2018-05-31"  # full date range for BigEarthNet
LABEL_CRS = CRS("EPSG:4326")
DATE_BUFFER = 60
LANDSAT_8_RGB_BANDS = ["SR_B4", "SR_B3", "SR_B2"]  # names of RGB bands from PC Landsat 8
BIGEARTHNET_RGB_BANDS = ["B04", "B03", "B02"]  # names of RGB bands from BigEarthNet

# Bounding box for demonstration fetching Items over Luxembourg
LUXEMBOURG_AOI = [6.06, 49.58, 6.21, 49.66]  # aoi around Luxembourg
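
The bounding box above follows the `[xmin, ymin, xmax, ymax]` order in longitude and latitude, and a swapped pair of values is an easy mistake to make. Here is a minimal sketch of a sanity check. It assumes the box does not cross the antimeridian, and `validate_bbox` is a hypothetical helper that is not part of any library used in this notebook.

```python
from typing import List


def validate_bbox(bbox: List[float]) -> bool:
    """Checks a [xmin, ymin, xmax, ymax] bbox given in degrees (EPSG:4326).

    ASSUMPTION: the box does not cross the antimeridian, so xmin < xmax.
    """
    if len(bbox) != 4:
        return False
    xmin, ymin, xmax, ymax = bbox
    lon_ok = -180 <= xmin < xmax <= 180
    lat_ok = -90 <= ymin < ymax <= 90
    return lon_ok and lat_ok


# same values as LUXEMBOURG_AOI
aoi_is_valid = validate_bbox([6.06, 49.58, 6.21, 49.66])
# min and max corners given in the wrong order
swapped_is_valid = validate_bbox([6.21, 49.66, 6.06, 49.58])
print(aoi_is_valid, swapped_is_valid)
```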

### Authentication with Radiant MLHub

Programmatic access to the Radiant MLHub API using the `pystac_client` library requires both the API end-point and an API key. You can obtain an API key for free by registering an account on [mlhub.earth](https://mlhub.earth/). This can be found under `Settings & API Key` from the drop-down once logged in.

In [None]:
MLHUB_API_KEY = getpass.getpass(prompt="MLHub API Key: ")

Once you have your API key, you will need to create a default profile by setting up a .mlhub/profiles file in your home directory. You can use the `mlhub configure` command line tool to do this:

`$ mlhub configure`<br>
API Key: {<i>Enter your API key here</i>}<br>
Wrote profile to /home/jovyan/.mlhub/profiles

### Configure API connection to Radiant MLHub

This makes a connection to the Radiant MLHub Data Catalog using the API endpoint URL, and the API key from your account.

In [None]:
mlhub_catalog = ps_client.open(
    url=MLHUB_API_URL, parameters={"key": MLHUB_API_KEY}, ignore_conformance=True
)

### Fetch label items from BigEarthNet over Luxembourg

The helper function below encapsulates the process of querying a STAC API endpoint to fetch an ItemCollection matching the query criteria.

In [None]:
def search_stac_api(
    catalog_client: ps_client,
    collections: List[str],
    bbox: List[float] = None,
    datetime: str = None,
    ids: List[str] = None,
) -> ItemCollection:
    """Uses a pystac client to query a STAC API endpoint for Items.
    Searches by IDs if given, otherwise by bbox and datetime. If neither
    is provided, all Items in the Collections are requested.

    Args:
        catalog_client: an instance of pystac_client.Client
        collections: a list of string names matching valid STAC Collections
        bbox: a list of floats specifying [xmin, ymin, xmax, ymax] values
        datetime: a string representing a single datetime or date range
        ids: a list of string Item IDs to fetch directly

    Returns:
        ItemCollection: pystac collection (iterable) of items found
    """
    if ids:
        search = catalog_client.search(collections=collections, ids=ids)
    elif bbox and datetime:
        search = catalog_client.search(
            collections=collections, bbox=bbox, datetime=datetime
        )
    else:
        search = catalog_client.search(collections=collections)

    return search.get_all_items()

We will now use the API client with the helper function above to get label Items over Luxembourg.

In [None]:
origin_label_items = search_stac_api(
    catalog_client=mlhub_catalog,
    collections=[BIGEARTHNET_LABEL_COLLECTION],
    bbox=LUXEMBOURG_AOI,
    datetime=BIGEARTHNET_TIME_RANGE,
)

In [None]:
len(origin_label_items)

This is another helper function that simply displays the geometry for labels from an ItemCollection overlaid on a map of the region.

In [None]:
def explore_search_extent(items: ItemCollection):
    """Extracts geometry from ItemCollection to display polygons on a map.

    Args:
        items: ItemCollection of Items retrieved from pystac_client search

    Returns:
        The interactive map returned by calling .explore() on the GeoDataFrame
    """
    item_feature_collection = items.to_dict()
    geom_df = GeoDataFrame.from_features(item_feature_collection).set_crs(4326)
    return geom_df[["geometry", "datetime"]].explore(
        column="datetime", style_kwds={"fillOpacity": 0.2}, cmap="viridis"
    )

Here are the BigEarthNet chips with their bounding boxes that matched the spatial parameters for the city of Luxembourg and surrounding areas.

In [None]:
explore_search_extent(origin_label_items)

### Download the entire label collection for BigEarthNet from Radiant MLHub

We could certainly use the method above to query label Items directly from our connection to the Radiant MLHub API endpoint. However, on very large collections such as BigEarthNet, pagination becomes a bottleneck when obtaining and resolving STAC Items, because the API only returns 100 Items at a time. Querying the entire Collection of nearly 600,000 Items could take hours.

Therefore, downloading the label Collection (which is only 160 MB) directly is preferable to paginating over the entire Collection using the API.
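
To put the pagination cost into numbers, here is a rough back-of-the-envelope estimate. The Item count and page size come from the text above. The latency per request is an assumed, illustrative value and was not measured against the MLHub API.

```python
import math

total_items = 600_000        # approximate size of the BigEarthNet label Collection
items_per_page = 100         # page size mentioned above
seconds_per_request = 2.0    # ASSUMPTION: illustrative latency, not a measured value

# number of sequential page requests needed to walk the whole Collection
pages = math.ceil(total_items / items_per_page)
estimated_hours = pages * seconds_per_request / 3600

print(f"{pages} requests, roughly {estimated_hours:.1f} hours")
```

Under these assumptions the full crawl needs several thousand sequential requests, which is why a single archive download is the more practical option here.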

In [None]:
label_collection_path = os.path.join(
    TMP_DIR, BIGEARTHNET_LABEL_COLLECTION, "collection.json"
)

Check whether the collection folder already exists before downloading the 173 MB dataset. If it does not, download and uncompress the `.tar.gz` file to extract the label collection files.

In [None]:
if not os.path.exists(label_collection_path):
    collection = Collection.fetch(BIGEARTHNET_LABEL_COLLECTION)
    archive_path = collection.download(TMP_DIR)
    !tar -xf {archive_path.as_posix()} -C {TMP_DIR}
else:
    print("Archive file already downloaded from Radiant MLHub, skipping...")
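
The `!tar` shell command above only works inside a notebook on a system that provides `tar`. As a portable alternative, here is a minimal sketch using the standard library `tarfile` module. `extract_archive` is a hypothetical helper name, and the extraction filter is only applied on Python versions that support it.

```python
import tarfile
from pathlib import Path


def extract_archive(archive_path: str, dest_dir: str) -> None:
    """Extracts a tar archive into dest_dir (hypothetical helper)."""
    Path(dest_dir).mkdir(parents=True, exist_ok=True)
    # newer Python versions support extraction filters that block unsafe paths
    extra = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    # "r:*" lets tarfile detect the compression (gzip for .tar.gz) on its own
    with tarfile.open(archive_path, mode="r:*") as archive:
        archive.extractall(path=dest_dir, **extra)
```

In the cell above, the call would look like `extract_archive(archive_path, TMP_DIR)`.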

In [None]:
bigearthnet_dir = os.listdir(os.path.join(TMP_DIR, BIGEARTHNET_LABEL_COLLECTION))
bigearthnet_dir[0:5]

This is the total count of label Item (chip) directories, plus one for the STAC Collection itself.

In [None]:
len(bigearthnet_dir)

### Obtain a random sample of label Items from BigEarthNet

We don't want to work with the entire dataset of nearly 600,000 labels. This would take too long to download, and we likely won't have enough disk space or memory, so let's work with a random sample of the dataset that is 1% of the original size.

In [None]:
assert os.path.exists(label_collection_path)
with open(label_collection_path, "r") as in_file:
    collection_data = json.load(in_file)

This confirms that the Collection references all of the label Item STAC objects.

In [None]:
label_item_links = [
    link["href"] for link in collection_data["links"] if link["rel"] == "item"
]
len(label_item_links)

Now we take a random sample that is 1/100th of the original dataset size.

In [None]:
label_item_sample = np.random.choice(
    a=label_item_links, size=int(len(label_item_links) / 100), replace=False
)
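
The sample above changes on every run because no seed is set. If a repeatable sample is needed, one option is to draw it from a seeded generator. The sketch below uses the standard library `random` module in place of NumPy. `sample_links` and the placeholder hrefs are hypothetical and only stand in for `label_item_links`.

```python
import random
from typing import List, Sequence


def sample_links(links: Sequence[str], fraction: float, seed: int) -> List[str]:
    """Draws a reproducible sample without replacement (hypothetical helper)."""
    sample_size = int(len(links) * fraction)
    rng = random.Random(seed)  # dedicated generator, global state untouched
    return rng.sample(list(links), sample_size)


# placeholder hrefs standing in for label_item_links
demo_links = [f"item_{i}/stac.json" for i in range(1000)]
first = sample_links(demo_links, fraction=0.01, seed=42)
second = sample_links(demo_links, fraction=0.01, seed=42)
print(len(first), first == second)
```

With NumPy, a similar effect can be achieved by calling `choice` on a generator created with `np.random.default_rng(seed)`.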

In [None]:
label_item_sample[0:5]

In [None]:
rand_idx = np.random.randint(len(label_item_sample))

In [None]:
first_label_item = Item.from_file(
    os.path.join(TMP_DIR, BIGEARTHNET_LABEL_COLLECTION, label_item_sample[rand_idx])
)

Chip ID for the sample label Item pulled:

In [None]:
first_label_item.id

Links for the sample label Item. Take special note of the `rel=source` Link listed:

In [None]:
first_label_item.links

### Fetch source items for random sample from BigEarthNet

If we had the source collection archive downloaded and uncompressed in the same parent directory as the labels collection, we could reference the source Items and images directly. However, the BigEarthNet source collection is over 60 GB when compressed. To work around the disk size limitations of a Planetary Computer instance, we can query the source Items from the MLHub API endpoint, the same way we got the labels, but filter to the exact source Item using IDs.

In [None]:
def get_source_item_ids(label_item: Item) -> List[str]:
    return [
        link.href.split("/")[-2] for link in label_item.links if link.rel == "source"
    ]
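
The function above relies on the source Item ID being the second-to-last segment of the link's `href`, that is, the name of the directory holding the Item's JSON file. Here is a small sketch of that string handling. The href below is made up to show the path shape and is not copied from the dataset.

```python
from pathlib import PurePosixPath

# ASSUMPTION: this href is invented for illustration; real hrefs differ
example_href = "../../example_source_collection/example_source_item_id/stac.json"

# same approach as get_source_item_ids: take the second-to-last path segment
source_item_id = example_href.split("/")[-2]

# equivalent lookup with pathlib: name of the directory containing the file
source_item_id_pathlib = PurePosixPath(example_href).parent.name

print(source_item_id, source_item_id_pathlib)
```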

In [None]:
origin_source_items = search_stac_api(
    catalog_client=mlhub_catalog,
    collections=[BIGEARTHNET_SOURCE_COLLECTION],
    ids=get_source_item_ids(first_label_item),
)

This is the number of source Items returned by the MLHub API for the source Item IDs linked from the sampled label Item.

In [None]:
len(origin_source_items)

Taking a look at some of the properties of the first source Item found:

In [None]:
for source_item in origin_source_items:
    print(source_item.id)
    print(source_item.datetime)
    print(source_item.bbox)
    print(source_item.properties)
    break

With the properties from this sample source Item, we can observe where the chip is located, the relevant Sentinel-2 bands (assets) and datetime the image was captured.

In [None]:
explore_search_extent(origin_source_items)

This is the location of the source items fetched from the label Items sample.

### Fetch Landsat 8 scenes based on source Item bbox and datetime

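Before configuring the API connection to the Microsoft Planetary Computer STAC endpoint, we define a few helper functions that build the search parameters and select the best Landsat 8 scene.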

In [None]:
def temporal_buffer(item_datetime: str, date_delta: int) -> str:
    """Takes a datetime string and returns a buffer around that date

    Args:
        item_datetime: string of the datetime property from an Item
        date_delta: integer for days to add before and after a date

    Returns:
        a string range representing the full date buffer
    """
    delta = td(days=date_delta)
    item_dt = dt.strptime(item_datetime, "%Y-%m-%dT%H:%M:%SZ")

    dt_start = item_dt - delta
    dt_start_str = dt_start.strftime("%Y-%m-%d")

    dt_end = item_dt + delta
    dt_end_str = dt_end.strftime("%Y-%m-%d")

    return f"{dt_start_str}/{dt_end_str}"
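
The function above expects datetimes in exactly the `%Y-%m-%dT%H:%M:%SZ` layout, so a value with fractional seconds would raise a `ValueError`. Here is a sketch of a more tolerant variant that tries a second layout. `temporal_buffer_tolerant` and `DATETIME_FORMATS` are hypothetical names, and the two layouts are an assumption about which inputs need to be supported.

```python
from datetime import datetime, timedelta

# ASSUMPTION: these two layouts cover the datetimes of interest
DATETIME_FORMATS = ("%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M:%S.%fZ")


def temporal_buffer_tolerant(item_datetime: str, date_delta: int) -> str:
    """Like temporal_buffer, but also accepts fractional seconds (hypothetical)."""
    for fmt in DATETIME_FORMATS:
        try:
            item_dt = datetime.strptime(item_datetime, fmt)
            break
        except ValueError:
            continue
    else:
        # no layout matched, so fail loudly instead of guessing
        raise ValueError(f"Unsupported datetime layout: {item_datetime}")

    delta = timedelta(days=date_delta)
    start = (item_dt - delta).strftime("%Y-%m-%d")
    end = (item_dt + delta).strftime("%Y-%m-%d")
    return f"{start}/{end}"


whole_seconds = temporal_buffer_tolerant("2017-07-17T10:10:29Z", 60)
fractional = temporal_buffer_tolerant("2017-07-17T10:10:29.123Z", 60)
print(whole_seconds, fractional)
```

Both calls produce the same 60-day window on either side of 17 July 2017.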

In [None]:
def min_cloud_cover_scene(label_geom: Polygon, search_items: ItemCollection) -> Item:
    """Finds the Item with minimal cloud cover from an ItemCollection

    Args:
        label_geom: Polygon geometry to ensure label completely within scene
        search_items: ItemCollection of the Items found from pystac_client search

    Returns:
        Item where label completely contained within, and minimal cloud cover
    """
    min_cc = np.inf
    min_cc_item = None
    for item in search_items:
        item_geom = shape(item.geometry)
        item_cc = item.properties["eo:cloud_cover"]
        if item_cc < min_cc and label_geom.within(item_geom):
            min_cc = item_cc
            min_cc_item = item
    return min_cc_item
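
The selection logic can be tried without any STAC objects. The sketch below replaces the shapely `within` test with a plain bounding-box comparison, which is a looser check than the polygon test used above, and uses made-up scenes stored as dictionaries. `bbox_within` and `pick_least_cloudy` are hypothetical helpers.

```python
from typing import Dict, List, Optional


def bbox_within(inner: List[float], outer: List[float]) -> bool:
    """True when inner [xmin, ymin, xmax, ymax] lies fully inside outer."""
    return (
        outer[0] <= inner[0]
        and outer[1] <= inner[1]
        and inner[2] <= outer[2]
        and inner[3] <= outer[3]
    )


def pick_least_cloudy(label_bbox: List[float], scenes: List[Dict]) -> Optional[Dict]:
    """Returns the containing scene with the lowest cloud cover, else None."""
    candidates = [s for s in scenes if bbox_within(label_bbox, s["bbox"])]
    return min(candidates, key=lambda s: s["cloud_cover"], default=None)


# made-up scenes for illustration only
scenes = [
    {"id": "scene_a", "bbox": [5.0, 49.0, 7.0, 50.0], "cloud_cover": 35.0},
    {"id": "scene_b", "bbox": [5.5, 49.2, 7.5, 50.5], "cloud_cover": 4.2},
    {"id": "scene_c", "bbox": [6.1, 49.6, 8.0, 51.0], "cloud_cover": 0.1},
]
best = pick_least_cloudy([6.06, 49.58, 6.21, 49.66], scenes)
print(best["id"])
```

Here `scene_c` has the least cloud cover but does not fully contain the label box, so it is excluded before the minimum is taken.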

In [None]:
def get_landsat_8_match(label_item: Item) -> Tuple[Item, Item]:
    """Finds the best Landsat 8 match using source Item datetime and bounding box.

    Args:
        label_item: the STAC label Item object

    Returns:
        Tuple of the BigEarthNet source Item and the Landsat 8 match Item
    """
    # get the matching source Item properties
    source_items = search_stac_api(
        catalog_client=mlhub_catalog,
        collections=[BIGEARTHNET_SOURCE_COLLECTION],
        ids=get_source_item_ids(label_item),
    )

    if source_items:
        source_item = source_items[0]
        source_bbox = source_item.bbox
        source_datetime = source_item.properties["datetime"]

        # search PC Catalog for L8 Items
        l8_items = search_stac_api(
            catalog_client=mspc_catalog,
            collections=[PLANETARY_COMPUTER_LANDSAT_8],
            bbox=source_bbox,
            datetime=temporal_buffer(source_datetime, DATE_BUFFER),
        )

        # filter to best L8 Item match
        signed_l8_items = planetary_computer.sign(l8_items)
        best_l8_match = min_cloud_cover_scene(
            shape(source_item.geometry), signed_l8_items
        )

        if not best_l8_match:
            print(
                "No Landsat 8 Item was found on the Planetary "
                "Computer matching the query parameters:"
            )
            print(
                f"Source Item ID: {source_item.id} "
                f"Bbox: {source_bbox}, "
                f"Datetime: {source_datetime}"
            )
            best_l8_match = None
    else:
        print(
            "No Sentinel-2 source Item was found in the "
            "BigEarthNet dataset matching that label item!"
        )
        # without a source Item there is nothing to match against
        source_item, best_l8_match = None, None
    return source_item, best_l8_match

Since it is known that the BigEarthNet dataset from MLHub has a 1-to-1 pairing of source and labels, we can safely assume the first source item is the appropriate match for our label.

This makes a connection to the Planetary Computer Data Catalog using the API endpoint URL.

In [None]:
mspc_catalog = ps_client.open(MSPC_API_URL)

We will now use the API client with the helper function above to fetch the best Landsat 8 match for the sampled label Item. This will find only the scenes where the label is completely within the scene, and there is minimal cloud cover.

In [None]:
source_item, best_l8_match = get_landsat_8_match(first_label_item)

In [None]:
if best_l8_match:
    print(best_l8_match.id)
    print(best_l8_match.bbox)
    print(best_l8_match.properties)

In [None]:
explore_search_extent(ItemCollection([best_l8_match]))

If everything worked correctly, the geographic scope of the Landsat 8 scene should encompass a much larger surface area than the Sentinel-2 source and label chips. From here we need to crop the image down and make sure the chips from both products match.

In [None]:
def get_redirect_url(asset: Asset) -> str:
    """Returns the direct URL to an asset.

    Args:
        asset: Asset object from an Item

    Returns:
        string response URL direct to Asset
    """
    response = requests.get(asset.href, allow_redirects=True)
    if response.status_code == 200:
        return response.url
    return None

In [None]:
s2_stack = stack(
    items=ItemCollection([source_item]),
    assets=BIGEARTHNET_RGB_BANDS,
    epsg=rio.open(get_redirect_url(source_item.assets["B02"])).crs.to_epsg(),
    resolution=10,
)

In [None]:
s2_stack

In [None]:
s2_stack[0].plot(col="band")

In [None]:
l8_original = stack(
    items=ItemCollection([best_l8_match]), assets=LANDSAT_8_RGB_BANDS, resolution=10
)

In [None]:
l8_original

As we can see from the metadata for the Xarray above, the Landsat 8 scene has a significantly larger geographic footprint, `~20,000 x ~20,000 pixels`, compared to `120 x 120 pixels` for the Sentinel-2 chips that were prepared for the dataset. We need to crop/mask the image down so they represent the exact same terrain.

Luckily, the `bounds_latlon` parameter of `stackstac` makes it easy to crop the image to this size automatically for all bands/assets requested.

In [None]:
l8_cropped = stack(
    items=ItemCollection([best_l8_match]),
    assets=LANDSAT_8_RGB_BANDS,
    bounds_latlon=source_item.bbox,
    resolution=10,
)

In [None]:
l8_cropped

In [None]:
l8_cropped[0].plot(col="band")

Now we have a cropped Landsat 8 chip that spatially and temporally matches our Sentinel-2 source imagery and label sample from the BigEarthNet dataset.
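
One caveat about the `resolution=10` setting: the Landsat 8 surface reflectance bands have a native resolution of 30 m, so the cropped chip is resampled to the 10 m grid and gains no extra detail. A quick arithmetic check, assuming a chip that is exactly 120 pixels wide:

```python
chip_pixels = 120            # BigEarthNet chip width in pixels (10 m bands)
target_resolution_m = 10     # resolution requested from stackstac
landsat_native_m = 30        # native resolution of the Landsat 8 SR bands

# ground distance covered by one side of the chip
chip_width_m = chip_pixels * target_resolution_m
# number of native Landsat pixels that span the same distance
native_pixels = chip_width_m // landsat_native_m

print(chip_width_m, native_pixels)
```

Each side of the chip covers 1200 m on the ground, which corresponds to about 40 native Landsat 8 pixels.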

### Launch a Dask gateway cluster for parallel processing

We will use Dask to optimize our data processing of hundreds of Landsat-8 scenes by parallelizing the workflow with a delayed computation graph. The Dask Client schedules, runs the delayed computations, and gathers the results, while the Dask Gateway provides a secure and centralized way of managing the multiple client clusters. This is especially useful for running Dask on Planetary Computer.

In [None]:
client = dd_client(
    # you can configure Dask client parameters here
)
client

In [None]:
# this cell will only work on PC or a machine with gateway cluster configured
# gateway = dask_gateway.Gateway()
# options = gateway.cluster_options()
# options["worker_cores"] = 7

### Scale the workflow using Dask Delayed

The two helper functions below encapsulate creating the cropped Landsat 8 chips, which we parallelize with Dask Delayed, and writing them to disk.

In [None]:
def create_landsat_8_dataarray(item_path: str) -> DataArray:
    """Creates a Landsat 8 chip from BigEarthNet label chip.

    Args:
        item_path: string path to the label item on disk

    Returns:
        Landsat 8 DataArray that has been cropped to label bbox
    """
    # read label Item object
    label_item = Item.from_file(
        os.path.join(TMP_DIR, BIGEARTHNET_LABEL_COLLECTION, item_path)
    )

    # fetch the Landsat 8 scene that best matches the label
    s2_source, l8_match = get_landsat_8_match(label_item)

    if l8_match:
        # crop L8 match to S2 dims and read image data
        l8_stack = stack(
            items=ItemCollection([l8_match]),
            assets=LANDSAT_8_RGB_BANDS,
            bounds_latlon=s2_source.bbox,
            resolution=10,
        )

        return l8_stack
    return None

In [None]:
def write_tifs_bands(l8_array: DataArray, l8_item_id: str) -> None:
    """Writes to a GeoTiff for each band in Landsat 8 DataArray

    Args:
        l8_array: the DataArray object created from the BigEarthNet label item
        l8_item_id: ID of the Landsat 8 Item, used for folder and file names
    """
    # write cropped L8 DataArray to a tiff file for each band
    for _band in LANDSAT_8_RGB_BANDS:
        l8_band_img = l8_array.sel(band=_band)
        l8_band_filename = os.path.join(
            TMP_DIR, "landsat_8_source", l8_item_id, f"{l8_item_id}_{_band}.tiff"
        )
        Path(os.path.split(l8_band_filename)[0]).mkdir(parents=True, exist_ok=True)
        l8_band_img[0].rio.to_raster(l8_band_filename)

This sets the stage for the Dask task scheduler by mapping all label Items to the `create_landsat_8_dataarray` function. Nothing in the task graph is actually executed until `compute()` is run.

In [None]:
task_pool = []

for item_path in label_item_sample:
    delayed_task = delayed(create_landsat_8_dataarray)(item_path)
    task_pool.append(delayed_task)

Now we will persist the objects into memory and run the computations to create our DataArrays.

In [None]:
%%time
task_pool = persist(*task_pool)
task_pool = compute(*task_pool)

Lastly, we want to write a GeoTIFF to disk for each band of each Landsat 8 DataArray we created.

In [None]:
%%time
for l8_array in task_pool:
    if isinstance(l8_array, DataArray):
        write_tifs_bands(l8_array, l8_array.id.values[0])

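This confirms that folders with images were written to disk. If there is a discrepancy between the sample size and the output, it's likely that there wasn't always a matching Landsat 8 scene given the geometry and datetime parameters for a particular Sentinel-2 source Item. Note also that the output folders are named after the Landsat 8 Item ID, so label chips that match the same Landsat 8 scene are written to the same folder.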

In [None]:
landsat_chip_dir = os.path.join(TMP_DIR, "landsat_8_source")
len(os.listdir(landsat_chip_dir))
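
To see how many label Items had no Landsat 8 match, the computed results can be counted directly instead of inferring the number from the folder listing. A minimal sketch with placeholder values standing in for the computed `task_pool`. `summarize_matches` is a hypothetical helper.

```python
from typing import Optional, Sequence, Tuple


def summarize_matches(results: Sequence[Optional[object]]) -> Tuple[int, int]:
    """Counts results that are not None versus None (hypothetical helper)."""
    matched = sum(result is not None for result in results)
    return matched, len(results) - matched


# placeholder values standing in for the computed task_pool
demo_results = ["chip_1", None, "chip_2", "chip_3", None]
matched, missing = summarize_matches(demo_results)
print(f"{matched} matched, {missing} without a Landsat 8 scene")
```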

Open one of the new Landsat 8 chips to inspect what it looks like.

In [None]:
landsat_images = glob(f"{landsat_chip_dir}/**/*.tiff", recursive=True)
first_l8_img = rioxarray.open_rasterio(landsat_images[0])
first_l8_img.plot()

Shut down the Dask client to clean up cluster resources.

In [None]:
client.shutdown()