# Training data pipeline


Most code borrowed from: https://colab.research.google.com/github/google/earthengine-community/blob/master/guides/linked/Earth_Engine_training_patches_getPixels.ipynb

In [None]:
import ee
import io
import numpy as np
import geopandas as gpd
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
import random
import tensorflow as tf
import pickle

import google
from google.api_core import retry


In [None]:
# Fixed seed for the NumPy generator below.
SEED = 54
# NumPy random generator seeded with SEED; handed to sample_random_points.
RNG = np.random.default_rng(SEED)

# Authentication

In [None]:
# Project id passed as `project=` to every ee.Initialize call in this notebook.
PROJECT = 'ksolvik-misc'

In [None]:
# Authenticate with Earth Engine before ee.Initialize is called.
ee.Authenticate()

In [None]:
# credentials, _ = google.auth.default()
# Initialize the Earth Engine client for PROJECT against the high-volume
# endpoint URL (the same URL the Beam DoFn uses in its setup()).
ee.Initialize(project=PROJECT, opt_url='https://earthengine-highvolume.googleapis.com')

# Set params

In [None]:
# Input features: the annual satellite-embedding image collection.
ALPHAEARTH_EMBEDDINGS = ee.ImageCollection('GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL')

# The year of embeddings to sample; the year before it (INPUT_YEAR - 1) is
# also pulled in as a second set of bands further down.
INPUT_YEAR = 2023

# MapBiomas Fire annual burned area, mean-aggregated when read at a coarser
# scale (at most 500 input pixels are combined per output pixel).
MB_FIRE = (ee.Image('projects/mapbiomas-public/assets/brazil/fire/collection4_1/mapbiomas_fire_collection41_annual_burned_v1')
           .reduceResolution('mean', maxPixels=500))
# MODIS burned area product (only the 'BurnDate' band is kept).
MODIS_FIRE =  (ee.ImageCollection('MODIS/061/MCD64A1')
               .select('BurnDate')
              )
# Calendar year used to filter MODIS_FIRE.
FIRE_YEAR = 2024


# Region of interest to sample from (read from a local shapefile).
ROI = gpd.read_file('../data/Limites_RAISG_2025/Lim_Raisg.shp')

# Sample size passed to sample_random_points.
N_SAMPLE = 2000

# Output pixel scale, passed to ee.Projection.atScale.
FINAL_SCALE = 500 # in meters
# Patch width and height, in pixels.
PATCH_SIZE = 128

# Probability weight of the validation partition in split_dataset; the
# training partition gets 1 - VALIDATION_RATIO.
VALIDATION_RATIO=0.2

In [None]:
def sample_random_points(roi: gpd.GeoDataFrame, n_sample: int, rng: np.random.Generator) -> np.ndarray:
  """Draw random points inside the region of interest.

  Args:
    roi: Geometries to sample within.
    n_sample: Sample size forwarded to `roi.sample_points`.
      NOTE(review): geopandas documents this size as per geometry, so the
      result has more than `n_sample` rows when `roi` holds several
      features -- confirm `roi` is a single feature if exactly `n_sample`
      points are expected.
    rng: NumPy generator forwarded to `roi.sample_points` as `rng`.

  Returns:
    The `.values` array of the coordinate table produced by
    `get_coordinates()`: one row per sampled point.
  """
  sampled = roi.sample_points(n_sample, rng=rng)
  # explode() splits multi-part geometries so each point gets its own row.
  # No index reset is needed: `.values` drops the index anyway.
  return sampled.geometry.explode().get_coordinates().values

In [None]:
## Precompute some inputs based on params
# Random sample locations inside the ROI, as the array of coordinate rows
# returned by sample_random_points (plain coordinates, not GeoJSON objects).
SAMPLE_POINTS  = sample_random_points(ROI, N_SAMPLE, RNG)

# Make an EPSG:4326 projection at FINAL_SCALE to discover the scale in degrees.
PROJ_AE = ee.Projection('EPSG:4326').atScale(FINAL_SCALE)
# getInfo() fetches the projection description to the client as a dict.
PROJ_AE_DICT = PROJ_AE.getInfo()
# Get scales out of the transform: element 0 is the X scale, element 4 the
# Y scale. The Y scale is stored with its sign flipped.
# NOTE(review): confirm the resulting sign gives the intended row order
# (north-up vs. south-up) in the patches requested with SCALE_Y.
SCALE_X = PROJ_AE_DICT['transform'][0]
SCALE_Y = -PROJ_AE_DICT['transform'][4]

