# Earth Engine Explore

Explore candidate explanatory and response variables for fire-risk modeling across the Amazon.


Most of the code is adapted from: https://colab.research.google.com/github/google/earthengine-community/blob/master/guides/linked/Earth_Engine_training_patches_getPixels.ipynb

In [None]:
from IPython.display import Image
from matplotlib import pyplot as plt


import concurrent
import ee
import google
import io
import json
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import multiprocessing
import numpy as np
import requests
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import geopandas as gpd
import pandas as pd
from tqdm.notebook import tqdm
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions


from google.api_core import retry
from google.protobuf import json_format
from IPython.display import Image
from matplotlib import rc

# Render matplotlib animations as HTML5 video in notebook output.
rc('animation', html='html5')

In [None]:
# Fixed seed so the random point sampling below is reproducible.
SEED = 54
# One shared generator, passed explicitly to the sampling helpers.
RNG = np.random.default_rng(seed=SEED)

# Authentication

In [None]:
PROJECT: str = 'ksolvik-misc'  # Google Cloud project used for Earth Engine calls

In [None]:
# Run the Earth Engine authentication flow before initializing the client below.
ee.Authenticate()

In [None]:
# Initialize the client against the high-volume endpoint, which is the one
# intended for many parallel computePixels requests like those issued below.
ee.Initialize(
    opt_url='https://earthengine-highvolume.googleapis.com',
    project=PROJECT,
)

# Set params

In [None]:
# --- Features: AlphaEarth annual satellite embeddings ---
ALPHAEARTH_EMBEDDINGS = ee.ImageCollection('GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL')
# Years whose AlphaEarth mosaic is sampled (np.arange excludes the end -> [2023]).
YEARS = np.arange(2023, 2024)

# --- Target: MapBiomas Fire annual burned area ---
MB_FIRE = ee.Image('projects/mapbiomas-public/assets/brazil/fire/collection4_1/mapbiomas_fire_collection41_annual_burned_v1')

# --- Sampling ---
# Region of interest (RAISG limits) from which patch locations are drawn.
ROI = gpd.read_file('../data/Limites_RAISG_2025/Lim_Raisg.shp')
# `size` handed to sample_random_points; NOTE(review): geopandas draws this
# many points per ROI geometry, so the total depends on len(ROI) — confirm.
N_SAMPLE = 2000

# --- Patch geometry ---
FINAL_SCALE = 30  # output pixel size, in metres
PATCH_SIZE = 128  # patch width and height, in pixels

# Fraction of examples routed to the validation split.
VALIDATION_RATIO = 0.2

# --- Model input / output band names ---
# 64 embedding bands A00..A63 followed by burned-area bands for 2000..2023.
INPUT_BANDS = [f'A{i:02d}' for i in range(64)] + [f'burned_area_{y}' for y in range(2000, 2024)]
# The model predicts the 2024 burned-area band.
OUTPUT_BANDS = ['burned_area_2024']

In [None]:
def sample_random_points(roi: gpd.GeoDataFrame, n_sample: int, rng: np.random.Generator) -> np.ndarray:
  """Draw random point locations inside a region of interest.

  Args:
    roi: polygon(s) to sample from. If it has a CRS set, the sampled points
      are reprojected to EPSG:4326; if its CRS is unset, coordinates are
      returned as-is.
    n_sample: ``size`` passed to ``sample_points``. geopandas draws this many
      points *per geometry*, so a multi-row ``roi`` yields
      ``n_sample * len(roi)`` points in total.
    rng: NumPy generator, for reproducible sampling.

  Returns:
    An ``(n_points, 2)`` array of ``[x, y]`` coordinates (lon/lat in
    EPSG:4326 when ``roi`` has a CRS). The patch requests in this notebook use
    these values directly as EPSG:4326 grid origins.
  """
  # sample_points yields one MultiPoint per input geometry; explode to points.
  points = roi.sample_points(n_sample, rng=rng).explode()
  # The patch requests use crsCode 'EPSG:4326', so coordinates must be in that
  # CRS; reproject when the ROI was read in another one (no-op if already 4326).
  if points.crs is not None:
    points = points.to_crs(epsg=4326)
  # get_coordinates() gives an x/y DataFrame; its index is irrelevant here.
  return points.get_coordinates().to_numpy()

In [None]:
## Precompute derived inputs from the parameters above
# Random patch origins inside the ROI: an (n_points, 2) array of x/y
# coordinates produced by sample_random_points (not GeoJSON objects).
SAMPLE_POINTS = sample_random_points(ROI, N_SAMPLE, RNG)

