In [32]:
import gradio as gr

In [33]:
import os 
import cv2
import pickle
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from patchify import patchify, unpatchify
from tensorflow.keras.models import Model, load_model

from tensorflow.keras.losses import Loss
from tensorflow.keras import backend as K
from tensorflow.keras.saving import register_keras_serializable

In [34]:
@register_keras_serializable()
class DiceFocalLoss(Loss):
    def __init__(self, alpha=0.25, gamma=2.0, dice_weight=0.5, focal_weight=0.5, 
                 smooth=1e-6, from_logits=False, name="dice_focal_loss", reduction = 'sum_over_batch_size'):
        super(DiceFocalLoss, self).__init__(name=name, reduction=reduction)
        self.alpha = alpha
        self.gamma = gamma
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.smooth = smooth
        self.from_logits = from_logits
        
    def dice_loss(self, y_true, y_pred, use_true_union=False):
        """
        Calculate Dice Loss for multiclass segmentation.
        
        Args:
            y_true: Ground truth labels (batch_size, height, width, num_classes)
            y_pred: Predicted probabilities (batch_size, height, width, num_classes)
            use_true_union: If True, uses A+B-A∩B formula; if False, uses A+B (Sørensen-Dice)
        """
        # Calculate intersection by summing over spatial dimensions (axis 1,2) for each class
        intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2])  # Shape: (batch, num_classes)
        
        if use_true_union:
            # True set theory union: A + B - A∩B
            sum_sets = tf.reduce_sum(y_true, axis=[1, 2]) + tf.reduce_sum(y_pred, axis=[1, 2])
            union = sum_sets - intersection
        else:
            # Sørensen-Dice coefficient: A + B (more commonly used in segmentation)
            union = tf.reduce_sum(y_true, axis=[1, 2]) + tf.reduce_sum(y_pred, axis=[1, 2])
        
        # Dice coefficient for each class and batch
        dice_coeff = (2.0 * intersection + self.smooth) / (union + self.smooth)
        
        # Average across classes and batches
        dice_loss = 1.0 - tf.reduce_mean(dice_coeff)
        return dice_loss
    
    def focal_loss(self, y_true, y_pred):
        """
        Calculate Focal Loss for multiclass segmentation.
        
        Args:
            y_true: Ground truth labels (batch_size, height, width, num_classes)
            y_pred: Predicted probabilities (batch_size, height, width, num_classes)
        """
        # Clip predictions to prevent log(0)
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
        
        # Standard cross entropy
        ce_loss = -y_true * tf.math.log(y_pred)
        
        # For focal loss, p_t is simply the predicted probability of the true class
        # Since y_pred is already probabilities and y_true is one-hot encoded:
        p_t = tf.reduce_sum(y_true * y_pred, axis=-1, keepdims=True)  # Extract prob of true class
        
        # Apply focal loss modulation: (1-p_t)^gamma
        focal_weight = tf.pow((1 - p_t), self.gamma)
        
        # Apply class weighting (alpha) and focal weight
        focal_loss = self.alpha * focal_weight * ce_loss
        
        # Return mean loss
        return tf.reduce_mean(focal_loss)
    
    def call(self, y_true, y_pred):
        """
        Calculate combined Dice + Focal Loss.
        """
        # Convert logits to probabilities if needed
        if self.from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)
        
        # Calculate individual losses
        dice_loss_val = self.dice_loss(y_true, y_pred)
        focal_loss_val = self.focal_loss(y_true, y_pred)
        
        # Combine losses with weights
        total_loss = (self.dice_weight * dice_loss_val + 
                     self.focal_weight * focal_loss_val)
        
        return total_loss

In [35]:
@register_keras_serializable()
def multiclass_jaccard_index(y_true, y_pred, smooth=1e-6):
    assert y_true.shape == y_pred.shape
    y_true_flatten = K.flatten(y_true)
    y_pred_flatten = K.flatten(y_pred)
    intersection = K.sum(y_true_flatten * y_pred_flatten)
    jaccard_index = (intersection + smooth) / (K.sum(y_true_flatten) + K.sum(y_pred_flatten) - intersection + smooth)
    return jaccard_index

In [36]:
saved_model = load_model('../model/semantic-segmentation-aerial-unet-v0.keras')
INPUT_CLASS_FILE_PATH = os.path.join('../data', 'training_data_class_map.pkl')