## Image retrieval functions

In [None]:
@retry.Retry()
def compute_patch(coords, image, patch_size, scale_x, scale_y, band_sel=None):
  """Fetch one square patch of `image` from Earth Engine as a NumPy array.

  Args:
    coords: Indexable pair; coords[0] and coords[1] become the X and Y
      translation (grid origin) of the request's affine transform.
    image: Sent as the request's 'expression'.
    patch_size: Width and height of the patch, in pixels.
    scale_x: Affine 'scaleX' of the request grid.
    scale_y: Affine 'scaleY' of the request grid.
    band_sel: Optional band ids; when not None they are sent as 'bandIds'.

  Returns:
    The array that np.load decodes from the NPY bytes returned by
    ee.data.computePixels.
  """
  affine = {
      'scaleX': scale_x,
      'shearX': 0,
      'translateX': coords[0],
      'shearY': 0,
      'scaleY': scale_y,
      'translateY': coords[1],
  }
  grid = {
      'dimensions': {'width': patch_size, 'height': patch_size},
      'affineTransform': affine,
      'crsCode': 'EPSG:4326',
  }
  request = {'expression': image, 'fileFormat': 'NPY', 'grid': grid}

  # Only restrict bands when the caller asked for a subset.
  if band_sel is not None:
    request['bandIds'] = band_sel

  raw_bytes = ee.data.computePixels(request)
  return np.load(io.BytesIO(raw_bytes))

def serialize_example(structured_array):
  """Convert a structured numpy array to a serialized tf.Example proto."""
  example = array_to_example(structured_array)
  return example.SerializeToString()

def array_to_example(structured_array):
  """Convert a structured numpy array to a tf.Example proto.

  Every field listed in `structured_array.dtype.names` becomes one
  float_list feature holding that field's flattened values.
  """
  feature = {
      name: tf.train.Feature(
          float_list=tf.train.FloatList(
              value=structured_array[name].flatten()))
      for name in structured_array.dtype.names
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))

In [None]:
def _annual_embedding_mean(target_year):
  """Mosaic one calendar year of embeddings and mean-aggregate it on PROJ_AE."""
  return (ALPHAEARTH_EMBEDDINGS
          .filter(ee.Filter.calendarRange(target_year, target_year, 'year'))
          .mosaic()
          .setDefaultProjection(PROJ_AE)
          .reduceResolution('mean', maxPixels=500)
          )


year = INPUT_YEAR
ae_year_mean = _annual_embedding_mean(year)
ae_prev_year_mean = _annual_embedding_mean(year - 1)
# Tag the previous-year bands with a '_prev_year' suffix before stacking.
# bandNames().getInfo() fetches the names to the client.
prev_year_band_names = [b + '_prev_year' for b in ae_prev_year_mean.bandNames().getInfo()]
ae_prev_year_mean = ae_prev_year_mean.rename(prev_year_band_names)

# The MapBiomas image is used as-is: no year selection is applied here.
# NOTE(review): confirm whether only the FIRE_YEAR band was intended.
mb_year = MB_FIRE

# Per-pixel minimum over the MODIS images that fall in FIRE_YEAR.
modis_fire_year = MODIS_FIRE.filter(ee.Filter.calendarRange(FIRE_YEAR, FIRE_YEAR, 'year'))
modis_year = modis_fire_year.min()

# Stack everything into one image: current-year embeddings, previous-year
# embeddings, MapBiomas fire, MODIS fire.
joined_img = (ae_year_mean
              .addBands(ae_prev_year_mean)
              .addBands(mb_year)
              .addBands(modis_year))