# Build an EPSG:4326 projection at FINAL_SCALE metres so the pixel size can be
# read back, in degrees, from its affine transform.
PROJ_AE = ee.Projection('EPSG:4326').atScale(FINAL_SCALE)
PROJ_AE_DICT = PROJ_AE.getInfo()
_ae_transform = PROJ_AE_DICT['transform']
SCALE_X = _ae_transform[0]
# transform[4] is negated here and later passed unchanged as the grid's scaleY.
SCALE_Y = -_ae_transform[4]

## Image retrieval functions

In [None]:

@retry.Retry()
def compute_patch(coords, image, patch_size, scale_x, scale_y, band_sel=None):
  """Fetch a square block of pixels from Earth Engine as a NumPy array.

  The request grid is ``patch_size`` x ``patch_size`` pixels in EPSG:4326.
  Its affine origin (translateX, translateY) is ``coords[0]``, ``coords[1]``
  and its pixel size is ``scale_x`` / ``scale_y``. Whether ``coords`` ends up
  as the upper-left corner depends on the sign of ``scale_y``.

  Args:
    coords: ``(x, y)`` origin of the grid.
    image: Earth Engine image expression to evaluate.
    patch_size: width and height of the patch, in pixels.
    scale_x, scale_y: pixel size, in degrees.
    band_sel: optional list of band ids to restrict the request to.

  Returns:
    The array decoded from the NPY payload returned by computePixels.
  """
  # NOTE(review): google.api_core's default Retry predicate only retries its own
  # transient error types; presumably ee.EEException is not among them — confirm.
  x_origin, y_origin = coords[0], coords[1]
  affine = {
      'scaleX': scale_x,
      'shearX': 0,
      'translateX': x_origin,
      'shearY': 0,
      'scaleY': scale_y,
      'translateY': y_origin,
  }
  grid = {
      'dimensions': {'width': patch_size, 'height': patch_size},
      'affineTransform': affine,
      'crsCode': 'EPSG:4326',
  }
  request = {'expression': image, 'fileFormat': 'NPY', 'grid': grid}

  if band_sel is not None:
    request['bandIds'] = band_sel

  npy_bytes = ee.data.computePixels(request)
  return np.load(io.BytesIO(npy_bytes))

def serialize_example(structured_array):
  """Serialize a structured NumPy array to tf.Example wire-format bytes."""
  example = array_to_example(structured_array)
  return example.SerializeToString()

def array_to_example(structured_array):
  """Convert a structured NumPy array to a ``tf.train.Example`` proto.

  Each named field of the array's dtype becomes one float feature, holding
  that field's values flattened to 1-D.
  """
  def _float_feature(values):
    # FloatList takes a flat sequence; the patch shape is restored at parse time.
    return tf.train.Feature(float_list=tf.train.FloatList(value=values.flatten()))

  feature_map = {
      field: _float_feature(structured_array[field])
      for field in structured_array.dtype.names
  }
  return tf.train.Example(features=tf.train.Features(feature=feature_map))

# Option 1: Execute using a simple concurrent.futures thread pool

In [None]:

# executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)

# writer = tf.io.TFRecordWriter(OUTPUT_FILE, 'GZIP')

# AlphaEarth annual embeddings are published at 10 m. reduceResolution() combines
# the pixels of the image's *default* projection that fall inside each output
# pixel, and a mosaic does not carry the sources' native projection. The default
# must therefore be set at the native 10 m scale. Setting it to the 30 m output
# projection (PROJ_AE) would make reduceResolution('mean') a 1:1 no-op, so
# patches would be nearest-neighbour samples instead of 3x3 means.
AE_NATIVE_PROJ = ee.Projection('EPSG:4326').atScale(10)

# NOTE: joined_img is overwritten on every iteration; only the last year in
# YEARS reaches the cells below (YEARS currently holds a single year).
for year in YEARS:
    year = int(year)
    ae_year_mean = (ALPHAEARTH_EMBEDDINGS
               .filter(ee.Filter.calendarRange(year, year, 'year'))
               .mosaic()
               .setDefaultProjection(AE_NATIVE_PROJ)
               .reduceResolution('mean')
               )
    # Not filtered by year: presumably MB_FIRE holds every burned_area_YYYY band
    # (INPUT_BANDS/OUTPUT_BANDS expect 2000..2024) — confirm against the asset.
    mb_year = MB_FIRE

    joined_img = ae_year_mean.addBands(mb_year)

