# Earth Engine Explore

Explore possible explanatory and response variables for fire-risk modeling across the Amazon.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pickle

In [None]:
# Single global seed so NumPy random draws are reproducible between runs.
SEED = 54
RNG = np.random.default_rng(seed=SEED)

# Set params

In [None]:
# Final scale
FINAL_SCALE = 500 # in Meters
# Set patch size
PATCH_SIZE = 128

# For model, input and output bands
INPUT_BANDS = (
    ['{}{:02d}'.format('A', i) for i in range(64)]
    + ['burned_area_2023']
    # + ['{}{:02d}_prev_year'.format('A', i) for i in range(64)]
    # + ['burned_area_{}'.format(y) for y in range(1985, 2024)]
)
# INPUT_BANDS = ['burned_area_{}'.format(y) for y in range(1985, 2024)]
OUTPUT_BANDS = ['BurnDate']

# Feature spec handed to tf.io.parse_single_example (see parse_tfrecord)
# to decode each serialized TFRecord example.
# NOTE: pickle.load can execute arbitrary code -- only load a trusted file.
with open('features_dict.pkl', mode='rb') as features_file:
    FEATURES_DICT = pickle.load(features_file)

# Create dataset

In [None]:
class Augment(tf.keras.layers.Layer):
  """Randomly flip all input bands and the label together.

  Every band and the label must receive the *same* flip; otherwise pixels stop
  lining up across bands. A RandomFlip layer draws a fresh random flip on every
  call, so calling it once per band (as a naive dict comprehension would)
  flips bands inconsistently. Instead, all tensors are stacked on a new
  trailing axis, flipped in one call, and split back apart.

  NOTE(review): assumes each band/label tensor has no channel axis, i.e.
  (H, W) per example or (B, H, W) per batch -- consistent with the
  np.stack(..., axis=2) RGB composite used for plotting; confirm against
  FEATURES_DICT.
  """

  def __init__(self, seed=42):
    super().__init__()
    # One layer, one random draw per call, shared by bands and label.
    self.flip = tf.keras.layers.RandomFlip(
        mode="horizontal_and_vertical", seed=seed)

  def call(self, inputs, labels):
    """Flip `inputs` (dict of band -> tensor) and `labels` identically.

    Returns:
      (flipped_inputs, flipped_labels) with the original dtypes restored.
    """
    names = list(inputs)
    tensors = [inputs[name] for name in names] + [labels]
    # Cast to float so a boolean label can be stacked with the bands.
    stacked = tf.stack([tf.cast(t, tf.float32) for t in tensors], axis=-1)
    pieces = tf.unstack(self.flip(stacked), axis=-1)
    flipped_inputs = {
        name: tf.cast(piece, inputs[name].dtype)
        for name, piece in zip(names, pieces)
    }
    flipped_labels = tf.cast(pieces[-1], labels.dtype)
    return flipped_inputs, flipped_labels


def parse_tfrecord(example_proto):
  """Decode one serialized example into a dict of tensors using FEATURES_DICT."""
  parsed = tf.io.parse_single_example(
      serialized=example_proto, features=FEATURES_DICT)
  return parsed


def to_tuple(inputs):
  """Split a parsed example into a Keras (features, label) pair.

  Args:
    inputs: mapping from band name to tensor, as produced by parse_tfrecord.

  Returns:
    A tuple of (dict of only the INPUT_BANDS entries, boolean mask that is
    True where the first OUTPUT_BANDS band is > 0).
  """
  target_band = OUTPUT_BANDS[0]
  features = dict((band, inputs[band]) for band in INPUT_BANDS)
  burned = inputs[target_band] > 0
  return features, burned


