<a href="https://colab.research.google.com/github/boothmanrylan/nonStandReplacingDisturbances/blob/colab_dev/nonStandReplacingDisturbances.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/boothmanrylan/presto.git
%cd presto

# replace all install_requires '==' with '>=' to make installing to colab env easier
setup_text = []
with open("setup.py", 'r') as f:
    for line in f.readlines():
        setup_text.append(line.replace("==", ">="))

with open("setup.py", "w") as f:
    for line in setup_text:
        f.write(line)

!pip install -e .
%cd ..

In [None]:
# Clone this project's repository and change into its root directory.
!git clone https://github.com/boothmanrylan/nonStandReplacingDisturbances.git
%cd nonStandReplacingDisturbances

In [None]:
import os

import google
from google.colab import auth
import ee
import geemap

In [None]:
auth.authenticate_user()

project = 'api-project-269347469410'
asset_path = f"projects/{project}/assets/rylan-nonstandreplacingdisturbances"

os.environ['GOOGLE_CLOUD_PROJECT'] = project
!gcloud config set project {project}

credentials, _ = google.auth.default()
ee.Initialize(
    credentials,
    project=project,
    # opt_url='https://earthengine-highvolume.googleapis.com',
)

In [None]:
NUM_POINTS = 500  # points per train/test/val group

disturbed_regions = ee.FeatureCollection(
    f"{asset_path}/my-data/usfs-nsr-disturbances"
)


def _buffered_bounding_box(feature):
    """Buffer a disturbance feature by 500 and replace it with its bounding box."""
    return feature.buffer(500, 100).bounds(100)


buffered_disturbed_regions = disturbed_regions.map(_buffered_bounding_box)

# Union of all buffered boxes; overlapping boxes are merged by dissolve().
buffered_geometry = buffered_disturbed_regions.geometry(100).dissolve(100)

def split_multipolygon(multipolygon):
    """Split a (multi)polygon into a FeatureCollection with one feature per part.

    Based on: https://gis.stackexchange.com/a/444779, but iterates over
    ``geometries()`` rather than ``coordinates()``. For a plain Polygon (e.g.
    when dissolve() leaves a single connected region) ``coordinates()`` is the
    list of that polygon's rings, not a list of polygons, so indexing into it
    produced invalid/incorrect parts. ``geometries()`` yields the component
    polygons of a MultiPolygon and a singleton list for a single Polygon.

    Args:
        multipolygon: ee.Geometry, typically a (Multi)Polygon.

    Returns:
        ee.FeatureCollection with one feature per part. Each feature has an
        'id' (the part's position within ``multipolygon``) and its 'area'
        (computed with a max error of 100).
    """
    parts = multipolygon.geometries()
    indices = ee.List.sequence(0, parts.size().subtract(1))

    def grab_polygon(i):
        geom = ee.Geometry(parts.get(i))
        return ee.Feature(geom, {'id': i, 'area': geom.area(100)})

    return ee.FeatureCollection(indices.map(grab_polygon))

split_geometry = split_multipolygon(buffered_geometry)

# Split into approx. 1/3 of the area for each of train/test/val by sorting the
# regions by area (largest first) and dealing them out round-robin.
split_geometry = split_geometry.sort('area', False)
N = split_geometry.size().subtract(1)

# The 'id' property is assigned *before* sorting, so filtering on it would
# ignore the sort entirely. Record each region's position in the sorted order
# as 'rank' and deal the regions out on that instead.
_sorted_regions = split_geometry.toList(split_geometry.size())
split_geometry = ee.FeatureCollection(
    ee.List.sequence(0, N).map(
        lambda i: ee.Feature(_sorted_regions.get(i)).set('rank', i)
    )
)

train_indices = ee.List.sequence(0, N, 3)
test_indices = ee.List.sequence(1, N, 3)
val_indices = ee.List.sequence(2, N, 3)

train_regions = split_geometry.filter(ee.Filter.inList('rank', train_indices))
test_regions = split_geometry.filter(ee.Filter.inList('rank', test_indices))
val_regions = split_geometry.filter(ee.Filter.inList('rank', val_indices))

