<a href="https://colab.research.google.com/github/boothmanrylan/canadaMSSForestDisturbances/blob/main/exportTrainingData.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
# Automatically reload imported modules (e.g. the mss_forest_disturbances
# package cloned below) before executing each cell, so edits take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Install notebook dependencies. apache-beam (with GCP extras) is pinned to
# 2.50.0; the Dataflow job launched in Step 3 uses it.
# NOTE(review): msslib is unpinned, so re-runs may pick up a different version.
# Consider pinning it for reproducibility.
!pip install -q -q --upgrade pip
!pip install -q -q "apache-beam[gcp]==2.50.0"
!pip install -q -q msslib

In [None]:
import os
import io
import itertools

import google
from google.colab import auth
from google.api_core import retry

import requests

import ee
import geemap
import geopandas

import numpy as np
from numpy.lib import recfunctions as rfn
import pandas as pd
import tensorflow as tf

import matplotlib.pyplot as plt

from msslib import msslib

In [None]:
!git clone --quiet https://github.com/boothmanrylan/canadaMSSForestDisturbances.git
%cd canadaMSSForestDisturbances
from mss_forest_disturbances import constants, grid, preprocessing

In [None]:
auth.authenticate_user()

os.environ['GOOGLE_CLOUD_PROJECT'] = constants.PROJECT
!gcloud config set project {constants.PROJECT}

credentials, _ = google.auth.default()
ee.Initialize(
    credentials,
    project=constants.PROJECT,
    opt_url=constants.HIGH_VOLUME_ENDPOINT
)

# Step 1. Create a Covering Grid of Forest-Dominated Canada

Step 1.1

Create a grid that covers all of forest-dominated Canada, excluding cells that are more than 70% water. Each cell is given a sequential integer `cell_id`. Export the resulting grid as an Earth Engine asset.

In [None]:
# Build the land-covering grid over the forested ecozones.
forest_region = ee.FeatureCollection(constants.ECOZONES).geometry()
land_covering_grid = grid.build_land_covering_grid(
    forest_region, constants.EXPORT_PATCH_SIZE
)

# Attach a sequential integer `cell_id` (0 .. n-1, in collection order) to
# every cell by pairing each list position with the feature stored there.
n_cells = land_covering_grid.size()
cells_as_list = land_covering_grid.toList(n_cells)
cell_indices = ee.List.sequence(0, n_cells.subtract(1))
id_grid = ee.FeatureCollection(
    cell_indices.map(
        lambda index: ee.Feature(cells_as_list.get(index)).set('cell_id', index)
    )
)

land_covering_grid_asset = os.path.join(
    constants.ASSET_PATH, "data", "land_covering_grid"
)
task = ee.batch.Export.table.toAsset(
    collection=id_grid,
    description="export_land_covering_grid",
    assetId=land_covering_grid_asset,
)
task.start()

Step 1.2

For each year for which we are generating training data, estimate the amount of harvest and fire that occurred in each cell of the grid created in Step 1.1. Export each year's resulting FeatureCollection as an Earth Engine asset.

In [None]:
def set_id(feature):
    """Attach an "id" property of the form "<cell_id>_<year>" to `feature`.

    Both the `cell_id` and `year` properties are formatted with "%d" before
    being joined by an underscore.
    """
    cell_part = feature.getNumber('cell_id').format("%d")
    year_part = feature.getNumber('year').format("%d")
    return feature.set("id", cell_part.cat('_').cat(year_part))

base_grid = ee.FeatureCollection(
    os.path.join(constants.ASSET_PATH, "data", "land_covering_grid")
)
annual_grids_dir = os.path.join(constants.ASSET_PATH, "data", "annual_grids")

# One export task per year: the base grid annotated with that year's
# disturbance counts, saved as annual_grids/grid<year>.
years = range(constants.FIRST_DISTURBANCE_YEAR, constants.LAST_MSS_YEAR + 1)
for year in years:
    annual_grid = grid.add_disturbance_counts(base_grid, year).map(set_id)
    task = ee.batch.Export.table.toAsset(
        collection=annual_grid,
        description=f"export_grid_with_disturbance_estimates_{year}",
        assetId=os.path.join(annual_grids_dir, f"grid{year}"),
    )
    task.start()

Step 1.3

Create a grid that covers all of forest-dominated Canada, excluding no cells, in which adjacent cells overlap by 8 pixels to avoid edge artifacts.

The grid from Step 1.1 is used to generate training data; this overlapping grid is used to create the final maps.

In [None]:
# Unlike the land-covering grid, this grid keeps every cell and overlaps
# neighbouring cells by constants.OVERLAP to avoid edge artifacts in the maps.
forest_region = ee.FeatureCollection(constants.ECOZONES).geometry()
overlapped_grid = grid.build_grid(
    forest_region, constants.PATCH_SIZE, constants.OVERLAP
)

