# Earth Engine Explore

Explore possible explanatory and response variables for fire risk modeling across the Amazon


Most of the code is adapted from the Earth Engine community guide: https://colab.research.google.com/github/google/earthengine-community/blob/master/guides/linked/Earth_Engine_training_patches_getPixels.ipynb

In [None]:
from IPython.display import Image
from matplotlib import pyplot as plt


import concurrent
import ee
import google
import io
import json
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import multiprocessing
import numpy as np
import requests
import tensorflow as tf
import geopandas as gpd
from tqdm.notebook import tqdm

from google.api_core import retry
from google.protobuf import json_format
from IPython.display import Image
from matplotlib import rc

rc('animation', html='html5')

In [None]:
# Fixed seed for the shared random generator.
SEED = 54
# NumPy generator seeded with SEED; passed to the ROI point sampling further down.
RNG = np.random.default_rng(SEED)

# Authentication

In [None]:
# Project identifier passed to ee.Initialize() below.
PROJECT = 'ksolvik-misc'

In [None]:
# Run the Earth Engine authentication flow (may prompt interactively).
ee.Authenticate()

In [None]:
# Initialize the Earth Engine client for PROJECT.
# NOTE(review): the high-volume endpoint
# (opt_url='https://earthengine-highvolume.googleapis.com') is left disabled
# here; consider enabling it for the concurrent pixel requests issued later in
# this notebook — TODO confirm.
ee.Initialize(project=PROJECT)

In [None]:
# Local path of the gzip-compressed TFRecord file the sampled patches are
# written to (and later read back from).
OUTPUT_FILE = '../data/test.tfrecord.gz'

# Annual satellite embedding image collection (AlphaEarth); filtered per year
# in the fetch loop below.
ALPHAEARTH_EMBEDDINGS = ee.ImageCollection('GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL')
# Region of interest: the RAISG boundary shapefile, read with geopandas.
# NOTE(review): the sampled x/y values are later sent to Earth Engine as
# EPSG:4326 coordinates — confirm the shapefile is stored in that CRS.
ROI = gpd.read_file('../data/Limites_RAISG_2025/Lim_Raisg.shp')

# Number of points requested from ROI.sample_points().
# NOTE(review): presumably this many points are drawn per ROI feature — verify.
N_SAMPLE = 5

# Random points inside the ROI, exploded to single points and flattened to a
# table of x/y coordinates, then re-indexed 0..n-1 so every point has a unique id.
SAMPLE = ROI.sample_points(N_SAMPLE, rng=RNG).geometry.explode().get_coordinates()
SAMPLE.index = np.arange(SAMPLE.shape[0])

# The years from which to sample. np.arange excludes the stop value, so this
# is currently 2023 only.
YEARS = np.arange(2023, 2024)

# Make an EPSG:4326 projection at SCALE_AE to discover the pixel size in degrees.
# SCALE_AE is presumably in metres — TODO confirm against the embedding resolution.
SCALE_AE = 10
PROJ_AE = ee.Projection('EPSG:4326').atScale(SCALE_AE)
# Client-side description of the projection (one round trip to Earth Engine).
PROJ_AE_DICT = PROJ_AE.getInfo()
# Get scales out of the transform: element 0 is used as the x scale and the
# negated element 4 as the y scale.
SCALE_X = PROJ_AE_DICT['transform'][0]
SCALE_Y = -PROJ_AE_DICT['transform'][4]
# Patch width and height in pixels.
PATCH_SIZE = 256


# MapBiomas annual burned-area image, reprojected onto the embedding projection
# so both sources share one pixel grid.
# NOTE(review): the asset path is under .../brazil/...; presumably it covers
# Brazil only while ROI is the wider RAISG boundary — confirm coverage.
MB_FIRE = ee.Image('projects/mapbiomas-public/assets/brazil/fire/collection4_1/mapbiomas_fire_collection41_annual_burned_v1')
MB_FIRE_REPROJ = MB_FIRE.reproject(PROJ_AE)

## Image retrieval functions

