<a href="https://colab.research.google.com/github/boothmanrylan/nonStandReplacingDisturbances/blob/colab_dev/nonStandReplacingDisturbances.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
!git clone https://github.com/boothmanrylan/presto.git
%cd presto

# replace all install_requires '==' with '>=' to make installing to colab env easier
setup_text = []
with open("setup.py", 'r') as f:
    for line in f.readlines():
        setup_text.append(line.replace("==", ">="))

with open("setup.py", "w") as f:
    for line in setup_text:
        f.write(line)

!pip install -e .
%cd ..

In [None]:
# !git clone https://github.com/boothmanrylan/nonStandReplacingDisturbances.git
# %cd nonStandReplacingDisturbances

In [None]:
import os

import google
from google.colab import auth
import ee
import geemap

In [None]:
auth.authenticate_user()

project = 'api-project-269347469410'
asset_path = f"projects/{project}/assets/rylan-nonstandreplacingdisturbances"

os.environ['GOOGLE_CLOUD_PROJECT'] = project
!gcloud config set project {project}

credentials, _ = google.auth.default()
ee.Initialize(
    credentials,
    project=project,
    opt_url='https://earthengine-highvolume.googleapis.com',
)

# Get points

In [None]:
disturbed_regions = ee.FeatureCollection(f"{asset_path}/my-data/usfs-nsr-disturbances")

# A few regions contain patchy disturbance spread over a large area and were
# drawn loosely, so the largest polygons are held out and reserved for testing
# the model when applied over a large area.
regions_with_area = disturbed_regions.map(
    lambda feat: feat.set('area', feat.area(100))
)
# holding out the 8 largest polygons leaves 5 regions after buffer/dissolve
largest_disturbed_regions = regions_with_area.limit(8, 'area', False).geometry(100)

# Grow every polygon by 1 km, replace it with its bounding box, then merge all
# of the boxes into one (multi)polygon.
buffered_disturbed_regions = disturbed_regions.map(
    lambda feat: feat.buffer(1000, 100).bounds(100)
)
buffered_geometry = buffered_disturbed_regions.geometry(100).dissolve(100)

def split_multipolygon(multipolygon):
    """Split a (multi)polygon geometry into one feature per member polygon.

    Based on: https://gis.stackexchange.com/a/444779

    Uses Geometry.geometries(), which returns each member polygon of a
    MultiPolygon (with its holes) and a single-element list for a plain
    Polygon. Indexing coordinates() instead would, for a plain Polygon, walk
    its rings and turn every hole into a separate filled polygon.

    Args:
        multipolygon: ee.Geometry, e.g. the output of dissolve().

    Returns:
        ee.FeatureCollection with one feature per polygon, each carrying an
        'id' (0-based position within the geometry) and an 'area' property
        (computed with a 100 m error margin).
    """
    polygons = multipolygon.geometries()
    indices = ee.List.sequence(0, polygons.size().subtract(1))

    def grab_polygon(i):
        geom = ee.Geometry(polygons.get(i))
        return ee.Feature(geom, {'id': i, 'area': geom.area(100)})

    return ee.FeatureCollection(indices.map(grab_polygon))

split_geometry = split_multipolygon(buffered_geometry)

# Regions overlapping the large, loosely drawn polygons are set aside for
# large-area evaluation; everything else is split into train/test/val.
overlaps_largest = ee.Filter.bounds(largest_disturbed_regions)
largest_regions = split_geometry.filterBounds(largest_disturbed_regions)
split_geometry = split_geometry.filter(overlaps_largest.Not())

# Seeded random assignment into roughly equal thirds.
split_geometry = split_geometry.randomColumn('random', 42)
train_regions = split_geometry.filter(ee.Filter.lte('random', 0.333))
test_regions = split_geometry.filter(
    ee.Filter.And(
        ee.Filter.gt('random', 0.333),
        ee.Filter.lte('random', 0.666),
    )
)
val_regions = split_geometry.filter(ee.Filter.gt('random', 0.666))

