# Training data pipeline


Most of the code is adapted from the Earth Engine community guide: https://colab.research.google.com/github/google/earthengine-community/blob/master/guides/linked/Earth_Engine_training_patches_getPixels.ipynb

In [None]:
import ee
import io
import numpy as np
import pandas as pd
import geopandas as gpd
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
import random
import tensorflow as tf

from google.api_core import retry


In [None]:
# Fixed seed so the random point sample can be reproduced between runs
SEED = 54
# Shared NumPy generator built from SEED; handed to sample_random_points()
RNG = np.random.default_rng(SEED)

# Set params

In [None]:
# The year we're predicting on (and target data will be sampled from)
TARGET_YEAR = 2024

# Region of interest to sample from, read from the RAISG limits shapefile
ROI = gpd.read_file('../data/Limites_RAISG_2025/Lim_Raisg.shp')

# Number of random points to request from the ROI (one patch is fetched per point)
# NOTE(review): GeoDataFrame.sample_points draws this many points per ROI row --
# confirm the shapefile holds a single feature if exactly N_SAMPLE points are expected
N_SAMPLE = 2000

# Final pixel scale of the fetched patches, in meters
FINAL_SCALE = 500 # in Meters

# Patch width and height, in pixels
PATCH_SIZE = 128

# Fraction of the samples to keep as validation data
VALIDATION_RATIO=0.2

# Helper Functions

In [None]:
def sample_random_points(roi: gpd.GeoDataFrame, n_sample: int, rng: np.random.Generator) -> np.ndarray:
  """Get random points within region of interest.

  Args:
    roi: Region of interest to draw the points from.
    n_sample: Number of points requested from `roi.sample_points`. Note that
      `sample_points` draws this many points per geometry (row) of `roi`.
    rng: NumPy random generator that controls the sample.

  Returns:
    2-D array with one row per sampled point, holding the coordinate columns
    produced by `get_coordinates()` (x first, then y).
  """
  # sample_points() returns multi-point geometries; explode() splits them into
  # individual points so that each one contributes its own coordinate row.
  points = roi.sample_points(n_sample, rng=rng).geometry.explode()
  # Only the raw values are returned, so the row index needs no renumbering.
  return points.get_coordinates().values

def array_to_example(structured_array):
  """Convert a structured numpy array to a tf.Example proto.

  Every field named in the array's dtype is flattened and stored as a
  float_list feature under the same name.
  """
  features_by_name = {
      name: tf.train.Feature(
          float_list=tf.train.FloatList(
              value=structured_array[name].flatten()))
      for name in structured_array.dtype.names
  }
  return tf.train.Example(
      features=tf.train.Features(feature=features_by_name))

def serialize_example(structured_array):
  """Convert a structured numpy array to a serialized tf.Example proto."""
  example = array_to_example(structured_array)
  return example.SerializeToString()

def split_dataset(element, n_partitions) -> int:
  """Randomly pick partition 0 or 1 for an element.

  Partition 1 is drawn with weight VALIDATION_RATIO and partition 0 with the
  remaining weight. Both arguments are ignored; they are only present so the
  function can be used as a beam.Partition callback.
  """
  partition_ids = [0, 1]
  partition_weights = [1 - VALIDATION_RATIO, VALIDATION_RATIO]
  (chosen,) = random.choices(partition_ids, partition_weights)
  return chosen

