In [None]:
#@title Copyright 2024 Google LLC. { display-mode: "form" }
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Earth Engine benchmarking toolkit

This notebook provides a set of tools for benchmarking the EECU-time cost of
a number of different types of Earth Engine processing operations.

For more information about how to use this notebook to estimate Earth Engine costs, please see [the developers' guide](https://developers.google.com/earth-engine/guides/computation_benchmarks).

In [None]:
from functools import cache
from google.colab import auth
from google.api_core import retry
from numpy.lib import recfunctions as rfn
from tqdm import tqdm
from tqdm.notebook import tqdm_notebook

import concurrent
import ee
import io
import google
import logging
import numpy as np
import requests
import tensorflow as tf

In [None]:
# Cloud project that Earth Engine is initialized with; also the project part
# of the BigQuery table IDs used for table exports.
PROJECT_ID = "your-project-here" # @param {type:"string"}
# Cloud Storage bucket that receives image exports and TFRecord files. The same
# value is reused as the BigQuery dataset name for table exports.
BUCKET_NAME = "your-bucket-here" # @param {type:"string"}
# Seed for random point / pixel sampling.
RANDOM_SEED = 1 # @param {type:"integer"}
# Number of worker threads issuing concurrent high-volume requests in
# extract_samples(). Keep this within your project's concurrent-request quota.
HTTP_PARALLELISM = 40 # @param {type:"integer"}
# When True, extract_samples() only prints what it would extract.
HV_DRY_RUN = True # @param {type:"boolean"}
# When True, export tasks are built but never started.
TASK_DRY_RUN = True # @param {type:"boolean"}

In [None]:
# Grab the right creds.
# Authenticate through Colab first, then load the default credentials.
# google.auth.default() returns a pair; only the credentials half is used.
auth.authenticate_user()
credentials, _ = google.auth.default()

# Initialize Earth Engine.
# The client is pointed at the high-volume endpoint (rather than the default
# one) and runs under PROJECT_ID.
ee.Initialize(
    credentials,
    project=PROJECT_ID,
    opt_url='https://earthengine-highvolume.googleapis.com')

# High-volume data extraction

Copied from the ["Pixels to the People!" blog post from Nick Clinton](https://medium.com/google-earth/pixels-to-the-people-2d3c14a46da6). This issues a large number of queries to the Earth Engine online stack.

In [None]:
@cache
def get_proj(scale):
  """Fetch the description of the EPSG:4326 projection at a nominal scale.

  The result is memoized per `scale`, so the `getInfo()` round trip happens
  once for each distinct scale.

  Args:
    scale: Nominal scale passed to `ee.Projection.atScale`.

  Returns:
    The `getInfo()` result for the scaled projection. Callers read its
    'transform' and 'crs' entries.
  """
  projection = ee.Projection('EPSG:4326').atScale(scale)
  return projection.getInfo()

@retry.Retry()
def get_patch(coords, image, scale, patch_size=128):
  """Fetch a square patch of pixels around a point as a numpy array.

  The request is sent through `ee.data.computePixels` in NPY format and is
  wrapped in `retry.Retry()`, so failed calls are retried by that decorator.

  Args:
    coords: Indexable pair; `coords[0]` and `coords[1]` are used as the x and
      y position of the patch center in the projection returned by `get_proj`.
    image: Expression placed in the request's 'expression' field.
    scale: Nominal scale handed to `get_proj`.
    patch_size: Width and height of the patch, in pixels.

  Returns:
    The array produced by `np.load` on the response bytes.
  """
  proj = get_proj(scale)
  transform = proj['transform']

  # Per-pixel step sizes taken from the affine transform (y is sign-flipped).
  scale_x = transform[0]
  scale_y = -transform[4]

  # Shift by half a patch so the requested grid is centered on `coords`.
  offset_x = -scale_x * patch_size / 2
  offset_y = -scale_y * patch_size / 2

  affine_transform = {
      'scaleX': scale_x,
      'shearX': 0,
      'shearY': 0,
      'scaleY': scale_y,
      'translateX': coords[0] + offset_x,
      'translateY': coords[1] + offset_y,
  }
  grid = {
      'dimensions': {'width': patch_size, 'height': patch_size},
      'affineTransform': affine_transform,
      'crsCode': proj['crs'],
  }
  request = {'fileFormat': 'NPY', 'grid': grid, 'expression': image}

  raw_bytes = ee.data.computePixels(request)
  return np.load(io.BytesIO(raw_bytes))


@cache
def _get_sample_coords_for_seed(roi, n, seed):
  """Fetch `n` random point geometries in `roi` for one specific seed (memoized)."""
  points = ee.FeatureCollection.randomPoints(region=roi, points=n, maxError=1, seed=seed)
  return points.aggregate_array('.geo').getInfo()


def get_sample_coords(roi, n):
  """Get a random sample of N points in the ROI.

  Results are memoized. The current value of the module-level `RANDOM_SEED` is
  part of the cache key, so changing the seed and re-running yields a fresh
  sample instead of silently returning points drawn with the previous seed.

  Args:
    roi: Region passed to `ee.FeatureCollection.randomPoints`; it must be
      hashable because it is used as a cache key.
    n: Number of points to request.

  Returns:
    The `getInfo()` result of aggregating the '.geo' property of the sampled
    points. `extract_samples` reads the 'coordinates' entry of each element.
  """
  return _get_sample_coords_for_seed(roi, n, RANDOM_SEED)


def array_to_example(structured_array, features):
  """Serialize a structured numpy array into a tf.Example proto.

  Args:
    structured_array: Array indexable by field name; each selected field is
      flattened before being stored.
    features: Iterable of field names to copy into the example.

  Returns:
    A `tf.train.Example` holding one float-list feature per requested name.
  """
  feature = {
      name: tf.train.Feature(
          float_list=tf.train.FloatList(value=structured_array[name].flatten()))
      for name in features
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))

class _No429Filter(logging.Filter):
  """Logging filter that drops WARNING-level records mentioning "429"."""

  def filter(self, record):
    # Don't allow 429 messages with log level WARNING.
    return not ("429" in record.getMessage() and record.levelno == logging.WARNING)


def extract_samples(image, roi, num_samples, scale, features, filename):
  """Fetch patches around random points in parallel and write them to a TFRecord.

  When `HV_DRY_RUN` is true, only a message is printed and nothing is fetched
  or written. Otherwise patches are fetched with `get_patch` on a thread pool
  of `HTTP_PARALLELISM` workers and written, in sample order, to
  `gs://{BUCKET_NAME}/{filename}.tfrecord.gz` as GZIP-compressed TFRecords.

  Args:
    image: Expression passed through to `get_patch`.
    roi: Region passed to `get_sample_coords`.
    num_samples: Number of sample points to request; also the progress-bar
      total.
    scale: Nominal scale passed through to `get_patch`.
    features: Field names copied into each example by `array_to_example`.
    filename: Output object name (without extension) and progress-bar label.

  Raises:
    Whatever `get_patch` raises once its own retries give up. A failure while
    serializing or writing a single patch is printed and skipped instead.
  """
  # Imported here so the `futures` submodule is guaranteed to be loaded; a bare
  # `import concurrent` does not import it.
  from concurrent.futures import ThreadPoolExecutor

  if HV_DRY_RUN:
    print(f"[Extracting {filename}]")
    return

  sample_points = get_sample_coords(roi, num_samples)

  # Remove spurious warnings about 429s (since we retry them anyway).
  http_logger = logging.getLogger("googleapiclient.http")
  no_429_filter = _No429Filter()
  http_logger.addFilter(no_429_filter)
  try:
    # The object name ends in .gz, so compress the records to match it.
    with tf.io.TFRecordWriter(
        f'gs://{BUCKET_NAME}/{filename}.tfrecord.gz', options='GZIP') as writer:
      # The `with` block shuts the pool down so worker threads do not pile up
      # across calls.
      with ThreadPoolExecutor(max_workers=HTTP_PARALLELISM) as executor:
        patches = executor.map(
            lambda point: get_patch(point['coordinates'], image, scale),
            sample_points)
        # Only the fetches run on worker threads. Records are written from this
        # thread so the single writer is never used concurrently.
        for patch in tqdm_notebook(patches, desc=filename, total=num_samples):
          try:
            writer.write(array_to_example(patch, features).SerializeToString())
          except Exception as e:
            # Best effort: report the bad patch and keep going.
            print(e)
  finally:
    http_logger.removeFilter(no_429_filter)

# Sentinel-2 Composite

A simple median composite built from cloud-masked Sentinel-2 surface reflectance images.

In [None]:
# Band names of the median composite that are extracted as features.
S2_FEATURES = [
    'B2_median',   # Blue
    'B3_median',   # Green
    'B4_median',   # Red
    'B8_median',   # NIR
    'AOT_median',  # AOT
]

def get_s2_median(start, end, max_cloud_probability=50):
  """Build a cloud-masked Sentinel-2 median composite for a date range.

  Surface-reflectance images are joined to their cloud-probability images on
  'system:index', pixels at or above the probability threshold are masked, and
  the remaining stack is reduced with a median. No spatial filter is applied
  here; the region is chosen by whoever exports or samples the result.

  Args:
    start: Start of the date range, passed to `filterDate`.
    end: End of the date range, passed to `filterDate`.
    max_cloud_probability: Pixels whose 'probability' band is not below this
      value are masked out. Defaults to 50.

  Returns:
    The image produced by reducing the masked collection with
    `ee.Reducer.median()`.
  """
  s2_sr = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED').filterDate(start, end)
  s2_clouds = ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY').filterDate(start, end)

  def index_join(primary, secondary, property_name):
    """Attach the matching `secondary` image's bands to each `primary` image."""
    joined = ee.ImageCollection(ee.Join.saveFirst(property_name).apply(
        primary=primary,
        secondary=secondary,
        condition=ee.Filter.equals(
            leftField='system:index',
            rightField='system:index'
        ))
    )
    return joined.map(
        lambda image: image.addBands(ee.Image(image.get(property_name))))

  def mask_clouds(image):
    """Keep only pixels whose cloud probability is below the threshold."""
    cloud_probability = image.select('probability')
    return image.updateMask(cloud_probability.lt(max_cloud_probability))

  with_cloud_probability = index_join(s2_sr, s2_clouds, 'cloud_probability')
  masked = ee.ImageCollection(with_cloud_probability.map(mask_clouds))
  # The second positional argument of reduce() is 8 (presumably parallelScale;
  # confirm against the Earth Engine API reference).
  return masked.reduce(ee.Reducer.median(), 8)

# Run the benchmarks

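Caution! This can be expensive. With `TASK_DRY_RUN` and `HV_DRY_RUN` set to `False`, the cell below starts real export tasks and issues high-volume requests, all of which consume EECU-time in your project.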

In [None]:
# Regions of interest.
regions = {
    'bay_area': ee.Geometry.Rectangle(
       [-123.05832753906247, 37.03109527141115, -121.14121328124997, 38.24468432993584]),
#   'nigeria': ee.FeatureCollection("USDOS/LSIB_SIMPLE/2017")\
#       .filter(ee.Filter.eq('country_na', 'Nigeria'))\
#       .first()\
#       .geometry(),
#   'germany': ee.FeatureCollection("USDOS/LSIB_SIMPLE/2017")\
#     .filter(ee.Filter.eq('country_na', 'Germany'))\
#     .first()\
#     .geometry(),
}

# Date ranges to composite over, all ending at BASE_DATE.
BASE_DATE = ee.Date('2024-01-01')
timeframes = {
    '3mo': ee.DateRange(BASE_DATE.advance(-3, 'month'), BASE_DATE),
#    '6mo': ee.DateRange(BASE_DATE.advance(-6, 'month'), BASE_DATE),
#    '1yr': ee.DateRange(BASE_DATE.advance(-1, 'year'), BASE_DATE),
}

# Export / sampling scales and sample counts to sweep over.
scales = [120] # [10, 30, 120]
sample_count = [100] # [100, 500, 1000]

# Every operation gets its own workload tag, set just before the operation is
# built and started, so its cost can be told apart from the others.
for region_name, roi in regions.items():
  for timeframe, date_range in timeframes.items():
    # One composite per (region, timeframe), shared by every scale below.
    image = get_s2_median(date_range.start(), date_range.end())
    for scale in scales:
      operation_name = f'{region_name}_{timeframe}_{scale}m_image'

      # Kick off an export job for the image.
      image_op = f'{operation_name}_0samples_exportimage'
      ee.data.setWorkloadTag(image_op)
      image_task = ee.batch.Export.image.toCloudStorage(
          image=image,
          bucket=BUCKET_NAME,
          fileNamePrefix=image_op,
          region=roi,
          scale=scale,
          description=image_op,
          maxPixels=2e10
      )

      if not TASK_DRY_RUN:
        image_task.start()
      else:
        # Same dry-run report as the BigQuery export below.
        print(f"[dry run] Starting {image_task}")

      for n in sample_count:
        # Export point samples to BQ.
        bq_op = f"{operation_name}_{n}samples_bq"
        ee.data.setWorkloadTag(bq_op)
        table = image.sample(
                region=roi,
                scale=scale,
                numPixels=n,
                tileScale=4,
                seed=RANDOM_SEED)
        # NOTE(review): BUCKET_NAME doubles as the BigQuery dataset ID here.
        # Bucket names containing hyphens are not valid dataset IDs; confirm a
        # dataset with this exact name exists before a real run.
        bq_task = ee.batch.Export.table.toBigQuery(
            collection=table,
            description=bq_op,
            table=f"{PROJECT_ID}.{BUCKET_NAME}.{bq_op}",
        )
        if not TASK_DRY_RUN:
          bq_task.start()
        else:
          print(f"[dry run] Starting {bq_task}")

        # Extract sample patches using the EE HV API.
        hv_op = f"{operation_name}_{n}samples_hv"
        ee.data.setWorkloadTag(hv_op)
        extract_samples(image=image,
                        roi=roi,
                        num_samples=n,
                        scale=scale,
                        features=S2_FEATURES,
                        filename=hv_op)

# Clear the tag so later requests are not attributed to the last benchmark.
ee.data.resetWorkloadTag()