In [None]:
# https://github.com/microsoft/PlanetaryComputerExamples/blob/main/competitions/s1floods/generate_auxiliary_input.ipynb

from dataclasses import dataclass
import os
from tempfile import TemporaryDirectory
from typing import List, Any, Dict

from shapely.geometry import box, mapping
import rasterio
from rasterio.warp import reproject, Resampling
import pyproj
from osgeo import gdal

from pystac_client import Client
import planetary_computer as pc
from pathlib import Path
import shutil

## Collect the paths of all training chips (VV band)

In [None]:
# Training chips live under data/raw/train_features/train_features, resolved from two
# directory levels above the notebook's current working directory.
DATA_PATH = Path.cwd().parent.parent.joinpath("data", "raw", "train_features", "train_features")

In [13]:
# One entry per chip: select the VV-band GeoTIFFs only. Sorting makes the processing
# order (and the progress log below) reproducible, since os.listdir order is arbitrary.
chip_paths = sorted(
    os.path.join(DATA_PATH, file_name)
    for file_name in os.listdir(DATA_PATH)
    if file_name.endswith("_vv.tif")
)
print(f"{len(chip_paths)} chips found.")

542 chips found.


## Clean the external data directory and create the output subdirectories

In [None]:
# Root directory for all generated auxiliary rasters (data/external, two levels above cwd).
EXTERNAL_DATA_PATH = Path.cwd().parent.parent.joinpath("data", "external")

In [None]:
# Remove everything inside the external data directory except the `.gitkeep`
# placeholder, so the auxiliary files below are generated from a clean slate.
# Directories are removed recursively; files and symlinks are unlinked. Real
# deletion failures are raised instead of being silently swallowed.
EXTERNAL_DATA_PATH.mkdir(parents=True, exist_ok=True)
for entry in EXTERNAL_DATA_PATH.iterdir():
    if entry.name == ".gitkeep":
        continue
    if entry.is_dir() and not entry.is_symlink():
        shutil.rmtree(entry)
    else:
        entry.unlink()

In [None]:
# One output subdirectory per auxiliary layer; these names must match the first
# element of each tuple in aux_file_params below.
directories = ["nasadem", "jrc_extent", "jrc_occurrence", "jrc_recurrence", "jrc_seasonality", "jrc_transitions", "jrc_change"]
for directory in directories:
    # exist_ok keeps this cell re-runnable even if the cleanup cell was skipped.
    (EXTERNAL_DATA_PATH / directory).mkdir(parents=True, exist_ok=True)

## Connect to the Planetary Computer STAC API

In [None]:
# Public Planetary Computer STAC endpoint; every collection search below goes through `catalog`.
STAC_API = "https://planetarycomputer.microsoft.com/api/stac/v1"
catalog = Client.open(url=STAC_API)

## Define helper functions and the chip metadata class

In [None]:
@dataclass
class ChipInfo:
    """
    Holds information about a training chip, including geospatial info for coregistration.

    Instances are built by get_chip_info from the chip's GeoTIFF and consumed by
    reproject_to_chip / create_chip_aux_file as the target grid for auxiliary rasters.
    """

    # Filesystem path of the chip GeoTIFF the info was read from.
    path: str
    # File-name prefix (text before the first "_"); used to name the auxiliary outputs.
    prefix: str
    # Coordinate reference system of the chip (the rasterio dataset's `crs`).
    crs: Any
    # Raster shape indexed as (height, width). NOTE(review): annotated List[int], but the
    # stored value is the rasterio dataset's `shape`, presumably a tuple — confirm.
    shape: List[int]
    # Geotransform of the chip (the rasterio dataset's `transform`). NOTE(review): annotated
    # List[float], but the stored value is presumably an affine.Affine — confirm.
    transform: List[float]
    # Chip bounds in `crs` units.
    bounds: rasterio.coords.BoundingBox
    # GeoJSON-like footprint in EPSG:4326, used for the STAC `intersects` search.
    footprint: Dict[str, Any]