def points_to_geeflow_csv(sample_points, out_csv_path, validation_ratio=None, random_state=None):
  """Write sampled points to a CSV with a shuffled train/val split.

  The CSV has the columns index, lat, lon, label and split. Rows are shuffled,
  renumbered from 0, and the trailing `validation_ratio` share of the rows is
  marked 'val' while the rest stays 'train'. `label` is always 0.

  Args:
    sample_points: 2-D array-like of points; column 0 is written as lon and
      column 1 as lat.
    out_csv_path: Destination path of the CSV file.
    validation_ratio: Fraction of rows marked as 'val'. Defaults to the
      module-level VALIDATION_RATIO.
    random_state: Optional seed or generator forwarded to `DataFrame.sample`
      to make the shuffle reproducible. None keeps the shuffle unseeded.
  """
  if validation_ratio is None:
    validation_ratio = VALIDATION_RATIO

  sample_points = np.asarray(sample_points)
  out_df = pd.DataFrame({
    'lat': sample_points[:, 1],
    'lon': sample_points[:, 0],
    'label': 0,
    'split': 'train',
  })

  # Shuffle order, then renumber so the written index runs 0..n-1
  out_df = out_df.sample(frac=1, random_state=random_state).reset_index(drop=True)
  out_df.index.name = 'index'

  # Label-based slice on the fresh 0..n-1 index: everything from row
  # `num_train` onwards becomes validation data.
  num_train = round(out_df.shape[0] * (1 - validation_ratio))
  out_df.loc[num_train:, 'split'] = 'val'

  out_df.to_csv(out_csv_path)

  


# Custom Beam DoFn for gathering EE data

