# 04-predict-flooding

Generate flood and permanent water extent predictions using a pre-trained UNet model.

In [None]:
!pip install rasterio

In [2]:
import os
import numpy as np
import pickle
import tensorflow as tf
import rasterio
from rasterio import windows
from rasterio import features
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt

In [None]:
# Mount Google Drive (Colab only); the Drive paths built in the cells below
# assume this mount point.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [4]:
# get trained UNet flood segmentation model
# Anchor Drive paths to the mount point passed to drive.mount() instead of the
# current working directory, so they remain valid after a `%cd`.
DRIVE_MOUNT_POINT = "/content/drive"
gdrive_folder_base = os.path.join(DRIVE_MOUNT_POINT, "MyDrive")
checkpoint_fname = os.path.join(
    gdrive_folder_base,
    "ccai-public",
    "checkpoints_unet_gt3classes_edge_weighted_ce_dice_loss",
    "starfm_preds_batchsize_32_lr_0.001.h5",
)

In [5]:
def dice_loss_edge_weight_ce_multiclass(
    edge_weight: int,
    is_logits: bool,
    num_classes: int
    ):
    """dice_loss_edge_weight_ce_multiclass

    Edge-weighted multiclass cross-entropy loss plus Dice loss.

    Based on Garg et al: https://arxiv.org/pdf/2302.08180.pdf

    The true labels are collapsed to a binary foreground mask (any class > 0).
    Pixels on the boundary of that mask get a cross-entropy weight of
    ``1 + edge_weight``: both the outer ring (background pixels found by a
    3x3 dilation) and the inner ring (foreground pixels found by a 3x3
    erosion). Every other pixel gets weight 1.

    Parameters
    ----------
    edge_weight : int
        Extra weight added to edge pixels of the true mask (floats are also
        accepted).
    is_logits : bool
        Whether values in `y_pred` are logits (softmax is then applied over
        the last axis) or probabilities.
    num_classes : int
        Number of possible classes in segmented image outputs.

    Returns
    -------
    Callable
        Loss function ``loss(y_true, y_pred)`` that can be passed into
        `model.compile()` with Keras or Tensorflow. ``y_true`` must be a 4-D
        NHWC tensor of integer class labels with one channel (required by the
        3x3x1 morphology kernel); ``y_pred`` is expected channels-last with
        `num_classes` channels so it flattens in the same order as the
        one-hot labels.
    """
    def loss(y_true, y_pred):

        ### EDGE WEIGHTED CE ###

        # Flat 3x3 structuring element: dilation2d / erosion2d with an all-zero
        # kernel take the max / min over each pixel's 3x3 neighbourhood.
        kernel = tf.zeros((3, 3, 1), dtype=tf.int64)
        morph_kwargs = dict(
            filters=kernel,
            strides=(1, 1, 1, 1),
            padding="SAME",
            data_format="NHWC",
            dilations=(1, 1, 1, 1),
        )

        # Binary foreground mask: 1 wherever the label is any class > 0.
        fg = tf.cast(tf.math.greater(tf.cast(y_true, tf.int64), 0), tf.int64)

        # Outer edge: background pixels adjacent to foreground (values 0/1).
        y_dil_edges = tf.nn.dilation2d(fg, **morph_kwargs) - fg
        # Inner edge: foreground pixels adjacent to background (values 0/1).
        # Erosion never exceeds its input, so the eroded mask must be
        # subtracted FROM the mask; the reverse order yields -1 here and a
        # negative (1 - edge_weight) weight after scaling.
        y_er_edges = fg - tf.nn.erosion2d(fg, **morph_kwargs)

        # The two rings are disjoint (one lies where fg == 0, the other where
        # fg == 1), so the edge indicator is 0 or 1 per pixel.
        edges = tf.cast(y_dil_edges + y_er_edges, "float64")
        edges = edges * tf.cast(edge_weight, "float64") + 1.0

        # Repeat each pixel's weight along the class axis: (B, H, W, 1) ->
        # (B, H, W, num_classes). Flattened, this is pixel-major/class-minor,
        # matching the flattened one-hot labels and predictions below.
        # (Concatenating along axis 0 would tile whole images and pair weights
        # with the wrong pixel/class.)
        edge_weights = tf.concat([edges] * num_classes, axis=-1)
        edge_weights = tf.reshape(edge_weights, [-1])

        # weighted CE
        y_pred = tf.cast(y_pred, "float64")
        if is_logits:
            # Softmax along the last dimension (channels-last predictions).
            y_pred = tf.keras.activations.softmax(y_pred)

        # Following Keras source code - clip so log() stays finite.
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1 - eps)

        y_true_f = tf.reshape(tf.one_hot(
            tf.cast(y_true, "int32"), depth=num_classes, axis=-1), [-1])
        y_true_f = tf.cast(y_true_f, "float64")
        y_pred_f = tf.reshape(y_pred, [-1])

        wce = -tf.reduce_mean(y_true_f * tf.math.log(y_pred_f) * edge_weights)

        ### DICE LOSS ###
        # Global (whole-batch, all-class) soft Dice over the flattened tensors.
        intersect = tf.reduce_sum(y_true_f * y_pred_f)
        denom = tf.reduce_sum(y_true_f + y_pred_f)

        dice = 1 - (2. * intersect) / (denom + 1e-7)

        return dice + wce

    return loss