This section defines two functions that return a `patch_size` × `patch_size` pixel patch as a numpy array, with the patch's grid origin (its upper-left corner) placed at the provided coordinates: `get_patch` reads the pixels of a stored asset, and `compute_patch` evaluates an `ee.Image` expression.  Both functions can be retried automatically by using the [Retry](https://googleapis.dev/python/google-api-core/latest/retry.html) decorator.  There is also a function to serialize a structured array to a `tf.Example` proto.

In [None]:
@retry.Retry()
def get_patch(coords, ic, patch_size, scale_x, scale_y, band_sel=None, filter_point=True):
  """Get a patch of pixels from an asset, anchored at the coords.

  The first image of `ic` (optionally restricted to images intersecting the
  point) is read through `ee.data.getPixels`. `coords` is used directly as the
  translation of the pixel grid, so it is the grid origin of the patch (the
  same convention as `compute_patch`), not the patch centre.

  Args:
    coords: [x, y] pair in EPSG:4326.
    ic: ee.ImageCollection the image is taken from.
    patch_size: Patch width and height in pixels.
    scale_x: Pixel size along x, in CRS units.
    scale_y: Pixel size along y, in CRS units.
    band_sel: Optional list of band ids to request; all bands when None.
    filter_point: When True, only images intersecting `coords` are considered.

  Returns:
    The numpy array decoded from the NPY response.

  Raises:
    ValueError: If the (filtered) collection contains no image.
  """
  if filter_point:
    candidates = ic.filterBounds(ee.Geometry.Point(coords))
  else:
    candidates = ic

  # One round trip to resolve the asset id of the first matching image.
  image_info = candidates.first().getInfo()
  if image_info is None:
    # An empty collection yields None; fail with a clear message instead of
    # an opaque "'NoneType' object is not subscriptable".
    raise ValueError('No image found in the collection for coords {}'.format(coords))
  image_id = image_info['id']

  # Make a request object.
  request = {
      'assetId': image_id,
      'fileFormat': 'NPY',
      'grid': {
          'dimensions': {
              'width': patch_size,
              'height': patch_size,
          },
          'affineTransform': {
              'scaleX': scale_x,
              'shearX': 0,
              'translateX': coords[0],
              'shearY': 0,
              'scaleY': scale_y,
              'translateY': coords[1],
          },
          'crsCode': 'EPSG:4326',
      },
  }

  if band_sel is not None:
    request['bandIds'] = band_sel

  return np.load(io.BytesIO(ee.data.getPixels(request)))

@retry.Retry()
def compute_patch(coords, image, patch_size, scale_x, scale_y, band_sel=None):
  """Compute a patch of pixels whose upper-left corner is defined by the coords.

  Args:
    coords: [x, y] pair in EPSG:4326 used as the grid translation.
    image: Expression (e.g. an ee.Image) evaluated by `ee.data.computePixels`.
    patch_size: Patch width and height in pixels.
    scale_x: Pixel size along x, in CRS units.
    scale_y: Pixel size along y, in CRS units.
    band_sel: Optional list of band ids to request; all bands when None.

  Returns:
    The numpy array decoded from the NPY response.
  """
  # Pixel grid: a square of patch_size pixels with no shear, translated to coords.
  affine = {
      'scaleX': scale_x,
      'shearX': 0,
      'translateX': coords[0],
      'shearY': 0,
      'scaleY': scale_y,
      'translateY': coords[1],
  }
  grid = {
      'dimensions': {'width': patch_size, 'height': patch_size},
      'affineTransform': affine,
      'crsCode': 'EPSG:4326',
  }
  request = {'expression': image, 'fileFormat': 'NPY', 'grid': grid}

  if band_sel is not None:
    request['bandIds'] = band_sel

  raw_bytes = ee.data.computePixels(request)
  return np.load(io.BytesIO(raw_bytes))

def _float_feature(floats):
  """Wrap a sequence of floats in a `tf.train.Feature`.

  Args:
    floats: Iterable of float values.

  Returns:
    A `tf.train.Feature` whose `float_list` holds the given values.
  """
  # The former debug print of the whole input was removed: it dumped every
  # value to the output on each call.
  return tf.train.Feature(float_list=tf.train.FloatList(value=floats))


