<a href="https://colab.research.google.com/github/boothmanrylan/nonStandReplacingDisturbances/blob/colab_dev/nonStandReplacingDisturbances.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the project repository and switch the notebook's working directory
# into the checkout (shell escape and IPython magic, notebook-only syntax).
!git clone https://github.com/boothmanrylan/nonStandReplacingDisturbances.git
%cd nonStandReplacingDisturbances

In [None]:
import os

import google
from google.colab import auth
import ee
import geemap

In [None]:
# Authenticate the Colab user so that default Google credentials are available.
auth.authenticate_user()

# Cloud project used for Earth Engine and the asset folder holding project data.
project = 'api-project-269347469410'
asset_path = f"projects/{project}/assets/rylan-nonstandreplacingdisturbances"

# Make the project the default both for the environment and for gcloud.
os.environ['GOOGLE_CLOUD_PROJECT'] = project
!gcloud config set project {project}

# Initialize Earth Engine with the default credentials against the project.
credentials, _ = google.auth.default()
ee.Initialize(
    credentials,
    project=project,
    # opt_url='https://earthengine-highvolume.googleapis.com',
)

In [None]:
NUM_POINTS = 500  # points per train/test/val group

# Polygons of known disturbances stored under the project's asset folder.
disturbed_regions = ee.FeatureCollection(
    f"{asset_path}/my-data/usfs-nsr-disturbances"
)


def _buffered_bounds(feature):
    """Buffers a feature by 500 (max error 100) and returns its bounds."""
    return feature.buffer(500, 100).bounds(100)


# Buffer every disturbance, then merge the resulting boxes into one geometry.
buffered_disturbed_regions = disturbed_regions.map(_buffered_bounds)
buffered_geometry = buffered_disturbed_regions.geometry(100).dissolve(100)

def split_multipolygon(multipolygon):
    """Turns every polygon of a multipolygon into its own feature.

    Adapted from: https://gis.stackexchange.com/a/444779

    Args:
        multipolygon: geometry whose coordinates() list has one entry per
            polygon

    Returns:
        ee.FeatureCollection with one feature per entry of the coordinates
        list; each feature has an 'id' (its position in that list) and an
        'area' (computed with a max error of 100) property
    """
    def to_feature(position):
        polygon = ee.Geometry.Polygon(multipolygon.coordinates().get(position))
        properties = {'id': position, 'area': polygon.area(100)}
        return ee.Feature(polygon, properties)

    last_position = multipolygon.coordinates().size().subtract(1)
    positions = ee.List.sequence(0, last_position)
    return ee.FeatureCollection(positions.map(to_feature))

split_geometry = split_multipolygon(buffered_geometry)

# Split into approx. 1/3 of the area for each of train/test/val by sorting the
# polygons by area (largest first) and then taking every third polygon.
split_geometry = split_geometry.sort('area', False)
N = split_geometry.size().subtract(1)

train_indices = ee.List.sequence(0, N, 3)
test_indices = ee.List.sequence(1, N, 3)
val_indices = ee.List.sequence(2, N, 3)

# The 'id' property is each polygon's position in the *unsorted* multipolygon,
# so filtering on it would ignore the sort above. Select by position within
# the area-sorted collection instead.
sorted_regions = split_geometry.toList(split_geometry.size())


def select_regions(indices):
    """Returns the features at the given positions of the area-sorted list."""
    return ee.FeatureCollection(indices.map(lambda i: sorted_regions.get(i)))


train_regions = select_regions(train_indices)
test_regions = select_regions(test_indices)
val_regions = select_regions(val_indices)