with open(INPUT_CLASS_FILE_PATH, 'rb') as f:
    data = pickle.load(f)

label_map = data['label_map']
class_labels_rgb = data['class_labels_rgb']

print(label_map)
print(class_labels_rgb)

class_rgb = {label_map.get(k): class_labels_rgb.get(k) for k in label_map.keys()}
label_map_rev = {v: k for k, v in label_map.items()}


def pad_to_nearest_multiple_patchsize(image, patch_size):
    height, width = image.shape[:2]
    pad_h = (patch_size - height % patch_size) % patch_size
    pad_w = (patch_size - width % patch_size) % patch_size
    padded = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
    return padded, height, width

def unpad_to_original_size(padded_image, original_height, original_width):
    if padded_image.ndim == 2:
        return padded_image[:original_height, :original_width]
    return padded_image[:original_height, :original_width, :]



def convert_to_rgb_colors(class_rgb, mask):
    num_classes = len(class_rgb)
    mask = mask.astype(np.int32)
    lut = np.zeros((num_classes, 3), dtype=np.uint8)

    for cl_id, color in class_rgb.items():
        lut[cl_id] = color

    rgb_mask = lut[mask]

    return rgb_mask   

    
    

{'Building': 0, 'Land (unpaved area)': 1, 'Road': 2, 'Vegetation': 3, 'Water': 4, 'Unlabeled': 5}
{'Building': (60, 16, 152), 'Land (unpaved area)': (132, 41, 246), 'Road': (110, 193, 228), 'Vegetation': (254, 221, 58), 'Water': (226, 169, 41), 'Unlabeled': (155, 155, 155)}


In [37]:
def preprocessing_pipeline(image, PATCH_SIZE=256):
    # image = cv2.imread(image_path)
    # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    padded_image, height, width = pad_to_nearest_multiple_patchsize(image, PATCH_SIZE)
    padded_image = padded_image / 255.0
    ph, pw = padded_image.shape[:2]
    im_patches = patchify(padded_image, (PATCH_SIZE, PATCH_SIZE, 3), step=PATCH_SIZE)
    num_i, num_j = im_patches.shape[:2]
    return im_patches, num_i, num_j, height, width, ph, pw

In [38]:
def convert_to_rgb_colors(class_rgb, mask):
    num_classes = len(class_rgb)
    mask = mask.astype(np.int32)
    lut = np.zeros((num_classes, 3), dtype=np.uint8)

    for cl_id, color in class_rgb.items():
        lut[cl_id] = color

    rgb_mask = lut[mask]

    return rgb_mask

In [39]:
def predict(image_filepath):
    # print(type(image_filepath))
    im_patches, num_i, num_j, height, width, ph, pw = preprocessing_pipeline(image_filepath)
    im_shape = im_patches.shape[:-1]
    mask_patches = np.zeros(im_shape)
    for i in range(num_i):
        for j in range(num_j):
            prediction = saved_model.predict(im_patches[i, j, :, :, :])
            predicted_image = np.argmax(prediction, axis=3)
            mask_patches[i, j, :, :, :] = predicted_image
    patches_reshaped = mask_patches.squeeze(axis=2) 
    unpatched_mask = unpatchify(patches_reshaped, (ph, pw))
    reconstructed_mask = unpad_to_original_size(unpatched_mask, height, width)
    rgb_mask = convert_to_rgb_colors(class_rgb, reconstructed_mask)
    return rgb_mask

In [40]:
with gr.Blocks() as demo:
    title = gr.Markdown("""<span style="font-size: 24px; font-weight: bold;">Semantic Segmentation on Satellite Images</span>""")
    with gr.Row():
        with gr.Column():
            gr.Markdown("Input Image")
            in_image = gr.Image(type="numpy", image_mode='RGB')
            btn = gr.Button('Submit')
        with gr.Column():
            gr.Markdown("Segmentation Result")
            out_image = gr.Image()
        
    btn.click(predict, inputs=in_image, outputs=out_image)

demo.launch()
    

* Running on local URL:  http://127.0.0.1:7878
* To create a public link, set `share=True` in `launch()`.




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 809ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 58ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 98ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 57ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 50ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 58ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 44ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 71ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 61ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 46ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3