## MOSAIKS feature extraction

This tutorial demonstrates the **MOSAIKS** method for extracting _feature vectors_ from satellite imagery patches for use in downstream modeling tasks. It will show:
- How to extract 1 km$^2$ patches of Landsat 8 multispectral imagery for a list of latitude, longitude points
- How to extract summary features from each of these imagery patches
- How to use the summary features in a linear model of the population density at each point
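
As a rough sanity check on the patch size, the sketch below converts a buffer around a point into a patch footprint. It assumes the 500 m buffer used as the default later in this notebook and the 30 m pixel size of the Landsat 8 surface reflectance bands; the function names are illustrative and not part of any library.

```python
def patch_side_pixels(buffer_m, resolution_m):
    """Approximate number of pixels along one side of a square patch.

    The patch extends `buffer_m` metres on each side of the point, so its
    side length is twice the buffer.
    """
    return (2 * buffer_m) / resolution_m


def patch_area_km2(buffer_m):
    """Area of the square patch in square kilometres."""
    side_km = (2 * buffer_m) / 1000.0
    return side_km * side_km


side_px = patch_side_pixels(buffer_m=500, resolution_m=30)
area_km2 = patch_area_km2(buffer_m=500)
print(f"~{side_px:.1f} px per side, {area_km2:.1f} km^2")
```

The exact pixel count of a cropped patch can differ by a pixel or so depending on how the patch aligns with the raster grid.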

### Background

Consider the case where you have a dataset of latitude and longitude points associated with some dependent variable (for example: population density, weather, housing prices, biodiversity) and, potentially, other independent variables. You would like to model the dependent variable as a function of the independent variables, but instead of including latitude and longitude directly in this model, you would like to include some high-dimensional representation of what the Earth looks like at that point (that hopefully explains some of the variance in the dependent variable!). From the computer vision literature, there are various [representation learning techniques](https://en.wikipedia.org/wiki/Feature_learning) that can be used to do this, i.e., to extract _feature vectors_ from imagery. This notebook gives an implementation of the technique described in [Rolf et al. 2021](https://www.nature.com/articles/s41467-021-24638-z), "A generalizable and accessible approach to machine learning with global satellite imagery", called Multi-task Observation using Satellite Imagery & Kitchen Sinks (**MOSAIKS**). For more information about **MOSAIKS**, see the [project's webpage](http://www.globalpolicy.science/mosaiks).


**Notes**:
- If you're running this on the [Planetary Computer Hub](http://planetarycomputer.microsoft.com/compute), make sure to choose the **GPU - PyTorch** profile when presented with the form to choose your environment.

In [None]:
!pip install -q git+https://github.com/geopandas/dask-geopandas

In [None]:
import warnings
import time
import os

RASTERIO_BEST_PRACTICES = dict(  # See https://github.com/pangeo-data/cog-best-practices
    CURL_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt",
    GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR",
    AWS_NO_SIGN_REQUEST="YES",
    GDAL_MAX_RAW_BLOCK_CACHE_SIZE="200000000",
    GDAL_SWATH_SIZE="200000000",
    VSI_CURL_CACHE_SIZE="200000000",
)
os.environ.update(RASTERIO_BEST_PRACTICES)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import contextily as ctx

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import rasterio
import rasterio.warp
import rasterio.mask
import shapely.geometry
import geopandas
import dask_geopandas
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from scipy.stats import spearmanr
from scipy.linalg import LinAlgWarning
from dask.distributed import Client


warnings.filterwarnings(action="ignore", category=LinAlgWarning, module="sklearn")

import pystac_client
import planetary_computer as pc

First, we define the PyTorch model that we will use to extract the features, along with a helper function. The **MOSAIKS** methodology describes several ways to do this; we use the simplest.
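
To make the feature extractor less of a black box, here is a dependency-free sketch of what one forward pass of a random convolutional feature extractor computes on a single-band patch: convolve with random filters, add a bias, apply a ReLU to the response and to its negation, and average each result over the patch. It mirrors the initialization used in this notebook (standard normal weights, bias of -1), but the nested-list image and all function names are illustrative assumptions, not the notebook's implementation.

```python
import random


def conv2d_valid(image, kernel, bias):
    """Valid (no padding) 2D cross-correlation of a single-band image."""
    k = len(kernel)
    rows = len(image) - k + 1
    cols = len(image[0]) - k + 1
    out = []
    for r in range(rows):
        row = []
        for c in range(cols):
            acc = bias
            for i in range(k):
                for j in range(k):
                    acc += image[r + i][c + j] * kernel[i][j]
            row.append(acc)
        out.append(row)
    return out


def mean_relu(response, sign):
    """Average of ReLU(sign * response) over all spatial positions."""
    values = [max(0.0, sign * v) for row in response for v in row]
    return sum(values) / len(values)


def rcf_features(image, num_features=8, kernel_size=3, seed=0):
    """Return `num_features` values: half from ReLU(conv), half from ReLU(-conv)."""
    assert num_features % 2 == 0
    rng = random.Random(seed)
    kernels = [
        [[rng.gauss(0.0, 1.0) for _ in range(kernel_size)] for _ in range(kernel_size)]
        for _ in range(num_features // 2)
    ]
    responses = [conv2d_valid(image, kernel, bias=-1.0) for kernel in kernels]
    positive = [mean_relu(response, +1.0) for response in responses]
    negative = [mean_relu(response, -1.0) for response in responses]
    return positive + negative


# A small synthetic 6x6 patch with values in [0, 1]
demo_patch = [[((r * 7 + c * 3) % 11) / 10.0 for c in range(6)] for r in range(6)]
demo_features = rcf_features(demo_patch)
print(len(demo_features))
```

Because the filters are never trained, the only learning happens in the downstream linear model that consumes these features.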

In [None]:
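def featurize(input_img, model, device):
    """Helper method for running an image patch through the model.

    Args:
        input_img (np.ndarray): Image in (C x H x W) format with a dtype of uint8.
        model (torch.nn.Module): Feature extractor network
        device (torch.device): Device on which the model lives and the patch is featurized.

    Returns:
        np.ndarray: Feature vector extracted from the patch.
    """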
    assert len(input_img.shape) == 3
    input_img = torch.from_numpy(input_img / 255.0).float()
    input_img = input_img.to(device)
    with torch.no_grad():
        feats = model(input_img.unsqueeze(0)).cpu().numpy()
    return feats


class RCF(nn.Module):
    """A model for extracting Random Convolution Features (RCF) from input imagery."""

    def __init__(self, num_features=16, kernel_size=3, num_input_channels=1):
        super(RCF, self).__init__()

        # We create `num_features / 2` filters so require `num_features` to be divisible by 2
        assert num_features % 2 == 0

        self.conv1 = nn.Conv2d(
            num_input_channels,
            num_features // 2,
            kernel_size=kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            bias=True,
        )

        nn.init.normal_(self.conv1.weight, mean=0.0, std=1.0)
        nn.init.constant_(self.conv1.bias, -1.0)

    def forward(self, x):
        x1a = F.relu(self.conv1(x), inplace=True)
        x1b = F.relu(-self.conv1(x), inplace=True)

        x1a = F.adaptive_avg_pool2d(x1a, (1, 1)).squeeze()
        x1b = F.adaptive_avg_pool2d(x1b, (1, 1)).squeeze()

        if len(x1a.shape) == 1:  # case where we passed a single input
            return torch.cat((x1a, x1b), dim=0)
        elif len(x1a.shape) == 2:  # case where we passed a batch of > 1 inputs
            return torch.cat((x1a, x1b), dim=1)

Next, we initialize the model and the PyTorch device it runs on.

In [None]:
num_features = 2048

device = torch.device("cuda")
model = RCF(num_features).eval().to(device)

### Read dataset of (lat, lon) points and corresponding labels
Zambia:   1997-2015  
Tanzania: 2003-2010  
Nigeria:  1995-2006  

In [None]:
year = 2013
adm_level = "adm1"

np.random.seed(42)

In [None]:
# load Data
gdf_crop = geopandas.read_file("data/unified_crop_data.gpkg")

# Filter for 1 Country
gdf_crop = gdf_crop[gdf_crop.adm0 == 'zambia']

# Filter for 1 year but keep geometry without crop data
gdf_crop = gdf_crop[(gdf_crop.year == year) | (np.isnan(gdf_crop.year))]

# find the bounds of your geodataframe
x_min, y_min, x_max, y_max = gdf_crop.total_bounds

# set sample size (number of points inside bounding box)
# this will be reduced to only points inside the country
n = 2000

# generate random data within the bounds
x = np.random.uniform(x_min, x_max, n)
y = np.random.uniform(y_min, y_max, n)

# convert them to a points GeoSeries
gdf_points = geopandas.GeoSeries(geopandas.points_from_xy(x, y))

# only keep those points within polygons
gdf_points = gdf_points[gdf_points.within(gdf_crop.unary_union)]

# make points GeoSeries into GeoDataFrame
gdf_points = geopandas.GeoDataFrame(gdf_points).rename(columns={0:'geometry'}).set_geometry('geometry')

# Make blank GeoDataFrame
gdf = geopandas.GeoDataFrame()

# Extract lon, lat, and geometry values and assign to columns
gdf['lon'] = gdf_points['geometry'].x
gdf['lat'] = gdf_points['geometry'].y
gdf['geometry'] = gdf_points['geometry']

# Set CRS
gdf = gdf.set_crs('EPSG:4326')

# Also make a regular dataframe
points = pd.DataFrame(gdf)

In [None]:
len(points)

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
gdf_crop.plot(ax=ax, color="blue", edgecolor="black", alpha=0.25)
gdf.plot(ax=ax)
ax.grid(False)
ctx.add_basemap(ax, crs="EPSG:4326")

Get rid of points with nodata population values

### Extract features from the imagery around each point

We need to find a suitable Landsat 8 scene for each point. As usual, we'll use `pystac-client` to search for items matching some conditions, but we don't want to make a separate `.search()` call for each of the sampled points, because each HTTP request is relatively slow. Instead, we will *batch* our points and search *in parallel*.
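
The batching idea can be sketched independently of Dask: split the points into chunks and submit one search per chunk to a pool of workers. In the sketch below, `fake_search` is a hypothetical stand-in for a real STAC query, and a thread pool stands in for the Dask cluster used in this notebook.

```python
from concurrent.futures import ThreadPoolExecutor


def chunked(items, size):
    """Split a list into consecutive chunks of at most `size` elements."""
    return [items[i:i + size] for i in range(0, len(items), size)]


def fake_search(batch):
    """Stand-in for one search request that covers a whole batch of points."""
    return [f"scene-for-{lon:.1f},{lat:.1f}" for lon, lat in batch]


def search_in_batches(points, batch_size, max_workers=4):
    """Run one search per batch in parallel and flatten the results in order."""
    batches = chunked(points, batch_size)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(fake_search, batches))  # map preserves batch order
    return [scene for batch in results for scene in batch]


demo_points = [(25.0 + i, -15.0) for i in range(10)]
demo_scenes = search_in_batches(demo_points, batch_size=4)
print(len(demo_scenes), "results from", len(chunked(demo_points, 4)), "requests")
```

Ten points are resolved with three requests instead of ten; the saving grows with the batch size as long as the points in a batch share scenes.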

We need to be a bit careful with how we batch up our points though. Since a single Landsat 8 scene will cover many points, we want to make sure that points which are spatially close together end up in the same batch. In short, we need to spatially partition the dataset. This is implemented in `dask-geopandas`.
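
One common way to partition points spatially is to sort them by their position along a Hilbert space-filling curve, since points that are close along the curve are also close on the map. The sketch below uses the textbook grid-to-curve conversion to illustrate the idea; it is not the `dask-geopandas` implementation, and the helper names and bounding box are made up for the example.

```python
def hilbert_distance(x, y, order):
    """Position of integer grid cell (x, y) along a Hilbert curve.

    The grid has 2**order cells per side; x and y must lie in [0, 2**order).
    """
    n = 1 << order
    d = 0
    s = n >> 1
    while s > 0:
        rx = 1 if (x & s) else 0
        ry = 1 if (y & s) else 0
        d += s * s * ((3 * rx) ^ ry)
        # Rotate/flip the quadrant so the sub-curve has the right orientation
        if ry == 0:
            if rx == 1:
                x = n - 1 - x
                y = n - 1 - y
            x, y = y, x
        s >>= 1
    return d


def to_grid(lon, lat, bounds, order):
    """Quantize a lon/lat pair onto the integer grid spanned by `bounds`."""
    x_min, y_min, x_max, y_max = bounds
    cells = (1 << order) - 1
    gx = int((lon - x_min) / (x_max - x_min) * cells)
    gy = int((lat - y_min) / (y_max - y_min) * cells)
    return gx, gy


demo_bounds = (22.0, -18.0, 34.0, -8.0)  # illustrative bounding box
demo_lonlat = [(33.5, -9.0), (22.5, -17.5), (23.0, -17.0), (33.0, -8.5)]
ordered = sorted(
    demo_lonlat,
    key=lambda p: hilbert_distance(*to_grid(p[0], p[1], demo_bounds, 8), 8),
)
print(ordered)
```

After sorting, consecutive rows tend to be neighbours on the map, so cutting the sorted table into equal-sized partitions yields spatially compact batches.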

So the overall workflow will be:

1. Find an appropriate STAC item for each point (in parallel, using the spatially partitioned dataset)
2. Feed the points and STAC items to a custom Dataset that can read imagery given a point and the URL of an overlapping Landsat 8 scene
3. Use a custom DataLoader, which uses our Dataset, to feed our model imagery and save the corresponding features
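
Step 1 boils down to a simple rule: sort the candidate scenes by cloud cover and take the first one whose footprint covers the point. The sketch below shows that rule with rectangular bounding boxes in place of real scene footprints; the scene dictionaries and the `pick_least_cloudy` helper are hypothetical and only mimic the shape of the problem.

```python
def pick_least_cloudy(point, scenes):
    """Return the id of the least cloudy scene whose bbox covers `point`.

    `point` is (lon, lat); each scene has "id", "cloud_cover", and a
    "bbox" of (west, south, east, north). Returns None if nothing covers it.
    """
    lon, lat = point
    for scene in sorted(scenes, key=lambda s: s["cloud_cover"]):
        west, south, east, north = scene["bbox"]
        if west <= lon <= east and south <= lat <= north:
            return scene["id"]
    return None


demo_catalog = [
    {"id": "scene-a", "cloud_cover": 8.0, "bbox": (20.0, -20.0, 30.0, -10.0)},
    {"id": "scene-b", "cloud_cover": 2.0, "bbox": (25.0, -18.0, 35.0, -8.0)},
]

print(pick_least_cloudy((27.0, -15.0), demo_catalog))  # covered by both scenes
print(pick_least_cloudy((21.0, -19.0), demo_catalog))  # covered by scene-a only
print(pick_least_cloudy((50.0, 0.0), demo_catalog))    # covered by neither
```

Points with no covering scene come back as `None`, which is why rows without a STAC item have to be dropped before any imagery is read.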

In [None]:
NPARTITIONS = 250

ddf = dask_geopandas.from_geopandas(gdf, npartitions=1)
hd = ddf.hilbert_distance().compute()
gdf["hd"] = hd
gdf = gdf.sort_values("hd")
gdf = gdf.reset_index()

dgdf = dask_geopandas.from_geopandas(gdf, npartitions=NPARTITIONS, sort=False)

We'll write a helper function that finds a STAC item covering each point in a partition of the dataset.

In [None]:
def query(points):
    """
    Find a STAC item for points in the `points` DataFrame

    Parameters
    ----------
    points : geopandas.GeoDataFrame
        A GeoDataFrame

    Returns
    -------
    geopandas.GeoDataFrame
        A new geopandas.GeoDataFrame with a `stac_item` column containing the STAC
        item that covers each point.
    """
    intersects = shapely.geometry.mapping(points.unary_union.convex_hull)

    search_start = f"{year}-01-01"
    search_end = f"{year}-12-31"
    catalog = pystac_client.Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1"
    )

    # Search the chosen year for scenes with less than 10% cloud cover
    search = catalog.search(
        collections=["landsat-8-c2-l2"],  # "landsat-8-c2-l2"   "sentinel-2-l2a"
        intersects=intersects,
        datetime=[search_start, search_end],
        query={"eo:cloud_cover": {"lt": 10}},
        limit=500,
    )
    ic = search.get_all_items_as_dict()

    features = ic["features"]
    features_d = {item["id"]: item for item in features}

    data = {
        "eo:cloud_cover": [],
        "geometry": [],
    }

    index = []

    for item in features:
        data["eo:cloud_cover"].append(item["properties"]["eo:cloud_cover"])
        data["geometry"].append(shapely.geometry.shape(item["geometry"]))
        index.append(item["id"])

    items = geopandas.GeoDataFrame(data, index=index, geometry="geometry").sort_values(
        "eo:cloud_cover"
    )
    point_list = points.geometry.tolist()

    point_items = []
    for point in point_list:
        covered_by = items[items.covers(point)]
        if len(covered_by):
            point_items.append(features_d[covered_by.index[0]])
        else:
            # There weren't any scenes matching our conditions for this point (too cloudy)
            point_items.append(None)

    return points.assign(stac_item=point_items)

In [None]:
%%time

with Client(n_workers=16) as client:
    print(client.dashboard_link)
    meta = dgdf._meta.assign(stac_item=[])
    df2 = dgdf.map_partitions(query, meta=meta).compute()

In [None]:
df2.shape

In [None]:
df3 = df2.dropna(subset=["stac_item"])

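# Note: CustomDataset indexes `fns` with the same index as `points`, so only the
# first len(points) entries (the SR_B1 URLs) are actually read below.
matching_urls = (
    [pc.sign(item["assets"]["SR_B1"]["href"]) for item in df3.stac_item.tolist()] +
    [pc.sign(item["assets"]["SR_B2"]["href"]) for item in df3.stac_item.tolist()] +
    [pc.sign(item["assets"]["SR_B3"]["href"]) for item in df3.stac_item.tolist()]
)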

points = df3[["lon", "lat"]].to_numpy()
df3.shape

In [None]:
class CustomDataset(Dataset):
    def __init__(self, points, fns, buffer=500):
        self.points = points
        self.fns = fns
        self.buffer = buffer

    def __len__(self):
        return self.points.shape[0]

    def __getitem__(self, idx):

        lon, lat = self.points[idx]
        fn = self.fns[idx]

        if fn is None:
            return None
        else:
            point_geom = shapely.geometry.mapping(shapely.geometry.Point(lon, lat))

            with rasterio.Env():
                with rasterio.open(fn, "r") as f:
                    point_geom = rasterio.warp.transform_geom(
                        "epsg:4326", f.crs.to_string(), point_geom
                    )
                    point_shape = shapely.geometry.shape(point_geom)
                    mask_shape = point_shape.buffer(self.buffer).envelope
                    mask_geom = shapely.geometry.mapping(mask_shape)
                    try:
                        out_image, out_transform = rasterio.mask.mask(
                            f, [mask_geom], crop=True
                        )
                    except ValueError as e:
                        if "Input shapes do not overlap raster." in str(e):
                            return None
                        # Any other error would leave `out_image` undefined, so re-raise it
                        raise

            out_image = out_image / 255.0
            out_image = torch.from_numpy(out_image).float()
            return out_image

In [None]:
dataset = CustomDataset(points, matching_urls)

dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    num_workers=os.cpu_count() ,
    collate_fn=lambda x: x,
    pin_memory=False,
)

In [None]:
x_all = np.zeros((points.shape[0], num_features), dtype=float)

tic = time.time()
i = 0
for images in dataloader:
    for image in images:

        if image is not None:
            # With the default 500 m buffer, a full image should be roughly 33x33 pixels
            # (i.e. ~1km^2 at Landsat's 30m/px spatial resolution), however we can receive
            # smaller images if an input point happens to be at the edge of a Landsat scene
            # (a literal edge case). To deal with these (edge) cases we crudely drop all
            # images where the spatial dimensions aren't both at least 20 pixels.
            if image.shape[1] >= 20 and image.shape[2] >= 20:
                image = image.to(device)
                with torch.no_grad():
                    feats = model(image.unsqueeze(0)).cpu().numpy()

                x_all[i] = feats
            else:
                # this happens if the point is close to the edge of a scene
                # (one or both of the spatial dimensions of the image are very small)
                pass
        else:
            pass  # this happens if we do not find a Landsat scene for some point

        if i % 1000 == 0:
            print(
                f"{i}/{points.shape[0]} -- {i / points.shape[0] * 100:0.2f}%"
                + f" -- {time.time()-tic:0.2f} seconds"
            )
            tic = time.time()
        i += 1

In [None]:
x_all.shape

In [None]:
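# Features were computed for the rows of `df3`, so reuse its index; otherwise rows
# dropped for lacking a STAC item would shift the features onto the wrong points.
x_all = pd.DataFrame(x_all, index=df3.index)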
x_all

In [None]:
gdf

In [None]:
gdf_features = gdf.join(x_all)

In [None]:
gdf_features = gdf_features.drop(['index', 'lon', 'lat', 'hd'], axis = 1)

In [None]:
gdf_features

In [None]:
cols = range(0, num_features)
gdf_features_long = pd.melt(gdf_features,  
                            id_vars=['geometry'], 
                            value_vars=cols, 
                            var_name = 'feature')

In [None]:
features = gdf_crop.sjoin(gdf_features_long, how = 'right', predicate = 'intersects')
features

In [None]:
features_summary = features.groupby([adm_level, 'year', 'feature']).agg({'value': 'mean'})

In [None]:
features_summary = features_summary.reset_index()
features_summary

In [None]:
features_summary_wide = features_summary.pivot(index = [adm_level, "year"], columns='feature', values='value')

In [None]:
features_summary_wide = features_summary_wide.reset_index().rename_axis(None, axis=1)

In [None]:
features_summary_wide