# 00. Train Model

Helper notebook for training the fire risk model on Google Colab.

In [None]:
!git clone https://github.com/kysolvik/aic-risk-modeling.git
!cd /content/aic-risk-modeling/; git checkout vertex-train; pip install -e .
!pip install --upgrade tensorflow-metadata


In [None]:

!pip install --upgrade tensorflow-metadata


In [None]:
# Authenticate the Colab user for this runtime.
# NOTE(review): presumably required for reading the gs:// data used below — confirm.
from google.colab import auth
auth.authenticate_user()

In [None]:
# Restart the runtime by asking the running kernel to shut down and restart.
# Everything defined so far in this session is lost; continue from the next cell.
import IPython

IPython.get_ipython().kernel.do_shutdown(restart=True)

In [None]:
# import matplotlib.pyplot as plt
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pickle
import json
import aic_risk_modeling.train as train

In [None]:
# Single seed for the whole notebook.
SEED = 54
# Dedicated NumPy generator for any explicit sampling done in this notebook.
RNG = np.random.default_rng(SEED)
# Also seed Python's `random`, NumPy's global RNG and TensorFlow, so that weight
# initialisation and other TF-side randomness are repeatable between runs.
tf.keras.utils.set_random_seed(SEED)

# Set params

In [None]:
# GCS prefix that holds the schema and the TFRecord shards
GCS_DATA_DIR = 'gs://aic-fire-amazon/results_2024_5k/'
# Glob suffix matching the TFRecord shards
TFRECORD_PATTERN = '*.tfrecord.gz'
# Patch size (used for both spatial dimensions of the model input)
PATCH_SIZE = 128

# Model inputs: bands A00..A63 followed by the 2023 burned-area band
INPUT_BANDS = [f'A{band_idx:02d}' for band_idx in range(64)] + ['burned_area_2023']
# Model target band
OUTPUT_BAND = 'BurnDate'


# Create dataset

In [None]:
# Glob patterns for the training / validation TFRecord shards.
training_pattern = os.path.join(GCS_DATA_DIR, 'training-{}'.format(TFRECORD_PATTERN))
validation_pattern = os.path.join(GCS_DATA_DIR, 'validation-{}'.format(TFRECORD_PATTERN))

# Load the schema and derive the per-feature parsing spec from it.
schema = train.load_schema_from_gcs(GCS_DATA_DIR)
feature_spec = train.build_features_dict(schema, patch_size=PATCH_SIZE)

print("Example features:")
for i, f in enumerate(list(feature_spec.keys())[:20]):
    print(i + 1, f)

# The model built later is wired from INPUT_BANDS, so the datasets must serve
# exactly those bands. Fail early with a clear message if the schema lacks any.
missing_bands = [band for band in list(INPUT_BANDS) + [OUTPUT_BAND] if band not in feature_spec]
if missing_bands:
    raise KeyError('Bands missing from the schema: {}'.format(missing_bands))

# Settings shared by the training and validation datasets.
BATCH_SIZE = 4
shared_dataset_kwargs = dict(
    input_bands=list(INPUT_BANDS),
    output_bands=[OUTPUT_BAND],
    batch_size=BATCH_SIZE,
    cache=False,
)

training_ds = train.dataset_from_gcs(
    training_pattern, feature_spec,
    shuffle_buffer=64,
    **shared_dataset_kwargs)
# Validation is read unshuffled so that repeated passes see the same order.
validation_ds = train.dataset_from_gcs(
    validation_pattern, feature_spec,
    shuffle=False,
    **shared_dataset_kwargs)

# Peek at one training batch as a sanity check.
for inputs, labels in training_ds.take(1):
    print("Batch inputs keys:", list(inputs.keys()))
    print("Label shape:", labels.shape)

In [None]:
# Core model: takes a PATCH_SIZE x PATCH_SIZE patch with one channel per input band.
model_input_shape = [PATCH_SIZE, PATCH_SIZE, len(INPUT_BANDS)]
model = train.get_mlp(model_input_shape)
model.summary()

In [None]:
# One single-channel Keras input per band, keyed (and named) by band name.
inputs_dict = {}
for band_name in INPUT_BANDS:
    inputs_dict[band_name] = tf.keras.Input(shape=(None, None, 1), name=band_name)

# Join the per-band inputs (in INPUT_BANDS order) with a default Concatenate layer
# and feed the result to the core model, giving a model that takes the dict of bands.
band_inputs = list(inputs_dict.values())
concat = tf.keras.layers.Concatenate()(band_inputs)
new_model = tf.keras.Model(inputs=inputs_dict, outputs=model(concat))

In [None]:

# Callbacks: keep only the best checkpoint (lowest val_loss) and stop once
# val_loss has not improved for 5 epochs.
checkpoint_filepath = './checkpoint.model.keras'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath, monitor='val_loss', mode='min', save_best_only=True)
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_loss', mode='min', patience=5)