def sample_points(rois):
    """Sample a balanced set of disturbed / undisturbed points inside ``rois``.

    Disturbed points (class 1) come from the disturbance polygons that
    intersect ``rois``:
      * 3 random points inside every polygon, plus
      * an extra ~10% of that count drawn uniformly over the union of the
        polygons, so that larger polygons receive more samples.
    The same number of undisturbed points (class 0) is then drawn from the
    part of ``rois`` outside the disturbance polygons.

    The result is built lazily; no data is fetched from Earth Engine here.

    Args:
        rois: ee.FeatureCollection of regions to sample within.

    Returns:
        ee.FeatureCollection of points, each with a 'class' property (0/1).
    """
    disturbed_polys = disturbed_regions.filterBounds(rois)

    # ensure every polygon gets at least three samples
    specific_disturbed_points = ee.FeatureCollection(disturbed_polys.map(
        lambda x: ee.FeatureCollection.randomPoints(x.geometry(), 3, 42)
    )).flatten()

    # extra samples spread over the union of the polygons, so that larger
    # polygons get proportionally more than the fixed three
    n_specific = specific_disturbed_points.size()
    other_disturbed_points = ee.FeatureCollection.randomPoints(
        disturbed_polys.geometry(),
        n_specific.divide(10).int(),
        42,
    )
    disturbed_points = specific_disturbed_points.merge(other_disturbed_points)
    disturbed_points = disturbed_points.map(lambda x: x.set('class', 1))

    # ensure that there are as many undisturbed as disturbed samples
    undisturbed_points = ee.FeatureCollection.randomPoints(
        rois.geometry().difference(disturbed_polys.geometry()),
        disturbed_points.size(),
        42,
    ).map(lambda x: x.set('class', 0))

    return disturbed_points.merge(undisturbed_points)

train_points = sample_points(train_regions)
test_points = sample_points(test_regions)
val_points = sample_points(val_regions)

# Rasterised disturbance mask (currently unused):
# disturbance_mask = disturbed_regions.map(
#     lambda x: x.set('foo', 1)
# ).reduceToImage(
#     ['foo'], ee.Reducer.first()
# ).unmask().gt(0)

# Fetch all three sizes in a single request instead of three round-trips.
print(*ee.List([
    train_points.size(),
    test_points.size(),
    val_points.size(),
]).getInfo())

In [None]:
# Interactive map of the disturbance polygons, the split regions and the
# sampled train/test/val points.
Map = geemap.Map()
for layer, vis_params, layer_name in [
    (disturbed_regions, {}, 'Disturbed Regions'),
    (split_geometry, {'color': 'white'}, 'ROI'),
    (train_points, {'color': 'red'}, 'Train Points'),
    (test_points, {'color': 'blue'}, 'Test Points'),
    (val_points, {'color': 'yellow'}, 'Val Points'),
]:
    Map.addLayer(layer, vis_params, layer_name)
Map

In [None]:
# Small rectangle, and at most 50 randomly chosen training points inside it,
# used to trial the export pipeline.
trial_export_region = ee.Geometry.Rectangle(
    [[-120.49118, 40.033924], [-120.29068, 40.208246]]
)
trial_export_points = (
    train_points
    .filterBounds(trial_export_region)
    .randomColumn('random', 42)
    .limit(50, 'random')
)

In [None]:
from multiprocessing import Pool

import numpy as np
from numpy.lib.recfunctions import structured_to_unstructured
from google.api_core import retry
import webdataset as wds
import torch
from presto import presto

SCALE = 10  # pixel size, in units of CRS (metres for EPSG:3857)
CRS = "EPSG:3857"
PROJECTION = ee.Projection(CRS)
MILLIS_PER_MONTH = 1000 * 60 * 60 * 24 * 30  # a fixed 30-day "month"
MAX_PARALLEL_TASKS = 40
S2_BANDS = [
    "B1", "B2", "B3", "B4", "B5", "B6", "B7",
    "B8", "B8A", "B9", "B10", "B11", "B12",
]

# TODO: do we want to exclude winter months (or allow for winter months to be excluded)
def create_time_chunk_list(start, end, delta):
    """Split the period [start, end) into consecutive chunks of length ``delta``.

    Args:
        start: date accepted by ee.Date, e.g. a "YYYY-MM-dd" string.
        end: date accepted by ee.Date; no chunk starts at or after it.
        delta: chunk length in milliseconds.

    Returns:
        ee.List of [chunk_start_millis, chunk_end_millis] pairs. The last
        chunk may extend past ``end``.
    """
    start = ee.Date(start).millis()
    end = ee.Date(end).millis()
    # ee.List.sequence includes its end point, so stop 1 ms early. Otherwise,
    # whenever (end - start) is an exact multiple of delta, an extra chunk
    # starting exactly at `end` is produced; it lies entirely outside the
    # requested date range and yields empty collections downstream.
    starts = ee.List.sequence(start, end.subtract(1), delta)
    ends = starts.map(lambda x: ee.Number(x).add(delta))
    return starts.zip(ends)