overlapped_grid_asset = os.path.join(constants.ASSET_PATH, "data", "overlapped_grid")
task = ee.batch.Export.table.toAsset(
    collection=overlapped_grid,
    description='export_overlapped_grid',
    assetId=overlapped_grid_asset,
)
task.start()

# Step 2. Select Cells from Grid to Create Train/Test/Val Datasets

In [None]:
# Combine the per-year grids exported in Step 1.2 into one collection.
years = range(constants.FIRST_DISTURBANCE_YEAR, constants.LAST_MSS_YEAR + 1)
annual_grids_assets = [
    os.path.join(constants.ASSET_PATH, "data", "annual_grids", f"grid{year}")
    for year in years
]
annual_grids = ee.FeatureCollection([
    ee.FeatureCollection(asset) for asset in annual_grids_assets
]).flatten()

covering_grid = ee.FeatureCollection(
    os.path.join(constants.ASSET_PATH, "data", "land_covering_grid")
)

# Perform the train/test/val splitting individually within each ecozone.
ecozones = annual_grids.aggregate_array("ecozone").distinct()
ecozone_grids = [
    annual_grids.filter(ee.Filter.eq("ecozone", x))
    for x in ecozones.getInfo()
]

forested_ecozones = ee.FeatureCollection(constants.ECOZONES)
total_forested_area = forested_ecozones.geometry().area()


def calc_area(ecozone_id):
    """Return the geometry area of the ecozone whose ECOZONE_ID is `ecozone_id`."""
    ecozone = forested_ecozones.filter(ee.Filter.eq("ECOZONE_ID", ecozone_id))
    return ecozone.geometry().area()


# Fraction of the total forested area taken up by each ecozone. The order
# matches `ecozone_grids` because both are derived from the `ecozones` list.
ecozone_areas = ecozones.map(calc_area)
ecozone_areas_percentage = ecozone_areas.map(
    lambda x: ee.Number(x).divide(total_forested_area)
).getInfo()

# Select 1200 fire, 1200 harvest, and 600 undisturbed cells in total,
# distributed across ecozones proportionally to ecozone size.
cell_counts = np.array([1200, 1200, 600])
splits = [0.7, 0.15, 0.15]
selected_cells = [
    grid.sample_cells(
        ecozone_grid,
        # np.ceil yields floats (e.g. 840.0); pass whole-number counts.
        *np.ceil(cell_counts * percent).astype(int).tolist(),
        *splits
    )
    for ecozone_grid, percent in zip(ecozone_grids, ecozone_areas_percentage)
]


def merge_ecozone_splits(split_index):
    """Merge one split (0=train, 1=test, 2=val) across all ecozones.

    The result is sorted on the "shuffle" property so that cells from
    different ecozones are intermingled. This assumes grid.sample_cells
    attaches a random "shuffle" column (TODO confirm).
    """
    return ee.FeatureCollection(
        [ecozone_selection[split_index] for ecozone_selection in selected_cells]
    ).flatten().sort("shuffle")


train_cells, test_cells, val_cells = (merge_ecozone_splits(i) for i in range(3))

# Model 2 is trained on cells that never appear in any model-1 split,
# regardless of year.
used_cells_ids = ee.FeatureCollection([
    train_cells, test_cells, val_cells
]).flatten().aggregate_array("cell_id")
unused_cell_filter = ee.Filter.listContains(
    rightField='cell_id',
    leftValue=used_cells_ids,
).Not()
unused_cells = covering_grid.filter(unused_cell_filter)
unused_cells = unused_cells.randomColumn("shuffle", 42).sort("shuffle")

# Take consecutive, non-overlapping slices of the shuffled unused cells.
MODEL2_SPLIT_SIZES = (700, 150, 150)  # train, test, val
model2_cells = []
offset = 0
for size in MODEL2_SPLIT_SIZES:
    model2_cells.append(ee.FeatureCollection(unused_cells.toList(size, offset)))
    offset += size
model2_train_cells, model2_test_cells, model2_val_cells = model2_cells

model1_cells = [train_cells, test_cells, val_cells]
for model, model_cells in zip(["model1", "model2"], [model1_cells, model2_cells]):
    for group, cells in zip(["train", "test", "val"], model_cells):
        path = os.path.join(constants.ASSET_PATH, "data", model, group)
        task = ee.batch.Export.table.toAsset(
            collection=cells,
            # Unique per task so the six exports are distinguishable in the
            # Earth Engine task list.
            description=f"export_{model}_{group}_cells",
            assetId=path,
        )
        task.start()

# Step 3. Export Image Patches