def array_to_example(structured_array):
  """Serialize a structured numpy array into a tf.Example proto.

  Every named field of the array becomes one float-list feature holding that
  field's values in flattened order.

  Args:
    structured_array: Numpy array with a structured dtype (named fields).

  Returns:
    A `tf.train.Example` with one feature per field name.
  """
  field_names = structured_array.dtype.names
  per_field = {
      name: tf.train.Feature(
          float_list=tf.train.FloatList(value=structured_array[name].flatten()))
      for name in field_names
  }
  return tf.train.Example(features=tf.train.Features(feature=per_field))

In [None]:

# `import concurrent` alone does not load the `futures` submodule; import it
# explicitly so this cell does not rely on another library having done so.
import concurrent.futures

# Fetch one patch per sample point and year on a thread pool and append each
# result to OUTPUT_FILE as a serialized tf.Example. The context managers make
# sure the writer is closed and the pool is shut down even if the loop raises.
with tf.io.TFRecordWriter(OUTPUT_FILE, 'GZIP') as writer:
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        for year in tqdm(YEARS):
            year = int(year)  # plain int rather than a numpy integer for the EE filter
            ae_year = ALPHAEARTH_EMBEDDINGS.filter(
                ee.Filter.calendarRange(year, year, 'year')
            ).mosaic()
            # NOTE(review): the whole burned-area image is attached for every
            # year; presumably only the band(s) for `year` are wanted — confirm
            # the band naming before selecting.
            mb_year = MB_FIRE_REPROJ

            # Later cells read `joined_img` as left by the final iteration.
            joined_img = ae_year.addBands(mb_year)

            # Map each pending request to a label identifying its sample point.
            future_to_image = {
                executor.submit(
                    compute_patch, [point.x, point.y], joined_img,
                    PATCH_SIZE, SCALE_X, SCALE_Y,
                ): 'ALL_{}'.format(index)
                for index, point in SAMPLE.iterrows()
            }

            for future in concurrent.futures.as_completed(future_to_image):
                patch_label = future_to_image[future]
                try:
                    np_array = future.result()
                    example_proto = array_to_example(np_array)
                    writer.write(example_proto.SerializeToString())
                    writer.flush()
                except Exception as e:
                    # Best effort: report which patch failed and carry on with
                    # the remaining ones.
                    print('Patch {} ({}) failed: {!r}'.format(patch_label, year, e))

In [None]:
# Describe how the serialized patches are decoded: one PATCH_SIZE x PATCH_SIZE
# float32 tensor per band of `joined_img` (as left by the fetch loop above).
KERNEL_SHAPE = [PATCH_SIZE, PATCH_SIZE]
FEATURES = joined_img.bandNames().getInfo()
COLUMNS = [
  tf.io.FixedLenFeature(shape=KERNEL_SHAPE, dtype=tf.float32) for _ in FEATURES
]
# Band name -> parsing spec, in band order.
FEATURES_DICT = {band: spec for band, spec in zip(FEATURES, COLUMNS)}

In [None]:
def parse_tfrecord(example_proto):
  """Decode one serialized tf.Example using the FEATURES_DICT spec.

  Args:
    example_proto: A serialized tf.Example record.

  Returns:
    The parsed features, keyed by the names in FEATURES_DICT.
  """
  parsed = tf.io.parse_single_example(
      serialized=example_proto, features=FEATURES_DICT)
  return parsed

# Read the gzip-compressed records back and parse them, five at a time.
dataset = (
    tf.data.TFRecordDataset(OUTPUT_FILE, compression_type='GZIP')
    .map(parse_tfrecord, num_parallel_calls=5)
)

In [None]:
# Value range mapped onto [0, 1] for display.
# NOTE(review): assumes the embedding bands lie in [-1, 1] — confirm against
# the dataset documentation. Narrow the range for more contrast.
VIS_MIN, VIS_MAX = -1.0, 1.0

# Show the first three embedding bands of every patch as an RGB composite.
for data in dataset:
  rgb = np.stack(
      [data[band].numpy() for band in ('A00', 'A01', 'A02')], axis=-1)
  # imshow clips float RGB input to [0, 1]; without rescaling, every negative
  # embedding value would be rendered as 0 (black).
  rgb = np.clip((rgb - VIS_MIN) / (VIS_MAX - VIS_MIN), 0.0, 1.0)
  plt.imshow(rgb)
  plt.show()
