## LandCoverNet Data Preparation

<img src='https://radiant-assets.s3-us-west-2.amazonaws.com/PrimaryRadiantMLHubLogo.png' alt='Radiant MLHub Logo' width='300'/>

This tutorial delves into building a scalable model on the LandCoverNet dataset.

This portion of the tutorial focuses on preparing LandCoverNet data for a semantic segmentation model.
Here:

1. We will inspect the source imagery for the labels we have

2. We will process the source imagery in parallel using Dask

3. We will select the labels and the filtered source images computed by Dask

4. We will save the images and their labels as a `pickle` file (`.pkl`) in our data directory, to be loaded for model training

This notebook builds on the esip-summer-2021-geospatial-ml tutorial, which can be found [here](https://github.com/TomAugspurger/esip-summer-2021-geospatial-ml/blob/main/segmentation-model.ipynb). It was particularly useful for loading the STAC Items and Sentinel-2 scenes using the `stackstac` library.

### Authentication

As demonstrated in the [Data Exploration notebook](/1.%20Data%20Exploration.ipynb) of this tutorial series, access to the Radiant MLHub API using the `pystac_client` library requires both an API end-point and an API key. This notebook assumes that you have already followed the steps in `1. Data Exploration.ipynb` and have an MLHub API key that has not expired.

In [None]:
import getpass

MLHUB_API_KEY = getpass.getpass(prompt="MLHub API Key: ")
MLHUB_ROOT_URL = "https://api.radiant.earth/mlhub/v1"

This notebook uses a number of STAC and geospatial libraries, which we import here.

In [None]:
import os
import pickle
import warnings
from datetime import datetime
from typing import List, Tuple

# silence noisy warnings from the geospatial stack before importing it
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.filterwarnings("ignore", "Creating an ndarray from ragged")

import dask
import dask_gateway
import numpy as np
import pandas as pd
import pystac
import pystac_client
import rasterio as rio
import rasterio.plot
import rioxarray
import stackstac
from pystac import Item
from pystac.item_collection import ItemCollection
from shapely.geometry import mapping, shape

### Launch a Dask gateway cluster for parallel processing

We will use Dask to speed up the processing of thousands of source image chips by parallelizing the workflow with a graph of delayed computations. The Dask Client schedules and runs the delayed computations and gathers the results, while Dask Gateway provides a secure, centralized way to manage multiple Dask clusters. This is especially useful for running Dask on the Planetary Computer.

In [None]:
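from dask.distributed import Client

# with no arguments, Client() starts a local Dask cluster
client = Client()
# apply the same warning filter on every worker
client.run(lambda: warnings.filterwarnings("ignore", "Creating an ndarray from ragged"))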

In [None]:
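gateway = dask_gateway.Gateway()
options = gateway.cluster_options()
options["worker_cores"] = 7
# the computations below run on the `client` created above; to use a Dask Gateway
# cluster instead, start one with these options, e.g.
# cluster = gateway.new_cluster(options); client = cluster.get_client()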

### Instantiate an instance of the MLHub API Client

Here again we demonstrate how to instantiate an API client connected to the MLHub end-point using the `pystac_client` library.

In [None]:
mlhub_client = pystac_client.Client.open(
    url=MLHUB_ROOT_URL, parameters={"key": MLHUB_API_KEY}, ignore_conformance=True
)

We set the temporary directory based on the current working directory.

In [None]:
tmp_dir = os.path.join(os.getcwd(), "landcovernet")
labels_dir = os.path.join(tmp_dir, "labels")

We need to make sure that the labels Collection has already been downloaded to the running Planetary Computer instance and stored under the shared directory `/home/jovyan/PlanetaryComputerExamples/landcovernet` (in its `labels` subdirectory, as set above). Please double-check the active working directory so that the catalog is found when you run this notebook.

In [None]:
# read the labels Collection downloaded in the previous notebook
catalog = pystac.read_file(
    os.path.join(labels_dir, "ref_landcovernet_v1_labels/collection.json")
)

### Loading the source imagery



To fetch the source images referenced by the source Item Assets, we first need to gather all of the label Items from the LandCoverNet labels Collection downloaded in the previous tutorial. The cell below resolves each Item link into a label Item STAC object.

In [None]:
links = catalog.get_item_links()  # links from the catalog
label_items = [link.resolve_stac_object().target for link in links]

This is a helper function to calculate the percentage of a raster image that is covered by clouds. It assumes that the input image is 256 by 256 pixels and that the cloud band stores a per-pixel cloud probability from 0 to 100. The values are rescaled to 0 to 1, summed, divided by the chip area, and multiplied by 100, which returns an integer cloud cover score between 0 and 100 to be used in place of the STAC Item metadata. ***NOTE: This function is only called if the Item metadata does not include the `eo:cloud_cover` property.***

In [None]:
def calculate_cloud_cover(img_arr: np.ndarray) -> int:

    """Takes a chip cloud cover band and returns the integer score
    by dividing the sum of normalized values by the chip area (HxW).

    Args:
    img_arr: np.ndarray - 2d array of cloud cover mask

    Returns:
    arr_cc: int - integer value of cloud cover score

    """
    CHIP_AREA = 256 * 256
    arr_filled = np.nan_to_num(img_arr)
    arr_norm = arr_filled / 100
    arr_sum = arr_norm.sum()
    arr_cc = arr_sum / CHIP_AREA * 100
    return int(arr_cc)
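A quick sanity check of the cloud cover calculation, assuming the `CLD` band stores per-pixel cloud probability from 0 to 100: a synthetic 256 by 256 chip whose top quarter is fully cloudy and whose other pixels are clear (or missing) should score 25. The function is repeated so the cell runs on its own; the synthetic arrays are purely illustrative.

```python
import numpy as np


def calculate_cloud_cover(img_arr: np.ndarray) -> int:
    """Same logic as above: mean cloud probability of a 256x256 chip, in percent."""
    CHIP_AREA = 256 * 256
    arr_filled = np.nan_to_num(img_arr)  # missing pixels (NaN) count as cloud-free
    arr_norm = arr_filled / 100  # rescale 0-100 probabilities to 0-1
    arr_sum = arr_norm.sum()
    return int(arr_sum / CHIP_AREA * 100)


# top quarter fully cloudy, remaining pixels clear
chip = np.zeros((256, 256))
chip[:64, :] = 100
print(calculate_cloud_cover(chip))  # 25

# same cloud, but the remaining pixels are missing instead of clear
chip_with_gaps = np.full((256, 256), np.nan)
chip_with_gaps[:64, :] = 100
print(calculate_cloud_cover(chip_with_gaps))  # 25
```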

For our use case, we decided not to train the model on the entire LandCoverNet source dataset. Instead, we sample chips that are representative of each season or quadrimester, or of any custom number of bins spread over a temporal range.

This helper function returns the median date of the source Items in one temporal period, parsed from the 8-digit date at the end of each source Item ID.

In [None]:
def get_median_date(id_arr: np.ndarray) -> int:

    """Takes an array of source Item IDs for one period and returns the median date

    Args:
    id_arr: np.ndarray - 1d array of source Item ID strings ending in an 8-digit date

    Returns:
    median_date: int - the median date value for the input array (0 if it is empty)

    """

    # the last 8 characters of each source Item ID encode its date
    dates = sorted(int(s[-8:]) for s in id_arr)

    # base case there are no source items
    if not dates:
        return 0

    # n // 2 is the middle element for odd n and the upper-middle element for
    # even n; it also covers the single-item case (index 0)
    median_date = dates[len(dates) // 2]

    return median_date
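One way to pick a median date that is always an actual acquisition date (so it can be turned back into an existing Item ID) is `statistics.median_high`, which returns the middle value for an odd count and the upper of the two middle values for an even count. The sketch below applies it to hypothetical source Item IDs; the `source_TILE_` prefix is made up, and only the trailing 8-digit date matters.

```python
import statistics


def median_high_date(item_ids):
    """Median date (as an int) of IDs whose last 8 characters are a date.

    Returns 0 for an empty input, mirroring get_median_date() above.
    """
    dates = sorted(int(s[-8:]) for s in item_ids)
    if not dates:
        return 0
    # median_high always returns one of the input values
    return statistics.median_high(dates)


# hypothetical IDs: a made-up prefix followed by an 8-digit date
ids = [
    "source_TILE_20180301",
    "source_TILE_20180115",
    "source_TILE_20180220",
]
print(median_high_date(ids))      # 20180220
print(median_high_date(ids[:2]))  # 20180301
```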

This helper function maps a datetime to an integer period index, based on the global `period_ranges` variable created by the `get_date_ranges()` function defined below. The period index is used later to group and rank source Items by cloud cover, so that within each temporal period we only work with the images that have the least cloud cover.

In [None]:
def assign_temporal_period(dt: datetime) -> int:
    """Takes a datetime and returns an integer based on n_periods defined"""
    for ix, pair in enumerate(period_ranges):
        if dt >= pair[0] and dt <= pair[1]:
            return ix + 1

    return None
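To see how the period index behaves, here is a standalone variant that takes the ranges as an argument instead of reading the global `period_ranges`. The two hand-written UTC ranges are hypothetical. Both ends of each range are inclusive, so a date on a shared boundary goes to the earlier period, and a date outside every range maps to `None`.

```python
from datetime import datetime, timezone
from typing import List, Optional


def assign_period(dt: datetime, ranges: List[List[datetime]]) -> Optional[int]:
    """Return the 1-based index of the first [start, end] range containing dt."""
    for ix, (start, end) in enumerate(ranges):
        if start <= dt <= end:  # inclusive on both ends, first match wins
            return ix + 1
    return None


utc = timezone.utc
ranges = [
    [datetime(2018, 1, 1, tzinfo=utc), datetime(2018, 6, 30, tzinfo=utc)],
    [datetime(2018, 6, 30, tzinfo=utc), datetime(2018, 12, 31, tzinfo=utc)],
]
print(assign_period(datetime(2018, 3, 15, tzinfo=utc), ranges))  # 1
print(assign_period(datetime(2018, 6, 30, tzinfo=utc), ranges))  # 1
print(assign_period(datetime(2019, 2, 1, tzinfo=utc), ranges))   # None
```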

This function takes the DataFrame built from Item metadata in `get_season_min_cloud_cover()`, ranks the Items within each period by cloud cover, keeps the Items with the lowest cloud cover, and returns a single source chip (the one at the median date) for each period the Items are split into.

In [None]:
def filter_period_items(cc_df: pd.DataFrame) -> pd.DataFrame:

    """Takes a dataframe of source Items with metadata and filters
    on ranked cloudcover by period/season (quadrimester).

    Args:
    cc_df: pd.DataFrame - unfiltered dataframe

    Returns:
    filtered_df: pd.Dataframe - filtered dataframe

    """
    pd.options.mode.chained_assignment = None
    # assigns a period and a cloud cover rank within each period
    cc_df["date_time"] = pd.to_datetime(cc_df["date_time"])
    cc_df["period"] = cc_df["date_time"].apply(assign_temporal_period)
    cc_df["rank"] = cc_df.groupby("period")["cloud_cover"].rank(
        method="min", ascending=True
    )

    id_prefix = cc_df.iloc[0]["id"][:-8]
    median_dates = []

    # filters DataFrame on rank
    min_cc_df = cc_df[cc_df["rank"] == 1]

    # for each period, get the median date of the lowest-cloud-cover source items
    for i in range(1, n_periods + 1):
        quarter_df = min_cc_df[min_cc_df["period"] == i]
        quarter_median_date = get_median_date(quarter_df["id"].values)
        quarter_median_id = id_prefix + str(quarter_median_date)
        median_dates.append(quarter_median_id)

    # filter the ranked DataFrame by median date
    filtered_df = min_cc_df[min_cc_df["id"].isin(median_dates)]
    return filtered_df
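The ranking step relies on pandas `rank(method="min")`, which gives tied values the same (lowest) rank. Every Item tied for the lowest cloud cover in a period therefore keeps rank 1 and survives the `rank == 1` filter, and the median-date step then picks one of them. A toy DataFrame with made-up IDs and cloud cover values illustrates this:

```python
import pandas as pd

# made-up source Items: two periods, with a tie for lowest cloud cover in period 1
cc_df = pd.DataFrame(
    {
        "id": ["src_20180105", "src_20180120", "src_20180210", "src_20180701", "src_20180715"],
        "period": [1, 1, 1, 2, 2],
        "cloud_cover": [3, 0, 0, 12, 7],
    }
)

# rank within each period; tied values share the minimum rank
cc_df["rank"] = cc_df.groupby("period")["cloud_cover"].rank(method="min", ascending=True)
min_cc_df = cc_df[cc_df["rank"] == 1]
print(min_cc_df["id"].tolist())
# ['src_20180120', 'src_20180210', 'src_20180715']
```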

This is a wrapper function that creates a DataFrame from a list of Items and applies the filtering functions defined above. When an Item lacks the `eo:cloud_cover` property, its cloud cover is computed from the `CLD` asset.

In [None]:
def get_season_min_cloud_cover(item_list: List[Item]) -> ItemCollection:

    """Takes a list of source Items and returns a single chip per season
    ranked by the minimum cloud cover from eo:cloud_cover property

    Args:
    item_list: List[Item] - iterable of source Items returned from search

    Returns:
    ItemCollection - STAC Iterable containing Items filtered by cloud cover
    """

    # constructs a DataFrame of each source item properties
    df_list = []
    for ui in item_list:
        if "eo:cloud_cover" in ui.properties:
            cloud_cover = ui.properties["eo:cloud_cover"]
        else:
            # no eo:cloud_cover property: compute it from the CLD asset
            with rio.open(ui.get_assets()["CLD"].href) as src:
                cloud_cover = calculate_cloud_cover(src.read())
        uid = {
            "item": ui,
            "id": ui.id,
            "cloud_cover": cloud_cover,
            "date_time": ui.datetime,
        }
        df_list.append(uid)

    cc_df = pd.DataFrame(df_list)

    # filters source items by cloud cover rank and returns ItemCollection
    if not cc_df.empty:
        filtered_df = filter_period_items(cc_df)

        return ItemCollection(filtered_df["item"].tolist())

    return None

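This function uses the footprint of a label Item and each period in `period_ranges` to query the MLHub API for matching source Items. The search starts with a cloud cover threshold of 10% and relaxes it in steps of 5% until at least one Item is returned.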

In [None]:
def get_label_item_collection(label_item: Item) -> ItemCollection:

    """Takes a label Item from the LandCoverNet Collection and searches
    for source imagery for chips that match spatial and temporal criteria

    Args:
    label_item: Item - item of current iteration in the get_item() Dask parallelization

    Returns:
    ItemCollection - STAC Iterable containing Items that match search criteria
    """

    year_collection = ItemCollection([])

    # iterate over each start and end date per period
    for start, end in period_ranges:

        cc_thresh = 10  # start each period with a strict cloud cover threshold
        item_results = ItemCollection([])

        # relax the threshold in steps of 5 until items are found; stop once the
        # threshold passes 100, i.e. every cloud cover value has been allowed
        while cc_thresh <= 105:

            # performs a temporal and spatial search for each label item
            search = mlhub_client.search(
                collections=["ref_landcovernet_v1_source"],
                intersects=mapping(shape(label_item.geometry)),
                datetime=[start, end],
                query={"eo:cloud_cover": {"lt": cc_thresh}},
            )

            # converts search results to ItemCollection
            item_results = search.get_all_items()

            if len(item_results) > 0:
                break
            cc_thresh += 5

        year_collection += item_results  # concatenate ItemCollections for each period

    filtered_items = get_season_min_cloud_cover(year_collection.items)

    return filtered_items
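The threshold relaxation can be isolated from the MLHub API. The sketch below uses a hypothetical `search_fn` callable standing in for `mlhub_client.search(...).get_all_items()` and bounds the loop, so it stops once every cloud cover value from 0 to 100 has been allowed, even for a period with no imagery at all.

```python
from typing import Callable, List, Sequence, Tuple


def search_with_relaxed_threshold(
    search_fn: Callable[[float], Sequence],
    start_thresh: float = 10,
    step: float = 5,
    max_thresh: float = 105,
) -> Tuple[Sequence, float]:
    """Call search_fn(threshold) with a growing threshold until it returns results.

    search_fn is a hypothetical stand-in for a search filtered on
    "eo:cloud_cover" < threshold; max_thresh = 105 admits every value in 0-100.
    """
    thresh = start_thresh
    results = search_fn(thresh)
    while not results and thresh < max_thresh:
        thresh += step
        results = search_fn(thresh)
    return results, thresh


# toy catalog: cloud cover values of the Items available in one period
available = [42.0, 18.5, 63.0]


def fake_search(thresh: float) -> List[float]:
    return [cc for cc in available if cc < thresh]


print(search_with_relaxed_threshold(fake_search))   # ([18.5], 20)
print(search_with_relaxed_threshold(lambda t: []))  # ([], 105)
```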

This is the primary function that drives all the processing required to filter and load source imagery and label data into a stack of Xarray DataArrays for further processing, e.g. splitting the dataset into training and validation sets prior to training a machine learning model.

In [None]:
def get_item(label_item: Item, assets: Tuple[str]) -> tuple:

    """Takes label Item and asset bands to construct arrays for model training

    Args:
    label_item: Item - item of current iteration in the get_item() Dask parallelization
    assets: Tuple[str] - a set of strings corresponding to the Asset band names

    Returns:
    (data, labels): tuple of xarray DataArrays - X and y arrays for model training,
        or None if no source Items were found for the label Item
    """
    warnings.simplefilter(action="ignore", category=FutureWarning)
    assets = list(assets)
    labels = rioxarray.open_rasterio(
        os.path.join(
            labels_dir, "ref_landcovernet_v1_labels", label_item.id, "labels.tif"
        ),
    ).squeeze()

    source_item_collection = get_label_item_collection(label_item)

    # get_season_min_cloud_cover() returns None when no source Items were found
    if source_item_collection is not None and len(source_item_collection) > 0:

        # round the label chip bounds so the stacked imagery uses the same extent
        bounds = tuple(round(x, 0) for x in labels.rio.bounds())

        # lazily stack the selected source Items into a (time, band, y, x) DataArray
        data = stackstac.stack(
            items=source_item_collection,
            assets=assets,
            dtype="float32",
            resolution=10,
            bounds=bounds,
            epsg=labels.rio.crs.to_epsg(),
        )

        # reuse the label coordinates so that X and y share the same pixel grid
        data = data.assign_coords(x=labels.x.data, y=labels.y.data)
        # scale pixel values by 4000 and clip them to [0, 1]
        data /= 4000
        data = np.clip(data, 0, 1)

        return data, labels.astype("int64")

    return None
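The last two steps of `get_item()` divide pixel values by 4000 and clip them to `[0, 1]`, so any value of 4000 or more saturates at 1. The sketch below shows the effect on a few made-up pixel values with plain NumPy; the stacked DataArray gets the same element-wise arithmetic, and missing pixels (NaN) stay NaN.

```python
import numpy as np

# made-up raw values for a handful of pixels; NaN marks a missing pixel
raw = np.array([0.0, 1000.0, 4000.0, 9000.0, np.nan], dtype="float32")

scaled = np.clip(raw / 4000, 0, 1)
print(scaled.tolist())  # [0.0, 0.25, 1.0, 1.0, nan]
```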

This function takes the temporal range of the Collection and the number of periods (the global `n_periods` defined below) and returns a list of datetime ranges split into equal-sized buckets. For example, over a one-year temporal range, `n_periods=3` returns quadrimesters (four-month periods).

In [None]:
def get_date_ranges(
    start: str, end: str, periods: int
) -> List[List[datetime]]:

    """Builds a list of equal-width start and end date ranges spanning the temporal extent

    Args:
    start: str - first date of the temporal extent, e.g. "2018-01-01"
    end: str - last date of the temporal extent
    periods: int - number of periods to split the temporal extent into

    Returns:
    period_ranges: List[List[datetime]] - a list of [start, end] pairs of
        UTC-aware datetimes, one per period

    """

    period_ranges = []
    all_dates = pd.DataFrame(
        pd.date_range(start=start, end=end, freq="1D"),
        columns=["Date"],
    )
    # equal-width bins over the daily dates; keep each bin once, in date order
    date_bins = pd.cut(all_dates.Date, bins=periods).drop_duplicates()

    for interval in date_bins:
        period_ranges.append(
            [
                interval.left.tz_localize("UTC").to_pydatetime(),
                interval.right.tz_localize("UTC").to_pydatetime(),
            ]
        )

    return period_ranges
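The binning can be tried in isolation. The sketch below follows the same `pd.cut` approach, with the start and end dates passed in as arguments; `get_date_ranges_from` is a hypothetical name, and the 2018 calendar year with `periods=3` is an example input. Note that `pd.cut` widens the first bin slightly, so the first range starts several hours before January 1, and neighbouring ranges share their boundary.

```python
from datetime import datetime
from typing import List

import pandas as pd


def get_date_ranges_from(start: str, end: str, periods: int) -> List[List[datetime]]:
    """Split [start, end] into `periods` equal-width, UTC-aware [start, end] pairs."""
    all_dates = pd.DataFrame(
        pd.date_range(start=start, end=end, freq="1D"), columns=["Date"]
    )
    # equal-width bins over the daily dates; keep each bin once, in date order
    date_bins = pd.cut(all_dates.Date, bins=periods).drop_duplicates()
    return [
        [
            interval.left.tz_localize("UTC").to_pydatetime(),
            interval.right.tz_localize("UTC").to_pydatetime(),
        ]
        for interval in date_bins
    ]


for period_start, period_end in get_date_ranges_from("2018-01-01", "2018-12-31", 3):
    print(period_start.isoformat(), "->", period_end.isoformat())
```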

Here we specify the temporal extent of the label Collection, the number of periods (`n_periods`) to divide the temporal range into, and the bands to fetch for each source Item.

In [None]:
temporal_start = catalog.extent.temporal.intervals[0][0].strftime(
    "%Y-%m-%d"
)  # global starting datetime for label Collection
temporal_end = catalog.extent.temporal.intervals[0][1].strftime(
    "%Y-%m-%d"
)  # global ending datetime for label Collection
n_periods = 5

period_ranges = get_date_ranges(temporal_start, temporal_end, n_periods)
assets = ("B04", "B03", "B02")  # we will make use of the RGB bands

### Load the source imagery

Now we bring everything together. We wrap the `get_item()` function defined above with `dask.delayed`, build one lazy task per label Item, and compute the tasks in parallel in chunks of 20 label Items. Each task fetches up to `n_periods` source images for its label Item, and the results are collected into a list of `(data, labels)` pairs of DataArrays holding the aligned images (X) and labels (y).

In [None]:
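%%time
import logging

# silence performance warnings logged by dask.distributed
logger = logging.getLogger("distributed.utils_perf")
logger.setLevel(logging.ERROR)

# wrap get_item() so that each call builds a lazy task instead of running immediately
get_item_ = dask.delayed(get_item)

Xys_list = []
chunk_size = 20  # number of label Items computed per batch
for i in range(0, len(label_items), chunk_size):
    label_chunk = label_items[i : i + chunk_size]

    # build the lazy tasks for this chunk, then run them in parallel
    lazy_results = [get_item_(label, assets) for label in label_chunk]
    Xys_list.append(list(dask.compute(*lazy_results)))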

The Dask client can be shut down with the following command:

In [None]:
client.shutdown()

We computed the results in chunks of 20 label Items, so `Xys_list` is a list of lists. This flattens it into a single list.

In [None]:
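# flatten the chunks and drop label Items for which no source imagery was found
flat_list = [item for sublist in Xys_list for item in sublist if item is not None]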
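The chunk-then-flatten pattern can be checked with plain Python. The sketch below uses made-up label IDs and a hypothetical `process` function in place of the delayed `get_item()` calls, flattens with `itertools.chain.from_iterable` (equivalent to a nested list comprehension), and drops `None` results, which is what `get_item()` returns when it finds no source Items for a label Item.

```python
from itertools import chain

label_ids = [f"label_{i:03d}" for i in range(45)]  # made-up label Item IDs
chunk_size = 20


def process(label_id):
    """Hypothetical stand-in for get_item(): no result for every 10th label."""
    return None if label_id.endswith("0") else (label_id, "data")


# process the labels in chunks, as the Dask loop does
results_by_chunk = []
for i in range(0, len(label_ids), chunk_size):
    label_chunk = label_ids[i : i + chunk_size]
    results_by_chunk.append([process(label) for label in label_chunk])

# flatten the list of chunk results and drop labels with no result
flat = [r for r in chain.from_iterable(results_by_chunk) if r is not None]
print(len(results_by_chunk), len(flat))  # 3 40
```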

This confirms that every item was extracted, i.e. the flattened list has the same length as the list of label Items.

In [None]:
len(flat_list), len(label_items)  # the two counts match if every item was extracted

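This checks the shape of the data for the first label Item: `stackstac` returns dimensions `(time, band, y, x)`, and the time dimension should hold up to `n_periods` source images.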

In [None]:
flat_list[0][0].shape  # confirm that we have the desired shape for a chip

In [None]:
import shutil

# optional: delete the downloaded labels to free disk space on the Planetary Computer
shutil.rmtree(labels_dir, ignore_errors=True)

The last step before training a neural network is to dump the dataset we just created into a pickle file stored locally on the running Planetary Computer instance. This lets the next notebook load the dataset directly instead of repeating the processing above.

In [None]:
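# write the list of (data, labels) pairs to items.pkl, overwriting any earlier run
with open(os.path.join(tmp_dir, "items.pkl"), "wb") as f:
    pickle.dump(flat_list, f)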
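In the next notebook the file can be read back with `pickle.load`. The minimal round trip below uses a temporary directory and made-up `(data, labels)` NumPy pairs in place of the real DataArrays. Opening the file with `"wb"` rather than `"ab"` matters: appending on a re-run would write a second pickle after the first, and a single `pickle.load` only returns the first one.

```python
import os
import pickle
import tempfile

import numpy as np

# made-up (data, labels) pairs shaped like (time, band, y, x) and (y, x)
items = [
    (np.zeros((3, 3, 4, 4), dtype="float32"), np.ones((4, 4), dtype="int64"))
    for _ in range(2)
]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "items.pkl")

    # write with "wb" so that re-running the cell overwrites the file
    with open(path, "wb") as f:
        pickle.dump(items, f)

    # read the dataset back, e.g. at the start of the training notebook
    with open(path, "rb") as f:
        loaded = pickle.load(f)

print(len(loaded), loaded[0][0].shape)  # 2 (3, 3, 4, 4)
```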