#     future_to_image = {
#         executor.submit(compute_patch, list(point), joined_img, PATCH_SIZE, SCALE_X, SCALE_Y):
#             'ALL_{}'.format(index) for index, point in enumerate(SAMPLE_POINTS)
#     }

#     arrays = ()
#     types = []
#     for future in concurrent.futures.as_completed(future_to_image):
#       image_id = future_to_image[future]
#       try:
#           np_array = future.result()
#           example_proto = array_to_example(np_array)
#           writer.write(example_proto.SerializeToString())
#           writer.flush()
#       except Exception as e:
#           print(e)
#           pass

# writer.close()

# Option 2: Execute using an Apache Beam pipeline

In [None]:
import random

In [None]:
class EEComputePatch(beam.DoFn):
    """Beam DoFn that emits one Earth Engine pixel patch per input coordinate."""

    def setup(self):
        # Beam calls setup() when a DoFn instance is created on a worker, so each
        # worker process gets its own initialized Earth Engine client.
        ee.Initialize(
            opt_url='https://earthengine-highvolume.googleapis.com',
            project=PROJECT,
        )

    def process(self, coords, image, patch_size, scale_x, scale_y, band_sel=None):
        """Yield the NPY-decoded patch whose grid origin is ``coords``.

        ``image``, ``patch_size``, ``scale_x``, ``scale_y`` and ``band_sel``
        arrive as extra ParDo arguments. The grid is EPSG:4326 with
        translateX/translateY set to ``coords[0]``/``coords[1]``.
        """
        affine = dict(
            scaleX=scale_x,
            shearX=0,
            translateX=coords[0],
            shearY=0,
            scaleY=scale_y,
            translateY=coords[1],
        )
        request = dict(
            expression=image,
            fileFormat='NPY',
            grid=dict(
                dimensions=dict(width=patch_size, height=patch_size),
                affineTransform=affine,
                crsCode='EPSG:4326',
            ),
        )

        if band_sel is not None:
            request['bandIds'] = band_sel

        payload = ee.data.computePixels(request)
        yield np.load(io.BytesIO(payload))

In [None]:
def split_dataset(element, n_partitions, validation_ratio=None) -> int:
  """Assign an element to the training (0) or validation (1) partition.

  Intended for ``beam.Partition(split_dataset, 2)``. The assignment comes from
  a SHA-256 hash of the element rather than unseeded ``random.choices``. The
  split is therefore reproducible across runs and worker processes, and
  identical serialized examples always land in the same split.

  Args:
    element: the PCollection element. bytes/bytearray/memoryview and str are
      hashed directly; any other object is hashed via its ``repr``.
    n_partitions: partition count supplied by ``beam.Partition``. Only the
      two-way (training, validation) split is produced; it is otherwise unused.
    validation_ratio: fraction of elements routed to partition 1. Defaults to
      the module-level ``VALIDATION_RATIO``.

  Returns:
    0 for training, 1 for validation.
  """
  import hashlib

  ratio = VALIDATION_RATIO if validation_ratio is None else validation_ratio
  if isinstance(element, (bytes, bytearray, memoryview)):
    payload = bytes(element)
  elif isinstance(element, str):
    payload = element.encode('utf-8')
  else:
    payload = repr(element).encode('utf-8')
  digest = hashlib.sha256(payload).digest()
  # Keep the top 53 bits so the value maps exactly onto a double in [0, 1).
  fraction = (int.from_bytes(digest[:8], 'big') >> 11) / float(1 << 53)
  return 1 if fraction < ratio else 0

# Run the export locally on the direct runner with 4 worker processes.
beam_options = PipelineOptions(
    [], direct_num_workers=4, direct_running_mode='multi_processing'
)

with beam.Pipeline(options=beam_options) as pipeline:
  # points -> EE patches -> serialized tf.Example bytes
  serialized_examples = (
      pipeline
      | 'Create points' >> beam.Create(SAMPLE_POINTS)
      | 'Get patch' >> beam.ParDo(EEComputePatch(), joined_img, PATCH_SIZE, SCALE_X, SCALE_Y)
      | 'Serialize' >> beam.Map(serialize_example)
  )
  # Partition 0 is training, partition 1 is validation (see split_dataset).
  training_data, validation_data = (
      serialized_examples | 'Split dataset' >> beam.Partition(split_dataset, 2)
  )

  training_data | 'Write training data' >> beam.io.WriteToTFRecord(
      '../data/beam/training', file_name_suffix='.tfrecord.gz'
  )
  validation_data | 'Write validation data' >> beam.io.WriteToTFRecord(
      '../data/beam/validation', file_name_suffix='.tfrecord.gz'
  )