Based on the [people-and-planet-ai land-cover classification sample](https://github.com/GoogleCloudPlatform/python-docs-samples/tree/main/people-and-planet-ai/land-cover-classification)
and the [Earth Engine computePixels training-patches guide](https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_training_patches_computePixels.ipynb).

In [None]:
# Container image used by the Dataflow workers. It is built by the next cell
# and referenced via --sdk_container_image when launching the job.
image_uri = os.path.join(
    constants.BASE_DOCKER_IMAGE_URI, "dataflow", "dockerfile:1.0"
)

In [None]:
# this only needs to be run to re/create the docker image!
# Uses Cloud Build to build the current directory (the cloned repository) and
# pushes the result under the `image_uri` tag defined above.
!gcloud builds submit --tag {image_uri} .

In [None]:
temp_location = os.path.join(constants.BUCKET, 'temp')
staging_location = os.path.join(constants.BUCKET, 'staging')
output_prefix = os.path.join(constants.BUCKET, 'scratch', 'test_export2')
input_asset = os.path.join(constants.ASSET_PATH, 'data', 'train_cells')

!python dataflow_job.py \
    --runner='DataflowRunner' \
    --project={constants.PROJECT} \
    --job_name='test-data-export' \
    --region='us-central1' \
    --temp_location={temp_location} \
    --staging_location={staging_location} \
    --num_workers=20 \
    --max-requests=20 \
    --input-asset={input_asset} \
    --output-prefix={output_prefix} \
    --experiments=use_runner_v2 \
    --sdk_container_image={image_uri} \
    --sdk_location=container

# Step 4. Verify TFRecords were Created Properly

In [None]:
# Feature specs for the exported TFRecords:
# - one EXPORT_PATCH_SIZE x EXPORT_PATCH_SIZE float32 plane per band,
# - an int64 label plane of the same size,
# - two length-1 int64 metadata fields.
IMAGE_FEATURES = {
    b: tf.io.FixedLenFeature(
        shape=[constants.EXPORT_PATCH_SIZE, constants.EXPORT_PATCH_SIZE],
        dtype=tf.float32,
    )
    for b in constants.BANDS
}

LABEL_FEATURES = {
    "label": tf.io.FixedLenFeature(
        shape=[constants.EXPORT_PATCH_SIZE, constants.EXPORT_PATCH_SIZE],
        dtype=tf.int64,
    )
}

METADATA_FEATURES = {
    m: tf.io.FixedLenFeature(shape=1, dtype=tf.int64)
    for m in ["ecozone", "doy"]
}

# Combined spec, so each serialized example is decoded once instead of three
# times. The three groups must not share any feature names.
ALL_FEATURES = {**IMAGE_FEATURES, **METADATA_FEATURES, **LABEL_FEATURES}
assert len(ALL_FEATURES) == (
    len(IMAGE_FEATURES) + len(METADATA_FEATURES) + len(LABEL_FEATURES)
), "band, metadata, and label feature names must be distinct"


def parse(example_proto):
    """Parse one serialized tf.train.Example into (image, metadata, label).

    Each element of the returned tuple is a dict of tensors keyed like the
    matching spec: band names for `image`, "ecozone"/"doy" for `metadata`,
    and "label" for `label`.
    """
    parsed = tf.io.parse_single_example(example_proto, ALL_FEATURES)
    image = {k: parsed[k] for k in IMAGE_FEATURES}
    metadata = {k: parsed[k] for k in METADATA_FEATURES}
    label = {k: parsed[k] for k in LABEL_FEATURES}
    return image, metadata, label


files = tf.data.Dataset.list_files(f"{output_prefix}/*/*.tfrecord.gz")
dataset = tf.data.TFRecordDataset(files, compression_type="GZIP")
dataset = dataset.map(parse, num_parallel_calls=5)

# Display a few parsed patches for a visual sanity check: a contrast-stretched
# composite of the first three bands next to the label map, followed by the
# example's metadata.
VIS_MIN, VIS_MAX = 0.02, 0.08  # data range stretched to [0, 1] for display

for im, m, label in dataset.take(5):
    im = tf.stack([im[b] for b in constants.BANDS], axis=-1).numpy()
    label = label["label"]

    # Matplotlib's imshow ignores vmin/vmax for RGB(A) input and expects float
    # RGB values in [0, 1], so the stretch has to be applied by hand.
    rgb = np.clip((im[:, :, :3] - VIS_MIN) / (VIS_MAX - VIS_MIN), 0, 1)

    fig, axes = plt.subplots(1, 2, squeeze=True)
    axes[0].imshow(rgb)
    axes[0].set_title("first three bands (stretched)")
    axes[1].imshow(label)
    axes[1].set_title("label")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations
    print([(k, v.numpy()) for k, v in m.items()])
