<a href="https://colab.research.google.com/github/boothmanrylan/canadaMSSForestDisturbances/blob/main/exportTrainingData.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the runtime.
%load_ext autoreload
%autoreload 2

In [None]:
# Install the notebook's third-party dependencies.
# apache-beam (with the GCP extras) is pinned to 2.46.0; the rest are unpinned.
!pip install --quiet --upgrade pip
!pip install --quiet "apache-beam[gcp]==2.46.0"
!pip install --quiet geemap
!pip install --quiet msslib

In [None]:
import os
import io
import itertools

import google
from google.colab import auth
from google.api_core import retry

import requests

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

import ee
import geemap
import geopandas

import numpy as np
from numpy.lib import recfunctions as rfn
import pandas as pd
import tensorflow as tf

import matplotlib.pyplot as plt

In [None]:
# Google Cloud project, Cloud Storage bucket and region used by this notebook.
PROJECT = 'api-project-269347469410'
BUCKET = 'gs://rylan-mssforestdisturbances/'
LOCATION = 'us-central1'

# Earth Engine endpoint handed to ee.Initialize through `opt_url`.
HIGH_VOLUME_ENDPOINT = 'https://earthengine-highvolume.googleapis.com'

# Authenticate the Colab user first; the default credentials requested below
# are then used to initialise Earth Engine.
auth.authenticate_user()

os.environ['GOOGLE_CLOUD_PROJECT'] = PROJECT
!gcloud config set project {PROJECT}

credentials, _ = google.auth.default()
ee.Initialize(credentials, project=PROJECT, opt_url=HIGH_VOLUME_ENDPOINT)

# NOTE(review): msslib is imported only after ee.Initialize — presumably it
# needs an initialised Earth Engine session at import time; confirm before
# moving this into the main import cell.
from msslib import msslib

In [None]:
# Clone the project repository and change into it so that the
# `mss_forest_disturbances` package can be imported from the checkout.
!git clone --quiet https://github.com/boothmanrylan/canadaMSSForestDisturbances.git
%cd canadaMSSForestDisturbances
from mss_forest_disturbances import data

In [None]:
# NOTE(review): MAX_REQUESTS is not referenced anywhere else in the visible
# notebook — confirm it is still needed.
MAX_REQUESTS = 20
# Earth Engine asset folder under which this notebook reads and writes its
# grid and cell collections.
ASSET_PATH = "projects/api-project-269347469410/assets/rylan-mssforestdisturbances/"

# Step 1. Create a Covering Grid of Forest Dominated Canada

Step 1.1

Create a grid that covers all of forest-dominated Canada, excluding cells that are more than 70% water. Export the resulting grid as an Earth Engine asset.

In [None]:
# Cell size passed to data.build_land_covering_grid.
GRID_CELL_SIZE = 512

grid = data.build_land_covering_grid(
    data.ECOZONES.geometry(),
    GRID_CELL_SIZE,
)

# Number the cells by their position in the list form of the grid so that
# each exported feature carries a 'cell_id' property.
grid_list = grid.toList(grid.size())
ids = ee.List.sequence(0, grid.size().subtract(1))
id_grid = ee.FeatureCollection(
    ids.map(
        lambda index: ee.Feature(grid_list.get(index)).set('cell_id', index)
    )
)

# Export the numbered grid as an Earth Engine table asset.
task = ee.batch.Export.table.toAsset(
    collection=id_grid,
    description="export_land_covering_grid",
    assetId=os.path.join(ASSET_PATH, "data", "land_covering_grid"),
)
task.start()

Step 1.2

For each year for which we are generating training data, estimate the amount of harvest and fire that occurred in each cell of the grid created in Step 1.1. Export the resulting FeatureCollection as an Earth Engine asset.

In [None]:
def set_id(feature):
    """Attach an "id" property of the form "<cell_id>_<year>" to a feature.

    Args:
        feature: an ee.Feature exposing numeric "cell_id" and "year"
            properties.

    Returns:
        The result of ``feature.set("id", ...)``, where the value is the two
        numbers formatted with "%d" and joined by an underscore.
    """
    cell_id = feature.getNumber('cell_id').format("%d")
    year = feature.getNumber('year').format("%d")
    # Named `unique_id` rather than `id` so the builtin is not shadowed.
    unique_id = cell_id.cat('_').cat(year)
    return feature.set("id", unique_id)

# The grid exported in Step 1.1; every yearly export starts from it.
base_grid = ee.FeatureCollection(
    os.path.join(ASSET_PATH, "data", "land_covering_grid")
)