def get_footprint(bounds, crs, densify_pts=21):
    """Gets a GeoJSON footprint (in epsg:4326) from rasterio bounds and CRS.

    A rectangle in a projected CRS generally maps to a curved, rotated quadrilateral in
    lon/lat, so transforming only two opposite corners can produce a box that misses
    parts of the chip. Instead, points are sampled along all four edges and the lon/lat
    envelope of that densified perimeter is returned. This keeps STAC ``intersects``
    searches from missing items that touch the chip's edges.

    NOTE(review): footprints crossing the antimeridian are not handled specially.

    Args:
        bounds: Bounding box with ``left``, ``bottom``, ``right`` and ``top`` in ``crs`` units.
        crs: Source CRS, anything ``pyproj.Transformer.from_crs`` accepts.
        densify_pts: Number of intermediate points sampled per edge (0 = corners only).

    Returns:
        GeoJSON-like dict (``shapely.geometry.mapping``) of the lon/lat bounding box.
    """
    transformer = pyproj.Transformer.from_crs(crs, "epsg:4326", always_xy=True)

    steps = max(int(densify_pts), 0) + 1
    width = bounds.right - bounds.left
    height = bounds.top - bounds.bottom

    xs, ys = [], []
    for k in range(steps + 1):
        frac = k / steps
        x_on_edge = bounds.left + frac * width
        y_on_edge = bounds.bottom + frac * height
        # Points on the bottom, top, left and right edges respectively.
        xs.extend([x_on_edge, x_on_edge, bounds.left, bounds.right])
        ys.extend([bounds.bottom, bounds.top, y_on_edge, y_on_edge])

    lons, lats = transformer.transform(xs, ys)
    return mapping(box(min(lons), min(lats), max(lons), max(lats)))


def get_chip_info(chip_path):
    """Read a chip GeoTIFF's georeferencing metadata and bundle it into a ChipInfo.

    The chip's prefix is its file name up to the first underscore.
    """
    with rasterio.open(chip_path) as dataset:
        crs, shape, transform, bounds = (
            dataset.crs,
            dataset.shape,
            dataset.transform,
            dataset.bounds,
        )

    # Everything before the first "_" of the file name identifies the chip.
    name_prefix, *_ = os.path.basename(chip_path).split("_")
    footprint = get_footprint(bounds, crs)

    return ChipInfo(
        path=chip_path,
        prefix=name_prefix,
        crs=crs,
        shape=shape,
        transform=transform,
        bounds=bounds,
        footprint=footprint,
    )

In [None]:
def reproject_to_chip(
    chip_info, input_path, output_path, resampling=Resampling.nearest
):
    """
    Reproject a raster at input_path onto the grid of chip_info, saving to output_path.

    The output GeoTIFF takes its CRS, transform, width and height from chip_info; the
    remaining profile entries are copied from the source's ``meta``. Every band is
    reprojected with the caller-supplied ``resampling`` method.

    Use Resampling.nearest for classification rasters. Otherwise use something
    like Resampling.bilinear for continuous data.

    Args:
        chip_info: ChipInfo describing the target grid (crs, transform, shape).
        input_path: Path of a raster rasterio can open (e.g. a VRT mosaic).
        output_path: Destination GeoTIFF path.
        resampling: rasterio.warp.Resampling method applied to every band.
    """
    with rasterio.open(input_path) as src:
        kwargs = src.meta.copy()
        kwargs.update(
            {
                "crs": chip_info.crs,
                "transform": chip_info.transform,
                # chip_info.shape is indexed as (height, width).
                "width": chip_info.shape[1],
                "height": chip_info.shape[0],
                "driver": "GTiff",
            }
        )

        with rasterio.open(output_path, "w", **kwargs) as dst:
            for i in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, i),
                    destination=rasterio.band(dst, i),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=chip_info.transform,
                    dst_crs=chip_info.crs,
                    # Honor the requested method (previously hard-coded to nearest,
                    # which silently ignored e.g. bilinear for elevation data).
                    resampling=resampling,
                )

