<a href="https://colab.research.google.com/github/boothmanrylan/canadaMSSForestDisturbances/blob/main/exportTrainingData.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

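Based on the Google Cloud land-cover classification sample
(https://github.com/GoogleCloudPlatform/python-docs-samples/tree/main/people-and-planet-ai/land-cover-classification)
and the Earth Engine guide to exporting training patches with `computePixels`
(https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_training_patches_computePixels.ipynb).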



# Setup

In [None]:
# Upgrade pip, then install Apache Beam with the GCP extras pinned to 2.46.0.
!pip install --quiet --upgrade pip
!pip install --quiet "apache-beam[gcp]==2.46.0"

# exit() ends the Python process so the runtime restarts with the newly
# installed packages. Continue from the next cell afterwards; re-running this
# cell would reinstall and exit again.
exit() # restart runtime to ensure we get the newly installed packages

In [None]:
import os

import google
from google.colab import auth
from google.api_core import retry

import requests

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

import ee
import geopandas

import numpy as np
import tensorflow as tf

In [None]:
# Google Cloud project, Cloud Storage bucket, and region used by this notebook.
PROJECT = 'api-project-269347469410'
BUCKET = 'gs://rylan-mssforestdisturbances/'
LOCATION = 'us-central1'

# Earth Engine high-volume API endpoint, passed to ee.Initialize below.
HIGH_VOLUME_ENDPOINT = 'https://earthengine-highvolume.googleapis.com'

# Authenticate the Colab user (this `auth` is google.colab.auth).
auth.authenticate_user()

# Make PROJECT the default project for client libraries and for the gcloud CLI
# (IPython substitutes the Python value of PROJECT into the shell command).
os.environ['GOOGLE_CLOUD_PROJECT'] = PROJECT
!gcloud config set project {PROJECT}

# Fetch the default credentials and initialize Earth Engine against the
# high-volume endpoint.
# NOTE(review): only the bare `google` namespace is imported explicitly;
# `google.auth` is presumably reachable because another google.* import loaded
# it as a side effect. An explicit `import google.auth` would be more robust.
credentials, _ = google.auth.default()
ee.Initialize(credentials, project=PROJECT, opt_url=HIGH_VOLUME_ENDPOINT)

In [None]:
# clone and install msslib
# (pip-installs the cloned repository, then returns to the parent directory)
!git clone --quiet https://github.com/boothmanrylan/msslib.git
%cd msslib
!pip install --quiet .
%cd ..

# Clone this project's repository and stay inside it. The import below
# presumably relies on `mss_forest_disturbances` being found relative to this
# working directory; it is not pip-installed here.
!git clone --quiet https://github.com/boothmanrylan/canadaMSSForestDisturbances.git
%cd canadaMSSForestDisturbances
from mss_forest_disturbances import data

In [None]:
# Read the training grid cells (GeoJSON stored under BUCKET/data/) into a
# GeoDataFrame.
train_file = os.path.join(BUCKET, "data", "train_cells.geojson")
train_cells = geopandas.read_file(train_file)

In [None]:
# NOTE(review): not referenced in the visible cells; presumably intended to cap
# the number of concurrent Earth Engine requests. Confirm before relying on it.
MAX_REQUESTS = 20

# Export

In [None]:
# Ecozone IDs and disturbance classes. write_tfrecord writes one TFRecord group
# per (ecozone, disturbance type) combination.
ECOZONES = [4, 5, 6, 7, 9, 11, 12, 13, 14, 15]
DISTURBANCE_TYPES = ['fire', 'harvest', 'undisturbed']
# Backward-compatible alias for the originally misspelled constant name.
DISTURBANCE_TPYES = DISTURBANCE_TYPES

def get_image_label_metadata(series):
    """Expand one sample record into ``(image, label, metadata)`` tuples.

    Not implemented yet.

    Args:
        series: a single sample record (described as a pandas Series) with
            keys lat, lon, year, ecozone, train/test/val, and
            fire/harvest/no disturbance.

    Returns:
        Must be an iterable (ideally a generator using ``yield``) because it
        is consumed by ``beam.FlatMap``. Each element is unpacked by
        ``beam.MapTuple`` into ``serialize_tensor(image, label, metadata)``.

    Raises:
        NotImplementedError: always, until the TODOs below are completed.
    """
    # TODO: use lat, lon, and year as inputs to msslib.getCol()
    # TODO: must return an iterable in order for FlatMap to work: use yield
    # Fail loudly here. Returning None would make beam.FlatMap fail later with
    # an opaque "NoneType is not iterable" error inside the pipeline.
    raise NotImplementedError(
        "get_image_label_metadata is not implemented yet"
    )


def serialize_tensor(image, label, metadata):
    """Serialize one ``(image, label, metadata)`` sample to bytes.

    Not implemented yet. Intended to build a ``tf.train.Example`` and return
    ``example.SerializeToString()``, whose output is written by
    ``beam.io.WriteToTFRecord``.

    Raises:
        NotImplementedError: always, until the TODOs below are completed.
    """
    # TODO: create a tf.train.Example()
    # TODO: return example.SerializeToString() --> ensure we can read/parse this later on
    # Fail loudly here. A None return would presumably only surface later as
    # an encoding error inside the TFRecord writer.
    raise NotImplementedError("serialize_tensor is not implemented yet")

class ProcessSampleGroup(beam.PTransform):
    """Turn a PCollection of sample records into gzipped TFRecord files.

    Steps: reshuffle the records, expand each one into
    ``(image, label, metadata)`` tuples with ``get_image_label_metadata``,
    serialize every tuple with ``serialize_tensor``, and write the results
    to files whose paths start with ``prefix`` and end with ``.tfrecord.gz``.
    """

    def __init__(self, prefix):
        """
        Args:
            prefix: output path prefix passed to ``beam.io.WriteToTFRecord``.
        """
        super().__init__()
        self.prefix = prefix

    def expand(self, pcoll):
        return (
            pcoll
            # Break fusion so the per-sample work can be redistributed.
            | beam.Reshuffle()
            | beam.FlatMap(get_image_label_metadata)
            # Unpack each (image, label, metadata) tuple into positional args.
            | beam.MapTuple(serialize_tensor)
            | beam.io.WriteToTFRecord(self.prefix, file_name_suffix=".tfrecord.gz")
        )

def filter(x, ecozone, disturbance_type):
    """Return True when record ``x`` matches both ``ecozone`` and ``disturbance_type``.

    Used as the ``beam.Filter`` predicate in write_tfrecord. Note that this
    module-level name shadows the built-in ``filter``.
    """
    # Without an explicit return the function yields None, which beam.Filter
    # treats as False, so every element would be dropped.
    return x['ecozone'] == ecozone and x['disturbance_type'] == disturbance_type

def write_tfrecord(input_file, output_prefix, pipeline_options=None):
    """Run a Beam pipeline that writes one TFRecord group per (ecozone, disturbance type).

    Args:
        input_file: path to the sample table. GeoJSON/JSON/GeoPackage/Shapefile
            paths are read with ``geopandas.read_file``; anything else is read
            as CSV with pandas.
        output_prefix: root under which each group is written with the prefix
            ``<output_prefix>/ecozone<N>/<disturbance_type>``.
        pipeline_options: optional ``PipelineOptions`` for the pipeline (for
            example, to run on Dataflow). ``None`` keeps Beam's defaults.
    """
    # Local import: pandas is not imported in the notebook's import cell.
    import pandas as pd

    # Named `samples` (not `data`) so the imported mss_forest_disturbances
    # `data` module is not shadowed.
    geo_suffixes = (".geojson", ".json", ".gpkg", ".shp")
    if str(input_file).lower().endswith(geo_suffixes):
        samples = geopandas.read_file(input_file)
    else:
        samples = pd.read_csv(input_file)

    # beam.Create over a DataFrame would iterate its column names. Emit one
    # dict per row instead so the predicate can index x['ecozone'], etc.
    records = samples.to_dict("records")

    with beam.Pipeline(options=pipeline_options) as pipeline:
        pcoll = pipeline | "CreateSamples" >> beam.Create(records)

        for ecozone in ECOZONES:
            for disturbance_type in DISTURBANCE_TYPES:
                path = os.path.join(
                    output_prefix,
                    f"ecozone{ecozone}",
                    disturbance_type
                )
                # Every transform in a pipeline needs a unique label. Without
                # one, repeated Filter/ProcessSampleGroup applications collide.
                group = f"ecozone{ecozone}_{disturbance_type}"

                # Pass the loop values as extra args so they are bound now,
                # rather than captured by a closure whose correctness would
                # depend on when Beam serializes the callable.
                inner_pcoll = pcoll | f"Filter_{group}" >> beam.Filter(
                    filter, ecozone, disturbance_type
                )
                inner_pcoll | f"Process_{group}" >> ProcessSampleGroup(prefix=path)