In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import datetime as dt
import glob

import numpy as np
import pandas as pd
from pyproj import CRS

from torchgeo.experimental.spatiotemporal.gpd_index import (
    build_fake_sparse_rasters,
    build_raster_index,
    extract_slice,
    jitter_geometries,
)

## GeoPandas as a Tile Index for Spatiotemporally Sparse Tiles

Examples of temporally and spatially sparse tifs for several collections. A *collection* here means a set of uniform tifs with an expected band layout — e.g. Landsat, Sentinel-2, or a target product such as land use / land cover.

We build three collections:
- Landsat at 30 m: 10 dates for 10 AOIs, 4 bands, in UTM zone 10N (EPSG:32610)
- Sentinel-2 at 10 m: 10 dates for 10 AOIs, 4 bands, in Web Mercator (EPSG:3857)
- target at 30 m: 3 dates for 10 AOIs, 1 band, in EPSG:4326

Note:
- We do not cover the case where targets are vector data; since the index is a GeoPandas GeoDataFrame, that case should be relatively natural to implement.

In [3]:
num_overlapping_sites = 10
rng = np.random.default_rng(42)

# Each origin is drawn longitude-first, i.e. (lon, lat).
# NOTE(review): the downstream argument is named `lat_lons`, but these tuples
# are (lon, lat) — confirm which order build_fake_sparse_rasters expects.
random_lat_lon_origins = [
    (
        # lon: spans UTM zone 10N (EPSG:32610, -126..-120) AND zone 11N
        # (-120..-114), so not every site lies inside the Landsat CRS zone.
        rng.uniform(-123, -117),
        # lat: western North America; points near the western edge of the
        # longitude range can fall offshore.
        rng.uniform(35, 55),
    )
    for _ in range(num_overlapping_sites)
]

# Arguments shared by every fake collection. All calls share one `rng`, so
# keep the call order (targets, sentinel2, landsat) fixed for reproducibility.
fake_raster_kwargs = dict(
    lat_lons=random_lat_lon_origins,
    start_date=dt.datetime(2023, 1, 1),
    end_date=dt.datetime(2023, 1, 31),
    rng=rng,
)

# case where targets are rasters: 1 band, 30 m, 3 dates, EPSG:4326
targets = build_fake_sparse_rasters(
    output_dir='/tmp/example_rasters/targets',
    crs=CRS.from_epsg(4326),
    num_dates=3,
    resolution_meters=30,
    nbands=1,
    **fake_raster_kwargs,
)

# 4 bands, 10 m, 10 dates. Real Sentinel-2 tiles ship in UTM; Web Mercator
# (EPSG:3857) is used here to exercise a third CRS.
sentinel2 = build_fake_sparse_rasters(
    output_dir='/tmp/example_rasters/sentinel2',
    crs=CRS.from_epsg(3857),
    num_dates=10,
    resolution_meters=10,
    nbands=4,
    **fake_raster_kwargs,
)

# 4 bands, 30 m, 10 dates, UTM zone 10N
landsat = build_fake_sparse_rasters(
    output_dir='/tmp/example_rasters/landsat',
    crs=CRS.from_epsg(32610),
    num_dates=10,
    resolution_meters=30,
    nbands=4,
    **fake_raster_kwargs,
)

# Sanity check: count tifs on disk, one directory level below the root (one
# folder per collection). Stale tifs from an earlier run with different
# settings would inflate this count.
len(glob.glob('/tmp/example_rasters/*/*.tif'))

230

Think of each index below as defining an independent dataset.

In [4]:
# Tif names end in `_<8 digits>.tif` (presumably YYYYMMDD); the capture group
# extracts that date stamp.
date_regex = r'_(\d{8})\.tif$'

# One independent index per collection, built in the same order as before.
target_index, sentinel2_tiff_index, landsat_tiff_index = (
    build_raster_index(tifs=tifs, date_regex=date_regex, collection=name)
    for tifs, name in (
        (targets, 'target'),
        (sentinel2, 'sentinel2'),
        (landsat, 'landsat'),
    )
)

# Feature collections are stacked into a single frame; targets stay separate.
collection_index = pd.concat([sentinel2_tiff_index, landsat_tiff_index])

In [5]:
# (number of target tiles, number of feature tiles across all collections)
tuple(map(len, (target_index, collection_index)))

(30, 200)

Given a tile index for the target collection and one or more feature collections, we now apply selection criteria to build a GeoDataFrame that we can use for sampling.

In this example, we require:
- a feature tile is a candidate for a target tile only if it was acquired within the 8-day window leading up to the target tile's date
- the target and all candidate feature tiles must intersect; their intersection is the viable sampling area
- every feature collection must be present for a sampling area to be viable

Note: there is likely room to optimize these operations. This logic would sit somewhere between the union operator on datasets (with a GeoPandas GeoDataFrame as the index instead of an R-tree) and the sampler, and it should ideally be precomputed.

Think of the operation below as one way to union spatiotemporal datasets to define valid AOIs from which to sample.

In [6]:
%time
window = 8  # valid aoi must fall within 8days leading up to target tif
required_collection = {'sentinel2', 'landsat', 'target'}