def sample_points(rois):
    """Samples an equal number of disturbed and undisturbed points in rois.

    Disturbed points are labelled with 'class' 1 and undisturbed points with
    'class' 0. All random draws use the seed 42.

    Args:
        rois: collection of regions to sample within

    Returns:
        ee.FeatureCollection of the disturbed points followed by the
        undisturbed points
    """
    def labeller(value):
        return lambda feature: feature.set('class', value)

    def three_points_inside(polygon):
        return ee.FeatureCollection.randomPoints(polygon.geometry(), 3, 42)

    disturbed_polys = disturbed_regions.filterBounds(rois)

    # three points drawn inside every individual disturbed polygon
    per_polygon_points = ee.FeatureCollection(
        disturbed_polys.map(three_points_inside)
    ).flatten()

    # extra points drawn over the combined disturbed geometry (one tenth as
    # many as above), which land in polygons in proportion to their area
    per_polygon_count = per_polygon_points.size()
    print(per_polygon_count.getInfo())
    extra_count = per_polygon_count.divide(10).int()
    area_weighted_points = ee.FeatureCollection.randomPoints(
        disturbed_polys.geometry(), extra_count, 42
    )

    disturbed_points = per_polygon_points.merge(area_weighted_points)
    disturbed_points = disturbed_points.map(labeller(1))

    # as many undisturbed points as disturbed ones, drawn outside the
    # disturbed polygons
    undisturbed_area = rois.geometry().difference(disturbed_polys)
    undisturbed_points = ee.FeatureCollection.randomPoints(
        undisturbed_area, disturbed_points.size(), 42
    ).map(labeller(0))

    return disturbed_points.merge(undisturbed_points)

# Draw the labelled sample points for each of the three groups.
train_points = sample_points(train_regions)
test_points = sample_points(test_regions)
val_points = sample_points(val_regions)

# Unused draft of a rasterized disturbance mask:
# disturbance_mask = disturbed_regions.map(
#     lambda x: x.set('foo', 1)
# ).reducetoimage(
#     ['foo'], ee.reducer.first()
# ).unmask().gt(0)

# Report how many points ended up in the train, test and val groups.
group_sizes = [
    group.size().getInfo()
    for group in (train_points, test_points, val_points)
]
print(*group_sizes)

In [None]:
# Create the interactive map; without this `Map` is undefined on a fresh
# kernel and every addLayer call below raises NameError.
Map = geemap.Map()
Map.addLayer(disturbed_regions, {}, 'Disturbed Regions')
Map.addLayer(split_geometry, {'color': 'white'}, 'ROI')
Map.addLayer(train_points, {'color': 'red'}, 'Train Points')
Map.addLayer(test_points, {'color': 'blue'}, 'Test Points')
Map.addLayer(val_points, {'color': 'yellow'}, 'Val Points')
# Bare expression so the notebook renders the map as the cell output.
Map

In [None]:
# Small rectangle used to try the export on a subset of the training points.
trial_export_region = ee.Geometry.Rectangle(
    [[-120.49118, 40.033924], [-120.29068, 40.208246]]
)

# Up to 50 training points inside the rectangle, ordered by a seeded random
# column.
trial_export_points = (
    train_points
    .filterBounds(trial_export_region)
    .randomColumn('random', 42)
    .limit(50, 'random')
)

In [None]:
from multiprocessing import Pool

from google.api_core import retry
import webdataset as wbs
import torch

def create_time_chunk_list(start, end, delta):
    """ Splits the period from start to end into chunks of length delta.

    Args:
        start: start date, anything accepted by ee.Date
        end: end date, anything accepted by ee.Date
        delta: int, chunk length in ms

    Returns:
        ee.List of [chunk_start, chunk_end] pairs in ms since the epoch. Chunk
        starts are generated by ee.List.sequence(start, end, delta), so the
        final chunk may extend past end.
    """
    start = ee.Date(start).millis()
    end = ee.Date(end).millis()
    starts = ee.List.sequence(start, end, delta)
    # elements handed to a mapped function are untyped computed objects and
    # must be cast to ee.Number before number methods can be called on them
    ends = starts.map(lambda x: ee.Number(x).add(delta))
    return starts.zip(ends)


