In [None]:
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import layers
from keras.models import Model
from tensorflow.keras.applications import MobileNetV2
import keras
import keras_cv
import tensorflow as tf
import os
import numpy as np

# Model configuration

In [None]:
def get_configs(num_classes):
    """Select the loss, metrics and output activation for a class count.

    Args:
        num_classes: Number of segmentation classes.

    Returns:
        A 4-tuple ``(loss_fn, model_metric, model_metric2, activation)``:
        for more than two classes, categorical cross-entropy, a one-hot
        ``MeanIoU`` named ``'mean_iou'``, ``CategoricalAccuracy`` and
        ``'softmax'``; otherwise binary cross-entropy, a ``BinaryIoU`` named
        ``'binary_iou'``, the string ``'accuracy'`` and ``'sigmoid'``.
    """
    # Binary (or smaller) problems take the early exit.
    if not num_classes > 2:
        return (
            keras.losses.BinaryCrossentropy(from_logits=False),
            tf.keras.metrics.BinaryIoU(target_class_ids=[0, 1], name='binary_iou'),
            'accuracy',
            'sigmoid',
        )

    # Multi-class: labels and predictions are both dense (non-sparse) maps.
    categorical_loss = keras.losses.CategoricalCrossentropy(from_logits=False)
    mean_iou = keras.metrics.MeanIoU(
        num_classes=num_classes,
        sparse_y_true=False,
        sparse_y_pred=False,
        name='mean_iou',
    )
    categorical_accuracy = keras.metrics.CategoricalAccuracy()
    return categorical_loss, mean_iou, categorical_accuracy, 'softmax'

# Create a U-Net model with a MobileNetV2 backbone

In [None]:
# Unet with pre-trained MobilenetV2 encoder
def unet_builder(img_size, num_classes, weights = None):
    """Build and compile a U-Net whose encoder is a MobileNetV2 backbone.

    Args:
        img_size: Input shape passed to both ``keras.Input`` and MobileNetV2.
        num_classes: Number of output channels; also selects loss, metrics
            and output activation through ``get_configs``.
        weights: Optional path to weights for the whole model. When truthy
            the encoder is frozen and the file is loaded after compilation.

    Returns:
        The compiled Keras model.

    NOTE(review): the backbone is frozen only when ``weights`` is truthy;
    with ``weights=None`` every encoder layer stays trainable.
    """
    # (skip-connection layer, decoder conv filters), deepest stage first.
    decoder_plan = (
        ("block_13_expand_relu", 512),
        ("block_6_expand_relu", 256),
        ("block_3_expand_relu", 128),
        ("block_1_expand_relu", 64),
    )

    def decoder_block(inputs, skip, num_filters, block_name):
        """Upsample x2, apply a 3x3 ReLU conv, then concatenate the skip."""
        upsampled = layers.UpSampling2D(size=(2, 2), name=f"{block_name}_upsampling")(inputs)
        features = layers.Conv2D(
            num_filters,
            (3, 3),
            padding='same',
            activation='relu',
            kernel_initializer='he_normal',
            name=f"{block_name}_conv",
        )(upsampled)
        return layers.Concatenate(name=f"{block_name}_concat")([features, skip])

    # Input layer shared by the encoder and the final model
    inputs = keras.Input(shape=img_size, name='Input')

    # Encoder: headless MobileNetV2 attached to the input tensor
    encoder = tf.keras.applications.MobileNetV2(
        input_shape=img_size,
        include_top=False,
        input_tensor=inputs,
    )

    # Freeze the backbone when weights for the full model are supplied
    if weights:
        encoder.trainable = False

    # Decoder: start at the bottleneck and climb back up through the skips
    x = encoder.get_layer("block_16_expand_relu").output
    for stage, (skip_name, num_filters) in enumerate(decoder_plan, start=1):
        skip = encoder.get_layer(skip_name).output
        x = decoder_block(x, skip, num_filters, f"decoder_d{stage}")

    loss_fn, model_metric, model_metric2, activation = get_configs(num_classes)

    # One last x2 upsampling, then the per-pixel prediction head
    x = layers.UpSampling2D(size=(2, 2), name='upsampling')(x)
    outputs = layers.Conv2D(
        num_classes,
        (3, 3),
        padding='same',
        activation=activation,
        name='output',
    )(x)

    model = keras.models.Model(inputs=inputs, outputs=[outputs])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=loss_fn,
        metrics=[model_metric, model_metric2],
    )

    if weights:
        model.load_weights(weights)

    return model