# One export task per year, 1985 through 1995 inclusive.
for year in range(1985, 1996):
    asset_name = f"disturbance_estimate_grid_{year}"

    # Add the per-cell disturbance counts, then tag each cell with its id.
    with_counts = data.add_disturbance_counts(base_grid, year)
    annual_grid = with_counts.map(set_id)

    task = ee.batch.Export.table.toAsset(
        collection=annual_grid,
        description=f"export_grid_with_disturbance_estimates_{year}",
        assetId=os.path.join(ASSET_PATH, "data", "annual_grids", asset_name),
    )
    task.start()

# Step 2. Select Cells from Grid to Create Train/Test/Val Datasets

In [None]:
# Asset ids of the yearly grids exported in Step 1.2 (1985 through 1995).
annual_grids_assets = [
    os.path.join(
        ASSET_PATH, "data", "annual_grids", f"disturbance_estimate_grid_{year}"
    )
    for year in range(1985, 1996)
]

# Merge all of the yearly grids into a single collection.
annual_grids = ee.FeatureCollection(
    [ee.FeatureCollection(asset) for asset in annual_grids_assets]
).flatten()

# The train/test/val split is performed separately inside each ecozone, so
# build one filtered collection per distinct ecozone value.
ecozones = annual_grids.aggregate_array("ecozone").distinct().getInfo()
ecozone_grids = [
    annual_grids.filter(ee.Filter.eq("ecozone", ecozone))
    for ecozone in ecozones
]

# Arguments forwarded positionally to data.sample_cells after the grid.
cell_counts = [200, 200, 200]
splits = [0.7, 0.15, 0.15]
selected_cells = [
    data.sample_cells(ecozone_grid, *cell_counts, *splits)
    for ecozone_grid in ecozone_grids
]

# Combine the per-ecozone selections for each group. Element 0 of every
# selection is the train group, 1 the test group and 2 the val group.
# Sorting on "shuffle" intermingles the ecozones.
train_cells = ee.FeatureCollection(
    [selection[0] for selection in selected_cells]
).flatten().sort("shuffle")
test_cells = ee.FeatureCollection(
    [selection[1] for selection in selected_cells]
).flatten().sort("shuffle")
val_cells = ee.FeatureCollection(
    [selection[2] for selection in selected_cells]
).flatten().sort("shuffle")

# Export each group to Google Earth Engine, in train, test, val order.
for description, asset_name, cells in (
    ("export_train_cells", "train_cells", train_cells),
    ("export_test_cells", "test_cells", test_cells),
    ("export_val_cells", "val_cells", val_cells),
):
    task = ee.batch.Export.table.toAsset(
        collection=cells,
        description=description,
        assetId=os.path.join(ASSET_PATH, "data", asset_name),
    )
    task.start()

In [None]:
def count_images(feat):
    """Record how many images msslib.getCol returns for a cell in its year.

    The collection is requested at the centroid of the cell's geometry for
    the single year stored in the cell's "year" property, with no cloud-cover
    restriction (maxCloudCover=100).

    Note: ERROR_MARGIN and PROJECTION are defined in later cells of this
    notebook, so those cells must be run before this function is used.

    Args:
        feat: an ee.Feature with a numeric "year" property.

    Returns:
        The feature with a "num_images" property set to the collection size.
    """
    cell_year = feat.getNumber("year")
    cell_centre = feat.geometry(ERROR_MARGIN, PROJECTION).centroid(1)

    available = msslib.getCol(
        aoi=cell_centre,
        yearRange=[cell_year, cell_year],
        doyRange=data.DOY_RANGE,
        maxCloudCover=100,
    )
    return feat.set("num_images", available.size())

# NOTE(review): the variable is named `train_cells` but the asset loaded is
# "val_cells" — confirm which split this check is meant to cover.
train_cells = ee.FeatureCollection(
    os.path.join(ASSET_PATH, "data", "val_cells")
)
train_cells = train_cells.map(count_images)
# Keep only the cells for which no image was found.
filtered_cells = train_cells.filter(ee.Filter.eq("num_images", 0))
# Prints the total number of cells followed by the number with no imagery.
print(train_cells.size().getInfo(), filtered_cells.size().getInfo())

# Step 3. Export Image Patches

Based on https://github.com/GoogleCloudPlatform/python-docs-samples/tree/main/people-and-planet-ai/land-cover-classification
and https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_training_patches_computePixels.ipynb