def sample_points(rois):
    """Sample a balanced set of disturbed / undisturbed points inside rois.

    Disturbed points: 3 random points inside every disturbed polygon that
    intersects rois, plus extra points (10% of that count) spread over the
    union of those polygons so larger polygons receive more samples.
    Undisturbed points: the same number of random points inside rois but
    outside the disturbed polygons.

    Args:
        rois: ee.FeatureCollection of regions to sample within.

    Returns:
        ee.FeatureCollection of points with a 'class' property (1 = disturbed,
        0 = undisturbed).
    """
    disturbed_polys = disturbed_regions.filterBounds(rois)
    disturbed_geometry = disturbed_polys.geometry()

    def label(value):
        return lambda feat: feat.set('class', value)

    def three_points_in(poly):
        return ee.FeatureCollection.randomPoints(poly.geometry(), 3, 42)

    per_polygon = ee.FeatureCollection(disturbed_polys.map(three_points_in)).flatten()
    n_extra = per_polygon.size().divide(10).int()
    area_weighted = ee.FeatureCollection.randomPoints(disturbed_geometry, n_extra, 42)
    disturbed = per_polygon.merge(area_weighted).map(label(1))

    undisturbed_area = rois.geometry().difference(disturbed_geometry)
    undisturbed = ee.FeatureCollection.randomPoints(
        undisturbed_area, disturbed.size(), 42
    ).map(label(0))

    return disturbed.merge(undisturbed)

train_points = sample_points(train_regions)
test_points = sample_points(test_regions)
val_points = sample_points(val_regions)

# fetch all three sizes in a single round trip
split_sizes = ee.List([
    train_points.size(),
    test_points.size(),
    val_points.size(),
]).getInfo()
print(*split_sizes)

In [None]:
Map = geemap.Map()

# (ee object, visualization params, layer name), drawn bottom to top
map_layers = [
    (disturbed_regions, {}, 'Disturbed Regions'),
    (largest_regions, {"color": "purple"}, 'Largest Disturbed Regions'),
    (split_geometry, {'color': 'white'}, 'Remaining Disturbed Regions'),
    (train_points, {'color': 'red'}, 'Train Points'),
    (test_points, {'color': 'blue'}, 'Test Points'),
    (val_points, {'color': 'yellow'}, 'Val Points'),
]
for ee_object, vis_params, layer_name in map_layers:
    Map.addLayer(ee_object, vis_params, layer_name)

Map

# Create Presto Inputs

In [None]:
def _chunk_month_offsets(
    start_month,
    total_months,
    delta,
    exclude_months,
    max_chunks,
    extend_past_end,
):
    """Return the month offsets (from the first month) at which kept chunks start.

    A chunk starting at offset m covers month offsets m .. m + delta - 1.
    """
    if extend_past_end:
        # keep every chunk that starts before end, even if it runs past it
        last_start = total_months - 1
    else:
        # keep only chunks that finish on or before the (exclusive) end
        last_start = total_months - delta
    offsets = list(range(0, last_start + 1, delta))

    if exclude_months:
        excluded = set(exclude_months)
        offsets = [
            m for m in offsets
            if not any(
                (start_month - 1 + m + k) % 12 + 1 in excluded
                for k in range(delta)
            )
        ]

    if max_chunks is not None:
        # drop the earliest chunks; clamp so a short list is returned whole
        offsets = offsets[max(len(offsets) - max_chunks, 0):]
    return offsets


def create_time_chunk_list(
    start,
    end,
    delta=1,
    exclude_months=None,
    max_chunks=24,
    extend_past_end=False,
):
    """Create a list of [start, stop) date pairs that can be used in filterDate.

    Regardless of what day of the month is given for start, start is always
    moved to the first of that month. End is moved to the first of the
    following month unless it is already the first of a month (end is
    exclusive in all ee filterDate methods). E.g., "2021-01-15"-"2021-06-30"
    is treated as "2021-01-01"-"2021-07-01".

    Chunk boundaries are computed from whole calendar months, so every chunk
    is exactly `delta` calendar months long.

    Args:
        start: string, start date in format YYYY-MM-dd (anything ee.Date
            accepts works).
        end: string, end date in format YYYY-MM-dd (anything ee.Date accepts
            works).
        delta: int >= 1, length (in months) of each time chunk.
        exclude_months: List[int] or None, drop every time chunk that contains
            any month in this list. Uses the ee convention of months being 1-12.
        max_chunks: int or None, maximum number of time chunks to return. If
            more chunks are created, the earliest ones are dropped so that the
            latest max_chunks are returned. Applies whether or not
            exclude_months is given; None disables the limit.
        extend_past_end: bool, if True the final time chunk may extend past
            the end date (every chunk starting before end is kept); if False
            only chunks ending on or before end are kept.

    Returns:
        ee.List[ee.List[ee.Date]], start, stop date pairs in chronological order.

    Raises:
        ValueError: if delta is less than 1.
    """
    if delta < 1:
        raise ValueError(f"delta must be a positive number of months, got {delta}")

    start_date = ee.Date(start)
    end_date = ee.Date(end)
    # a single round trip for every calendar field needed below
    start_year, start_month, end_year, end_month, end_day = (
        int(value) for value in ee.List([
            start_date.get("year"),
            start_date.get("month"),
            end_date.get("year"),
            end_date.get("month"),
            end_date.get("day"),
        ]).getInfo()
    )

    # Whole calendar months from the first of start's month to the exclusive
    # end. Counting from year/month avoids Date.difference(..., "month"),
    # which is based on an average month length, and avoids building a
    # month-13 date for December ends.
    total_months = (end_year - start_year) * 12 + (end_month - start_month)
    if end_day != 1:
        total_months += 1

    offsets = _chunk_month_offsets(
        start_month,
        total_months,
        delta,
        exclude_months,
        max_chunks,
        extend_past_end,
    )

    first = ee.Date.fromYMD(start_year, start_month, 1)
    return ee.List([
        ee.List([first.advance(m, "month"), first.advance(m + delta, "month")])
        for m in offsets
    ])