# TODO: add normalization (scaling and offsets) to each data getter
# see presto/presto/dataops/pipelines/s1_s2_era5_srtm.py normalize method

def get_s1(point, start, end, delta):
    """Create a median Sentinel-1 (VV, VH) image for each delta between start and end.

    Only IW-mode scenes that carry both VV and VH polarisations are used, and
    only from one orbit direction: whichever of descending/ascending has more
    scenes over ``point`` in the full period (ties go to ascending). This
    keeps the viewing geometry consistent while maximising observations.

    Based on:
    openmapflow/openmapflow/eo/sentinel1.py:get_image_collection
    and
    openmapflow/openmapflow/eo/sentinel1.py:get_single_image
    from: https://github.com/nasaharvest/openmapflow/tree/main

    Args:
        point: ee.Geometry, used to filterBounds the Sentinel-1 collection.
        start: string, start date in format YYYY-MM-dd
        end: string, end date in format YYYY-MM-dd
        delta: int, chunk length to split total time period into (in ms)

    Returns:
        tuple of (ee.Image, int): the per-chunk median VV/VH images stacked
        with toBands(), and the number of bands per chunk (2).
    """
    polarisations = ["VV", "VH"]

    s1 = (
        ee.ImageCollection("COPERNICUS/S1_GRD")
        .filterDate(start, end)
        .filterBounds(point)
        .filter(ee.Filter.eq("instrumentMode", "IW"))
    )
    for polarisation in polarisations:
        s1 = s1.filter(
            ee.Filter.listContains("transmitterReceiverPolarisation", polarisation)
        )

    by_pass = {
        orbit: s1.filter(ee.Filter.eq("orbitProperties_pass", orbit))
        for orbit in ("DESCENDING", "ASCENDING")
    }
    s1 = ee.ImageCollection(ee.Algorithms.If(
        by_pass["DESCENDING"].size().gt(by_pass["ASCENDING"].size()),
        by_pass["DESCENDING"],
        by_pass["ASCENDING"],
    ))

    def median_of_chunk(chunk):
        bounds = ee.List(chunk)
        in_chunk = s1.filterDate(bounds.get(0), bounds.get(1))
        return in_chunk.median().select(polarisations)

    chunks = create_time_chunk_list(start, end, delta)
    stacked = ee.ImageCollection(chunks.map(median_of_chunk)).toBands()
    return stacked, len(polarisations)


def get_s2(point, start, end, delta):
    """Get a quality-mosaic Sentinel-2 image for each delta between start and end.

    Uses the harmonized L1C collection (COPERNICUS/S2_HARMONIZED). The
    non-harmonized COPERNICUS/S2 collection is deprecated, and its scenes
    processed with baseline 04.00+ (from 2022-01-25) have their values
    shifted by 1000. That would make a multi-year series such as the
    2021-2024 default used by process_point_fn internally inconsistent.

    Within each chunk, each pixel is taken from the scene with the highest
    Cloud Score+ ``cs_cdf`` value (qualityMosaic).

    Based on:
    openmapflow/openmapflow/eo/sentinel1.py:get_single_image
    (NOTE(review): presumably the Sentinel-2 counterpart was meant — confirm)
    from: https://github.com/nasaharvest/openmapflow/tree/main

    Args:
        point: ee.Geometry, used to filterBounds
        start: string, start date in format YYYY-MM-dd
        end: string, end date in format YYYY-MM-dd
        delta: int, chunk length to split total time period into (in ms)

    Returns:
        tuple of (ee.Image, int): the S2_BANDS of every chunk stacked with
        toBands(), and the number of bands per chunk (len(S2_BANDS)).
    """
    col = (
        ee.ImageCollection("COPERNICUS/S2_HARMONIZED")
        .filterDate(start, end)
        .filterBounds(point)
    )
    # Loop-invariant: the same Cloud Score+ collection is linked for every chunk.
    cloud_score_plus = ee.ImageCollection("GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED")

    def process_time_chunk(chunk):
        chunk = ee.List(chunk)
        curr_chunk = (
            col.filterDate(chunk.get(0), chunk.get(1))
            .linkCollection(cloud_score_plus, ["cs_cdf"])
        )
        # TODO: is a qualityMosaic better than masking and taking the median?
        return curr_chunk.qualityMosaic("cs_cdf").select(S2_BANDS)

    time_chunks = create_time_chunk_list(start, end, delta)

    return ee.ImageCollection(time_chunks.map(process_time_chunk)).toBands(), len(S2_BANDS)