def get_s1(point, start, end, delta):
    """ Creates a median Sentinel1 image for each delta between start and end.

    Based on:
    openmapflow/openmapflow/eo/sentinel1.py:get_image_collection
    and
    openmapflow/openmapflow/eo/sentinel1.py:get_single_image
    from: https://github.com/nasaharvest/openmapflow/tree/main

    Args:
        point: ee.Geometry, used to filterBounds of the complete sentinel1
        start: string, start date in format YYYY-MM-dd
        end: string, end date in format YYYY-MM-dd
        delta: int, chunk length to split total time period into (in ms)

    Returns:
        ee.Image, the VV and VH bands of every time chunk's median stacked
        into a single image with toBands()
    """
    col = (
        ee.ImageCollection("COPERNICUS/S1_GRD")
        .filterDate(start, end)
        .filterBounds(point)
        .filter(ee.Filter.eq("instrumentMode", "IW"))
        .filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VV"))
        .filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VH"))
    )

    # want all images to either be descending or ascending, but not both. also
    # want as many observations as possible so take the collection that has
    # the most images in it
    descend_col = col.filter(ee.Filter.eq("orbitProperties_pass", "DESCENDING"))
    ascend_col = col.filter(ee.Filter.eq("orbitProperties_pass", "ASCENDING"))
    # ee.Algorithms.If returns an untyped computed object, cast it back to an
    # ImageCollection so that filterDate is available below
    col = ee.ImageCollection(ee.Algorithms.If(
        descend_col.size().gt(ascend_col.size()),
        descend_col,
        ascend_col,
    ))

    def process_time_chunk(chunk):
        chunk = ee.List(chunk)
        curr_chunk = col.filterDate(chunk.get(0), chunk.get(1))
        return curr_chunk.median().select(["VV", "VH"])

    time_chunks = create_time_chunk_list(start, end, delta)

    return ee.ImageCollection(time_chunks.map(process_time_chunk)).toBands()


def get_s2(point, start, end, delta):
    """ Gets a quality mosaic Sentinel2 image for each delta between start and end.

    Each image is linked to its Cloud Score+ "cs_cdf" band, which is then used
    as the quality band of the mosaic.

    Based on:
    openmapflow/openmapflow/eo/sentinel1.py:get_single_image
    from: https://github.com/nasaharvest/openmapflow/tree/main

    Args:
        point: ee.Geometry, used to filterBounds
        start: string, start date in format YYYY-MM-dd
        end: string, end date in format YYYY-MM-dd
        delta: int, chunk length to split total time period into (in ms)

    Returns:
        ee.Image, the quality mosaics of all time chunks stacked into a single
        image with toBands()
    """
    col = (
        ee.ImageCollection("COPERNICUS/S2")
        .filterDate(start, end)
        .filterBounds(point)
    )
    cloud_score_plus = ee.ImageCollection("GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED")

    def process_time_chunk(chunk):
        chunk = ee.List(chunk)
        curr_chunk = (
            col.filterDate(chunk.get(0), chunk.get(1))
            .linkCollection(cloud_score_plus, ["cs_cdf"])
        )
        # TODO: is a qualityMosaic better than masking and taking the median?
        return curr_chunk.qualityMosaic("cs_cdf")

    time_chunks = create_time_chunk_list(start, end, delta)

    return ee.ImageCollection(time_chunks.map(process_time_chunk)).toBands()


def get_era5(point, start, end, delta):
    """ Gets an ERA5 image for each delta between start and end.

    All ERA5 images are dated to the first of the month, we choose to use
    the ERA5 image based on the start date of each delta period.

    Based on:
    openmapflow/openmapflow/eo/era5.py:get_single_image
    from: https://github.com/nasaharvest/openmapflow/tree/main

    Args:
        point: ee.Geometry, used to filterBounds
        start: string, start date in format YYYY-MM-dd
        end: string, end date in format YYYY-MM-dd
        delta: int, chunk length to split total time period into (in ms)

    Returns:
        ee.Image, the images of all time chunks stacked into a single image
        with toBands()
    """
    col = (
        ee.ImageCollection("ECMWF/ERA5_LAND/MONTHLY_AGGR")
        .filterDate(start, end)
        .filterBounds(point)
    )

    def process_time_chunk(chunk):
        # the chunk is a [start_ms, end_ms] pair; the year and month have to
        # be read from a date built from the chunk start, not from the list
        chunk_start = ee.Date(ee.List(chunk).get(0))
        start_year = chunk_start.get("year")
        start_month = chunk_start.get("month")
        start_date = ee.Date.fromYMD(start_year, start_month, 1)
        # one day either side of the first of the month to pick up the
        # image dated to that day
        curr_chunk = col.filterDate(
            start_date.advance(-1, "day"),
            start_date.advance(1, "day"),
        )
        return curr_chunk.mean()

    time_chunks = create_time_chunk_list(start, end, delta)

    return ee.ImageCollection(time_chunks.map(process_time_chunk)).toBands()