In [None]:
from multiprocessing import Pool

import numpy as np
from numpy.lib.recfunctions import structured_to_unstructured
from google.api_core import retry
import webdataset as wds
import torch
from presto import presto

# Pixel size (in CRS units) and projection every data source is sampled in.
SCALE = 10
CRS = "EPSG:3857"
PROJECTION = ee.Projection(CRS)

# Approximate month length (30 days) in milliseconds.
MILLIS_PER_MONTH = 30 * 24 * 60 * 60 * 1000

# Number of worker processes used for parallel Earth Engine requests.
MAX_PARALLEL_TASKS = 40

# Band names per data source, in the order each getter returns its bands.
S2_BANDS = [
    "B1", "B2", "B3", "B4", "B5", "B6",
    "B7", "B8", "B8A", "B10", "B11", "B12",
]
S1_BANDS = ["VV", "VH"]
ERA5_BANDS = ["temperature_2m", "total_precipitation"]
SRTM_BANDS = ["elevation", "slope"]
DYNAMIC_WORLD_BANDS = ["label"]

In [None]:
def get_s1(point, time_chunks):
    """Create a median Sentinel-1 image for each period in time_chunks.

    Based on:
    openmapflow/openmapflow/eo/sentinel1.py:get_image_collection
    and
    openmapflow/openmapflow/eo/sentinel1.py:get_single_image
    from: https://github.com/nasaharvest/openmapflow/tree/main

    Args:
        point: ee.Geometry, used to filterBounds the Sentinel-1 GRD collection.
        time_chunks: ee.List[ee.List[ee.Date]], list of start, stop date pairs.
            Can be generated with create_time_chunk_list.

    Returns:
        (ee.Image, List[str]): the per-chunk medians stacked with toBands, and
        the band names (S1_BANDS) of each chunk.
    """
    overall_start = ee.Date(ee.List(time_chunks.get(0)).get(0))
    overall_end = ee.Date(ee.List(time_chunks.get(-1)).get(1))
    dual_pol_iw = (
        ee.ImageCollection("COPERNICUS/S1_GRD")
        .filterDate(overall_start, overall_end)
        .filterBounds(point)
        .filter(ee.Filter.eq("instrumentMode", "IW"))
        .filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VV"))
        .filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VH"))
    )

    # Use a single orbit direction, whichever has more images (ties go to
    # ascending), so ascending and descending passes are never mixed.
    def by_pass(direction):
        return dual_pol_iw.filter(ee.Filter.eq("orbitProperties_pass", direction))

    descending = by_pass("DESCENDING")
    ascending = by_pass("ASCENDING")
    single_pass = ee.ImageCollection(ee.Algorithms.If(
        descending.size().gt(ascending.size()),
        descending,
        ascending,
    ))

    def chunk_median(chunk):
        bounds = ee.List(chunk)
        # NOTE(review): a chunk with no images presumably yields a band-less
        # median, making select() fail; a backfill from earlier months was
        # sketched previously but is not enabled.
        in_chunk = single_pass.filterDate(ee.Date(bounds.get(0)), ee.Date(bounds.get(1)))
        return in_chunk.median().select(S1_BANDS)

    stacked = ee.ImageCollection(time_chunks.map(chunk_median)).toBands()
    return stacked, S1_BANDS