def get_era5(point, start, end, delta):
    """Get an ERA5-Land monthly image for each delta between start and end.

    All ERA5 monthly images are dated to the first of the month; for each
    delta period the image of the month containing the period's start date
    is used.

    Based on:
    openmapflow/openmapflow/eo/era5.py:get_single_image
    from: https://github.com/nasaharvest/openmapflow/tree/main

    Args:
        point: ee.Geometry, used to filterBounds
        start: string, start date in format YYYY-MM-dd
        end: string, end date in format YYYY-MM-dd
        delta: int, chunk length to split total time period into (in ms)

    Returns:
        tuple of (ee.Image, int): "temperature_2m" and "total_precipitation"
        for every chunk stacked with toBands(), and the number of bands per
        chunk (2).
    """
    col = (
        ee.ImageCollection("ECMWF/ERA5_LAND/MONTHLY_AGGR")
        .filterDate(start, end)
        .filterBounds(point)
    )

    def process_time_chunk(chunk):
        chunk_start = ee.Date(ee.List(chunk).get(0))
        start_date = ee.Date.fromYMD(
            chunk_start.get("year"), chunk_start.get("month"), 1
        )
        # +/- one day around the 1st selects exactly that month's image
        curr_chunk = col.filterDate(
            start_date.advance(-1, "day"),
            start_date.advance(1, "day"),
        )
        # ee.ImageCollection has no rename(); select() with a second list of
        # names renames the bands (here to match Presto's naming convention).
        return curr_chunk.select(
            ["temperature_2m", "total_precipitation_sum"],
            ["temperature_2m", "total_precipitation"],
        ).mean()

    time_chunks = create_time_chunk_list(start, end, delta)

    return ee.ImageCollection(time_chunks.map(process_time_chunk)).toBands(), 2


def get_srtm(*args, **kwargs):
    """Get the SRTM DEM together with the slope derived from it.

    SRTM is a single static image, so unlike the other data getters nothing
    is computed per time step. All arguments are accepted and ignored so this
    can be called the same way as the other getters.

    Returns:
        tuple of (ee.Image, int): an image with bands "elevation" and
        "slope", and the number of bands (2).
    """
    dem = ee.Image("USGS/SRTMGL1_003").rename("elevation")
    return dem.addBands(ee.Terrain.slope(dem).rename("slope")), 2


def get_dynamic_world(point, start, end, delta):
    """Get the modal Dynamic World label for each delta between start and end.

    Based on:
    presto/presto/dataops/pipelines/dynamicworld.py:DynamicWorldMonthly2020_2021
    from: https://github.com/nasaharvest/presto

    Args:
        point: ee.Geometry, used to filterBounds
        start: string, start date in format YYYY-MM-dd
        end: string, end date in format YYYY-MM-dd
        delta: int, chunk length to split total time period into (in ms)

    Returns:
        tuple of (ee.Image, int): the per-chunk mode of the "label" band
        stacked with toBands(), and the number of bands per chunk (1).
    """
    labels = (
        ee.ImageCollection("GOOGLE/DYNAMICWORLD/V1")
        .filterBounds(point)
        .filterDate(start, end)
        .select("label")
    )

    def modal_label(chunk):
        bounds = ee.List(chunk)
        return labels.filterDate(bounds.get(0), bounds.get(1)).mode()

    chunks = create_time_chunk_list(start, end, delta)
    per_chunk = ee.ImageCollection(chunks.map(modal_label))
    return per_chunk.toBands(), 1


def create_month_tensor(start, end, delta):
    """Return a tensor of the calendar month in which each time chunk starts.

    NOTE(review): ee.Date.get("month") is 1-based (1-12); confirm this matches
    the month indexing Presto expects.
    """
    def chunk_start_month(chunk):
        return ee.Date(ee.List(chunk).get(0)).get("month")

    months = create_time_chunk_list(start, end, delta).map(chunk_start_month)
    return torch.as_tensor(months.getInfo())


def _strip_tobands_prefix(band_name):
    """Drop the '<image index>_' prefix that ImageCollection.toBands() adds."""
    prefix, sep, rest = band_name.partition("_")
    return rest if sep and prefix.isdigit() else band_name


