## On Demand Training Data from Radiant MLHub and Planetary Computer

<img src='https://radiant-assets.s3-us-west-2.amazonaws.com/PrimaryRadiantMLHubLogo.png' alt='Radiant MLHub Logo' width='300'/>

In this tutorial, we will walk through the process of requesting on-demand training data from the [Planetary Computer Data Catalog](https://planetarycomputer.microsoft.com/catalog) to pair with the [BigEarthNet](https://mlhub.earth/data/bigearthnet_v1) dataset downloaded from Radiant MLHub. This is an important workflow for someone in the geospatial community who wants to train an ML model on a data source outside of a prepackaged dataset, such as those found on MLHub. They can start with any dataset containing source image and label collections in STAC, obtain a random sample to work with, fetch source images from a different collection or satellite product, and then reproject and crop those images to match the spatial and temporal extent of the original dataset.

**NOTE:** because the workflow documented below uses libraries like `pystac_client` and `stackstac`, the datasets queried need to be organized into STAC Collections.

Let's start by importing the Python libraries we'll use in this notebook.

In [None]:
import getpass
import tempfile
from pathlib import Path
import os
import json
from glob import glob
import requests
from typing import List, Tuple, Dict, Any
from datetime import datetime as dt
from datetime import timedelta as td

import planetary_computer
import pystac_client
from pystac import ItemCollection, Item, Asset
import dask

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from stackstac import stack
from geopandas import GeoDataFrame
import rasterio as rio
from rasterio.plot import show
import rioxarray
from xarray import DataArray
from shapely.geometry import shape
from shapely.geometry import Polygon
from pyproj import CRS

### Define global variables

We will also need to define other initial global variables to get our workflow started, e.g. a temporary working directory to download and write data to, the STAC API endpoints, names of Collections, and other variables like the RGB bands for those collections. These are pretty flexible depending on your individual needs.

In [None]:
# Temporary working directory on local machine or PC instance
TMP_DIR = tempfile.gettempdir()

# API endpoints for MLHub and Planetary Computer catalogs
MLHUB_API_URL = "https://api.radiant.earth/mlhub/v1"
MSPC_API_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"

# Names of Collections that will be queried against using pystac_client
BIGEARTHNET_SOURCE_COLLECTION = "bigearthnet_v1_source"  # sentinel-2 source imagery
BIGEARTHNET_LABEL_COLLECTION = "bigearthnet_v1_labels"  # geojson classification labels
PLANETARY_COMPUTER_LANDSAT_8 = "landsat-8-c2-l2"  # landsat 8 source imagery on PC
OUTPUT_DIR = "landsat_8_source"

# Default variables that will be used in the API queries
BIGEARTHNET_TIME_RANGE = "2017-06-01/2018-05-31"  # full date range for BigEarthNet
LABEL_CRS = CRS("EPSG:4326")
DATE_BUFFER = 60
LANDSAT_8_RGB_BANDS = ["SR_B4", "SR_B3", "SR_B2"]  # names of RGB bands from PC Landsat 8
BIGEARTHNET_RGB_BANDS = ["B04", "B03", "B02"]  # names of RGB bands from BigEarthNet

# Bounding box for demonstration fetching Items over Luxembourg
LUXEMBOURG_AOI = [6.06, 49.58, 6.21, 49.66]  # aoi around Luxembourg
SPAIN_AOI = [-9.73, 35.84, 3.43, 43.87]

### Authentication with Radiant MLHub

Programmatic access to the Radiant MLHub API using the `pystac_client` library requires both the API endpoint and an API key. You can obtain an API key for free by registering an account on [mlhub.earth](https://mlhub.earth/). Once logged in, the key can be found under `Settings & API Key` in the drop-down menu.

In [None]:
MLHUB_API_KEY = getpass.getpass(prompt="MLHub API Key: ")

### Configure API connection to Radiant MLHub

This makes a connection to the Radiant MLHub Data Catalog using the API endpoint URL, and the API key from your account.

In [None]:
mlhub_catalog = pystac_client.Client.open(
    url=MLHUB_API_URL, parameters={"key": MLHUB_API_KEY}, ignore_conformance=True
)

### Fetch label items from BigEarthNet over Luxembourg

We will now use the `search` function from the API client to get label Items over Luxembourg as a simple use-case.

In [None]:
origin_label_items = mlhub_catalog.search(
    collections=BIGEARTHNET_LABEL_COLLECTION,
    bbox=LUXEMBOURG_AOI,
    datetime=BIGEARTHNET_TIME_RANGE,
    max_items=100
).get_all_items()

This is a helper function that simply displays the geometry for labels from an ItemCollection overlaid on a map of the region.

In [None]:
def explore_search_extent(items: ItemCollection) -> Any:
    """Extracts geometry from ItemCollection to display polygons on a map.

    Args:
        items: ItemCollection of Items retrieved from pystac_client search

    Returns:
        Interactive map produced by calling .explore() on the GeoDataFrame
    """
    item_feature_collection = items.to_dict()
    geom_df = GeoDataFrame.from_features(item_feature_collection).set_crs(4326)
    print(geom_df.bounds)
    return geom_df[["geometry", "datetime"]].explore(
        column="datetime", style_kwds={"fillOpacity": 0.2}, cmap="viridis"
    )

Here are the BigEarthNet chips with their bounding boxes that matched the spatial parameters for the city of Luxembourg and surrounding areas.

In [None]:
explore_search_extent(origin_label_items)

### Download BigEarthNet Source Items from Radiant MLHub

We could certainly use the method above to query all label and source Items directly from our connection to the Radiant MLHub API endpoint. However, on very large collections such as BigEarthNet, pagination becomes a bottleneck when fetching and resolving STAC Items.

Querying all of the nearly 600,000 Items in a single collection would take almost an hour, depending on your connection speed. This means it could take a few hours to download all Items in the Catalog.

One alternative is to download the `.tar.gz` of the collections directly from the Radiant MLHub dataset detail page. The file size of the labels archive is not large, only 165 MB. However, because there are over half a million objects, it takes a long time to decompress the entire download.

Therefore, we can showcase this workflow by paginating over the source Item Collection to fetch the first 5,000 Items available (which represent less than 1% of the entire collection).

In [None]:
bigearthnet_source_search = mlhub_catalog.search(
    collections=BIGEARTHNET_SOURCE_COLLECTION,
    bbox=SPAIN_AOI,
    # limit=100, # limit of items per page
    max_items=5000 # total Item recall
)

It should take less than a minute to fetch all the STAC Items for the sample of 5,000 we've queried.

In [None]:
%%time
bigearthnet_source_items = bigearthnet_source_search.get_all_items()

In [None]:
explore_search_extent(bigearthnet_source_items)

We can see from this map that the source Items fetched are concentrated in Portugal. This is merely a consequence of fetching the first 5,000 source Items from the Catalog API with a bounding box search over Spain (the `SPAIN_AOI` box also covers Portugal). Had we downloaded the entire Catalog locally, or run an unfiltered search, we could fetch a random sample that is more representative of the entire dataset.
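
As a minimal sketch of the random sampling mentioned above, assuming the IDs of every Item were already available locally, a seeded draw with the standard library could look like the following. The `all_item_ids` list is a made-up stand-in, not real BigEarthNet IDs.

```python
import random

# Hypothetical stand-in for the IDs of every Item in a locally downloaded catalog
all_item_ids = [f"item_{i:06d}" for i in range(600_000)]

# Seed the generator so the same sample can be drawn again later
rng = random.Random(42)

# Draw 5,000 distinct IDs uniformly at random, without replacement
sample_ids = rng.sample(all_item_ids, k=5000)

print(len(sample_ids), len(set(sample_ids)))
```

Sampling without replacement guarantees that no Item appears twice, and the fixed seed keeps the selection reproducible between runs.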

### Configure API connection to Planetary Computer

This makes a connection to the Planetary Computer Data Catalog using the API endpoint URL.

In [None]:
mspc_catalog = pystac_client.Client.open(MSPC_API_URL)

### Fetch Landsat 8 scenes based on source Item bbox and datetime

Since the BigEarthNet dataset from MLHub is known to have a 1-to-1 pairing of source and label Items, we can safely assume that a sampled source Item is the appropriate match for its label.

We will now use the API client with the helper functions below to fetch the best Landsat 8 match for the sampled Item. This will find only the scenes where the label is completely within the scene, and select the one with minimal cloud cover.

In [None]:
def temporal_buffer(item_datetime: str, date_delta: int) -> str:
    """Takes a datetime string and returns a buffer around that date

    Args:
        item_datetime: string of the datetime property from an Item
        date_delta: integer for days to add before and after a date

    Returns:
        a string range representing the full date buffer
    """
    delta = td(days=date_delta)
    item_dt = dt.strptime(item_datetime, "%Y-%m-%dT%H:%M:%SZ")

    dt_start = item_dt - delta
    dt_start_str = dt_start.strftime("%Y-%m-%d")

    dt_end = item_dt + delta
    dt_end_str = dt_end.strftime("%Y-%m-%d")

    return f"{dt_start_str}/{dt_end_str}"
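
The `temporal_buffer` function above parses one fixed datetime format. If a catalog returned datetimes with fractional seconds (an assumption, not something observed in this notebook), a more tolerant variant could be sketched as follows. The name `temporal_buffer_tolerant` is hypothetical.

```python
from datetime import datetime, timedelta


def temporal_buffer_tolerant(item_datetime: str, date_delta: int) -> str:
    """Hypothetical variant of temporal_buffer that accepts more ISO 8601 forms."""
    # fromisoformat() in older Python versions rejects a trailing "Z",
    # so rewrite it as an explicit UTC offset first
    item_dt = datetime.fromisoformat(item_datetime.replace("Z", "+00:00"))
    delta = timedelta(days=date_delta)
    start = (item_dt - delta).strftime("%Y-%m-%d")
    end = (item_dt + delta).strftime("%Y-%m-%d")
    return f"{start}/{end}"


print(temporal_buffer_tolerant("2017-08-15T10:30:31Z", 60))
print(temporal_buffer_tolerant("2017-08-15T10:30:31.024000Z", 60))
```

Both calls produce the range `2017-06-16/2017-10-14`, which is 60 days on either side of the Item date.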

In [None]:
def min_cloud_cover_scene(label_geom: Polygon, search_items: ItemCollection) -> Item:
    """Finds the Item with minimal cloud cover from an ItemCollection

    Args:
        label_geom: Polygon geometry to ensure label completely within scene
        search_items: ItemCollection of the Items found from pystac_client search

    Returns:
        Item where label completely contained within, and minimal cloud cover
    """
    min_cc = np.inf
    min_cc_item = None
    for item in search_items:
        item_geom = shape(item.geometry)
        item_cc = item.properties["eo:cloud_cover"]
        if item_cc < min_cc and label_geom.within(item_geom):
            min_cc = item_cc
            min_cc_item = item
    return min_cc_item
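
To illustrate the selection rule without any network access, here is a simplified sketch that swaps the `shapely` polygon test for a plain bounding-box containment check. The scene dictionaries, their IDs, and the helper names `bbox_within` and `pick_scene` are made up for illustration.

```python
from typing import Dict, List, Optional


def bbox_within(inner: List[float], outer: List[float]) -> bool:
    """True when inner [minx, miny, maxx, maxy] lies entirely inside outer."""
    return (
        outer[0] <= inner[0]
        and outer[1] <= inner[1]
        and inner[2] <= outer[2]
        and inner[3] <= outer[3]
    )


def pick_scene(label_bbox: List[float], scenes: List[Dict]) -> Optional[Dict]:
    """Returns the least cloudy scene that fully contains label_bbox, if any."""
    candidates = [s for s in scenes if bbox_within(label_bbox, s["bbox"])]
    return min(candidates, key=lambda s: s["cloud_cover"], default=None)


# Made-up scenes for illustration only
scenes = [
    {"id": "scene_a", "bbox": [5.0, 49.0, 7.0, 50.0], "cloud_cover": 35.0},
    {"id": "scene_b", "bbox": [6.0, 49.5, 6.5, 49.7], "cloud_cover": 12.0},
    {"id": "scene_c", "bbox": [6.1, 49.6, 6.3, 49.8], "cloud_cover": 1.0},
]
label_bbox = [6.06, 49.58, 6.21, 49.66]

best = pick_scene(label_bbox, scenes)
print(best["id"])
```

Here `scene_c` has the lowest cloud cover but does not fully contain the label, so `scene_b` is chosen. Real scene footprints are rotated polygons rather than axis-aligned boxes, which is why the notebook's function tests the actual geometries.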

In [None]:
def get_landsat_8_match(bbox: List[float], geometry: Dict[str, Any], datetime: str) -> Item:
    """Finds the best Landsat 8 match using source Item datetime and bounding box.

    Args:
        bbox: bounding box of the STAC source Item
        geometry: GeoJSON geometry of the STAC source Item
        datetime: datetime of the STAC source Item

    Returns:
        best_l8_match: matching Landsat 8 source Item
    """

    # search PC Catalog for L8 Items
    l8_items = mspc_catalog.search(
        collections=PLANETARY_COMPUTER_LANDSAT_8,
        bbox=bbox,
        datetime=temporal_buffer(datetime, DATE_BUFFER),
    ).get_all_items()

    # filter to best L8 Item match
    signed_l8_items = planetary_computer.sign(l8_items)
    best_l8_match = min_cloud_cover_scene(
        shape(geometry), 
        signed_l8_items
    )

    return best_l8_match

In [None]:
sample_source_item = bigearthnet_source_items[np.random.randint(0, len(bigearthnet_source_items))]

In [None]:
best_l8_match = get_landsat_8_match(
    sample_source_item.bbox,
    sample_source_item.geometry,
    sample_source_item.properties['datetime']
)

In [None]:
if best_l8_match:
    print(best_l8_match.id)
    print(best_l8_match.bbox)
    print(best_l8_match.geometry)
    print(best_l8_match.properties)

In [None]:
explore_search_extent(ItemCollection([best_l8_match]))

If everything worked correctly, the geographic scope of the Landsat 8 scene should encompass a much larger surface area than the Sentinel-2 source and label chips. From here we need to crop the image down and make sure the chips from both products match.

In [None]:
def get_redirect_url(asset: Asset) -> str:
    """Returns the direct URL to an asset.

    Args:
        asset: Asset object from an Item

    Returns:
        string response URL direct to Asset
    """
    response = requests.get(asset.href, allow_redirects=True)
    if response.status_code == 200:
        return response.url
    return None

In [None]:
def plot_rgb_chip(rgb_stack: DataArray, norm: int) -> None:
    """Plots the first time step of an RGB stack, with pixel values divided by norm."""
    img_arr = rgb_stack[0].to_numpy().squeeze()
    fig, ax = plt.subplots(figsize=(7, 7))
    show(img_arr / norm, ax=ax)

In [None]:
s2_stack = stack(
    items=ItemCollection([sample_source_item]),
    assets=BIGEARTHNET_RGB_BANDS,
    epsg=rio.open(get_redirect_url(sample_source_item.assets["B02"])).crs.to_epsg(),
    resolution=10,
)

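The `stackstac.stack` function returns a four-dimensional DataArray with the dimensions `time`, `band`, `y`, and `x`. Since we passed a single Item, the `time` dimension has length 1, which is why the cells below index the stack with `[0]` before plotting.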

In [None]:
s2_stack

This is the true color image representation of the Sentinel-2 chip we fetched RGB assets for.

In [None]:
plot_rgb_chip(s2_stack, 4000)

Here are the RGB bands plotted together in a single row of subplots.

In [None]:
s2_stack[0].plot(col="band")

In [None]:
l8_original = stack(
    items=ItemCollection([best_l8_match]), assets=LANDSAT_8_RGB_BANDS, resolution=10
)

In [None]:
l8_original

As we can see from the metadata for the Xarray above, the Landsat 8 scene has a significantly larger geographic footprint, `~20,000 x ~20,000 pixels`, compared to `120 x 120 pixels` for the Sentinel-2 chips that were prepared for the dataset. We need to crop/mask the Landsat 8 images down so they represent the same geographical footprint.

Luckily, the `bounds_latlon` parameter of `stackstac` makes it easy to crop the image to this size automatically for all bands/assets requested.

In [None]:
l8_cropped = stack(
    items=ItemCollection([best_l8_match]),
    assets=LANDSAT_8_RGB_BANDS,
    bounds_latlon=sample_source_item.bbox,
    resolution=10,
)

In [None]:
l8_cropped

In [None]:
l8_cropped[0].data.compute()

In [None]:
plot_rgb_chip(l8_cropped, 23000)

Now we have a cropped Landsat 8 chip that spatially and temporally matches our Sentinel-2 source imagery and label sample from the BigEarthNet dataset. The first observation is that the Landsat 8 image appears blurry compared to Sentinel-2. This is because Sentinel-2 RGB bands have a 10m resolution, while the same bands for Landsat 8 have a 30m resolution.
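
The arithmetic behind the blurriness can be worked through directly. This sketch assumes an idealized chip of exactly 120 pixels per side and ignores any reprojection effects:

```python
# Side length of a BigEarthNet chip: 120 pixels at 10 m per pixel
chip_pixels = 120
sentinel_2_resolution_m = 10
landsat_8_resolution_m = 30

chip_side_m = chip_pixels * sentinel_2_resolution_m

# Number of native Landsat 8 pixels along one side of the same footprint
native_l8_pixels = chip_side_m // landsat_8_resolution_m

# Each native Landsat 8 pixel spans this many 10 m output pixels per side
upsampling_factor = landsat_8_resolution_m // sentinel_2_resolution_m

print(chip_side_m, native_l8_pixels, upsampling_factor)
```

A 1,200 m footprint therefore holds only about 40 x 40 native Landsat 8 pixels. Requesting `resolution=10` stretches each of them over a 3 x 3 block of output pixels, which matches the chip dimensions but adds no detail.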

### Scale the workflow using Dask Delayed

We will now use Dask to optimize processing the Landsat 8 scenes by parallelizing the workflow with a delayed computation graph. The Dask Client schedules and runs the delayed computations, then gathers the results. With parallel processing, we can speed up the runtime of our image processing workflow by 10-20x.

These are some helper functions that we will use to encapsulate the process of creating the cropped Landsat 8 chips and write them to disk in parallel using the Dask Delayed decorator.

In [None]:
def create_landsat_8_chip(source_item: Dict[str, Any]) -> DataArray:
    """Creates a Landsat 8 chip from a BigEarthNet source Item.

    Args:
        source_item: JSON/dictionary representation of source Item

    Returns:
        Landsat 8 DataArray that has been cropped to sentinel-2 bbox
    """

    # fetch the Landsat 8 scene that best matches the label
    l8_match = get_landsat_8_match(
        source_item['bbox'],
        source_item['geometry'],
        source_item['properties']['datetime']
    )

    # crop L8 match to S2 dims and read image data
    l8_stack = stack(
        items=ItemCollection([l8_match]),
        assets=LANDSAT_8_RGB_BANDS,
        bounds_latlon=source_item['bbox'],
        resolution=10,
    )

    return l8_stack

In [None]:
def write_tif_bands(l8_array: DataArray, l8_item_id: str) -> None:
    """Writes a GeoTIFF for each band in a Landsat 8 DataArray

    Args:
        l8_array: the cropped Landsat 8 DataArray created from a BigEarthNet source Item
        l8_item_id: ID used to name the output folder and files
    """
    # write cropped L8 DataArray to a tiff file for each band
    for _band in LANDSAT_8_RGB_BANDS:
        l8_band_img = l8_array.sel(band=_band)
        l8_band_filename = os.path.join(
            TMP_DIR, OUTPUT_DIR, l8_item_id, f"{l8_item_id}_{_band}.tiff"
        )
        Path(os.path.split(l8_band_filename)[0]).mkdir(parents=True, exist_ok=True)
        l8_band_img[0].rio.to_raster(l8_band_filename)
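
As a quick illustration of where `write_tif_bands` places its output, this sketch rebuilds the file paths for a single Item without writing anything to disk. The Item ID `example_item_0` is hypothetical.

```python
import os
import tempfile

# Same constants as defined earlier in this notebook
TMP_DIR = tempfile.gettempdir()
OUTPUT_DIR = "landsat_8_source"
LANDSAT_8_RGB_BANDS = ["SR_B4", "SR_B3", "SR_B2"]

# Hypothetical Item ID, used only to illustrate the naming scheme
item_id = "example_item_0"

paths = [
    os.path.join(TMP_DIR, OUTPUT_DIR, item_id, f"{item_id}_{band}.tiff")
    for band in LANDSAT_8_RGB_BANDS
]
for path in paths:
    print(path)
```

Each Item gets its own folder containing one single-band GeoTIFF per RGB band.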

This sets the stage for the Dask Task Scheduler by mapping all source Items to the `create_landsat_8_chip` function. Nothing in the task graph will actually be executed until the `.compute()` command is run.
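
The lazy behavior of `dask.delayed` can be seen in a toy example that is unrelated to imagery. The `square` function and the `calls` list are made up for illustration:

```python
import dask

calls = []  # records which inputs were actually processed


@dask.delayed
def square(x):
    calls.append(x)
    return x * x


# Building the graph only records the work to be done
tasks = [square(i) for i in range(4)]
total = dask.delayed(sum)(tasks)
calls_before_compute = list(calls)

# The functions run only when compute() is called; the synchronous
# scheduler is used here purely to keep the toy example deterministic
result = total.compute(scheduler="synchronous")

print(calls_before_compute, sorted(calls), result)
```

The `calls` list stays empty until `compute()` runs, at which point all four inputs are processed and summed.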

In [None]:
import dask.distributed  # ensures the distributed submodule is loaded

client = dask.distributed.Client()  # you can configure Dask client parameters here
client

One quirk of processing the DataArray objects returned from `stackstac.stack()` on a Dask cluster is that an error can be thrown saying the DataArrays don't have the method `rio.to_raster()`. The `rio` accessor only exists once `rioxarray` has been imported. Normally we could solve this problem by explicitly importing the `rioxarray` library, but we also need to import the module on each worker in the client cluster.

In [None]:
import importlib
client.run(lambda: importlib.import_module("rioxarray"))

In [None]:
%%time
chunk_size = 125

for i in range(0, len(bigearthnet_source_items[0:500]), chunk_size):
    future_pool = []
    item_chunk=bigearthnet_source_items[i:i+chunk_size]
    for source_item in item_chunk:
        item_dict = dask.delayed(Item.to_dict)(source_item)
        l8_xarray = dask.delayed(create_landsat_8_chip)(item_dict)
        image_writer = dask.delayed(write_tif_bands)(l8_xarray, item_dict['id'])
        future_pool.append(image_writer)
    future_pool = dask.persist(*future_pool)
    dask.compute(*future_pool)
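
The outer loop above submits work in batches rather than all at once. The batching logic on its own can be sketched as follows, with a plain list of integers standing in for the 500 source Items:

```python
# Stand-in for the 500 source Items processed above
items = list(range(500))
chunk_size = 125

# Step through the list in strides of chunk_size and slice out each batch
chunks = [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]

print(len(chunks), [len(c) for c in chunks])
```

With these values the loop runs four times, and each pass builds and computes the task graph for 125 Items before moving on to the next batch.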

Now that our parallelized workflow has completed, let's confirm that folders with images were written to disk.

In [None]:
landsat_chip_dir = os.path.join(TMP_DIR, OUTPUT_DIR)
len(os.listdir(landsat_chip_dir))

We can also open one of the new Landsat 8 chips to inspect what it looks like.

In [None]:
landsat_images = glob(f"{landsat_chip_dir}/**/*.tiff", recursive=True)
first_l8_img = rioxarray.open_rasterio(landsat_images[0])
first_l8_img.plot()

Lastly, we will shut down the Dask client to clean up cluster resources.

In [None]:
client.shutdown()