In [None]:
def get_s2(point, time_chunks):
    """Get a quality-mosaic Sentinel-2 (L1C) image for each period in time_chunks.

    Based on:
    openmapflow/openmapflow/eo/sentinel1.py:get_single_image
    from: https://github.com/nasaharvest/openmapflow/tree/main

    Args:
        point: ee.Geometry, used to filterBounds.
        time_chunks: ee.List[ee.List[ee.Date]], list of start, stop date pairs.
            Can be generated with create_time_chunk_list.

    Returns:
        (ee.Image, List[str]): the per-chunk mosaics (cast to long) stacked
        with toBands, and the band names (S2_BANDS) of each chunk.
    """
    # S2_HARMONIZED removes the +1000 DN offset applied to scenes processed
    # with baseline 04.00+ (from 2022-01-25), keeping multi-year time series
    # consistent; it is also the collection Cloud Score+ is defined for.
    col = (
        ee.ImageCollection("COPERNICUS/S2_HARMONIZED")
        .filterBounds(point)
    )
    cloud_score_plus = ee.ImageCollection("GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED")

    def process_time_chunk(chunk):
        chunk = ee.List(chunk)
        curr_chunk = (
            col.filterDate(chunk.get(0), chunk.get(1))
            .linkCollection(cloud_score_plus, ["cs_cdf"])
        )
        # per pixel, keep the observation with the highest cs_cdf (clearest)
        # TODO: is a qualityMosaic better than masking and taking the median?
        return curr_chunk.qualityMosaic("cs_cdf").select(S2_BANDS).toLong()

    return ee.ImageCollection(time_chunks.map(process_time_chunk)).toBands(), S2_BANDS

In [None]:
def get_era5(point, time_chunks):
    """Get a mean ERA5-Land monthly image for each period in time_chunks.

    ERA5-Land monthly aggregate images are dated to the first of the month.
    For each chunk, every monthly image dated within [chunk start - 1 day,
    chunk end) is averaged.

    Only the temperature_2m and total_precipitation_sum bands are returned,
    reported under the names in ERA5_BANDS.

    Based on:
    openmapflow/openmapflow/eo/era5.py:get_single_image
    from: https://github.com/nasaharvest/openmapflow/tree/main

    Args:
        point: ee.Geometry, used to filterBounds.
        time_chunks: ee.List[ee.List[ee.Date]], list of start, stop date pairs.
            Can be generated with create_time_chunk_list.

    Returns:
        (ee.Image, List[str]): the per-chunk means (as float) stacked with
        toBands, and the band names (ERA5_BANDS) of each chunk.
    """
    col = (
        ee.ImageCollection("ECMWF/ERA5_LAND/MONTHLY_AGGR")
        .filterBounds(point)
    )

    def process_time_chunk(chunk):
        start = ee.Date(ee.List(chunk).get(0)).advance(-1, "day")
        end = ee.Date(ee.List(chunk).get(1))

        curr_chunk = col.filterDate(start, end).mean()

        temp = curr_chunk.select("temperature_2m")
        percip = curr_chunk.select("total_precipitation_sum")
        # Keep floating point: an integer cast would truncate precipitation
        # (in metres, far below 1 per month) to 0 and drop temperature decimals.
        return ee.Image.cat([temp, percip]).toFloat()

    return ee.ImageCollection(time_chunks.map(process_time_chunk)).toBands(), ERA5_BANDS

In [None]:
def get_srtm(*args, **kwargs):
    """Get the SRTM DEM together with its derived slope.

    SRTM is one global image from a single point in time, so unlike the other
    getters it returns one image rather than one per time chunk.

    Args:
        *args, **kwargs: ignored; accepted so this can be called exactly like
            the other data getters.

    Returns:
        (ee.Image, List[str]): an image with "elevation" and "slope" bands
        (cast to long), and SRTM_BANDS.
    """
    dem = ee.Image("USGS/SRTMGL1_003").rename("elevation")
    terrain = dem.addBands(ee.Terrain.slope(dem).rename("slope"))
    return terrain.toLong(), SRTM_BANDS