In [None]:
# Template request for ee.data.computePixels. The translate terms of the
# affine transform are filled in per patch by get_image_label_metadata.
# NOTE: an earlier version read the projection from data.PROJECTION instead.
PROJECTION = ee.Projection('EPSG:4269').atScale(60)
proj = PROJECTION.getInfo()

# Pixel scales taken from the projection's transform; the y term is negated.
scale_x = proj['transform'][0]
scale_y = -proj['transform'][4]

# Patch edge length in pixels.
PATCH_SIZE = 512

# Half-patch offsets added to a cell centroid to obtain the translate terms.
OFFSET_X = -scale_x * PATCH_SIZE / 2
OFFSET_Y = -scale_y * PATCH_SIZE / 2

_patch_dimensions = {
    'width': PATCH_SIZE,
    'height': PATCH_SIZE,
}
_patch_transform = {
    'scaleX': scale_x,
    'shearX': 0,
    'shearY': 0,
    'scaleY': scale_y,
}

REQUEST = {
    'fileFormat': 'NPY',
    'grid': {
        'dimensions': _patch_dimensions,
        'affineTransform': _patch_transform,
        'crsCode': proj['crs'],
    },
}

In [None]:
# Band names selected from each exported image: the six base bands followed
# by their "historical_" counterparts, in the same order.
bands = ['nir', 'red_edge', 'red', 'green', 'tca', 'ndvi']
historical_bands = [f'historical_{name}' for name in bands]
BANDS = [*bands, *historical_bands]

# Error margin passed, together with PROJECTION, to feature.geometry() when
# cell geometries are extracted.
ERROR_MARGIN = ee.ErrorMargin(0.1, "projected")


def ee_init():
    """Initialise Earth Engine against the high-volume endpoint.

    Requests the default Google credentials with the cloud-platform and
    earthengine scopes and initialises Earth Engine with them, using the
    project returned alongside the credentials. get_image_ids and
    get_image_label_metadata call this before touching Earth Engine.
    """
    scopes = [
        "https://www.googleapis.com/auth/cloud-platform",
        "https://www.googleapis.com/auth/earthengine",
    ]
    credentials, project = google.auth.default(scopes=scopes)

    ee.Initialize(
        credentials.with_quota_project(None),
        project=project,
        opt_url=HIGH_VOLUME_ENDPOINT,
    )


def _get_images_from_feature(feature):
    """Return the msslib image collection for a grid cell.

    Args:
        feature: an ee.Feature with a numeric "year" property.

    Returns:
        The collection returned by msslib.getCol for the centroid of the
        cell's geometry, restricted to the cell's year and data.DOY_RANGE,
        with maxCloudCover=100.
    """
    cell_year = feature.getNumber("year")
    cell_centre = feature.geometry(ERROR_MARGIN, PROJECTION).centroid(1)

    return msslib.getCol(
        aoi=cell_centre,
        yearRange=[cell_year, cell_year],
        doyRange=data.DOY_RANGE,
        maxCloudCover=100,
    )


def get_image_ids(row, asset_path):
    """List the images available for one grid cell.

    Args:
        row: mapping with an "id" entry identifying the cell.
        asset_path: asset id of the FeatureCollection that holds the cell.

    Returns:
        An iterator of (image_id, feature_id, asset_path) tuples, one per
        image found for the cell; feature_id and asset_path repeat for every
        image.
    """
    ee_init()

    # Look the cell up by its "id" property.
    cells = ee.FeatureCollection(asset_path)
    cell = cells.filter(ee.Filter.eq("id", row["id"])).first()

    # Fetch the ids of every image found for the cell.
    cell_images = _get_images_from_feature(cell)
    image_ids = cell_images.aggregate_array("system:id").getInfo()

    return zip(
        image_ids,
        itertools.repeat(row["id"]),
        itertools.repeat(asset_path),
    )