# TC Yasa

Generate flood predictions for croplands on Vanua Levu, Fiji, following Tropical Cyclone Yasa (December 2020).

In [25]:
# TC Yasa scene folder on Drive; derived from the shared Drive base path so the
# two cannot drift apart.
gdrive_folder = os.path.join(gdrive_folder_base, "tc-yasa-aoi2")

In [26]:
# The checkpoint was saved with the custom loss under the name "loss", so an
# equivalent callable must be supplied for Keras to deserialize it.
model = keras.models.load_model(
    checkpoint_fname,
    custom_objects={
        "loss": dice_loss_edge_weight_ce_multiclass(
            edge_weight=3, is_logits=False, num_classes=3
        )
    },
)

In [None]:
# Load the STARFM-derived model inputs for this scene and segment flood extent.
inputs_np = np.load(os.path.join(gdrive_folder, "starfm_preds_cday_0.npy"))
print(f"shape of inputs: {inputs_np.shape}")

# The model consumes a batch, so prepend a batch axis of size 1
# (presumably (H, W, C) -> (1, H, W, C); confirm against the saved arrays).
tf_tmp_inputs = tf.expand_dims(tf.convert_to_tensor(inputs_np, dtype=tf.float32), 0)

# Drop the batch axis again, then keep the highest-scoring class per pixel.
preds = np.squeeze(model.predict(tf_tmp_inputs), axis=0)
preds_cat = preds.argmax(axis=-1)

In [28]:
# save prediction: reuse the georeferencing of the matching STARFM NDVI raster,
# but declare the output as a single-band int16 class mask. Otherwise the
# reference raster's dtype, band count and nodata value would be inherited and
# the int16 classes would be silently cast / written beside empty bands.
with rasterio.open(os.path.join(gdrive_folder, "starfm_synth_ndvi_0_cday_idx_0.tif")) as ref:
    meta = ref.meta.copy()

assert preds_cat.shape == (meta["height"], meta["width"]), (
    f"prediction shape {preds_cat.shape} does not match reference raster "
    f"({meta['height']}, {meta['width']})"
)
# Every pixel receives a class (0 included), so the mask has no nodata value;
# this also stops a float nodata (e.g. NaN) being applied to an int16 band.
meta.update(count=1, dtype="int16", nodata=None)

with rasterio.open(os.path.join(gdrive_folder, "flood-mask-tc-yasa-preds-cday-idx-0.tif"), "w", **meta) as dst:
    dst.write(preds_cat.astype("int16"), 1)

In [None]:
# Show the predicted class map. Nearest-neighbour interpolation stops the
# resampling filter from blending neighbouring class indices into colours
# that correspond to no class.
fig, ax = plt.subplots()
class_img = ax.imshow(preds_cat, interpolation="nearest")
ax.set_title("TC Yasa predicted classes (cday idx 0)")
ax.set_xlabel("column (pixel)")
ax.set_ylabel("row (pixel)")
fig.colorbar(class_img, ax=ax, label="predicted class index")
plt.show()

In [None]:
# Load the STARFM-derived model inputs for this scene and segment flood extent.
inputs_np = np.load(os.path.join(gdrive_folder, "starfm_preds_cday_1.npy"))
print(f"shape of inputs: {inputs_np.shape}")

# The model consumes a batch, so prepend a batch axis of size 1
# (presumably (H, W, C) -> (1, H, W, C); confirm against the saved arrays).
tf_tmp_inputs = tf.expand_dims(tf.convert_to_tensor(inputs_np, dtype=tf.float32), 0)

# Drop the batch axis again, then keep the highest-scoring class per pixel.
preds = np.squeeze(model.predict(tf_tmp_inputs), axis=0)
preds_cat = preds.argmax(axis=-1)

In [31]:
# save prediction: reuse the georeferencing of the matching STARFM NDVI raster,
# but declare the output as a single-band int16 class mask. Otherwise the
# reference raster's dtype, band count and nodata value would be inherited and
# the int16 classes would be silently cast / written beside empty bands.
with rasterio.open(os.path.join(gdrive_folder, "starfm_synth_ndvi_0_cday_idx_1.tif")) as ref:
    meta = ref.meta.copy()

assert preds_cat.shape == (meta["height"], meta["width"]), (
    f"prediction shape {preds_cat.shape} does not match reference raster "
    f"({meta['height']}, {meta['width']})"
)
# Every pixel receives a class (0 included), so the mask has no nodata value;
# this also stops a float nodata (e.g. NaN) being applied to an int16 band.
meta.update(count=1, dtype="int16", nodata=None)

with rasterio.open(os.path.join(gdrive_folder, "flood-mask-tc-yasa-preds-cday-idx-1.tif"), "w", **meta) as dst:
    dst.write(preds_cat.astype("int16"), 1)

In [None]:
# Show the predicted class map. Nearest-neighbour interpolation stops the
# resampling filter from blending neighbouring class indices into colours
# that correspond to no class.
fig, ax = plt.subplots()
class_img = ax.imshow(preds_cat, interpolation="nearest")
ax.set_title("TC Yasa predicted classes (cday idx 1)")
ax.set_xlabel("column (pixel)")
ax.set_ylabel("row (pixel)")
fig.colorbar(class_img, ax=ax, label="predicted class index")
plt.show()