# Parse results

In [None]:
# Every band of joined_img is parsed back as a float32 PATCH_SIZE x PATCH_SIZE patch.
FEATURES = joined_img.bandNames().getInfo()
KERNEL_SHAPE = [PATCH_SIZE, PATCH_SIZE]
COLUMNS = [
  tf.io.FixedLenFeature(shape=KERNEL_SHAPE, dtype=tf.float32)
  for _ in FEATURES
]
FEATURES_DICT = {band: column for band, column in zip(FEATURES, COLUMNS)}

In [None]:
def parse_tfrecord(example_proto):
  """Decode one serialized tf.Example into a dict of float32 patch tensors."""
  return tf.io.parse_single_example(example_proto, features=FEATURES_DICT)

# Quick look at a single shard. NOTE(review): the shard name/count is whatever
# the Beam run produced — update if the pipeline writes a different number.
dataset = (
    tf.data.TFRecordDataset('../data/beam/training-00001-of-00008.tfrecord.gz', compression_type='GZIP')
    .map(parse_tfrecord, num_parallel_calls=5)
)

In [None]:
# for data in dataset:
#   rgb = np.stack([
#       data['A00'].numpy(),
#       data['A01'].numpy(),
#       data['A02'].numpy()], axis=2)
#   plt.imshow(rgb+0.5)
#   plt.show()
#   plt.imshow(data['burned_area_2024'])
#   plt.show()


# Parse with option to augment

In [None]:
class Augment(tf.keras.layers.Layer):
  """Randomly flip all input bands and the label with one shared flip.

  Every band in ``inputs`` and the ``labels`` tensor receive the same random
  horizontal/vertical flip, so all pixels stay spatially aligned.
  """

  def __init__(self, seed=42):
    super().__init__()
    # A single RandomFlip called once per batch. Calling a RandomFlip layer once
    # per band draws a fresh random flip on every call, which would misalign the
    # bands with each other and with the labels.
    self.flip = tf.keras.layers.RandomFlip(
        mode="horizontal_and_vertical", seed=seed)

  def call(self, inputs, labels):
    """Apply one shared random flip to ``inputs`` and ``labels``.

    Args:
      inputs: dict mapping band name to a tensor; assumed (batch, height, width)
        as produced by ``to_tuple`` + batching — confirm if the pipeline changes.
      labels: tensor with the same shape as each band.

    Returns:
      ``(inputs, labels)`` with unchanged structure, keys and shapes.
    """
    names = list(inputs.keys())
    # Stack bands + label as channels: RandomFlip then sees a 4-D
    # (batch, height, width, channels) tensor and flips H/W. Passing a 3-D
    # (batch, H, W) tensor would make it treat the batch axis as height.
    stacked = tf.stack([inputs[name] for name in names] + [labels], axis=-1)
    flipped = self.flip(stacked, training=True)
    channels = tf.unstack(flipped, axis=-1)
    return dict(zip(names, channels[:-1])), channels[-1]


def parse_tfrecord(example_proto):
  """Decode a serialized tf.Example using the band schema in FEATURES_DICT."""
  parsed = tf.io.parse_single_example(example_proto, features=FEATURES_DICT)
  return parsed


def to_tuple(inputs):
  """Split a parsed example into ``(input band dict, label)``.

  Only the first entry of OUTPUT_BANDS is used as the label.
  """
  model_inputs = {band: inputs[band] for band in INPUT_BANDS}
  label = inputs[OUTPUT_BANDS[0]]
  return model_inputs, label