# Dice loss, tracked with the IoU of class 1.
adam = tf.keras.optimizers.Adam(learning_rate=0.0025)
iou_metric = tf.keras.metrics.BinaryIoU(target_class_ids=[1])
new_model.compile(optimizer=adam, loss="Dice", metrics=[iou_metric])

fit_options = dict(
    validation_data=validation_ds,
    epochs=20,
    steps_per_epoch=500,
    validation_steps=100,
    callbacks=[model_checkpoint_callback, early_stopping_callback],
)
new_model.fit(training_ds, **fit_options)

In [None]:
# Reload the model from the checkpoint file. The training cell's ModelCheckpoint
# callback writes './checkpoint.model.keras' with save_best_only=True while
# monitoring val_loss, so this replaces the in-memory model with the saved one.
new_model = tf.keras.models.load_model('checkpoint.model.keras')

In [None]:
# Fixed evaluation subset: the first 50 batches of the validation dataset.
small_ds = validation_ds.take(50)

# Collect the target masks batch by batch and join them along the batch axis.
# Concatenating (rather than np.array on a list of batches) also copes with a
# final batch that is smaller than the others.
mask_batches = [np.asarray(mask) for _, mask in small_ds.as_numpy_iterator()]
if mask_batches:
    valid_masks = np.concatenate(mask_batches, axis=0)
else:
    valid_masks = np.empty((0, PATCH_SIZE, PATCH_SIZE))
# One (PATCH_SIZE, PATCH_SIZE) mask per sample; uses the configured patch size
# instead of a hard-coded 128.
valid_masks = valid_masks.reshape(-1, PATCH_SIZE, PATCH_SIZE)


In [None]:
# Per-sample target masks and last year's burned area, gathered in ONE pass.
# The pass runs over `small_ds` (the subset the predictions are computed on) so
# that these arrays have the same number of samples as the prediction arrays
# they are compared against in the metric cells.
mask_rows = []
burn_lastyear_rows = []
for batch_inputs, batch_masks in small_ds.as_numpy_iterator():
    mask_rows.extend(batch_masks)
    burn_lastyear_rows.extend(batch_inputs['burned_area_2023'])

valid_masks = np.array(mask_rows)
valid_burn_lastyear = np.array(burn_lastyear_rows)

In [None]:
# Predict batch by batch over the evaluation subset, keeping predictions and
# labels aligned in two parallel lists of NumPy arrays.
pred_list = []
mask_list = []
for images, labels in small_ds:
    # `predict_on_batch` runs a single forward pass on this batch. `predict()`
    # is meant for whole datasets and adds per-call overhead (and a progress
    # bar) when called inside a loop over small batches.
    pred_list.append(new_model.predict_on_batch(images))
    mask_list.append(labels.numpy())

In [None]:
# Join the per-batch arrays along the batch axis, then view them as one
# (PATCH_SIZE, PATCH_SIZE) image per sample. Concatenating handles a smaller
# final batch, and PATCH_SIZE replaces the hard-coded 128.
valid_masks = np.concatenate(
    [np.asarray(batch_masks) for batch_masks in mask_list], axis=0
).reshape(-1, PATCH_SIZE, PATCH_SIZE)
out = np.concatenate(
    [np.asarray(batch_preds) for batch_preds in pred_list], axis=0
).reshape(-1, PATCH_SIZE, PATCH_SIZE)

In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score, jaccard_score

In [None]:
# out = new_model.predict(validation_ds)

In [None]:

# Baseline: score last year's burned area as if it were the prediction.
# Both arrays are binarised at 0.5 and flattened to one value per pixel.
truth_flat = valid_masks.flatten() > 0.5
baseline_flat = valid_burn_lastyear.flatten() > 0.5
for metric_fn in (f1_score, recall_score, precision_score, jaccard_score):
    print(metric_fn(truth_flat, baseline_flat))

In [None]:
# Number of elements of `out` that exceed the 0.5 threshold.
(out > 0.5).sum()

In [None]:

# Score the model predictions against the target masks.
# All four metrics use the SAME prediction threshold; previously F1 was computed
# at 0.9 while recall / precision / Jaccard used 0.5, so the numbers were not
# comparable with each other.
PRED_THRESHOLD = 0.5
truth_flat = valid_masks.flatten() > 0.5
pred_flat = out.flatten() > PRED_THRESHOLD
for metric_fn in (f1_score, recall_score, precision_score, jaccard_score):
    print(metric_fn(truth_flat, pred_flat))