In [None]:
def write_vrt(items, asset_key, dest_path):
    """Write a VRT with hrefs extracted from a list of items for a specific asset.

    Each asset href is signed with ``planetary_computer.sign`` and read through GDAL's
    ``/vsicurl/`` virtual file system, so no data is downloaded up front.

    Args:
        items: Iterable of STAC items that all carry ``asset_key``.
        asset_key: Key of the asset whose hrefs form the mosaic.
        dest_path: Path of the VRT file to create.

    Raises:
        ValueError: if ``items`` is empty (GDAL cannot build a VRT from no sources).
        RuntimeError: if ``gdal.BuildVRT`` reports failure by returning None.
    """
    items = list(items)
    if not items:
        raise ValueError(f"No STAC items given; cannot build a VRT for asset '{asset_key}'.")

    hrefs = [pc.sign(item.assets[asset_key].href) for item in items]
    vsi_hrefs = [f"/vsicurl/{href}" for href in hrefs]

    vrt = gdal.BuildVRT(str(dest_path), vsi_hrefs)
    if vrt is None:
        raise RuntimeError(f"gdal.BuildVRT failed to create {dest_path}")
    vrt.FlushCache()
    # Dropping the last reference closes the GDAL dataset and finalizes the file on disk.
    del vrt

In [None]:
def create_chip_aux_file(
    dir_, chip_info, collection_id, asset_key, file_name, resampling=Resampling.nearest
):
    """
    Write an auxiliary raster co-registered with a training chip.

    Searches the STAC catalog for items of ``collection_id`` that intersect the chip
    footprint, mosaics their ``asset_key`` asset into a temporary VRT, and reprojects
    that mosaic onto the chip's grid. The result is saved as
    ``EXTERNAL_DATA_PATH / dir_ / f"{chip_info.prefix}.tif"``.

    Args:
        dir_: Subdirectory of EXTERNAL_DATA_PATH that receives the output.
        chip_info: ChipInfo describing the target grid and footprint.
        collection_id: STAC collection to search.
        asset_key: Asset key within each item to mosaic.
        file_name: Not used to name the output (kept so existing calls keep working).
        resampling: rasterio Resampling method forwarded to reproject_to_chip.

    Returns:
        Path of the written GeoTIFF.

    Raises:
        ValueError: if no items of the collection intersect the chip footprint.
    """
    output_path = EXTERNAL_DATA_PATH / dir_ / f"{chip_info.prefix}.tif"
    search = catalog.search(collections=[collection_id], intersects=chip_info.footprint)
    items = list(search.get_items())
    if not items:
        raise ValueError(
            f"No '{collection_id}' items intersect chip {chip_info.prefix}; "
            f"cannot build the '{asset_key}' auxiliary file."
        )
    with TemporaryDirectory() as tmp_dir:
        # The VRT only references remote files, so it can live in a throwaway directory.
        vrt_path = os.path.join(tmp_dir, "source.vrt")
        write_vrt(items, asset_key, vrt_path)
        reproject_to_chip(chip_info, vrt_path, output_path, resampling=resampling)
    return output_path

In [None]:
# Parameters passed to create_chip_aux_file, one tuple per auxiliary layer:
# (output subdirectory, STAC collection id, asset key, file name, resampling method).
# NASADEM elevation uses bilinear resampling; every JRC surface-water layer uses
# nearest neighbour. The file name is currently not used to name the outputs.
aux_file_params = [("nasadem", "nasadem", "elevation", "nasadem.tif", Resampling.bilinear)]
aux_file_params.extend(
    (f"jrc_{asset}", "jrc-gsw", asset, f"jrc-gsw-{asset}.tif", Resampling.nearest)
    for asset in ("extent", "occurrence", "recurrence", "seasonality", "transitions", "change")
)

In [None]:
# Iterate over the chips and generate all aux input files.
# Outputs are written as EXTERNAL_DATA_PATH/<layer dir>/<chip prefix>.tif, so an
# existing output is skipped and an interrupted run can be resumed.
# NOTE(review): a file left half-written by an interrupted run also counts as existing.
count = len(chip_paths)
for i, chip_path in enumerate(chip_paths):
    print(f"({i+1} of {count}) {chip_path}")
    # Same prefix rule as get_chip_info: the file name up to the first "_".
    # os.path.basename (not split('/')) keeps this portable across OSes.
    prefix = os.path.basename(chip_path).split("_")[0]
    missing_params = [
        params
        for params in aux_file_params
        if not (EXTERNAL_DATA_PATH / params[0] / f"{prefix}.tif").exists()
    ]
    if not missing_params:
        continue
    chip_info = get_chip_info(chip_path)
    for dir_, collection_id, asset_key, file_name, resampling_method in missing_params:
        print(f"  ... Creating chip data for {collection_id} {asset_key}")
        create_chip_aux_file(
            dir_, chip_info, collection_id, asset_key, file_name, resampling=resampling_method
        )