In [None]:
def get_dynamic_world(point, time_chunks):
    """Get a mode Dynamic World label image for each period in time_chunks.

    Each chunk uses the per-pixel mode label of the images inside the chunk.
    Pixels (or whole chunks) without data fall back to the mode over the three
    months before the chunk start through the chunk end. If that is also
    empty, they fall back to label 9 (one past the last Dynamic World class,
    used here as "no data").

    Based on:
    presto/presto/dataops/pipelines/dynamicworld.py:DynamicWorldMonthly2020_2021
    from: https://github.com/nasaharvest/presto

    Args:
        point: ee.Geometry, used to filterBounds.
        time_chunks: ee.List[ee.List[ee.Date]], list of start, stop date pairs.
            Can be generated with create_time_chunk_list.

    Returns:
        (ee.Image, List[str]): the per-chunk label images stacked with
        toBands, and DYNAMIC_WORLD_BANDS.
    """
    labels = (
        ee.ImageCollection("GOOGLE/DYNAMICWORLD/V1")
        .filterBounds(point)
        .select("label")
    )
    no_data = ee.Image.constant(9).rename("label")

    def mode_or(collection, fallback):
        # mode() of an empty collection has no bands, so anything using its
        # "label" band would fail; return the fallback image in that case
        return ee.Image(ee.Algorithms.If(
            collection.size().gt(0),
            collection.mode().unmask(fallback),
            fallback,
        ))

    def process_time_chunk(chunk):
        chunk = ee.List(chunk)
        start = ee.Date(chunk.get(0))
        end = ee.Date(chunk.get(1))

        recent_mode = mode_or(labels.filterDate(start.advance(-3, "month"), end), no_data)
        curr_chunk = mode_or(labels.filterDate(start, end), recent_mode)

        # in case anything is still masked replace with 9
        return curr_chunk.unmask(9)

    return ee.ImageCollection(time_chunks.map(process_time_chunk)).toBands(), DYNAMIC_WORLD_BANDS

# Multiprocessing export

In [None]:
def create_month_tensor(time_chunks):
    """Return a 1-D tensor with the month each time chunk starts in, as 0-11.

    Presto expects months numbered 0-11 while ee.Date.get("month") returns
    1-12, hence the subtraction.

    Args:
        time_chunks: ee.List[ee.List[ee.Date]], start, stop date pairs.

    Returns:
        torch.Tensor with one entry per time chunk.
    """
    def zero_based_start_month(chunk):
        chunk_start = ee.Date(ee.List(chunk).get(0))
        return chunk_start.get("month").subtract(1)

    months = time_chunks.map(zero_based_start_month).getInfo()
    return torch.as_tensor(months)


@retry.Retry()
def process_point_fn(
    index,
    point_list,
    time_chunks,
    crs=CRS,
    scale=SCALE,
):
    """Create the Presto input sample for one point.

    Uses ee.data.computePixels to fetch a single pixel per data source from
    Earth Engine. Can be used as the function in multiprocessing.Pool.map to
    make Earth Engine requests in parallel.

    NOTE(review): google.api_core's default Retry predicate only retries
    transient google.api_core exceptions; confirm that Earth Engine errors
    raised here (ee.EEException) are actually retried.

    Args:
        index: int, index of the point to process within point_list.
        point_list: ee.List of ee.Feature points, each with a numeric 'class'
            property.
        time_chunks: ee.List[ee.List[ee.Date]], list of start, stop date pairs
            (e.g. from create_time_chunk_list).
        crs: string, EPSG crs code defining the projection to get the data in.
        scale: number, pixel size in crs units.

    Returns:
        dict for webdataset.TarWriter.write with keys "__key__" (zero-padded
        "sampleNNNNNN"), "inputs.pth", "masks.pth", "dynamic_worlds.pth",
        "latlons.pth" ([lat, lon]), "months.pth" and "cls" (the point's class).
    """
    point = ee.Feature(point_list.get(index))
    geometry = point.geometry()

    # One round trip for everything needed about the point: its lon/lat
    # geometry, the point projected to the *unscaled* crs (used as the grid
    # translation below), and its class label.
    point_info = ee.Dictionary({
        "lonlat": geometry,
        "projected": point.geometry(1, ee.Projection(crs)),
        "cls": point.getNumber("class"),
    }).getInfo()
    lon, lat = point_info["lonlat"]["coordinates"]
    coords = point_info["projected"]["coordinates"]

    request = {
        "fileFormat": "NUMPY_NDARRAY",
        "grid": {
            "dimensions": {
                "width": 1,
                "height": 1,
            },
            "affineTransform": {
                "scaleX": scale,
                "shearX": 0,
                "translateX": coords[0],
                "shearY": 0,
                "scaleY": -scale,
                "translateY": coords[1],
            },
            "crsCode": crs,
        },
    }

    data = {}
    bands = {}
    data_getters = {
        "s1": get_s1,
        "s2": get_s2,
        "era5": get_era5,
        "srtm": get_srtm,
        "dynamic_world": get_dynamic_world,
    }
    for source, get_image_fn in data_getters.items():
        image, band_names = get_image_fn(geometry, time_chunks)
        raw = ee.data.computePixels({**request, "expression": image})
        # one row per time chunk (a single row for srtm), one column per band
        arr = structured_to_unstructured(raw).reshape(-1, len(band_names))
        arr = arr.astype(np.float32, copy=False)

        if source == "dynamic_world":
            # take the single band column; squeeze() would collapse to a 0-d
            # array when there is only one time chunk
            data[source] = torch.as_tensor(arr[:, 0], dtype=torch.int)
        else:
            bands[f"{source}_bands"] = band_names
            data[source] = torch.as_tensor(arr, dtype=torch.float32)

    # repeat srtm to be the same shape as all other inputs
    num_timesteps = data["dynamic_world"].shape[0]
    data["srtm"] = data["srtm"].expand(num_timesteps, -1)

    x, mask, dw = presto.construct_single_presto_input(**data, **bands)

    return {
        "__key__": f"sample{index:06d}",
        "inputs.pth": x,
        "masks.pth": mask,
        "dynamic_worlds.pth": dw,
        "latlons.pth": torch.tensor([lat, lon]),
        "months.pth": create_month_tensor(time_chunks),
        "cls": point_info["cls"],
    }