In [None]:
def visualize_risk_predict(input_batch, target_batch, output_i, batch_i, suptitle, cutoff=0.5):
    """Show one sample as a 2x2 figure: embeddings, prior burn, prediction, target.

    Args:
        input_batch: Mapping from band name to a batched array or tensor. Must
            contain 'A01', 'A16', 'A09' and 'burned_area_2023'.
        target_batch: Batched target masks (array or tensor), indexed by sample.
        output_i: Prediction for this single sample.
        batch_i: Index of the sample inside `input_batch` / `target_batch`.
        suptitle: Title for the whole figure.
        cutoff: Threshold applied to the target mask only. The prior-year burn
            panel always uses 0.5 and the prediction is shown unthresholded.

    Inputs may be TensorFlow tensors or NumPy arrays, either 2-D per sample or
    with a trailing length-1 channel axis.
    """
    import matplotlib.pyplot as plt

    def _as_image(sample):
        # Convert to NumPy and drop a trailing single-channel axis, if any.
        arr = np.asarray(sample)
        if arr.ndim == 3 and arr.shape[-1] == 1:
            arr = arr[..., 0]
        return arr

    fig, axs = plt.subplots(2, 2)
    fig.suptitle(suptitle)
    panels = axs.flatten()

    # Embeddings: three bands shown as a false-colour RGB image.
    rgb = np.stack(
        [_as_image(input_batch[band][batch_i]) for band in ('A01', 'A16', 'A09')],
        axis=-1)
    # Rescale from [vmin, vmax] to [0, 1]; clip so imshow gets a valid RGB range.
    vmin = -0.3
    vmax = 0.3
    rgb = np.clip((rgb - vmin) / (vmax - vmin), 0.0, 1.0)
    panels[0].imshow(rgb)
    panels[0].set_title('Embeddings')

    # 2023 burn (model input)
    panels[1].imshow(_as_image(input_batch['burned_area_2023'][batch_i]) > 0.5)
    panels[1].set_title('Burned area 2023')

    # 2024 prediction
    panels[2].imshow(_as_image(output_i))
    panels[2].set_title('Predicted burned area 2024')

    # 2024 burn (target)
    panels[3].imshow(_as_image(target_batch[batch_i]) > cutoff)
    panels[3].set_title('Actual burned area 2024')
    fig.tight_layout()

    plt.show()


In [None]:
# Plot every sample of the evaluation subset next to its prediction.
# The inputs and targets are taken batch by batch from `small_ds`, so they are
# the samples the predictions in `out` were computed for, while `image_idx`
# walks the flat prediction array in the same order.
image_idx = 0
for batch_inputs, batch_targets in small_ds:
    for sample_i in range(batch_targets.shape[0]):
        if image_idx >= out.shape[0]:
            break
        visualize_risk_predict(
            input_batch=batch_inputs,
            target_batch=batch_targets,
            output_i=out[image_idx],
            batch_i=sample_i,
            suptitle='Image {}'.format(image_idx),
            cutoff=0,
        )
        image_idx += 1

In [None]:
# Plot only the samples where either the target or the prediction contains
# burned pixels, looking at no more than the first 101 samples (indices 0..100).
MAX_IMAGE_INDEX = 100
j = 0
reached_limit = False
for batch_inputs, batch_targets in small_ds:
    # Local name on purpose: the flat prediction array `out` is left untouched.
    batch_preds = new_model.predict_on_batch(batch_inputs)
    for i in range(batch_targets.shape[0]):
        target_has_burn = (batch_targets[i].numpy() > 0.5).any()
        pred_has_burn = (batch_preds[i] > 0.5).any()
        if target_has_burn or pred_has_burn:
            visualize_risk_predict(
                input_batch=batch_inputs,
                target_batch=batch_targets,
                output_i=batch_preds[i],
                batch_i=i,
                suptitle='Image {}'.format(j),
                cutoff=0.5
            )
        j += 1
        if j > MAX_IMAGE_INDEX:
            reached_limit = True
            break
    # A bare `break` only leaves the inner loop; stop the batch loop as well.
    if reached_limit:
        break

In [None]:
# Same selection as above, but using the predictions already collected in
# `pred_list`. The flat prediction array is rebuilt here instead of relying on
# the global `out`, which other cells may rebind. `pred_list` was filled from
# the leading batches of the (unshuffled) validation dataset, so index j lines
# up with the j-th sample read from `validation_ds`.
flat_preds = np.concatenate([np.asarray(p) for p in pred_list], axis=0)
flat_preds = flat_preds.reshape(-1, PATCH_SIZE, PATCH_SIZE)

MAX_IMAGE_INDEX = 100
j = 0
finished = False
for batch_inputs, batch_targets in validation_ds:
    for i in range(batch_targets.shape[0]):
        # Never index past the predictions that actually exist.
        if j >= flat_preds.shape[0]:
            finished = True
            break
        target_has_burn = (batch_targets[i].numpy() > 0.5).any()
        pred_has_burn = (flat_preds[j] > 0.5).any()
        if target_has_burn or pred_has_burn:
            visualize_risk_predict(
                input_batch=batch_inputs,
                target_batch=batch_targets,
                output_i=flat_preds[j],
                batch_i=i,
                suptitle='Image {}'.format(j),
                cutoff=0.5
            )
        j += 1
        if j > MAX_IMAGE_INDEX:
            finished = True
            break
    # A bare `break` only leaves the inner loop; stop reading the dataset too.
    if finished:
        break