def get_srtm(*args, **kwargs):
    """ Builds a two band image from the SRTM DEM: elevation and slope.

    The DEM is a single image rather than a collection, so nothing is computed
    per time step and every argument is ignored.

    Args:
        *args, **kwargs: unused, accepted so that this getter can be called
            the same way as the other data getters

    Returns:
        ee.Image with the bands "elevation" and "slope"
    """
    dem = ee.Image("USGS/SRTMGL1_003").rename("elevation")
    return dem.addBands(ee.Terrain.slope(dem).rename("slope"))


def get_dynamic_world(point, start, end, delta):
    """ Builds the per-pixel mode of Dynamic World for every time chunk.

    Based on:
    presto/presto/dataops/pipelines/dynamicworld.py:DynamicWorldMonthly2020_2021
    from: https://github.com/nasaharvest/presto

    Args:
        point: ee.Geometry, used to filterBounds
        start: string, start date in format YYYY-MM-dd
        end: string, end date in format YYYY-MM-dd
        delta: chunk length, passed on to create_time_chunk_list

    Returns:
        ee.Image, the mode images of all time chunks stacked into a single
        image with toBands()
    """
    dynamic_world = ee.ImageCollection("GOOGLE/DYNAMICWORLD/V1")
    dynamic_world = dynamic_world.filterBounds(point)
    dynamic_world = dynamic_world.filterDate(start, end)

    def mode_of_chunk(bounds):
        bounds = ee.List(bounds)
        chunk_start, chunk_end = bounds.get(0), bounds.get(1)
        return dynamic_world.filterDate(chunk_start, chunk_end).mode()

    per_chunk_images = create_time_chunk_list(start, end, delta).map(mode_of_chunk)
    return ee.ImageCollection(per_chunk_images).toBands()


@retry.Retry()
def process_point_fn(point, start, end, delta, projection=PROJECTION, scale=SCALE):
    """ Downloads a single pixel from every data source for one sample point.

    Work in progress: the inputs are assembled but not yet saved (see TODOs),
    so nothing is returned.

    Args:
        point: the sample point, wrapped in ee.Feature
        start: start date, passed on to the data getters
        end: end date, passed on to the data getters
        delta: chunk length, passed on to the data getters
        projection: projection the point is converted to and the pixel grid
            is defined in; assumed to be a CRS code string since it is also
            used as the grid's crsCode -- TODO confirm
        scale: pixel size of the requested grid, in units of projection

    Returns:
        None
    """
    point = ee.Feature(point)

    # project the point
    coords = point.geometry(1, projection).getInfo()["coordinates"]

    # a 1x1 pixel grid whose origin is the projected point
    request = {
        "fileFormat": "NUMPY_NDARRAY",
        "grid": {
            "dimensions": {
                "width": 1,
                "height": 1,
            },
            "affineTransform": {
                "scaleX": scale,
                "shearX": 0,
                "translateX": coords[0],
                "shearY": 0,
                "scaleY": -scale,
                "translateY": coords[1],
            },
            # must be the CRS the coordinates above were projected into
            "crsCode": projection
        }
    }

    data = {}
    bands = {}
    data_getters = {
        "s1": get_s1,
        "s2": get_s2,
        "era5": get_era5,
        "srtm": get_srtm,
        "dynamic_world": get_dynamic_world,
    }
    for source, getter in data_getters.items():
        image = getter(point, start, end, delta)
        request["expression"] = image
        # NOTE(review): confirm torch.as_tensor accepts the array type that
        # computePixels returns for the NUMPY_NDARRAY format
        data[source] = torch.as_tensor(ee.data.computePixels(request))
        if source != "dynamic_world":
            bands[f"{source}_bands"] = image.bandNames().getInfo()

    presto_input = construct_single_presto_input(**data, **bands)
    months = create_month_tensor(start, end, delta)
    class_label = point.get("class")

    # TODO also need to save the dates (month) of each input
    # TODO also need to save the latitude/longitude of the point

    # TODO: save the presto input using webdataset

# Run process_point_fn over every entry of points_list in a pool of
# MAX_PARALLEL_TASKS worker processes.
# NOTE(review): Pool.map passes a single argument per call, but
# process_point_fn declares start, end and delta as required parameters --
# confirm how these are meant to be supplied (e.g. with functools.partial).
with Pool(MAX_PARALLEL_TASKS) as p:
    p.map(process_point_fn, points_list)