# Augmentation pipeline with Keras CV

In [None]:
#---------------Augmentation pipeline------------
# Colour augmentations (applied to the image only). Images are decoded to
# float32 in [0, 1], so every layer that takes a value_range uses (0, 1).
color_augmentations = [
    keras_cv.layers.RandomChannelShift(value_range=(0, 1), factor=0.2),
    keras_cv.layers.RandomColorDegeneration(factor=0.5),
    keras_cv.layers.RandomHue(factor=0.2, value_range=(0, 1)),
]

# Define spatial augmentations (applied jointly to image and mask)
spatial_augmentations = [
    keras_cv.layers.RandomFlip(mode='horizontal'),
    keras_cv.layers.RandomFlip(mode='vertical'),
    keras_cv.layers.RandomRotation(factor = 0.15),
]

# An integer literal is used for the upper bound: randint expects ints.
image_only_augmentations = keras_cv.layers.RandomAugmentationPipeline(layers=color_augmentations,
                                                                      augmentations_per_image=2,
                                                                      seed=np.random.randint(1_000_000))
image_mask_augmentations = keras_cv.layers.RandomAugmentationPipeline(layers=spatial_augmentations,
                                                                      augmentations_per_image=2,
                                                                      seed=np.random.randint(1_000_000))

def apply_augmentation_pipeline(inputs):
    """Apply colour and spatial augmentations to one image/mask pair.

    Args:
        inputs: Dict with keys ``"images"`` and ``"segmentation_masks"``.
            The split below assumes a 3-channel image and a 1-channel mask.

    Returns:
        Dict with the same two keys: the augmented image and the augmented
        mask cast back to ``uint8``.
    """
    image = inputs["images"]
    mask = inputs["segmentation_masks"]

    # Colour-based augmentations touch the image only.
    color_augmented_image = image_only_augmentations(image)

    # Stack the colour-augmented image with the mask so that both receive
    # the same random flip/rotation. (Previously the original image was
    # stacked here, which silently discarded the colour augmentations.)
    stacked = tf.concat([color_augmented_image, tf.cast(mask, dtype=tf.float32)], axis=-1)
    stacked = image_mask_augmentations(stacked)
    augmented_image, augmented_mask = tf.split(stacked, [3, 1], axis=-1)

    # Rotation may interpolate mask values between labels; round to the
    # nearest label instead of truncating towards zero in the uint8 cast.
    augmented_mask = tf.cast(tf.round(augmented_mask), dtype=tf.uint8)

    return {"images": augmented_image, "segmentation_masks": augmented_mask}

# Load the small fire detection dataset

In [None]:
# Dataset source: https://github.com/hayatkhan8660-maker/Fire_Seg_Dataset?tab=readme-ov-file

# Training hyper-parameters: BATCH_SIZE is the default batch size of
# preprocess_ds, EPOCHS is the epoch budget passed to model.fit.
BATCH_SIZE = 32
EPOCHS = 20
# Number of segmentation classes; passed to unet_builder and read by plot_metrics.
NUM_CLASSES = 2
# Spatial size images and masks are resized to before entering the model.
MODEL_WIDTH = 224
MODEL_HEIGHT = 224

# Dataset root, relative to the notebook's working directory.
DATASET_ROOT = "../small_fire_detect"

# Dataset files
train_images_dir = f"{DATASET_ROOT}/images_train"
train_masks_dir = f"{DATASET_ROOT}/masks_train"
# The validation split is used as the evaluation set.
test_images_dir = f"{DATASET_ROOT}/images_val"
test_masks_dir = f"{DATASET_ROOT}/masks_val"

def create_tf_dataset(images_dir, masks_dir):
    """Build a dataset of decoded ``(image, mask)`` pairs from two folders.

    Files are paired by position after sorting each directory listing, so
    both folders must sort into the same order.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the segmentation masks.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image, mask)``: a 3-channel
        float32 image and a 1-channel uint8 mask.
    """
    def _sorted_paths(directory):
        # Full paths of every entry in `directory`, sorted by file name.
        return [os.path.join(directory, fname) for fname in sorted(os.listdir(directory))]

    def load_image(image_path, mask_path):
        # Image: decode to RGB, then convert to float32.
        raw_image = tf.io.read_file(image_path)
        image = tf.image.convert_image_dtype(
            tf.image.decode_jpeg(raw_image, channels=3), tf.float32
        )

        # Mask: decode as a single channel and keep it as uint8.
        raw_mask = tf.io.read_file(mask_path)
        mask = tf.cast(tf.image.decode_png(raw_mask, channels=1), tf.uint8)

        return image, mask

    path_pairs = tf.data.Dataset.from_tensor_slices(
        (_sorted_paths(images_dir), _sorted_paths(masks_dir))
    )
    return path_pairs.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