def get_dataset(pattern, batch_size, shuffle=True, shuffle_buffer=512, seed=None):
  """Build a batched tf.data pipeline from GZIP-compressed TFRecord shards.

  Args:
    pattern: glob matching the TFRecord files.
    batch_size: number of examples per batch.
    shuffle: if True, shuffle both the file order and the examples. If False,
      files are listed in a deterministic order and examples are not
      shuffled, so repeated passes (e.g. labels vs. predictions) line up.
    shuffle_buffer: size of the example-level shuffle buffer.
    seed: optional seed for the file and example shuffles.

  Returns:
    A dataset of ({band_name: tensor}, label) batches (see to_tuple).
  """
  # list_files shuffles file order by default, so without forwarding `shuffle`
  # even a "shuffle=False" dataset would come out in random file order.
  files = tf.data.Dataset.list_files(pattern, shuffle=shuffle, seed=seed)
  dataset = files.interleave(
      lambda filename: tf.data.TFRecordDataset(filename, compression_type='GZIP'))
  dataset = dataset.map(parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.map(to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
  # Cache the parsed examples so later epochs skip file reading and parsing.
  dataset = dataset.cache()
  if shuffle:
    dataset = dataset.shuffle(shuffle_buffer, seed=seed)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
  return dataset


# Training: shuffled batches of 8 examples. Validation: batches of 1, with
# no example-level shuffling.
# Optional augmentation: append
#   .map(Augment(), num_parallel_calls=tf.data.AUTOTUNE)
# to the training dataset.
training_dataset = get_dataset('../data/beam/training-*.tfrecord.gz', batch_size=8)
validation_dataset = get_dataset(
    '../data/beam/validation-*.tfrecord.gz', batch_size=1, shuffle=False)

# Print dtype and shape of each band and of the label for one training batch.
for inputs, outputs in training_dataset.take(1):
    print("inputs:")
    for name, values in inputs.items():
        print(f"  {name}: {values.dtype.name} {values.shape}")
    print(f"outputs: {outputs.dtype.name} {outputs.shape}")

In [None]:
def get_unet(input_shape):
    """Build a four-level U-Net for per-pixel binary prediction.

    Args:
        input_shape: Shape of one example, ``(height, width, channels)``.
            Only the channel count is used; height and width stay ``None``
            so the model accepts any patch size. If ``None``, the channel
            count defaults to ``len(INPUT_BANDS)``.

    Returns:
        A ``keras.Model`` named "U-Net" mapping ``(batch, H, W, channels)``
        to ``(batch, H, W, 1)`` sigmoid probabilities.

    Note:
        H and W must be divisible by 16 (four 2x max-poolings). Otherwise the
        upsampled decoder maps do not match the skip connections.
    """
    # Honour the caller's channel count instead of silently reading a global.
    channels = len(INPUT_BANDS) if input_shape is None else input_shape[-1]
    inputs = keras.Input(shape=[None, None, channels])

    # --- Encoder ---
    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4, p4 = encoder_block(p3, 512)

    # --- Bottleneck ---
    b = conv_block(p4, 1024)

    # --- Decoder (each level concatenates the matching encoder skip) ---
    d1 = decoder_block(b, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    # --- Output layer: one sigmoid probability per pixel ---
    outputs = layers.Conv2D(1, 1, padding="same", activation="sigmoid")(d4)

    return keras.Model(inputs, outputs, name="U-Net")


def conv_block(inputs, num_filters):
    """Apply two 3x3 separable convolutions (ReLU, 'same' padding) in sequence."""
    x = inputs
    for _ in range(2):
        x = layers.SeparableConv2D(
            num_filters, 3, padding="same", activation="relu")(x)
    return x


def encoder_block(inputs, num_filters):
    """Run conv_block, then 2x2 max-pool.

    Returns:
        (pre-pool features for the skip connection, pooled features).
    """
    features = conv_block(inputs, num_filters)
    return features, layers.MaxPooling2D(pool_size=(2, 2))(features)


def decoder_block(inputs, skip, num_filters):
    """Upsample 2x with a transposed conv, concatenate `skip`, then conv_block."""
    upsampled = layers.Conv2DTranspose(
        num_filters, kernel_size=2, strides=2, padding="same")(inputs)
    merged = layers.concatenate([upsampled, skip])
    return conv_block(merged, num_filters)


def get_unet_lite(input_shape):
    """Build a smaller, three-level U-Net (32/64/128 filters, 256 at the bottleneck).

    Args:
        input_shape: full input shape passed to ``layers.Input``.

    Returns:
        A ``keras.Model`` named "U-Net-Lite" with one sigmoid channel per pixel.

    Note:
        Spatial dims must be divisible by 8 (three 2x max-poolings) for the skips to align.
    """
    inputs = layers.Input(shape=input_shape)

    # Encoder: keep each level's pre-pool features for the decoder skips.
    skips = []
    x = inputs
    for filters in (32, 64, 128):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    x = conv_block(x, 256)

    # Decoder: walk back up, pairing each level with its encoder skip.
    for filters, skip in zip((128, 64, 32), reversed(skips)):
        x = decoder_block(x, skip, filters)

    outputs = layers.Conv2D(1, 1, activation="sigmoid")(x)
    return keras.Model(inputs, outputs, name="U-Net-Lite")

In [None]:
def get_model(input_shape):
    """Build a small residual encoder-decoder (Xception-style) segmentation model.

    Args:
        input_shape: Shape of one example, ``(height, width, channels)``.
            Only the channel count is used; height and width stay ``None``.
            If ``None``, the channel count defaults to ``len(INPUT_BANDS)``.

    Returns:
        A ``keras.Model`` with a 1-channel sigmoid output per pixel.

    Note:
        The encoder downsamples 8x in total (strided entry conv plus two
        stride-2 blocks) and the decoder upsamples 8x. The output therefore
        matches the input's spatial size only when H and W are divisible by 8.
    """
    # Honour the caller's channel count instead of silently reading a global.
    channels = len(INPUT_BANDS) if input_shape is None else input_shape[-1]
    inputs = keras.Input(shape=[None, None, channels])

    ### [First half of the network: downsampling inputs] ###

    # Entry block (stride 2)
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    for filters in [64, 128]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual to the new width and resolution
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    for filters in [128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Project residual to the new width and resolution
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Add a per-pixel classification layer
    outputs = layers.Conv2D(1, 3, activation="sigmoid", padding="same")(x)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model

def get_mlp(input_shape):
    """Build a per-pixel MLP (Dense layers act on the channel axis only).

    No spatial context is used: each pixel's band values are mapped
    independently to a burn probability.

    Args:
        input_shape: Shape of one example, ``(height, width, channels)``.
            Only the channel count is used; height and width stay ``None``.
            If ``None``, the channel count defaults to ``len(INPUT_BANDS)``.

    Returns:
        A ``keras.Model`` mapping ``(batch, H, W, channels)`` to
        ``(batch, H, W, 1)`` sigmoid probabilities.
    """
    # Honour the caller's channel count instead of silently reading a global.
    channels = len(INPUT_BANDS) if input_shape is None else input_shape[-1]
    inputs = keras.Input(shape=[None, None, channels])

    # Hidden layers (no downsampling: Dense applies to the last axis per pixel).
    x = layers.Dense(128, activation='relu')(inputs)
    # Optional regularisation that was tried: layers.Dropout(0.3)
    x = layers.Dense(64, activation='relu')(x)

    # Add a per-pixel classification layer
    outputs = layers.Dense(1, activation="sigmoid")(x)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model

def get_multi_scale_mlp_head(input_shape, hidden=128):
    """Build a per-pixel MLP head that fuses features from three resolutions.

    At each scale, the input is average-pooled by 1x, 2x or 4x. A Dense layer
    then acts on the channel axis, and the result is bilinearly upsampled back
    to full resolution. The three maps are concatenated, layer-normalised and
    mixed by another Dense layer before the sigmoid output.

    Args:
        input_shape: Shape of one example, ``(height, width, channels)``.
            Only the channel count is used; height and width stay ``None``.
            If ``None``, the channel count defaults to ``len(INPUT_BANDS)``.
        hidden: width of every hidden Dense layer.

    Returns:
        A ``keras.Model`` mapping ``(batch, H, W, channels)`` to
        ``(batch, H, W, 1)`` sigmoid probabilities.

    Note:
        H and W must be divisible by 4. Otherwise the pooled-then-upsampled
        maps do not match the full-resolution branch in the concatenation.
    """
    # Honour the caller's channel count instead of silently reading a global.
    channels = len(INPUT_BANDS) if input_shape is None else input_shape[-1]
    inputs = keras.Input(shape=[None, None, channels])

    # --- scale 1: full resolution ---
    s1 = layers.Dense(hidden, activation="gelu")(inputs)

    # --- scale 2: 1/2 resolution (64x64 for a 128x128 patch) ---
    s2 = layers.AveragePooling2D(pool_size=2)(inputs)
    s2 = layers.Dense(hidden, activation="gelu")(s2)
    s2 = layers.UpSampling2D(size=2, interpolation="bilinear")(s2)

    # --- scale 3: 1/4 resolution (32x32 for a 128x128 patch) ---
    s3 = layers.AveragePooling2D(pool_size=4)(inputs)
    s3 = layers.Dense(hidden, activation="gelu")(s3)
    s3 = layers.UpSampling2D(size=4, interpolation="bilinear")(s3)

    # Fuse the three scales per pixel
    fused = layers.Concatenate()([s1, s2, s3])
    fused = layers.LayerNormalization()(fused)
    fused = layers.Dense(hidden, activation="gelu")(fused)

    outputs = layers.Dense(1, activation='sigmoid')(fused)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model


In [None]:
# Per-pixel MLP over the stacked input bands. The other builders above
# (get_unet, get_unet_lite, get_model, get_multi_scale_mlp_head) take the
# same argument.
model = get_mlp(input_shape=[PATCH_SIZE, PATCH_SIZE, len(INPUT_BANDS)])
model.summary()

In [None]:
# Wrap `model` so it accepts the dataset's dict of per-band tensors: one named
# single-channel Input per band, concatenated on the channel axis.
inputs_dict = dict(
    (band, tf.keras.Input(shape=(None, None, 1), name=band))
    for band in INPUT_BANDS
)

concat = tf.keras.layers.Concatenate(axis=-1)([*inputs_dict.values()])
new_model = tf.keras.Model(inputs=inputs_dict, outputs=model(concat))

In [None]:
# Keep only the checkpoint with the lowest validation loss, and stop once
# validation loss has not improved for 10 epochs.
checkpoint_filepath = './checkpoint.model.keras'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_loss', mode='min', patience=10)

# Dice loss; track IoU of the positive (burned) class, id 1.
new_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2.5e-3),
    loss="Dice",
    metrics=[tf.keras.metrics.BinaryIoU(target_class_ids=[1])],
)

new_model.fit(
    training_dataset,
    validation_data=validation_dataset,
    epochs=25,
    callbacks=[model_checkpoint_callback, early_stopping_callback],
)

In [None]:
# Reload the model saved by the ModelCheckpoint callback (lowest val_loss).
new_model = tf.keras.models.load_model(filepath='checkpoint.model.keras')

In [None]:
# Gather per-example validation labels and the 2023 burned-area input in a
# single pass. Reading both from the same batches guarantees that
# valid_masks[k] and valid_burn_lastyear[k] describe the same example,
# even if the dataset's iteration order were not stable between passes.
_mask_rows, _burn_rows = [], []
for _batch_inputs, _batch_targets in validation_dataset:
    # Iterating a NumPy array yields one example per batch row.
    _mask_rows.extend(_batch_targets.numpy())
    _burn_rows.extend(_batch_inputs['burned_area_2023'].numpy())
valid_masks = np.array(_mask_rows)
valid_burn_lastyear = np.array(_burn_rows)

In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score, jaccard_score

In [None]:
# Model probabilities for every validation example, stacked along axis 0.
out = new_model.predict(x=validation_dataset)


In [None]:

# Persistence baseline: predict that the pixels burned in 2023 burn again,
# scored against the labels. Prints F1, recall, precision and Jaccard (IoU).
baseline_true = valid_masks.flatten() > 0.5
baseline_pred = valid_burn_lastyear.flatten() > 0.5
for score_fn in (f1_score, recall_score, precision_score, jaccard_score):
    print(score_fn(baseline_true, baseline_pred))

In [None]:

# Model scores against the labels, thresholding probabilities at 0.99.
# Prints F1, recall, precision and Jaccard (IoU).
model_true = valid_masks.flatten() > 0.5
model_pred = out.flatten() > 0.99
for score_fn in (f1_score, recall_score, precision_score, jaccard_score):
    print(score_fn(model_true, model_pred))

In [None]:
def visualize_risk_predict(input_batch, target_batch, output_i, batch_i, suptitle, cutoff=0.5):
    """Plot a 2x2 panel for one example: embeddings, 2023 burn, prediction, label.

    Args:
        input_batch: dict of band name -> batched tensor. Reads bands
            'A01', 'A16', 'A09' and 'burned_area_2023'.
        target_batch: batched label tensor (boolean burn mask per to_tuple).
        output_i: model output for this example (probabilities).
        batch_i: index of the example within the batch.
        suptitle: figure title.
        cutoff: probability threshold applied to ``output_i`` to draw the
            predicted burn mask.
    """
    fig, axs = plt.subplots(2, 2)
    fig.suptitle(suptitle)
    ax_emb, ax_prev, ax_pred, ax_true = axs.flatten()

    # False-colour composite of three embedding bands, linearly rescaled from
    # [-0.3, 0.3] to [0, 1] and clipped (imshow expects RGB floats in [0, 1]).
    vmin, vmax = -0.3, 0.3
    rgb = np.stack(
        [input_batch[band][batch_i].numpy() for band in ('A01', 'A16', 'A09')],
        axis=2)
    rgb = np.clip((rgb - vmin) / (vmax - vmin), 0.0, 1.0)
    ax_emb.imshow(rgb)
    ax_emb.set_title('Embeddings')

    # 2023 burn (model input)
    ax_prev.imshow(input_batch['burned_area_2023'][batch_i].numpy() > 0.5)
    ax_prev.set_title('Burned area 2023')

    # Prediction thresholded at `cutoff`; squeeze a trailing channel axis if present.
    ax_pred.imshow(np.squeeze(np.asarray(output_i)) > cutoff)
    ax_pred.set_title('Predicted burned area 2024')

    # 2024 burn (target). Labels are boolean, so 0.5 separates burned from
    # unburned; `cutoff` applies to the model's probabilities, not to these.
    ax_true.imshow(target_batch[batch_i].numpy() > 0.5)
    ax_true.set_title('Actual burned area 2024')
    fig.tight_layout()

    plt.show()


In [None]:
# Plot every validation example whose label, or whose prediction (prob > 0.5),
# contains at least one burned pixel. `j` is the running example index into
# `out`; this assumes `out` lists predictions in dataset iteration order.
j = 0
for batch_inputs, batch_targets in validation_dataset:
    for i in range(batch_targets.shape[0]):
        if (batch_targets[i].numpy() > 0.5).any() or (out[j] > 0.5).any():
            visualize_risk_predict(
                input_batch=batch_inputs,
                target_batch=batch_targets,
                output_i=out[j],
                batch_i=i,
                suptitle='Image {}'.format(j),
                cutoff=0.99,
            )
        j += 1