valid_data = []
for i, target_tile in target_index.groupby('location'):
    # subset all collections w.r.t. time
    k = collection_index.datetime - target_tile.iloc[0].datetime
    idx = (k < pd.Timedelta(0, 'D')) & (k > -pd.Timedelta(window, 'D'))
    subset = collection_index[idx]
    # subset all collections wr.t. space
    subset = subset[subset.intersects(target_tile.iloc[0]['geometry'])]

    # find the intersection of all collection geometries
    # we need to actually use overlay or some other way of handling the case
    # where not all collections intersect. e.g. if one is disjoint, we will
    # get an empty GeoSeries
    combined = pd.concat([subset, target_tile])
    combined = combined.clip(combined.intersection_all())

    # combined = pd.concat([target_subset, collection_subset])
    combined['target_id'] = target_tile.iloc[0].name

    # apply any criteria like min number w.r.t. a collection or that all
    # collections are present, etc.
    if set(combined.collection.unique()) != required_collection:
        continue

    valid_data.append(combined)

final = pd.concat(valid_data)

CPU times: user 4 μs, sys: 0 ns, total: 4 μs
Wall time: 8.11 μs


In [7]:
# Sanity check: every retained target group contains all required
# collections. Comparing against required_collection avoids a magic number,
# and a plain column nunique() avoids the pandas>=2.2-only `include_groups`.
collections_per_target = final.groupby('target_id')['collection'].nunique()
all_collections_present = bool(
    collections_per_target.eq(len(required_collection)).all()
)
assert all_collections_present, collections_per_target
all_collections_present

True

In [8]:
# Splitting by target_id gives each target a variable number of overlapping
# scenes. We must decide whether the ML model can handle an irregular
# temporal grid, or whether to resample onto a regular temporal grid or build
# a composite (e.g. a median composite) instead.
final.groupby('target_id').count().iloc[:5]

Unnamed: 0_level_0,crs,geometry,location,datetime,collection
target_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,3,3,3,3,3
1,3,3,3,3,3
2,8,8,8,8,8
4,3,3,3,3,3
5,4,4,4,4,4


In [9]:
# Interactive map of the retained, clipped footprints. jitter_geometries
# presumably offsets the heavily overlapping geometries so individual tiles
# can be told apart on the map — TODO confirm against its implementation.
jitter_geometries(final).explore()

Next we need to decide which grid to align to: a global EPSG:4326 grid at the target's resolution, or the native grid of the target tifs? Note that if targets are vector data rather than tifs, there is no native grid to align to.

In [10]:
# Take the first target's group as a worked example. next(iter(...)) raises
# StopIteration on an empty `final`, whereas a `for ...: break` would silently
# leave `subset` bound to the stale value from the filtering loop.
nm, subset = next(iter(final.groupby('target_id')))

In [37]:
window_size = 16  # patch height/width in pixels
# ~30 m expressed in degrees (30 / 111_320 ≈ 0.0002695), matching the 30 m
# target rasters; assumes the footprints are in EPSG:4326 — TODO confirm.
res = 0.0002695
patch_extent = window_size * res

target_example = subset.loc[subset.collection == 'target'].iloc[0]
ex_path = target_example['location']

# Patches grow +x/+y from their sampled origin, so the footprint is shrunk by
# one full patch extent (derived from `res`, not a separate constant). Every
# origin drawn from these bounds then yields a patch inside the footprint.
sample_bounds = target_example['geometry'].buffer(-patch_extent).bounds
# An over-small footprint buffers to empty (NaN bounds) or a degenerate box.
assert sample_bounds[0] < sample_bounds[2] and sample_bounds[1] < sample_bounds[3], (
    'target footprint is too small for the requested window_size'
)

batch_size = 128
patches = []
for _ in range(batch_size):
    # Draw the patch origin from the notebook's seeded generator so the batch
    # is reproducible (np.random.* uses unseeded global state).
    x = rng.uniform(sample_bounds[0], sample_bounds[2])
    y = rng.uniform(sample_bounds[1], sample_bounds[3])

    # The extra 1% of a pixel is a hack, presumably so floating-point rounding
    # still yields exactly window_size pixels per side — TODO confirm.
    slice_bounds = (
        x,
        y,
        x + patch_extent + res * 0.01,
        y + patch_extent + res * 0.01,
    )

    # Read the same window from every scene (target + features), resampled
    # onto a common EPSG:4326 grid, and stack their bands along axis 0.
    slices = [
        extract_slice(
            url,
            target_crs='EPSG:4326',
            slice_bounds=slice_bounds,
            target_resolution=res,
        )
        for url in subset['location']
    ]
    patches.append(np.concatenate(slices, axis=0))

# The channel count is the band total over the scenes in `subset` (e.g.
# 1 + 4 + 4 = 9 with one scene per collection). It is derived from the data,
# not hard-coded, so other target groups work too.
batch = np.stack(patches).astype(np.float32, copy=False)

In [38]:
# (batch_size, channels, window_size, window_size). Channels are the bands of
# every scene in `subset` concatenated along axis 0; this presumes
# extract_slice returns (bands, rows, cols) — TODO confirm.
batch.shape

(128, 9, 16, 16)

**TODO:** use xarray to add a time dimension.