@retry.Retry()
def get_image_label_metadata(image_id, feature_id, asset_path):
    """Download the image patch, label patch and metadata for one image.

    The patch is PATCH_SIZE x PATCH_SIZE pixels, positioned by adding
    OFFSET_X / OFFSET_Y to the centroid of the grid cell. The call is wrapped
    in google.api_core's retry.Retry() with its default settings.

    Args:
        image_id: Earth Engine id of the image to export.
        feature_id: value of the "id" property of the grid cell.
        asset_path: asset id of the FeatureCollection that holds the cell.

    Returns:
        A tuple (np_image, np_label, metadata): the arrays loaded from the
        NPY payloads returned by ee.data.computePixels for the image and for
        the label, and a dict of the metadata values fetched with getInfo().
    """
    ee_init()

    image = msslib.process(ee.Image(image_id))
    image, label = data.prepare_image_for_export(image)
    image = image.select(BANDS)

    col = ee.FeatureCollection(asset_path)
    feature = col.filter(ee.Filter.eq("id", feature_id)).first()
    metadata = data.prepare_metadata_for_export(image, feature)
    metadata = {key: val.getInfo() for key, val in metadata.items()}

    geom = feature.geometry(ERROR_MARGIN, PROJECTION)
    coords = geom.centroid(1).getInfo()["coordinates"]

    # Build fresh nested dicts for this call. dict(REQUEST) alone is a shallow
    # copy, so writing translateX/translateY through it would modify the
    # module-level REQUEST['grid']['affineTransform'] shared by every call.
    grid = dict(REQUEST['grid'])
    grid['affineTransform'] = dict(
        REQUEST['grid']['affineTransform'],
        translateX=coords[0] + OFFSET_X,
        translateY=coords[1] + OFFSET_Y,
    )

    # Masked pixels are replaced with 0 before download.
    image_request = dict(REQUEST, grid=grid, expression=image.unmask(0))
    np_image = np.load(io.BytesIO(ee.data.computePixels(image_request)))

    label_request = dict(REQUEST, grid=grid, expression=label.unmask(0))
    np_label = np.load(io.BytesIO(ee.data.computePixels(label_request)))

    return np_image, np_label, metadata


def serialize_tensor(image, label, metadata):
    """Serialise one image/label/metadata triple as a tf.train.Example.

    Args:
        image: array indexable by each band name in BANDS; every band is
            flattened and stored as a float list.
        label: array indexable by "label"; flattened and stored as an int64
            list under the "label" key.
        metadata: mapping of name to a single value; each value is stored as
            a one-element int64 list.

    Returns:
        The serialised Example as bytes.
    """
    def _float_feature(values):
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    def _int64_feature(values):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    features = {
        band: _float_feature(image[band].flatten())
        for band in BANDS
    }
    features["label"] = _int64_feature(label["label"].flatten())
    features.update(
        {key: _int64_feature([value]) for key, value in metadata.items()}
    )

    example = tf.train.Example(features=tf.train.Features(feature=features))
    return example.SerializeToString()


def write_tfrecord(
    input_asset_path, output_prefix, pipeline_options=None, *, max_cells=20
):
    """Export image/label patches for the cells of an asset as TFRecords.

    Cells are partitioned by (ecozone, disturbance_type) and each partition
    is written under ``<output_prefix>/ecozone<ecozone>/<disturbance_type>``
    with the ".tfrecord.gz" suffix.

    Args:
        input_asset_path: asset id of the FeatureCollection of cells. Its
            "disturbance_type", "ecozone", "id" and "shuffle" properties are
            read.
        output_prefix: path prefix under which the TFRecord files are written.
        pipeline_options: forwarded to beam.Pipeline(options=...).
        max_cells: number of cells to export, taken from the start of the
            collection after sorting by "shuffle". Defaults to 20, the size
            of the previously hard-coded test subset; pass None to export
            every cell.
    """
    col = ee.FeatureCollection(input_asset_path)
    df = geemap.ee_to_df(
        col, col_names=['disturbance_type', 'ecozone', 'id', "shuffle"]
    )

    # Sorting on "shuffle" makes the optional subset a random sample of the
    # complete dataframe.
    df = df.sort_values(by="shuffle", ignore_index=True)
    if max_cells is not None:
        df = df.head(max_cells)

    ecozones = set(df['ecozone'])
    disturbance_types = set(df['disturbance_type'])

    # One partition (and one output path) per ecozone/disturbance pairing.
    sets = list(itertools.product(ecozones, disturbance_types))
    paths = [
        os.path.join(output_prefix, f"ecozone{ecozone}", disturbance_type)
        for ecozone, disturbance_type in sets
    ]

    def partition(elem, _num_partitions):
        elem_set = (int(elem["ecozone"]), elem["disturbance_type"])
        return sets.index(elem_set)

    with beam.Pipeline(options=pipeline_options) as pipeline:
        pcoll = pipeline | beam.Create(list(df.iloc))  # iloc cannot be directly pickled
        groups = pcoll | beam.Partition(partition, len(sets))

        for i, group in enumerate(groups):
            # Step labels must be unique within a pipeline, so prefix each
            # one with the partition's ecozone and disturbance type.
            uid = f"{sets[i][0]}_{sets[i][1]}"
            (group
             | f"{uid} get ids" >> beam.FlatMap(get_image_ids, asset_path=input_asset_path)
             | f"{uid} reshuffle" >> beam.Reshuffle()
             | f"{uid} get data" >> beam.MapTuple(get_image_label_metadata)
             | f"{uid} serialize" >> beam.MapTuple(serialize_tensor)
             | f"{uid} write" >> beam.io.WriteToTFRecord(paths[i], file_name_suffix=".tfrecord.gz")
            )