# Specify the size and shape of patches expected by the model: one
# fixed-length float32 feature of KERNEL_SHAPE per band of the joined image.
FEATURES = joined_img.bandNames().getInfo()
KERNEL_SHAPE = [PATCH_SIZE, PATCH_SIZE]
COLUMNS = [
  tf.io.FixedLenFeature(shape=KERNEL_SHAPE, dtype=tf.float32) for _ in FEATURES
]
FEATURES_DICT = {name: column for name, column in zip(FEATURES, COLUMNS)}
# Save the feature spec to features_dict.pkl.
with open('features_dict.pkl', 'wb') as f:
    pickle.dump(FEATURES_DICT, f)

# Beam pipeline

In [None]:
class EEComputePatch(beam.DoFn):
    """Beam DoFn that requests one Earth Engine pixel patch per input element."""

    def setup(self):
        """Initialize the Earth Engine client on the high-volume endpoint."""
        ee.Initialize(project=PROJECT, opt_url='https://earthengine-highvolume.googleapis.com')

    def process(self, coords, image, patch_size, scale_x, scale_y, band_sel=None):
        """Request a patch of `image` and yield it as a NumPy array.

        Args:
            coords: Indexable pair; coords[0] and coords[1] become the X and
                Y translation (grid origin) of the request's affine transform.
            image: Sent as the request's 'expression'.
            patch_size: Width and height of the patch, in pixels.
            scale_x: Affine 'scaleX' of the request grid.
            scale_y: Affine 'scaleY' of the request grid.
            band_sel: Optional band ids; when not None they are sent as
                'bandIds'.

        Yields:
            The array that np.load decodes from the NPY bytes returned by
            ee.data.computePixels.
        """
        affine = {
            'scaleX': scale_x,
            'shearX': 0,
            'translateX': coords[0],
            'shearY': 0,
            'scaleY': scale_y,
            'translateY': coords[1],
        }
        grid = {
            'dimensions': {'width': patch_size, 'height': patch_size},
            'affineTransform': affine,
            'crsCode': 'EPSG:4326',
        }
        request = {'expression': image, 'fileFormat': 'NPY', 'grid': grid}

        # Only restrict bands when the caller asked for a subset.
        if band_sel is not None:
            request['bandIds'] = band_sel

        raw_bytes = ee.data.computePixels(request)
        yield np.load(io.BytesIO(raw_bytes))

In [None]:
def split_dataset(element, n_partitions, validation_ratio=None) -> int:
  """Randomly assign an element to partition 0 (training) or 1 (validation).

  Args:
    element: The element being partitioned; it is not inspected.
    n_partitions: Partition count supplied by beam.Partition. Unused: the
      split is always two-way, so only indices 0 and 1 are returned.
    validation_ratio: Probability weight of partition 1. When None, the
      module-level VALIDATION_RATIO is used.

  Returns:
    0 with weight `1 - validation_ratio`, otherwise 1.

  NOTE(review): the draw uses the global `random` module, which the code
  visible here never seeds, so the split is not tied to SEED.
  """
  if validation_ratio is None:
    validation_ratio = VALIDATION_RATIO
  weights = [1 - validation_ratio, validation_ratio]
  return random.choices([0, 1], weights)[0]

# Eight direct-runner workers in multi-processing mode.
beam_options = PipelineOptions([], direct_num_workers=8, direct_running_mode='multi_processing')

with beam.Pipeline(options=beam_options) as pipeline:
  # One element per sampled coordinate row.
  sample_points = pipeline | 'Create points' >> beam.Create(SAMPLE_POINTS)
  # Fetch the stacked image patch for each point.
  patches = sample_points | 'Get patch' >> beam.ParDo(
      EEComputePatch(), joined_img, PATCH_SIZE, SCALE_X, SCALE_Y)
  # Turn each patch into serialized tf.Example bytes.
  serialized = patches | 'Serialize' >> beam.Map(serialize_example)
  # Partition 0 is training, partition 1 is validation.
  training_data, validation_data = (
      serialized | 'Split dataset' >> beam.Partition(split_dataset, 2))

  training_sink = beam.io.WriteToTFRecord(
      '../data/beamMODIS/training', file_name_suffix='.tfrecord.gz')
  validation_sink = beam.io.WriteToTFRecord(
      '../data/beamMODIS/validation', file_name_suffix='.tfrecord.gz')

  training_data | 'Write training data' >> training_sink
  validation_data | 'Write validation data' >> validation_sink