In [None]:
class EEComputePatch(beam.DoFn):
    """DoFn() for computing an Earth Engine patch for each input coordinate.

    Args:
        config (dict): Dictionary containing configuration settings
            in the following key:value pairs:
                project_id (str): Google Cloud project id
                patch_size (int): Patch size, in pixels, of output chips
                scale (float): Final scale, in m
                target_year (int): Year of prediction
                target_key (str): Name of target data, corresponds to key in self.prep_dict
                inputs_keys (list): Names of input data, correspond to keys in self.prep_dict
                proj (str): Projection, e.g. "EPSG:4326"
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Maps the dataset names accepted in `config` to the method that
        # builds the corresponding Earth Engine image for a given year.
        self.prep_dict = {
            'embeddings': self._prep_embeddings,
            'mcd64': self._prep_mcd64,
            'mb_fire': self._prep_mb_burned_area
        }

    def setup(self):
        """Initialize Earth Engine and build the multi-band image to sample.

        Input datasets are prepared for the year before `target_year`; the
        target dataset is prepared for `target_year` itself.
        """
        print(f"Initializing Earth Engine for project: {self.config['project_id']}")
        ee.Initialize(project=self.config['project_id'], opt_url='https://earthengine-highvolume.googleapis.com')

        # Set some params
        self.proj = ee.Projection(self.config['proj']).atScale(self.config['scale'])
        self.proj_dict = self.proj.getInfo()
        self.scale_x = self.proj_dict['transform'][0]
        # Negated y scale: with the coords used as the upper-left corner of a
        # patch, rows step in the -y direction.
        self.scale_y = -self.proj_dict['transform'][4]

        # Setup Earth Engine image object with all target bands
        inputs_list = [
            self.prep_dict[k](self.config['target_year']-1)
            for k in self.config['inputs_keys']
        ]
        outputs_list = [self.prep_dict[self.config['target_key']](self.config['target_year'])]
        full_list = inputs_list + outputs_list
        # Collect the original band names up front: toBands() prepends the
        # system index to each name, so the bands are renamed back afterwards.
        band_names = [
            bn
            for image in full_list
            for bn in image.bandNames().getInfo()
        ]

        # Final prepped image
        self.prepped_image = ee.ImageCollection(full_list).toBands().rename(band_names)

    def process(self, coords):
        """Compute a patch of pixels, with upper-left corner defined by the coords.

        Args:
            coords: Sequence whose first two items are the x and y translation
                of the patch grid, in the units of the configured projection.

        Yields:
            The array loaded from the NPY payload returned by Earth Engine.
        """
        # Make a request object.
        request = {
            'expression': self.prepped_image,
            'fileFormat': 'NPY',
            'grid': {
                'dimensions': {
                    'width': self.config['patch_size'],
                    'height': self.config['patch_size']
                },
                'affineTransform': {
                    'scaleX': self.scale_x,
                    'shearX': 0,
                    # float() keeps numpy scalars out of the request payload
                    'translateX': float(coords[0]),
                    'shearY': 0,
                    'scaleY': self.scale_y,
                    'translateY': float(coords[1])
                },
                # Same CRS the scale values above were derived from
                'crsCode': self.config['proj'],
            },
        }

        yield self._compute_pixels(request)

    # The retry wraps this plain method rather than process(): process() is a
    # generator function, so calling it only creates the generator and the
    # request would run later, outside of any retry wrapper.
    # NOTE(review): Retry() uses its default predicate -- confirm that the
    # errors raised by ee.data.computePixels are among those it retries.
    @retry.Retry()
    def _compute_pixels(self, request):
        """Fetch the pixels for `request` and load them as a numpy array."""
        return np.load(io.BytesIO(ee.data.computePixels(request)))

    def _prep_embeddings(self, year):
        """Mosaic of the annual satellite embeddings for `year`, mean-reduced.

        Reads self.proj, so it can only be called once setup() has set it.
        """
        return (
            ee.ImageCollection('GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL')
            .filter(ee.Filter.calendarRange(year, year, 'year'))
            .mosaic()
            .setDefaultProjection(self.proj)
            .reduceResolution('mean', maxPixels=500)
            )

    def _prep_mcd64(self, year):
        """Per-pixel minimum of the MCD64A1 'BurnDate' band over `year`."""
        return (
            ee.ImageCollection('MODIS/061/MCD64A1')
            .select('BurnDate')
            .filter(ee.Filter.calendarRange(year, year, 'year'))
            .min()
            )

    def _prep_mb_burned_area(self, year):
        """MapBiomas annual burned-area band for `year`, mean-reduced."""
        return (
            ee.Image('projects/mapbiomas-public/assets/brazil/fire/collection4_1/mapbiomas_fire_collection41_annual_burned_v1')
            .select([f'burned_area_{year}'])
            .reduceResolution('mean', maxPixels=500)
            )

    def _prep_default(self, year):
        """Example prep method

        Template only: it is not registered in self.prep_dict and the
        collection is created without a source.
        """
        return (
            ee.ImageCollection()
            .mean()
            .reduceResolution('mean', maxPixels=500)
            )

# Run beam pipeline

In [None]:
# Random sample of point coordinates inside the ROI (one x/y row per point).
SAMPLE_POINTS = sample_random_points(ROI, N_SAMPLE, RNG)

# Also dump the sampled points to a csv compatible with geeflow
points_to_geeflow_csv(SAMPLE_POINTS, './amazon.csv')

# Direct runner, multi-processing mode, 8 workers
beam_options = PipelineOptions(
    [],
    direct_num_workers=8,
    direct_running_mode='multi_processing',
)

# Settings consumed by EEComputePatch (see its docstring for the keys)
config_dict = dict(
    project_id='ksolvik-misc',
    patch_size=PATCH_SIZE,
    scale=FINAL_SCALE,
    target_year=TARGET_YEAR,
    target_key='mcd64',
    inputs_keys=['embeddings', 'mb_fire'],
    proj='EPSG:4326',
)

with beam.Pipeline(options=beam_options) as pipeline:
  # One element per sampled point -> one fetched patch -> one serialized example
  points = pipeline | 'Create points' >> beam.Create(SAMPLE_POINTS)
  patches = points | 'Get patch' >> beam.ParDo(EEComputePatch(config_dict))
  examples = patches | 'Serialize' >> beam.Map(serialize_example)

  # Partition 0 is the training set, partition 1 the validation set
  training_data, validation_data = (
      examples | 'Split dataset' >> beam.Partition(split_dataset, 2)
  )

  training_data | 'Write training data' >> beam.io.WriteToTFRecord(
      '../data/beamMODIS/training',
      file_name_suffix='.tfrecord.gz',
  )
  validation_data | 'Write validation data' >> beam.io.WriteToTFRecord(
      '../data/beamMODIS/validation',
      file_name_suffix='.tfrecord.gz',
  )