# Dataflow settings for the export job. The region comes from the LOCATION
# constant defined with the other project settings instead of repeating the
# literal here.
pipeline_options = PipelineOptions(
    runner="DataflowRunner",
    project=PROJECT,
    job_name="test-data-export-workflow",
    region=LOCATION,
    save_main_session=True,
    setup_file="./setup.py",
    max_num_workers=20,
    disk_size_gb=50,
    temp_location=os.path.join(BUCKET, "temp"),
)
# To run without these Dataflow options, use pipeline_options = None instead.

# Export the train cells to the scratch/test_export folder of the bucket.
train_col_asset_path = os.path.join(ASSET_PATH, "data", "train_cells")
write_tfrecord(
    train_col_asset_path,
    os.path.join(BUCKET, "scratch", "test_export"),
    pipeline_options,
)

In [None]:
# Sanity check: look up the first image found for the first train cell and
# confirm it can be re-created from its "system:id" alone.
collection = ee.FeatureCollection(train_col_asset_path)
feature = collection.first()
images = msslib.getCol(
    aoi=feature.geometry().centroid(1),
    yearRange=[feature.getNumber("year"), feature.getNumber("year")],
    doyRange=data.DOY_RANGE,
    maxCloudCover=100
)
image = images.first()

system_id = image.get('system:id').getInfo()

test_image = ee.Image(system_id)
# Call getInfo(); printing the bare attribute only shows the bound method.
print(test_image.getInfo())

# Step 4. Verify TFRecords were Created Properly

In [None]:
# Parsing specs for the exported Examples.
# Each band in BANDS is read as a float32 plane of 512 x 512 values.
IMAGE_FEATURES = {
    band: tf.io.FixedLenFeature(shape=[512, 512], dtype=tf.float32)
    for band in BANDS
}

# The label is read as an int64 plane of the same size.
LABEL_FEATURES = {
    "label": tf.io.FixedLenFeature(shape=[512, 512], dtype=tf.int64),
}

# Each metadata entry is read as a single int64 value.
METADATA_FEATURES = {
    name: tf.io.FixedLenFeature(shape=1, dtype=tf.int64)
    for name in ["ecozone", "doy"]
}

def parse(example_proto):
    """Parse one serialised Example into image, metadata and label dicts.

    Args:
        example_proto: a serialised tf.train.Example.

    Returns:
        A tuple (image, metadata, label) of dicts parsed with IMAGE_FEATURES,
        METADATA_FEATURES and LABEL_FEATURES respectively.
    """
    specs = [IMAGE_FEATURES, METADATA_FEATURES, LABEL_FEATURES]
    image, metadata, label = [
        tf.io.parse_single_example(example_proto, spec) for spec in specs
    ]
    return image, metadata, label

# NOTE(review): this glob is a local relative path, whereas the export cell
# writes under BUCKET/scratch/test_export — confirm where the files to verify
# actually live.
files = tf.data.Dataset.list_files("scratch/train/ecozone4/fire*.tfrecord.gz")
dataset = tf.data.TFRecordDataset(files, compression_type="GZIP")
dataset = dataset.map(parse, num_parallel_calls=5)

# Display range for the false-colour composite of the first three bands.
DISPLAY_MIN = 0.02
DISPLAY_MAX = 0.08

for im, m, label in dataset.take(5):
    im = tf.stack([im[b] for b in BANDS], axis=-1)
    label = label["label"]

    # imshow ignores vmin/vmax for RGB input, so apply the stretch by hand
    # and clip to the [0, 1] range that imshow expects for float RGB data.
    rgb = (im[:, :, :3] - DISPLAY_MIN) / (DISPLAY_MAX - DISPLAY_MIN)
    rgb = tf.clip_by_value(rgb, 0.0, 1.0)

    fig, axes = plt.subplots(1, 2, squeeze=True)
    axes[0].imshow(rgb)
    axes[1].imshow(label)
    plt.show()
    print([(k, v.numpy()) for k, v in m.items()])