def get_dataset(pattern, batch_size, shuffle=True):
  """Build a batched ``(inputs, label)`` tf.data pipeline from GZIP TFRecords.

  Args:
    pattern: glob matching the TFRecord shards.
    batch_size: examples per batch.
    shuffle: when True, shuffle both the shard order and the examples (512-element
      buffer). When False, shards are listed in the deterministic order given by
      ``list_files(..., shuffle=False)`` and examples are not shuffled.

  Returns:
    A prefetching ``tf.data.Dataset`` of ``(dict[band -> tensor], label)``
    batches.
  """
  # list_files shuffles filenames by default, which made even the "unshuffled"
  # validation set non-deterministic. The evaluation cells align predict()
  # output with separate passes over validation_dataset, so order must be stable.
  files = tf.data.Dataset.list_files(pattern, shuffle=shuffle)
  dataset = files.interleave(
      lambda filename: tf.data.TFRecordDataset(filename, compression_type='GZIP'))
  dataset = dataset.map(parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.map(to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.cache()
  if shuffle:
    dataset = dataset.shuffle(512)
  dataset = dataset.batch(batch_size)
  return dataset.prefetch(buffer_size=tf.data.AUTOTUNE)


# Training: batches of 8 with a shared random flip applied to bands and label.
# Validation: batches of 1, no example-level shuffling, no augmentation.
flip_augment = Augment()
training_dataset = get_dataset('../data/beam/training-*.tfrecord.gz', 8).map(
    flip_augment, num_parallel_calls=tf.data.AUTOTUNE)
validation_dataset = get_dataset('../data/beam/validation-*.tfrecord.gz', 1, shuffle=False)

# Sanity check: print the dtype and shape of each tensor in one training batch.
for inputs, outputs in training_dataset.take(1):
  print("inputs:")
  for band_name, band_values in inputs.items():
    print(f"  {band_name}: {band_values.dtype.name} {band_values.shape}")
  print(f"outputs: {outputs.dtype.name} {outputs.shape}")

In [None]:
def get_model(input_shape):
    """Build a convolutional encoder-decoder for per-pixel binary segmentation.

    Args:
        input_shape: unused. The input is declared with unknown spatial size
            and len(INPUT_BANDS) channels.

    Returns:
        A keras.Model mapping (batch, H, W, len(INPUT_BANDS)) to
        (batch, H', W', 1) sigmoid probabilities. The entry conv plus two
        downsampling blocks reduce H/W by 8 (each stride-2 "same" op rounds
        up), and three upsampling blocks multiply by 8. H' == H and W' == W
        only when H and W are divisible by 8.
    """
    inputs = keras.Input(shape=[None, None, len(INPUT_BANDS)])

    ### [First half of the network: downsampling inputs] ###

    # Entry block
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Each block halves H/W (max-pool stride 2) and adds a strided 1x1-conv
    # projection of the previous block's output so the shapes match.
    for filters in [64, 128]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    # Each block doubles H/W; the residual is upsampled the same way and
    # projected with a 1x1 conv to the block's filter count.
    for filters in [128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Project residual
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Add a per-pixel classification layer
    # (single sigmoid channel: probability that the pixel is burned)
    outputs = layers.Conv2D(1, 3, activation="sigmoid", padding="same")(x)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model

def get_mlp(input_shape):
    """Build a per-pixel MLP classifier (no spatial context).

    Dense layers act on the last (channel) axis, so each pixel is classified
    independently from its own len(INPUT_BANDS) band values.

    Args:
        input_shape: unused; the input has unknown spatial size and
            len(INPUT_BANDS) channels.

    Returns:
        A keras.Model mapping (batch, H, W, len(INPUT_BANDS)) to
        (batch, H, W, 1) sigmoid probabilities.
    """
    n_channels = len(INPUT_BANDS)
    pixel_features = keras.Input(shape=[None, None, n_channels])

    # Two hidden ReLU layers: 128 then 64 units.
    hidden = pixel_features
    for units in (128, 64):
        hidden = layers.Dense(units, activation='relu')(hidden)

    burn_probability = layers.Dense(1, activation="sigmoid")(hidden)
    return keras.Model(pixel_features, burn_probability)

def get_multi_scale_mlp_head(input_shape, hidden=128):
    """Build a per-pixel MLP that also sees 2x- and 4x-averaged context.

    Three branches are fused per pixel:
      * full resolution: Dense on the raw bands;
      * 1/2 resolution: 2x2 average pool -> Dense -> bilinear upsample x2;
      * 1/4 resolution: 4x4 average pool -> Dense -> bilinear upsample x4.
    H and W must be divisible by 4, otherwise the upsampled branches do not
    line up with the full-resolution one at the Concatenate.

    Args:
        input_shape: unused; the input has unknown spatial size and
            len(INPUT_BANDS) channels.
        hidden: width of every Dense projection.

    Returns:
        A keras.Model producing (batch, H, W, 1) sigmoid probabilities.
    """
    inputs = keras.Input(shape=[None, None, len(INPUT_BANDS)])

    def pooled_branch(factor):
        # Average-pool by `factor`, project, then upsample back to full size.
        branch = layers.AveragePooling2D(pool_size=factor)(inputs)
        branch = layers.Dense(hidden, activation="gelu")(branch)
        return layers.UpSampling2D(size=factor, interpolation="bilinear")(branch)

    full_res = layers.Dense(hidden, activation="gelu")(inputs)
    half_res = pooled_branch(2)
    quarter_res = pooled_branch(4)

    # Fuse the three scales channel-wise, normalize, and project once more.
    fused = layers.Concatenate()([full_res, half_res, quarter_res])
    fused = layers.LayerNormalization()(fused)
    fused = layers.Dense(hidden, activation="gelu")(fused)

    return keras.Model(inputs, layers.Dense(1, activation='sigmoid')(fused))


In [None]:
mlp_input_shape = [PATCH_SIZE, PATCH_SIZE, len(INPUT_BANDS)]
model = get_mlp(mlp_input_shape)
model.summary()

In [None]:
# One named (H, W, 1) input per band, matching the dict produced by to_tuple.
inputs_dict = {}
for band in INPUT_BANDS:
    inputs_dict[band] = tf.keras.Input(shape=(None, None, 1), name=band)

# Stack the bands as channels in INPUT_BANDS order and feed the per-pixel model.
band_stack = [inputs_dict[band] for band in INPUT_BANDS]
concat = tf.keras.layers.Concatenate()(band_stack)
new_model = tf.keras.Model(inputs=inputs_dict, outputs=model(concat))

In [None]:
checkpoint_filepath = './checkpoint.model.keras'
# Save the model only when val_loss reaches a new minimum.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)
# Stop after 10 epochs without val_loss improvement.
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_loss', mode='min', patience=10)

new_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0025),
    loss="Dice",
    metrics=[tf.keras.metrics.BinaryIoU(target_class_ids=[1])],
)

new_model.fit(
    training_dataset,
    validation_data=validation_dataset,
    epochs=25,
    callbacks=[model_checkpoint_callback, early_stopping_callback],
)

In [None]:
# Reload the best model written by the ModelCheckpoint callback during fit().
new_model = tf.keras.models.load_model(filepath='checkpoint.model.keras')

In [None]:
# Collect, in validation_dataset order:
#   valid_masks          - the label (burned_area_2024) of every example;
#   valid_burn_lastyear  - a persistence baseline, burned_area_2023 + burned_area_2022.
# The original comprehension sized its inner loop with an undefined/stale global
# `batch` instead of the loop variable. It also iterated the dataset twice.
# A single pass keeps both arrays aligned example-by-example. Concatenating
# along the batch axis gives one row per example for any batch size.
mask_batches = []
burn_batches = []
for val_batch_inputs, val_batch_labels in validation_dataset:
    mask_batches.append(val_batch_labels.numpy())
    burn_batches.append(val_batch_inputs['burned_area_2023'].numpy()
                        + val_batch_inputs['burned_area_2022'].numpy())
valid_masks = np.concatenate(mask_batches)
valid_burn_lastyear = np.concatenate(burn_batches)

In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score, jaccard_score

In [None]:
# Per-pixel burn probabilities for every validation example, in dataset order.
out = new_model.predict(x=validation_dataset)


In [None]:

# Persistence baseline: predict "burned" wherever 2022+2023 burned (> 0.5).
baseline_true = valid_masks.flatten() > 0.5
baseline_pred = valid_burn_lastyear.flatten() > 0.5
for score_fn in (f1_score, recall_score, precision_score, jaccard_score):
    print(score_fn(baseline_true, baseline_pred))

In [None]:

# Model scores, using a 0.99 probability threshold on the predictions.
model_true = valid_masks.flatten() > 0.5
model_pred = out.flatten() > 0.99
for score_fn in (f1_score, recall_score, precision_score, jaccard_score):
    print(score_fn(model_true, model_pred))

In [None]:
# Plot each validation example whose true mask or thresholded prediction has
# at least one burned pixel. `out` rows are indexed with the same running
# counter, so this assumes predict() saw the examples in this same order.
j = 0
for batch in validation_dataset:
    for label_tensor in batch[1]:
        true_mask = label_tensor.numpy()
        pred_mask = out[j] > 0.99
        if np.any(true_mask > 0.5) or np.any(pred_mask):
            plt.imshow(true_mask)
            plt.title(j)
            plt.show()
            plt.imshow(pred_mask)
            plt.title('{}-pred'.format(j))
            plt.show()
        j += 1