In [None]:
col = train_points.randomColumn('random', 42).limit(50, 'random')
chunks = create_time_chunk_list(
    "2019-01-01",
    "2024-04-01",
    delta=1,
    exclude_months=[10, 11, 12, 1, 2, 3],
    max_chunks=24,
)
output_path = "trial_dataset.tar"

N = col.size().getInfo()
point_list = col.toList(N)
indices = list(range(N))

# multiprocessing pickles the mapped callable by reference, so it has to be a
# named module-level function rather than a lambda
def fn(index):
    return process_point_fn(index, point_list, chunks)

with Pool(MAX_PARALLEL_TASKS) as pool:
    outputs = pool.map(fn, indices)

with wds.TarWriter(output_path) as sink:
    for sample in outputs:
        sink.write(sample)

# Run downstream task

In [None]:
raw_dataset = wds.WebDataset("trial_dataset.tar")
decoded_dataset = raw_dataset.decode().to_tuple(
    "cls", "inputs.pth", "masks.pth", "dynamic_worlds.pth", "latlons.pth", "months.pth",
)
# The next cell iterates this loader, so it must be defined here (leaving it
# commented out raises NameError on Restart & Run All). Batches come in file
# order because the next cell splits train/test by batch position.
dataloader = torch.utils.data.DataLoader(decoded_dataset, batch_size=10, shuffle=False)

In [None]:
from tqdm.auto import tqdm

# number of leading batches used for training; the remaining batches are test
N_TRAIN_BATCHES = 3

pretrained_model = presto.Presto.load_pretrained()
# inference only: make sure no training-mode behaviour (e.g. dropout) is active
pretrained_model.eval()

features_list = []
class_list = []
with torch.no_grad():
    for class_label, x, mask, dw, latlons, month in tqdm(dataloader):
        encodings = pretrained_model.encoder(
            x, dynamic_world=dw, mask=mask, latlons=latlons, month=month
        ).cpu().numpy()
        features_list.append(encodings)
        class_list.append(class_label.numpy())

train_features = np.concatenate(features_list[:N_TRAIN_BATCHES])
train_labels = np.concatenate(class_list[:N_TRAIN_BATCHES])
test_features = np.concatenate(features_list[N_TRAIN_BATCHES:])
test_labels = np.concatenate(class_list[N_TRAIN_BATCHES:])

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score

model = RandomForestClassifier(class_weight="balanced", random_state=42)
model.fit(train_features, train_labels)
# predict() maps back through model.classes_; argmax over predict_proba gives
# a column index, which equals the label only when classes_ is exactly [0, 1]
# (it is wrong e.g. when the training batches contain a single class)
predictions = model.predict(test_features)
f1_score(test_labels, predictions)