@retry.Retry()
def process_point_fn(
    point,
    sink,
    start="2021-01-01",
    end="2024-01-01",
    delta=MILLIS_PER_MONTH,
    crs=CRS,
    scale=SCALE,
):
    """Create the Presto input for one point and write it to ``sink``.

    Uses ee.data.computePixels to fetch a single pixel per data source from
    Earth Engine.

    NOTE(review): if this is parallelised with multiprocessing, each worker
    needs its own sink; one TarWriter shared across processes is not safe.

    Args:
        point: (index, lat, lon, class_label) tuple. make_points_list converts
            a FeatureCollection of points into a list of these.
        sink: an opened wds.TarWriter, where the results get saved to.
        start: string, start date in format YYYY-MM-dd
        end: string, end date in format YYYY-MM-dd
        delta: int, chunk length to split total time period into (in ms).
        crs: string, EPSG crs code defining the projection to get the data in.
        scale: int, pixel size (in crs units) to get the data at.

    Returns:
        None. Writes one sample with key ``sample<zero-padded 6-digit index>``,
        the Presto inputs under "inputs.pyd" and the label under "cls".
    """
    index, lat, lon, class_label = point
    point = ee.Feature(ee.Geometry.Point((lon, lat)))

    # project the point to the *unscaled* projection
    projection = ee.Projection(crs)
    coords = point.geometry(1, projection).getInfo()["coordinates"]

    # a single pixel whose top-left corner sits on the point
    request = {
        "fileFormat": "NUMPY_NDARRAY",
        "grid": {
            "dimensions": {
                "width": 1,
                "height": 1,
            },
            "affineTransform": {
                "scaleX": scale,
                "shearX": 0,
                "translateX": coords[0],
                "shearY": 0,
                "scaleY": -scale,
                "translateY": coords[1],
            },
            "crsCode": crs
        }
    }

    data = {}
    bands = {}
    data_getters = {
        "s1": get_s1,
        "s2": get_s2,
        "era5": get_era5,
        "srtm": get_srtm,
        "dynamic_world": get_dynamic_world,
    }
    for source, get_image_fn in data_getters.items():
        image, n_bands = get_image_fn(point.geometry(), start, end, delta)
        request["expression"] = image
        raw = ee.data.computePixels(request)
        # one row per time step, one column per band
        arr = structured_to_unstructured(raw).reshape(-1, n_bands)
        arr = arr.astype(np.float64, copy=False)

        if source == "dynamic_world":
            # reshape(-1) rather than squeeze(): a single time step must stay
            # 1-D so that shape[0] below is still the number of time steps
            data[source] = torch.as_tensor(arr.reshape(-1))
        else:
            # toBands() names bands '<t>_<band>' for every time step, whereas
            # Presto wants the n_bands per-step names (e.g. ['VV', 'VH']).
            band_names = image.bandNames().getInfo()[:n_bands]
            bands[f"{source}_bands"] = [
                _strip_tobands_prefix(name) for name in band_names
            ]
            data[source] = torch.as_tensor(arr)

    # repeat srtm to be the same shape as all other inputs
    num_timesteps = data["dynamic_world"].shape[0]
    data["srtm"] = data["srtm"].expand(num_timesteps, -1)

    x, mask, dw = presto.construct_single_presto_input(**data, **bands)
    latlon = [lat, lon]
    months = create_month_tensor(start, end, delta)

    inputs = {"x": x, "mask": mask, "dw": dw, "latlon": latlon, "month": months}

    # zero-pad the index so keys contain no spaces and sort lexicographically
    sink.write({
        "__key__": f"sample{index:06d}",
        "inputs.pyd": inputs,
        "cls": class_label,
    })

def make_points_list(col):
    """Convert a FeatureCollection of points into inputs for process_point_fn.

    Args:
        col: ee.FeatureCollection of point features, each with a 'class'
            property.

    Returns:
        list of (index, lat, lon, class_label) tuples, one per feature, in
        collection order. A list (rather than a one-shot iterator) so that it
        can be iterated more than once.
    """
    features = col.getInfo()["features"]
    return [
        (
            i,
            feature["geometry"]["coordinates"][1],  # lat
            feature["geometry"]["coordinates"][0],  # lon
            feature["properties"]["class"],
        )
        for i, feature in enumerate(features)
    ]

points_list = make_points_list(trial_export_points)

with wds.TarWriter("trial_dataset.tar") as sink:
    # Trial run: export only the first point.
    first_point = list(points_list)[0]
    process_point_fn(first_point, sink)

    # Parallel version. A named function is used instead of a lambda to
    # avoid the lambda pickle error.
    # NOTE(review): every worker would write through the same TarWriter
    # (inherited copies of one open file), which will likely interleave and
    # corrupt the archive; each worker probably needs its own shard/sink.
    # def map_fn(x):
    #     return process_point_fn(x, sink)

    # with Pool(MAX_PARALLEL_TASKS // 5) as p:
    #     p.map(map_fn, points_list)