## MOSAIKS feature extraction

This tutorial demonstrates the **MOSAIKS** method for extracting _feature vectors_ from satellite imagery patches for use in downstream modeling tasks. It will show:
- How to extract 1km$^2$ patches of Sentinel 2 multispectral imagery for a list of latitude, longitude points
- How to extract summary features from each of these imagery patches
- How to use the summary features in a linear model of the population density at each point

### Background

Consider the case where you have a dataset of latitude and longitude points associated with some dependent variable (for example: population density, weather, housing prices, biodiversity) and, potentially, other independent variables. You would like to model the dependent variable as a function of the independent variables, but instead of including latitude and longitude directly in this model, you would like to include some high-dimensional representation of what the Earth looks like at that point (one that hopefully explains some of the variance in the dependent variable!). The computer vision literature offers various [representation learning techniques](https://en.wikipedia.org/wiki/Feature_learning) for doing this, i.e. for extracting _feature vectors_ from imagery. This notebook implements the technique described in [Rolf et al. 2021](https://www.nature.com/articles/s41467-021-24638-z), "A generalizable and accessible approach to machine learning with global satellite imagery", called Multi-task Observation using Satellite Imagery & Kitchen Sinks (**MOSAIKS**). For more information about **MOSAIKS** see the [project's webpage](http://www.globalpolicy.science/mosaiks).


**Notes**:
- This example uses [Sentinel-2 Level-2A data](https://planetarycomputer.microsoft.com/dataset/sentinel-2-l2a). The techniques used here apply equally well to other remote-sensing datasets.
- If you're running this on the [Planetary Computer Hub](http://planetarycomputer.microsoft.com/compute), make sure to choose the **GPU - PyTorch** profile when presented with the form to choose your environment.

In [1]:
!pip install -q git+https://github.com/geopandas/dask-geopandas

In [2]:
import warnings
import time
import os
import gc

RASTERIO_BEST_PRACTICES = dict(  # See https://github.com/pangeo-data/cog-best-practices
    CURL_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt",
    GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR",
    AWS_NO_SIGN_REQUEST="YES",
    GDAL_MAX_RAW_BLOCK_CACHE_SIZE="200000000",
    GDAL_SWATH_SIZE="200000000",
    VSI_CURL_CACHE_SIZE="200000000",
)
os.environ.update(RASTERIO_BEST_PRACTICES)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import rasterio
import rasterio.warp
import rasterio.mask
import shapely.geometry
import geopandas
import dask_geopandas
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from scipy.stats import spearmanr
from scipy.linalg import LinAlgWarning
from dask.distributed import Client


warnings.filterwarnings(action="ignore", category=LinAlgWarning, module="sklearn")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")

import pystac_client
import planetary_computer as pc

## Set Parameters

In [7]:
num_features = 4000
cell_size = 2000
geodetic_epsg = 'EPSG:4326'
projected_epsg = 'EPSG:20935' 
country_code = 'ZMB'
# [
    # 'NGA', # Nigeria 
    # 'TZA', # Tanzania 
    # 'ZMB', # Zambia
# ]

## Create grid and sample points to featurize

In [8]:
africa = geopandas.read_file('data/africa_adm0.geojson')
country = africa[africa.adm0_a3 == country_code]

# Project country into local EPSG with units in meters
country_prj = country.to_crs(projected_epsg)

# Create grid of points
xmin, ymin, xmax, ymax = country_prj.total_bounds
xs = list(np.arange(xmin, xmax + cell_size, cell_size))
ys = list(np.arange(ymin, ymax + cell_size, cell_size))
def make_cell(x, y, cell_size):
    ring = [
        (x, y),
        (x + cell_size, y),
        (x + cell_size, y + cell_size),
        (x, y + cell_size)
    ]
    cell = shapely.geometry.Polygon(ring).centroid
    return cell
cells = []
for x in xs:
    for y in ys:
        cell = make_cell(x, y, cell_size)
        cells.append(cell)

# Put grid into a GeoDataFrame and select points
grid = geopandas.GeoDataFrame({'geometry': cells}, crs = projected_epsg)
grid['lon'] = grid.geometry.x
grid['lat'] = grid.geometry.y
grid['x'] = grid.groupby(['lon']).ngroup() + 1
grid['y'] = grid.groupby(['lat']).ngroup() + 1
grid['includepoint'] = (grid.y + grid.x) % 2 == 0 
grid = grid[grid.includepoint]

# Subset grid to the country, then reproject to geographic coordinates
gdf = grid[grid.within(country_prj.unary_union)]
gdf = gdf.to_crs(geodetic_epsg)
gdf['lon'] = gdf.geometry.x
gdf['lat'] = gdf.geometry.y
gdf = gdf[['lon', 'lat', 'geometry']].reset_index(drop = True)
gdf.shape

(94184, 3)
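
The checkerboard selection above (`(grid.y + grid.x) % 2 == 0`) keeps every other cell, so neighbouring sample points are spaced two cell widths apart along each row and column. A minimal NumPy sketch of the same idea, using small hypothetical bounds rather than the country geometry:

```python
import numpy as np

cell_size = 2000  # metres, same value as in the parameters above
xmin, ymin, xmax, ymax = 0, 0, 10_000, 8_000  # hypothetical bounds, not Zambia's

# Cell centroids along each axis
xs = np.arange(xmin, xmax, cell_size) + cell_size / 2
ys = np.arange(ymin, ymax, cell_size) + cell_size / 2

# Integer column/row index of every cell, then the checkerboard mask
col, row = np.meshgrid(np.arange(len(xs)), np.arange(len(ys)), indexing="ij")
keep = (col + row) % 2 == 0

# (x, y) coordinates of the retained centroids
centroids = np.column_stack([xs[col[keep]], ys[row[keep]]])
print(centroids.shape)
```

Half of the 5 x 4 cells survive the mask, which is the same thinning the `includepoint` column applies to the full grid.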

In [None]:
376779

First we define the PyTorch model that we will use to extract the features. The **MOSAIKS** methodology describes several ways to do this and we use the simplest: convolution filters whose weights are drawn at random from a normal distribution.

In [5]:
class RCF(nn.Module):
    """A model for extracting Random Convolution Features (RCF) from input imagery."""

    def __init__(self, num_features=16, kernel_size=3, num_input_channels=3):
        super(RCF, self).__init__()

        # We create `num_features / 2` filters so require `num_features` to be divisible by 2
        assert num_features % 2 == 0, "Please enter an even number of features."

        # Applies a 2D convolution over an input image composed of several input planes.
        self.conv1 = nn.Conv2d(
            num_input_channels,
            num_features // 2,
            kernel_size=kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            bias=True,
        )

        # Fills the input Tensor 'conv1.weight' with values drawn from the normal distribution
        nn.init.normal_(self.conv1.weight, mean=0.0, std=1.0)
        
        # Fills the input Tensor 'conv1.bias' with the value 'val = -1'.
        nn.init.constant_(self.conv1.bias, -1.0)

    def forward(self, x):
        # The rectified linear activation function or ReLU for short is a piecewise linear function 
        # that will output the input directly if it is positive, otherwise, it will output zero.
        x1a = F.relu(self.conv1(x), inplace=True)
        # Also rectify the negated response; both halves are concatenated below
        x1b = F.relu(-self.conv1(x), inplace=True)

        # Applies a 2D adaptive average pooling over an input signal composed of several input planes.
        x1a = F.adaptive_avg_pool2d(x1a, (1, 1)).squeeze()
        x1b = F.adaptive_avg_pool2d(x1b, (1, 1)).squeeze()

        if len(x1a.shape) == 1:  # case where we passed a single input
            return torch.cat((x1a, x1b), dim=0)
        elif len(x1a.shape) == 2:  # case where we passed a batch of > 1 inputs
            return torch.cat((x1a, x1b), dim=1)
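
To make the forward pass concrete, here is a minimal NumPy sketch of the same computation for a single image: convolve with random filters, add the constant bias of -1, rectify both the response and its negation, and average each map over the spatial dimensions. The function name `rcf_numpy` and the toy shapes are illustrative assumptions, not part of the notebook's model.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def rcf_numpy(image, weights, bias=-1.0):
    """Random convolutional features for one image.

    image:   array of shape (C, H, W)
    weights: array of shape (K, C, k, k), e.g. drawn from N(0, 1)
    Returns a vector of 2 * K features.
    """
    k = weights.shape[-1]
    # All k x k patches: shape (C, H - k + 1, W - k + 1, k, k)
    windows = sliding_window_view(image, (k, k), axis=(1, 2))
    # Valid cross-correlation with every filter: shape (K, H - k + 1, W - k + 1)
    response = np.einsum("chwij,fcij->fhw", windows, weights) + bias
    pos = np.maximum(response, 0).mean(axis=(1, 2))   # ReLU(response), averaged
    neg = np.maximum(-response, 0).mean(axis=(1, 2))  # ReLU(-response), averaged
    return np.concatenate([pos, neg])


rng = np.random.default_rng(0)
image = rng.random((3, 101, 101))        # toy stand-in for an RGB patch in [0, 1]
weights = rng.normal(size=(8, 3, 3, 3))  # 8 random 3x3 filters
feats = rcf_numpy(image, weights)
print(feats.shape)
```

Because ReLU(r) - ReLU(-r) = r, the two halves together retain the mean filter response while also recording how much of it is positive and how much is negative.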

Next, we initialize the model in evaluation mode and move it to the GPU.

In [6]:
device = torch.device("cuda")
model = RCF(num_features).eval().to(device)

### Extract features from the imagery around each point

We need to find a suitable Sentinel 2 scene for each point. As usual, we'll use `pystac-client` to search for items matching some conditions, but we don't want to make a separate `.search()` call for each of the tens of thousands of points, because each HTTP request is relatively slow. Instead, we will *batch* our points and search *in parallel*.

We need to be a bit careful with how we batch up our points though. Since a single Sentinel 2 scene will cover many points, we want to make sure that points which are spatially close together end up in the same batch. In short, we need to spatially partition the dataset. This is implemented in `dask-geopandas`.

So the overall workflow will be

1. Find an appropriate STAC item for each point (in parallel, using the spatially partitioned dataset)
2. Feed the points and STAC items to a custom Dataset that can read imagery given a point and the URL of an overlapping S2 scene
3. Use a custom DataLoader, which uses our Dataset, to feed our model imagery and save the corresponding features
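
Step 1 reduces to a simple rule, applied per point by the `query` helper in this notebook: sort the candidate scenes by cloud cover and take the first one whose footprint covers the point. A minimal pure-Python sketch of that rule, using hypothetical scene IDs and axis-aligned bounding boxes in place of real STAC footprints:

```python
def pick_scene(point, scenes):
    """Return the id of the least cloudy scene covering `point`, or None."""
    lon, lat = point
    for scene in sorted(scenes, key=lambda s: s["cloud_cover"]):
        west, south, east, north = scene["bbox"]
        if west <= lon <= east and south <= lat <= north:
            return scene["id"]
    return None  # no sufficiently clear scene covers this point


# Hypothetical scenes; real footprints are polygons, not boxes
scenes = [
    {"id": "scene-a", "cloud_cover": 7.5, "bbox": (27.0, -16.0, 28.0, -15.0)},
    {"id": "scene-b", "cloud_cover": 1.2, "bbox": (27.5, -16.0, 28.5, -15.0)},
]

print(pick_scene((27.8, -15.5), scenes))  # covered by both; scene-b is clearer
print(pick_scene((27.2, -15.5), scenes))  # only scene-a covers it
print(pick_scene((30.0, -15.5), scenes))  # nothing covers it
```

Points that come back as `None` are the ones the notebook later drops before reading any imagery.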

In [7]:
NPARTITIONS = 250

ddf = dask_geopandas.from_geopandas(gdf, npartitions=1)
hd = ddf.hilbert_distance().compute()
gdf["hd"] = hd
gdf = gdf.sort_values("hd")

dgdf = dask_geopandas.from_geopandas(gdf, npartitions=NPARTITIONS, sort=False)
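
Sorting by Hilbert distance works because a Hilbert curve visits every cell of a grid while only ever stepping to an adjacent cell, so points with similar distances tend to be spatially close. The sketch below is the textbook cell-to-distance conversion, included only to illustrate that property; `hilbert_index` is a hypothetical name and this is not the `dask-geopandas` implementation.

```python
def hilbert_index(n, x, y):
    """Distance along a Hilbert curve of cell (x, y) in an n x n grid (n a power of 2)."""
    d = 0
    s = n // 2
    while s > 0:
        rx = 1 if x & s else 0
        ry = 1 if y & s else 0
        d += s * s * ((3 * rx) ^ ry)
        # Rotate/flip the quadrant so the sub-curve has the standard orientation
        if ry == 0:
            if rx == 1:
                x, y = n - 1 - x, n - 1 - y
            x, y = y, x
        s //= 2
    return d


n = 8
cells = sorted(
    ((x, y) for x in range(n) for y in range(n)),
    key=lambda c: hilbert_index(n, *c),
)
# Manhattan distance between consecutive cells along the curve
steps = [abs(x1 - x0) + abs(y1 - y0) for (x0, y0), (x1, y1) in zip(cells, cells[1:])]
print(set(steps))
```

Every step has length 1, so cutting the sorted sequence into contiguous chunks, as the partitioning above does, yields spatially compact batches.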

We'll write a helper function that finds, for every point in a partition, the least cloudy Sentinel 2 scene covering it.

### Zambia Crop Info
median plant:   
date = 318  
month = 11 (Nov)    

median harvest:     
month = 5 (May)  

Season:  
length = 7 months  
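
Reading `date = 318` as a day of the year (an assumption; the note does not state the unit), the standard library confirms that it falls in mid-November, consistent with the month given above:

```python
from datetime import date, timedelta


def day_of_year_to_date(year, doy):
    """Convert a 1-based day of the year to a calendar date."""
    return date(year, 1, 1) + timedelta(days=doy - 1)


print(day_of_year_to_date(2017, 318))  # 2017-11-14
```

In a leap year such as 2016 the same day number lands one calendar day earlier, still in November.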

In [8]:
%%time
for yr in range(2016, 2019):
    def query(points):
        """
        Find a STAC item for points in the `points` DataFrame

        Parameters
        ----------
        points : geopandas.GeoDataFrame
            A GeoDataFrame

        Returns
        -------
        geopandas.GeoDataFrame
            A new geopandas.GeoDataFrame with a `stac_item` column containing the STAC
            item that covers each point.
        """
        intersects = shapely.geometry.mapping(points.unary_union.convex_hull)

        # search_start = f"{yr - 1}-11-01"
        search_start = f"{yr}-03-01"
        search_end = f"{yr}-05-30"
        catalog = pystac_client.Client.open(
            "https://planetarycomputer.microsoft.com/api/stac/v1"
        )

        # The time frame in which we search for non-cloudy imagery
        search = catalog.search(
            collections=["sentinel-2-l2a"],
            intersects=intersects,
            datetime=[search_start, search_end],
            query={"eo:cloud_cover": {"lt": 10}},
            limit=500,
        )
        ic = search.get_all_items_as_dict()

        features = ic["features"]
        features_d = {item["id"]: item for item in features}

        data = {
            "eo:cloud_cover": [],
            "geometry": [],
        }

        index = []

        for item in features:
            data["eo:cloud_cover"].append(item["properties"]["eo:cloud_cover"])
            data["geometry"].append(shapely.geometry.shape(item["geometry"]))
            index.append(item["id"])

        items = geopandas.GeoDataFrame(data, index=index, geometry="geometry").sort_values(
            "eo:cloud_cover"
        )
        point_list = points.geometry.tolist()

        point_items = []
        for point in point_list:
            covered_by = items[items.covers(point)]
            if len(covered_by):
                point_items.append(features_d[covered_by.index[0]])
            else:
                # There weren't any scenes matching our conditions for this point (too cloudy)
                point_items.append(None)

        return points.assign(stac_item=point_items)
    
    print("Matching images to points for", yr)
    
    with Client(n_workers=16) as client:
        meta = dgdf._meta.assign(stac_item=[])
        df2 = dgdf.map_partitions(query, meta=meta).compute()
    df3 = df2.dropna(subset=["stac_item"])

    matching_urls = [
        pc.sign(item["assets"]["visual"]["href"]) for item in df3.stac_item.tolist()
    ]

    points = df3[["lon", "lat"]].to_numpy()

    class CustomDataset(Dataset):
        def __init__(self, points, fns, buffer=500):
            self.points = points
            self.fns = fns
            self.buffer = buffer

        def __len__(self):
            return self.points.shape[0]

        def __getitem__(self, idx):

            lon, lat = self.points[idx]
            fn = self.fns[idx]

            if fn is None:
                return None
            else:
                point_geom = shapely.geometry.mapping(shapely.geometry.Point(lon, lat))

                with rasterio.Env():
                    with rasterio.open(fn, "r") as f:
                        point_geom = rasterio.warp.transform_geom(
                            "epsg:4326", f.crs.to_string(), point_geom
                        )
                        point_shape = shapely.geometry.shape(point_geom)
                        mask_shape = point_shape.buffer(self.buffer).envelope
                        mask_geom = shapely.geometry.mapping(mask_shape)
                        try:
                            out_image, out_transform = rasterio.mask.mask(
                                f, [mask_geom], crop=True
                            )
                        except ValueError as e:
                            if "Input shapes do not overlap raster." in str(e):
                                return None
                            raise  # any other ValueError is unexpected, so surface it

                out_image = out_image / 255.0
                out_image = torch.from_numpy(out_image).float()
                return out_image
    dataset = CustomDataset(points, matching_urls)

    dataloader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=False,
        num_workers=os.cpu_count(),
        collate_fn=lambda x: x,
        pin_memory=False,
    )
    
    x_all = np.zeros((points.shape[0], num_features), dtype=float)
    tic = time.time()
    i = 0
    
    print("Featurizing year:", yr)
    for images in dataloader:
        for image in images:
            if i % 5000 == 0:
                print(
                    f"{i}/{points.shape[0]} -- {i / points.shape[0] * 100:0.2f}%"
                    + f" -- {time.time()-tic:0.2f} seconds"
                )
                tic = time.time()
            if image is not None:
                # A full image should be ~101x101 pixels (i.e. ~1km^2 at a 10m/px spatial
                # resolution), however we can receive smaller images if an input point
                # happens to be at the edge of a S2 scene (a literal edge case). To deal
                # with these (edge) cases we crudely drop all images where either spatial
                # dimension is smaller than 20 pixels.
                if image.shape[1] >= 20 and image.shape[2] >= 20:
                    image = image.to(device)
                    with torch.no_grad():
                        feats = model(image.unsqueeze(0)).cpu().numpy()
                    x_all[i] = feats
                else:
                    # this happens if the point is close to the edge of a scene
                    # (one or both of the spatial dimensions of the image are very small)
                    pass
            else:
                pass  # this happens if we do not find a S2 scene for some point
            i += 1
    features = pd.DataFrame(x_all)
    features[["lon", "lat"]] = points.tolist()
    features['year'] = yr
    features.columns = features.columns.astype(str)
    
    # Save the features to a feather file
    file_name = (f'data/sentinel_2_{country_code}_{len(points)/1000:.0f}'+
                 f'k-points_{num_features/1000:.0f}k-features_{yr}.feather')
    # file_name = f'data/features_{yr}.feather'
    print("Saving file as:", file_name)
    features.to_feather(file_name)
    print("Save finished!")
    
    # Free memory before loop iterates
    del x_all
    del features
    gc.collect()
    print("")

Matching images to points for 2016
Featurizing year: 2016
0/60222 -- 0.00% -- 3.57 seconds
5000/60222 -- 8.30% -- 48.04 seconds
10000/60222 -- 16.61% -- 52.52 seconds
15000/60222 -- 24.91% -- 42.61 seconds
20000/60222 -- 33.21% -- 44.71 seconds
25000/60222 -- 41.51% -- 40.44 seconds
30000/60222 -- 49.82% -- 33.72 seconds
35000/60222 -- 58.12% -- 35.52 seconds
40000/60222 -- 66.42% -- 33.88 seconds
45000/60222 -- 74.72% -- 36.49 seconds
50000/60222 -- 83.03% -- 32.95 seconds
55000/60222 -- 91.33% -- 39.81 seconds
60000/60222 -- 99.63% -- 35.91 seconds
Saving file as: data/sentinel_2_ZMB_60k-points_2k-features_2016.feather
Save finished!

Matching images to points for 2017
Featurizing year: 2017
0/59117 -- 0.00% -- 0.44 seconds
5000/59117 -- 8.46% -- 49.76 seconds
10000/59117 -- 16.92% -- 46.39 seconds
15000/59117 -- 25.37% -- 42.18 seconds
20000/59117 -- 33.83% -- 42.40 seconds
25000/59117 -- 42.29% -- 47.37 seconds
30000/59117 -- 50.75% -- 43.21 seconds
35000/59117 -- 59.20% -- 44.20 s

In [9]:
f'data/sentinel_2_{country_code}_{len(points)/1000:.0f}'+\
 f'k-points_{num_features/1000:.0f}k-features_{yr}.feather'

'data/sentinel_2_ZMB_60k-points_2k-features_2018.feather'

In [10]:
len(points)

60283

In [11]:
# import seaborn as sns
# plt.figure(figsize = (15,10))
# sns.heatmap(features_df, annot=False, cmap = 'viridis')

In [12]:
# file_name = f'data/features_{year}.feather'
# features.to_feather(file_name)

In [13]:
# plt.figure(figsize = (15,15))
# plt.scatter(features.lon, features.lat, c = features['998'], cmap = 'viridis', s = 4, alpha = 1)

In [14]:
# features_2 = features.copy()
# features_2['total'] = features.iloc[:, 0:-3].mean(axis=1)
# features_2

In [15]:
# plt.figure(figsize = (15,15))
# plt.scatter(features_2.lon, features_2.lat, c = features_2['total'], cmap = 'viridis', s = 1, alpha = 1)