def preprocess_ds(inputs, batch_size = BATCH_SIZE, augment=False):
    """Resize, cache, shuffle, optionally augment, and batch a dataset.

    Args:
        inputs: Dataset of ``(image, mask)`` pairs; ``len(inputs)`` must be
            defined because it is used as the shuffle buffer size.
        batch_size: Number of samples per batch (incomplete final batches
            are dropped).
        augment: When True, ``apply_augmentation_pipeline`` is mapped over
            the samples.

    Returns:
        A batched, prefetched dataset of dicts with keys ``"images"`` and
        ``"segmentation_masks"``.
    """
    def binary_mask(mask):
        # Divide by 255 and cast back to uint8; the cast truncates, so only
        # pixels equal to 255 end up as 1 (presumably masks are stored as
        # 0/255 -- TODO confirm).
        mask = mask / 255
        return tf.cast(mask, dtype= tf.uint8)

    def resize(image, mask):
        image = tf.image.resize(image, [MODEL_HEIGHT, MODEL_WIDTH])
        # Nearest-neighbour keeps the mask values discrete.
        mask = tf.image.resize(mask, [MODEL_HEIGHT, MODEL_WIDTH], method='nearest')
        return {"images": image, "segmentation_masks": binary_mask(mask)}

    # Cache right after the deterministic decode/resize work. Caching at the
    # end of the pipeline (as before) froze the first epoch's shuffle order
    # and random augmentations, replaying them for every later epoch.
    dataset = (
        inputs
        .map(resize, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()
        .shuffle(buffer_size=len(inputs))
    )
    if augment:
        dataset = dataset.map(apply_augmentation_pipeline, num_parallel_calls=tf.data.AUTOTUNE)
    # prefetch goes last so it overlaps input preparation with training.
    return dataset.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

# Create the datasets
train_ds = create_tf_dataset(train_images_dir, train_masks_dir)
eval_ds = create_tf_dataset(test_images_dir, test_masks_dir)

# Pre-process/Augment
train_ds = preprocess_ds(train_ds, augment= True)
eval_ds = preprocess_ds(eval_ds, augment= False)


# Prepare for training
def dict_to_tuple(x):
    """Convert a sample dict into the ``(inputs, targets)`` tuple for fit.

    Args:
        x: Dict with keys ``"images"`` and ``"segmentation_masks"``; the
            mask's trailing channel axis is squeezed away.

    Returns:
        ``(images, one_hot_masks)`` where the masks are one-hot encoded
        with ``NUM_CLASSES`` channels.
    """
    class_ids = tf.cast(tf.squeeze(x["segmentation_masks"], axis=-1), "int32")
    # Depth follows NUM_CLASSES (previously hard-coded to 2) so the targets
    # always match the number of channels the model outputs.
    return x["images"], tf.one_hot(class_ids, NUM_CLASSES)

train_ds = train_ds.map(dict_to_tuple)
eval_ds = eval_ds.map(dict_to_tuple)

# Train the model

In [None]:
# Checkpoint: write weights only, and only when val_loss improves.
saving_cb = ModelCheckpoint(
    filepath='best_unet.weights.h5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

# Early stopping: give up after 8 epochs without a better val_loss and
# roll the model back to its best weights.
early_stopping_cb = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=8,
    restore_best_weights=True,
    verbose=1,
)

# Build the model without loading a weights file.
model = unet_builder(
    (MODEL_HEIGHT, MODEL_WIDTH, 3),
    NUM_CLASSES,
    weights=None,
)
model.summary()

history = model.fit(
    train_ds,
    validation_data=eval_ds,
    epochs=EPOCHS,
    callbacks=[saving_cb, early_stopping_cb],
)

# Plot training metrics

In [None]:
def plot_metrics(history, title=''):
    """Plot training and validation curves for the loss and both metrics.

    Args:
        history: Object returned by ``model.fit``; its ``history`` dict must
            contain each metric and its ``'val_'``-prefixed counterpart.
        title: Figure-level title.
    """
    # Metric names depend on which configuration get_configs selected.
    if NUM_CLASSES == 2:
        metric_names = ['loss', 'binary_iou', 'accuracy']
    else:
        metric_names = ['loss', 'mean_iou', 'categorical_accuracy']

    # One panel per metric, laid out in a single row.
    fig, axes = plt.subplots(1, len(metric_names), figsize=(18, 5), squeeze=False)
    fig.suptitle(title)

    for ax, name in zip(axes[0], metric_names):
        ax.plot(history.history[name], label='train')
        ax.plot(history.history['val_' + name], label='val')
        ax.set_title(name)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(name)
        ax.legend()

    plt.show()

plot_metrics(history, 'Segmentation head training')

# Fine-tune the backbone on the dataset

In [None]:
def finetune_full_model(model, train_ds, val_ds, epochs=10,num_classes = 2, learning_rate = 0.0001, callbacks=None):
    """Unfreeze every layer, recompile, and continue training the model.

    Args:
        model: Keras model to fine-tune; modified in place.
        train_ds: Training dataset passed to ``model.fit``.
        val_ds: Validation dataset passed to ``model.fit``.
        epochs: Number of fine-tuning epochs.
        num_classes: Class count used to pick loss and metrics via
            ``get_configs``.
        learning_rate: Learning rate of the fresh Adam optimizer.
        callbacks: Optional list of Keras callbacks.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # Make the whole network trainable.
    for layer in model.layers:
        layer.trainable = True

    loss_fn, model_metric, model_metric2, _ = get_configs(num_classes)

    # Recompile so the trainable changes and the new optimizer take effect.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss=loss_fn,
                  metrics=[model_metric, model_metric2])

    return model.fit(train_ds,
                     validation_data=val_ds,
                     epochs=epochs,
                     callbacks=callbacks)

# Pass NUM_CLASSES explicitly so fine-tuning uses the same loss/metric
# configuration the model was built with, rather than the default of 2.
history = finetune_full_model(model, train_ds, eval_ds, epochs=10, num_classes=NUM_CLASSES, learning_rate= 0.0001, callbacks=[saving_cb, early_stopping_cb])

# Test inference

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import cv2

def plot_segmentation_results(model, dataset, num_images=4, num_classes=21, normalized=False):
    """Show images next to their predicted segmentation overlays.

    Args:
        model: Model whose ``predict`` returns per-pixel class scores on the
            last axis.
        dataset: Batched ``tf.data.Dataset`` yielding ``(images, labels)``.
            It is unbatched here, so ``num_images`` counts individual
            samples regardless of the batch size.
        num_images: Maximum number of samples to display.
        num_classes: Unused; kept so existing callers keep working.
        normalized: When True, images are multiplied by 255 before display.

    Raises:
        ValueError: If the dataset yields no samples.
    """
    def overlay_segmentation(image, mask):
        # Paint every pixel with a non-zero predicted class pure green.
        overlay = image.copy()
        overlay[mask != 0] = [0, 255, 0]
        return overlay

    # Pull individual samples: `dataset.take` on a batched dataset would
    # return whole batches, which the old per-item squeeze could not handle.
    samples = [img.numpy() for img, _ in dataset.unbatch().take(num_images)]
    if not samples:
        raise ValueError("dataset yielded no images to plot")
    batch = np.stack(samples)

    # Single forward pass over all samples, then per-pixel class ids.
    pred_classes = np.argmax(model.predict(batch, verbose=0), axis=-1)

    # Build displayable uint8 images; previously `image` was left undefined
    # whenever `normalized` was False.
    display = batch * 255 if normalized else batch
    display = np.clip(display, 0, 255).astype(np.uint8)

    # Two (original, overlay) pairs per row; round up so odd counts and a
    # single image still get a valid grid.
    count = len(samples)
    rows = (count + 1) // 2
    fig, axes = plt.subplots(rows, 4, figsize=(12, 6 * rows), squeeze=False)
    axes = axes.ravel()
    for ax in axes:
        ax.axis("off")

    for i in range(count):
        axes[2 * i].imshow(display[i])
        axes[2 * i].set_title("Original Image")

        axes[2 * i + 1].imshow(overlay_segmentation(display[i], pred_classes[i]))
        axes[2 * i + 1].set_title("Segmented Image")

    fig.tight_layout()
    plt.show()

plot_segmentation_results(model, eval_ds, num_images=8, num_classes=2, normalized=True)