## MOSAIKS feature extraction

This tutorial demonstrates the **MOSAIKS** method for extracting _feature vectors_ from satellite imagery patches for use in downstream modeling tasks. It will show:
- How to extract 1 km$^2$ patches of Sentinel-2 or Landsat multispectral imagery for a list of latitude and longitude points
- How to extract summary features from each of these imagery patches
- How to use the summary features in a linear model of the population density at each point

### Background

Suppose you have a dataset of latitude and longitude points associated with some dependent variable (for example population density, weather, housing prices, or biodiversity) and, potentially, other independent variables. You would like to model the dependent variable as a function of the independent variables, but instead of including latitude and longitude directly in the model, you would like to include a high-dimensional representation of what the Earth looks like at each point, one that hopefully explains some of the variance in the dependent variable. The computer vision literature offers various [representation learning techniques](https://en.wikipedia.org/wiki/Feature_learning) for this, i.e. for extracting _feature vectors_ from imagery. This notebook implements the technique described in [Rolf et al. 2021](https://www.nature.com/articles/s41467-021-24638-z), "A generalizable and accessible approach to machine learning with global satellite imagery", called Multi-task Observation using Satellite Imagery & Kitchen Sinks (**MOSAIKS**). For more information about **MOSAIKS**, see the [project's webpage](http://www.globalpolicy.science/mosaiks).
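
To make the idea of a feature vector concrete, here is a minimal NumPy sketch of random convolutional features in the spirit of MOSAIKS: correlate a patch with random filters, apply a ReLU, and average-pool each response map into a single number. The function name, the Gaussian filters, the 3x3 kernel, and the zero bias are illustrative assumptions, not the configuration used in the paper or later in this notebook.

```python
import numpy as np


def random_conv_features(image, num_features=8, kernel_size=3, bias=0.0, seed=0):
    """Map a (channels, height, width) patch to a (num_features,) vector."""
    rng = np.random.default_rng(seed)
    channels = image.shape[0]
    # Assumption: filters are drawn from a standard normal distribution.
    filters = rng.standard_normal((num_features, channels, kernel_size, kernel_size))
    # Every kernel-sized window: shape (channels, H-k+1, W-k+1, k, k).
    windows = np.lib.stride_tricks.sliding_window_view(
        image, (kernel_size, kernel_size), axis=(1, 2)
    )
    # Correlate each window with each filter: shape (num_features, H-k+1, W-k+1).
    responses = np.einsum("chwij,fcij->fhw", windows, filters)
    # ReLU non-linearity, then global average pooling per filter.
    activated = np.maximum(responses + bias, 0.0)
    return activated.mean(axis=(1, 2))


patch = np.random.default_rng(1).random((1, 34, 34))  # synthetic single-band patch
features = random_conv_features(patch, num_features=8)
print(features.shape)
```

Each patch, whatever its size, is reduced to a fixed-length vector, which is what allows the features to be used as columns in a linear model.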

### Environment setup
This notebook works with or without an API key, but you will be given more permissive access to the data with an API key.
- If you're running this on the [Planetary Computer Hub](http://planetarycomputer.microsoft.com/compute), make sure to choose the **GPU - PyTorch** profile when presented with the form to choose your environment.
- The Planetary Computer Hub is pre-configured to use your API key.
- To use your API key locally, set the environment variable `PC_SDK_SUBSCRIPTION_KEY` or use `pc.settings.set_subscription_key(<YOUR API Key>)`.
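
As a small illustration of the environment-variable route, the sketch below stores a key under `PC_SDK_SUBSCRIPTION_KEY` from Python. The helper `configure_key` and the placeholder key are hypothetical; doing this before `planetary_computer` is first used is the safe ordering.

```python
import os


def configure_key(key, env=os.environ):
    """Store the API key under the variable name given above (hypothetical helper)."""
    if not key:
        raise ValueError("empty API key")
    env["PC_SDK_SUBSCRIPTION_KEY"] = key
    return env["PC_SDK_SUBSCRIPTION_KEY"]


# Demonstrated on a plain dict so the real environment is left untouched.
demo_env = {}
configure_key("<YOUR API Key>", env=demo_env)
print(sorted(demo_env))
```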
    
**Notes**:
- This example uses either
    - [sentinel-2-l2a data](https://planetarycomputer.microsoft.com/dataset/sentinel-2-l2a)
    - [landsat-c2-l2 data](https://planetarycomputer.microsoft.com/dataset/landsat-c2-l2)
- The techniques used here apply equally well to other remote-sensing datasets.

In [1]:
# !pip install -q git+https://github.com/geopandas/dask-geopandas
!pip install -q pyhere

In [2]:
import warnings
import time
import os
import gc
import calendar
import re

RASTERIO_BEST_PRACTICES = dict(  # See https://github.com/pangeo-data/cog-best-practices
    CURL_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt",
    GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR",
    AWS_NO_SIGN_REQUEST="YES",
    GDAL_MAX_RAW_BLOCK_CACHE_SIZE="200000000",
    GDAL_SWATH_SIZE="200000000",
    VSI_CURL_CACHE_SIZE="200000000",
)
os.environ.update(RASTERIO_BEST_PRACTICES)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyhere import here

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from scipy import ndimage as nd

import rasterio
import rasterio.warp
import rasterio.mask
import shapely.geometry
import geopandas
import dask_geopandas
from dask.distributed import Client

from pystac import Item
import stackstac
import pyproj

warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(action="ignore", category=RuntimeWarning)
warnings.filterwarnings(action="ignore", category=UserWarning)

import pystac_client
import planetary_computer as pc


# Disabling the benchmarking feature with torch.backends.cudnn.benchmark = False 
# causes cuDNN to deterministically select an algorithm, possibly at the cost of reduced performance.
# https://pytorch.org/docs/stable/notes/randomness.html
torch.backends.cudnn.benchmark = False

np.random.seed(42)
torch.manual_seed(42)

import random
random.seed(42)

## Set Parameters

In [3]:
num_features = 1024
country_code = 'ZMB'
use_file = True
# use_file = False

In [4]:
satellite = "landsat-c2-l2"
# bands = ["qa_pixel"]
# bands = ['cloud_qa']
bands = ['red']

In [5]:
if satellite == "landsat-c2-l2":
    resolution = 30
    min_image_edge = 6
else:
    resolution = 10
    min_image_edge = 20
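
To get a feel for what these resolutions mean, the following sketch estimates how many pixels a patch spans, given its half-width in degrees. The constant of roughly 111,320 m per degree, the function name, and the example latitude of 15°S are assumptions for illustration; the actual patch shape depends on the scene's projection.

```python
import math

METERS_PER_DEGREE = 111_320  # approximate length of one degree of latitude


def patch_edge_pixels(buffer_deg, resolution_m, lat_deg=0.0):
    """Approximate (rows, cols) of a square patch with half-width buffer_deg."""
    height_m = 2 * buffer_deg * METERS_PER_DEGREE
    # Degrees of longitude shrink with the cosine of the latitude.
    width_m = height_m * math.cos(math.radians(lat_deg))
    return round(height_m / resolution_m), round(width_m / resolution_m)


print(patch_edge_pixels(0.005, 30, lat_deg=-15))  # Landsat, 30 m pixels -> (37, 36)
print(patch_edge_pixels(0.005, 10, lat_deg=-15))  # Sentinel-2, 10 m pixels -> (111, 108)
```

A patch of about 0.01 degrees per side is therefore only a few dozen Landsat pixels wide, so a minimum edge of 6 pixels presumably filters out only badly truncated patches.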

## Create grid and sample points to featurize

In [6]:
if use_file:
    gdf = pd.read_feather(here('data', 'grid', 'ZMB_crop_weights_20k-points.feather'))
    gdf = (
        geopandas
        .GeoDataFrame(
            gdf, 
            geometry = geopandas.points_from_xy(x = gdf.lon, y = gdf.lat), 
            crs='EPSG:4326')
    )
else:
    cell_size = 0.01  # Very roughly 1 km
    ### Get country shape
    country_file_name = "data/geo_boundaries/africa_adm0.geojson"
    africa = geopandas.read_file(country_file_name)
    country = africa[africa.adm0_a3 == country_code]
    #### This would be simpler, but throws an error down the line if used
    # world = geopandas.read_file(geopandas.datasets.get_path('naturalearth_lowres'))
    # country = world.query(f'iso_a3 == "{country_code}"')
    ### Create grid of points
    xmin, ymin, xmax, ymax = country.total_bounds
    xs = list(np.arange(xmin, xmax + cell_size, cell_size))
    ys = list(np.arange(ymin, ymax + cell_size, cell_size))
    def make_cell(x, y, cell_size):
        ring = [
            (x, y),
            (x + cell_size, y),
            (x + cell_size, y + cell_size),
            (x, y + cell_size)
        ]
        cell = shapely.geometry.Polygon(ring).centroid
        return cell
    center_points = []
    for x in xs:
        for y in ys:
            cell = make_cell(x, y, cell_size)
            center_points.append(cell)
    ### Put grid into a GeoDataFrame for cropping to country shape
    gdf = geopandas.GeoDataFrame({'geometry': center_points}, crs = 'EPSG:4326')
    gdf['lon'], gdf['lat'] = gdf.geometry.x, gdf.geometry.y
    ### Subset to country 
    ### This negative buffer ensures that no points are taken at the border,
    ### which would lead to duplication with neighboring countries
    gdf = gdf[gdf.within(country.unary_union.buffer(-0.005))]
    gdf = gdf[['lon', 'lat', 'geometry']].reset_index(drop = True)
    gdf = gdf.sample(frac = 0.1, random_state=42, ignore_index=False)
    points = gdf[["lon", "lat"]].to_numpy()
pt_len = gdf.shape[0]
gdf.shape

(19598, 4)

In [7]:
NPARTITIONS = 250

ddf = dask_geopandas.from_geopandas(gdf, npartitions=1)
hd = ddf.hilbert_distance().compute()
gdf["hd"] = hd
gdf = gdf.sort_values("hd")

dgdf = dask_geopandas.from_geopandas(gdf, npartitions=NPARTITIONS, sort=False)

del ddf
del hd
del gdf
gc.collect()

110
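
The cell above sorts the points by Hilbert distance before partitioning, so that each partition holds spatially neighboring points and the convex hull queried per partition stays small. The sketch below shows the textbook Hilbert index computation on a small integer grid to illustrate why the ordering preserves locality; it is not the dask-geopandas implementation, and the function name is hypothetical.

```python
def hilbert_index(n, x, y):
    """Distance along the Hilbert curve of cell (x, y) in an n x n grid (n a power of 2)."""
    d = 0
    s = n // 2
    while s > 0:
        rx = 1 if (x & s) else 0
        ry = 1 if (y & s) else 0
        d += s * s * ((3 * rx) ^ ry)
        # Rotate or flip the quadrant so the sub-curve has the right orientation.
        if ry == 0:
            if rx == 1:
                x = n - 1 - x
                y = n - 1 - y
            x, y = y, x
        s //= 2
    return d


cells = [(3, 0), (0, 0), (2, 2), (0, 3)]
ordered = sorted(cells, key=lambda p: hilbert_index(4, *p))
print(ordered)
```

Cells that are consecutive along the curve are always grid neighbors, which is why cutting the sorted list into equal chunks yields compact groups.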

In [8]:
start_month = 7

year_start = 2022
year_end = 2022

buffer_size = 0.005
cloud_limit = 20

batch_size = 1

workers = os.cpu_count() 

print(
   f"""
    Using:  
        Satellite: {satellite}  
        Pixel Resolution: {resolution}  
        Grid Resolution: {buffer_size * 2} degree squared (WGS84) 
        Cloud Limit: less than {cloud_limit}%  
        Bands: {bands} 
        Points: {pt_len} 
        Number Features: {num_features} features 
        Year Range: {year_start} to {year_end} 
    """
)
# for yr in range(year_start, year_end+1):
    
yr = 2013

features = pd.DataFrame()
ft = []

if (yr == year_start):
    month_range = range(start_month, 13)
else:
    month_range = range(1, 13) 
        
    # for mn in month_range:
        
mn = 1

# Zero-padded month string, e.g. "01" or "12"
month = f"{mn:02d}"


    Using:  
        Satellite: landsat-c2-l2  
        Pixel Resolution: 30  
        Grid Resolution: 0.01 degree squared (WGS84) 
        Cloud Limit: less than 20%  
        Bands: ['red'] 
        Points: 19598 
        Number Features: 1024 features 
        Year Range: 2022 to 2022 
    


In [9]:
def query(points):
    """
    Find a STAC item for points in the `points` DataFrame

    Parameters
    ----------
    points : geopandas.GeoDataFrame
        A GeoDataFrame

    Returns
    -------
    geopandas.GeoDataFrame
        A new geopandas.GeoDataFrame with a `stac_item` column containing the STAC
        item that covers each point.
    """
    intersects = shapely.geometry.mapping(points.unary_union.convex_hull)

    catalog = pystac_client.Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1"
    )
    # Define search date range for query (uses the globals `yr`, `mn`, and `month`)
    ending_day = calendar.monthrange(yr, int(mn))[1]
    search_start = f"{yr}-{month}-01"
    search_end = f"{yr}-{month}-{ending_day}"

    # The time frame in which we search for non-cloudy imagery
    search = catalog.search(
        collections=[satellite],  
        intersects=intersects,
        datetime=[search_start, search_end],
        query={"eo:cloud_cover": {"lt": cloud_limit}},
        limit=500,
    )
    ic = search.get_all_items_as_dict()
    features = ic["features"]
    features_d = {item["id"]: item for item in features}
    data = {
        "eo:cloud_cover": [],
        "geometry": [],
    }
    index = []
    for item in features:
        data["eo:cloud_cover"].append(item["properties"]["eo:cloud_cover"])
        data["geometry"].append(shapely.geometry.shape(item["geometry"]))
        index.append(item["id"])
    items = geopandas.GeoDataFrame(data, index=index, geometry="geometry").sort_values(
        "eo:cloud_cover"
    )
    point_list = points.geometry.tolist()
    point_items = []
    # cloud_cover = []
    for point in point_list:
        covered_by = items[items.covers(point)]
        if len(covered_by):
            point_items.append(features_d[covered_by.index[0]])
            # cloud_cover.append(features_d[covered_by.index[0]].get('properties').get('eo:cloud_cover'))
        else:
            # There weren't any scenes matching our conditions for this point (too cloudy)
            point_items.append(None)
            # cloud_cover.append(None)
    return points.assign(stac_item=point_items)
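
The selection rule inside `query` is: sort the candidate scenes by cloud cover, then give each point the first scene that covers it, or `None` if no scene does. A minimal standard-library sketch of that rule follows, using bounding boxes in place of real scene footprints; the scene dictionaries and the function name are made up for illustration.

```python
def least_cloudy_cover(point, scenes):
    """Return the id of the least cloudy scene whose bbox contains point, else None."""
    lon, lat = point
    for scene in sorted(scenes, key=lambda s: s["cloud_cover"]):
        min_lon, min_lat, max_lon, max_lat = scene["bbox"]
        if min_lon <= lon <= max_lon and min_lat <= lat <= max_lat:
            return scene["id"]
    return None  # no acceptable scene covers this point


# Hypothetical scenes: (min_lon, min_lat, max_lon, max_lat) bounding boxes.
scenes = [
    {"id": "scene_a", "cloud_cover": 12.0, "bbox": (22.0, -17.0, 24.0, -15.0)},
    {"id": "scene_b", "cloud_cover": 3.0, "bbox": (23.0, -17.0, 25.0, -15.0)},
]
print(least_cloudy_cover((23.5, -16.0), scenes))
print(least_cloudy_cover((30.0, -16.0), scenes))
```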

In [10]:
tic = time.time()
print("Matching images to points for: ", mn, "-", yr, sep = "")

with Client(n_workers=16) as client:
    meta = dgdf._meta.assign(stac_item=[])
    df2 = dgdf.map_partitions(query, meta=meta).compute()
    
df3 = df2.dropna(subset=["stac_item"]).reset_index(drop = True)

matching_items = []
for item in df3.stac_item.tolist():
    signed_item = pc.sign(Item.from_dict(item))
    matching_items.append(signed_item)

points = df3[["lon", "lat"]].to_numpy()

print("Found acceptable images for ", 
      points.shape[0], "/", pt_len,
      " points in ", 
      f"{time.time()-tic:0.2f} seconds", 
      sep = "")

Matching images to points for: 1-2013
Found acceptable images for 1857/19598 points in 30.32 seconds


In [11]:
# Fall back to the CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [35]:
%%time
na_perc = np.zeros((points.shape[0], 1), dtype=float)

tic = time.time()

for i in range(0, len(points)):
    
    lon, lat = points[i]
    fn = matching_items[i]

    stack = stackstac.stack(fn, assets=bands, resolution=resolution)
    x_min, y_min = pyproj.Proj(stack.crs)(lon-buffer_size, lat-buffer_size)
    x_max, y_max = pyproj.Proj(stack.crs)(lon+buffer_size, lat+buffer_size)
    aoi = stack.loc[..., y_max:y_min, x_min:x_max]
    
    data = aoi.data.squeeze()
    na_perc[i] = (np.isnan(data).sum() / (data.shape[0] * data.shape[1]))#.compute()
    
    
#     out_image = aoi.data.squeeze()
#     out_image = torch.from_numpy(out_image.compute()).float()
#     out_image = out_image.to(device)
#     na_perc[i] = ((out_image.isnan()).sum() / out_image.numel()).item()
    
    if i % 500 == 0:
        print(
            f"{i}/{points.shape[0]} -- {i / points.shape[0] * 100:0.2f}%"
            + f" -- {time.time()-tic:0.2f} seconds"
        )
        tic = time.time()

0/1857 -- 0.00% -- 0.44 seconds
500/1857 -- 26.93% -- 59.68 seconds
1000/1857 -- 53.85% -- 56.57 seconds
1500/1857 -- 80.78% -- 60.49 seconds
CPU times: user 1min 20s, sys: 10.6 s, total: 1min 30s
Wall time: 3min 40s


In [36]:
na_perc

array([[0.38455598],
       [0.30632716],
       [0.38348765],
       ...,
       [0.47822823],
       [0.36261261],
       [0.41666667]])
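
Each entry of `na_perc` is the share of missing (NaN) pixels in one patch. The sketch below computes the same quantity on a small in-memory array; the helper name is hypothetical, and the guard for empty patches is an added assumption covering slices that fall outside the scene.

```python
import numpy as np


def nan_fraction(patch):
    """Fraction of NaN pixels in a patch; NaN if the patch is empty."""
    patch = np.asarray(patch, dtype=float)
    if patch.size == 0:
        return float("nan")  # avoids dividing by zero for empty slices
    return float(np.isnan(patch).mean())


patch = np.array([[1.0, np.nan], [np.nan, 4.0], [5.0, 6.0]])
print(nan_fraction(patch))  # 2 of the 6 pixels are missing
```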

In [30]:
class CustomDataset(Dataset):
    def __init__(self, points, items, buffer=buffer_size):
        self.points = points
        self.items = items
        self.buffer = buffer

    def __len__(self):
        return self.points.shape[0]

    def __getitem__(self, idx):

        lon, lat = self.points[idx]
        fn = self.items[idx]

        if fn is None:
            return None
        else:
            stack = stackstac.stack(fn, assets=bands, resolution=resolution)
            x_min, y_min = pyproj.Proj(stack.crs)(lon-self.buffer, lat-self.buffer)
            x_max, y_max = pyproj.Proj(stack.crs)(lon+self.buffer, lat+self.buffer)
            aoi = stack.loc[..., y_max:y_min, x_min:x_max]
            data = aoi.data.squeeze()
            na_percentage = np.isnan(data).sum() / (data.shape[0] * data.shape[1])
            return na_percentage

In [31]:
dataset = CustomDataset(points, matching_items)

dataloader = DataLoader(
    dataset,
    batch_size=10,
    shuffle=False,
    num_workers=workers * 2,
    collate_fn=lambda x: x,
    pin_memory=False,
    persistent_workers=False,
)

In [32]:
%%time
print("Collecting metadata: ", month, "-", yr, sep = "")

na_perc = np.zeros((points.shape[0], 1), dtype=float)
tic = time.time()
i = 0
for images in dataloader:
    for image in images:
        na_perc[i] = image
        
        if i % 500 == 0:
            print(
                f"{i}/{points.shape[0]} -- {i / points.shape[0] * 100:0.2f}%"
                + f" -- {time.time()-tic:0.2f} seconds"
            )
            tic = time.time()
        i += 1
        

Collecting metadata: 01-2013
0/1857 -- 0.00% -- 1.01 seconds
500/1857 -- 26.93% -- 50.76 seconds
1000/1857 -- 53.85% -- 44.70 seconds
1500/1857 -- 80.78% -- 50.56 seconds
CPU times: user 1min 2s, sys: 9.51 s, total: 1min 11s
Wall time: 3min 2s


In [33]:
na_perc

array([[0.38455598],
       [0.30632716],
       [0.38348765],
       ...,
       [0.47822823],
       [0.36261261],
       [0.41666667]])

In [34]:
workers

4

In [16]:
class CustomDataset(Dataset):
    def __init__(self, points, items, buffer=buffer_size):
        self.points = points
        self.items = items
        self.buffer = buffer

    def __len__(self):
        return self.points.shape[0]

    def __getitem__(self, idx):

        lon, lat = self.points[idx]
        fn = self.items[idx]

        if fn is None:
            return None
        else:
            stack = stackstac.stack(
                fn,
                assets=['qa_pixel'],
                resolution=resolution,
            )
            x_min, y_min = pyproj.Proj(stack.crs)(lon-self.buffer, lat-self.buffer)
            x_max, y_max = pyproj.Proj(stack.crs)(lon+self.buffer, lat+self.buffer)
            aoi = stack.loc[..., y_max:y_min, x_min:x_max]
            data = aoi.compute(
                scheduler="single-threaded"
                )
            out_image = data.data.squeeze()
            out_image = torch.from_numpy(out_image).float()
            return out_image

In [17]:
dataset = CustomDataset(points, matching_items, buffer = .05)

In [None]:
img = dataset[0]
img

In [None]:
# img['band']

In [None]:
# img['raster:bands']

In [None]:
import stackstac
import pystac_client
import pyproj
import planetary_computer as pc
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

lat, lon =  -5, 20

catalog = pystac_client.Client.open('https://planetarycomputer.microsoft.com/api/stac/v1')

search = catalog.search(
    collections=['landsat-c2-l2'],
    intersects=dict(type="Point", coordinates=[lon, lat]),
    datetime="2020-12-01/2021-01-01",
    query={"eo:cloud_cover": {"lt": 25}}
)

items = pc.sign(search)

stack = stackstac.stack(items, assets=["qa_pixel"])

x_utm, y_utm = pyproj.Proj(stack.crs)(lon, lat)

buffer = 2000  # meters

aoi = stack.loc[..., y_utm+buffer:y_utm-buffer, x_utm-buffer:x_utm+buffer]

image = aoi.data.squeeze().compute()

plt.figure(figsize=(10,10))
im = plt.imshow(image, interpolation='none')
values = np.unique(image.ravel())
colors = [im.cmap(im.norm(value)) for value in values]
patches = [mpatches.Patch(color=colors[i], label=f"{values[i]}") for i in range(len(values))]
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0. )
plt.show()
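
The values in the legend above are bit-packed flags rather than categories, so they are easier to read once decoded. The sketch below decodes the single-bit flags, assuming the Landsat Collection 2 `QA_PIXEL` layout (bit 0 fill, bit 1 dilated cloud, bit 3 cloud, bit 4 cloud shadow, bit 5 snow, bit 6 clear, bit 7 water); check the USGS documentation for the sensor you use before relying on it. The function name is hypothetical.

```python
# Assumed bit positions for Landsat Collection 2 QA_PIXEL single-bit flags.
QA_FLAGS = {
    0: "fill",
    1: "dilated_cloud",
    3: "cloud",
    4: "cloud_shadow",
    5: "snow",
    6: "clear",
    7: "water",
}


def decode_qa_pixel(value):
    """Return the names of the flags set in one QA_PIXEL value."""
    value = int(value)  # values may arrive as floats; NaN must be filtered out first
    return [name for bit, name in QA_FLAGS.items() if (value >> bit) & 1]


for qa in (1.0, 21824.0, 22280.0):
    print(int(qa), decode_qa_pixel(qa))
```

Under this layout, counting pixels with the `fill`, `cloud`, or `cloud_shadow` flags set would give a per-patch quality measure that does not depend on NaN values.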

In [None]:
stack['raster:bands']

In [None]:
dataset = CustomDataset(points, matching_items, buffer = 2)

In [None]:

import matplotlib.patches as mpatches
img = dataset[0]
plt.figure(figsize=(10,10))
im = plt.imshow(img, interpolation='none')

# get the unique values from data
# i.e. a sorted list of all values in data
values = np.unique(img.ravel())

# get the colors of the values, according to the 
# colormap used by imshow
colors = [ im.cmap(im.norm(value)) for value in values]
# create a patch (proxy artist) for every color 
patches = [ mpatches.Patch(color=colors[i], label=f"{values[i]}") for i in range(len(values)) ]
# put those patches as legend handles into the legend
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0. )


plt.show()
plt.close()

In [None]:
img

In [None]:
na_perc = ((img.isnan()).sum() / img.numel()).item()
na_perc

In [18]:
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=workers,
    collate_fn=lambda x: x,
    pin_memory=False,
    persistent_workers=False,
)

In [25]:
%%time
print("Collecting metadata: ", month, "-", yr, sep = "")

x_all = np.zeros((points.shape[0], 1), dtype=float)
tic = time.time()
i = 0
for images in dataloader:
    for image in images:
        if i % 500 == 0:
            print(
                f"{i}/{points.shape[0]} -- {i / points.shape[0] * 100:0.2f}%"
                + f" -- {time.time()-tic:0.2f} seconds"
            )
            tic = time.time()
            
        x_all[i] = ((image.isnan()).sum() / image.numel()).item()
        
        i += 1

Collecting metadata: 05-2022
0/19598 -- 0.00% -- 0.45 seconds
500/19598 -- 2.55% -- 20.19 seconds
1000/19598 -- 5.10% -- 21.35 seconds
1500/19598 -- 7.65% -- 146.39 seconds


KeyboardInterrupt: 

In [26]:
i

1759

In [27]:
x_all

array([[0.],
       [0.],
       [0.],
       ...,
       [0.],
       [0.],
       [0.]])

In [28]:
x_all.min()

0.0

In [29]:
x_all.max()

0.0

In [30]:
x_all.mean()

0.0

In [31]:
df3 = df2.dropna(subset=["stac_item"]).reset_index(drop = True)

In [32]:
df3['stac_item'].apply(pd.Series)['properties']

0        {'gsd': 30, 'created': '2022-06-02T09:16:38.13...
1        {'gsd': 30, 'created': '2022-06-02T09:16:38.13...
2        {'gsd': 30, 'created': '2022-06-02T09:16:38.13...
3        {'gsd': 30, 'created': '2022-06-02T09:16:38.13...
4        {'gsd': 30, 'created': '2022-09-14T06:16:10.62...
                               ...                        
19593    {'gsd': 30, 'created': '2022-09-14T06:15:23.87...
19594    {'gsd': 30, 'created': '2022-09-14T06:15:23.87...
19595    {'gsd': 30, 'created': '2022-09-14T06:15:23.87...
19596    {'gsd': 30, 'created': '2022-09-14T06:15:23.87...
19597    {'gsd': 30, 'created': '2022-09-14T06:15:23.87...
Name: properties, Length: 19598, dtype: object

In [33]:
df3['stac_id'] = df3['stac_item'].apply(pd.Series)['id']
df3['platform'] = df3['stac_item'].apply(pd.Series)['properties'].apply(pd.Series)['platform']
df3['cloud_cover'] = df3['stac_item'].apply(pd.Series)['properties'].apply(pd.Series)['eo:cloud_cover']
df3.drop(['geometry', 'hd', 'stac_item'], axis = 1, inplace = True)
df3[['na_percent', 'year', "month"]] = x_all, yr, mn
df3 = pd.DataFrame(df3)
df3

| | crop_perc | lon | lat | stac_id | platform | cloud_cover | na_percent | year | month |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.201666 | 22.144878 | -16.384232 | LC09_L2SP_175072_20220525_02_T1 | landsat-9 | 0.0 | 0.0 | 2022 | 5 |
| 1 | 0.279001 | 22.124878 | -16.384232 | LC09_L2SP_175072_20220525_02_T1 | landsat-9 | 0.0 | 0.0 | 2022 | 5 |
| 2 | 0.311719 | 22.134878 | -16.384232 | LC09_L2SP_175072_20220525_02_T1 | landsat-9 | 0.0 | 0.0 | 2022 | 5 |
| 3 | 0.419393 | 22.134878 | -16.394232 | LC09_L2SP_175072_20220525_02_T1 | landsat-9 | 0.0 | 0.0 | 2022 | 5 |
| 4 | 0.305175 | 22.104878 | -16.324232 | LE07_L2SP_176071_20220525_02_T1 | landsat-7 | 0.0 | 0.0 | 2022 | 5 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 19593 | 0.505651 | 27.844878 | -16.784232 | LE07_L2SP_172072_20220522_02_T1 | landsat-7 | 0.0 | 0.0 | 2022 | 5 |
| 19594 | 0.475907 | 27.854878 | -16.784232 | LE07_L2SP_172072_20220522_02_T1 | landsat-7 | 0.0 | 0.0 | 2022 | 5 |
| 19595 | 0.437835 | 27.864878 | -16.784232 | LE07_L2SP_172072_20220522_02_T1 | landsat-7 | 0.0 | 0.0 | 2022 | 5 |
| 19596 | 0.466984 | 27.864878 | -16.794232 | LE07_L2SP_172072_20220522_02_T1 | landsat-7 | 0.0 | 0.0 | 2022 | 5 |


In [None]:
fn = f'{satellite}_{country_code}_{pt_len/1000:.0f}k-points_meta_{yr}_{mn}.csv'
file_name = here('data', 'random_features', satellite, fn)
fn