In [None]:
import numpy as np
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
import cv2
import os
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import LeakyReLU
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.metrics import MeanIoU

In [15]:
# Functionalities

# Define Jaccard Loss
def jaccard_loss(y_true, y_pred, smooth=100):
    """
    Calculates the Jaccard loss, also known as the Intersection over Union (IoU) loss.
    Args:
        y_true (tensor): Ground truth labels.
        y_pred (tensor): Predicted labels.
        smooth (float): Smoothing factor to avoid division by zero.
    Returns:
        jaccard loss (tensor)
    """
    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    sum_ = tf.keras.backend.sum(y_true_f + y_pred_f)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return 1 - jac


# 1. Load the VV and exclusion mask with rasterio

def load_and_reproject_to_vv(vv_path, exclusion_mask_path):
    with rasterio.open(vv_path) as vv_src:
        vv_data = vv_src.read(1)
        vv_meta = vv_src.meta
        
    with rasterio.open(exclusion_mask_path) as exclusion_src:
        exclusion_nodata = exclusion_src.nodata  # Get no-data value of exclusion mask
        if exclusion_src.crs != vv_meta['crs']:
            # If CRS is different, reproject the exclusion mask to VV's CRS
            transform, width, height = calculate_default_transform(
                exclusion_src.crs, vv_meta['crs'], exclusion_src.width, exclusion_src.height, *exclusion_src.bounds
            )
            exclusion_meta = exclusion_src.meta.copy()
            exclusion_meta.update({
                'crs': vv_meta['crs'],
                'transform': transform,
                'width': width,
                'height': height
            })
            
            exclusion_data = np.zeros((height, width), dtype=np.float32)
            reproject(
                source=rasterio.band(exclusion_src, 1),
                destination=exclusion_data,
                src_transform=exclusion_src.transform,
                src_crs=exclusion_src.crs,
                dst_transform=transform,
                dst_crs=vv_meta['crs'],
                resampling=Resampling.nearest
            )
        else:
            exclusion_data = exclusion_src.read(1)

    return vv_data, vv_meta, exclusion_data, exclusion_nodata


# 2. Apply the exclusion mask on VV
def apply_exclusion_mask(vv_data, exclusion_data, exclusion_nodata):
    # Ensure exclusion mask has the same shape as VV
    exclusion_data_resized = cv2.resize(exclusion_data, vv_data.shape[::-1], interpolation=cv2.INTER_NEAREST)
    
    # Apply exclusion mask: where exclusion mask has no-data value, set VV to no-data (0.0)
    vv_data_with_exclusion = np.where(exclusion_data_resized == exclusion_nodata, 0.0, vv_data)
    
    return vv_data_with_exclusion

# 3. Normalize the VV data to range [0, 1] from the range [-25, 5]
def normalize_vv(vv_data):
    vv_normalized = (vv_data + 25) / 30.0  # [-25, 5] to [0, 1]
    vv_normalized = np.clip(vv_normalized, 0, 1)  # Ensure values are within the range
    return vv_normalized

# 4. Load the Keras model
def load_unet_model(model_path):

    custom_objects = {
    'jaccard_loss': jaccard_loss,
    'MeanIoU': MeanIoU(num_classes=2),
    'LeakyReLU': LeakyReLU,
    }

    model = load_model(model_path, custom_objects=custom_objects)
    return model

# 5. Cut the VV into (256, 256) tiles, pad if necessary
def create_patches(vv_data, patch_size=256):
    h, w = vv_data.shape
    padded_vv = np.pad(vv_data, ((0, patch_size - h % patch_size), (0, patch_size - w % patch_size)), constant_values=0)
    
    patches = []
    for i in range(0, padded_vv.shape[0], patch_size):
        for j in range(0, padded_vv.shape[1], patch_size):
            patch = padded_vv[i:i+patch_size, j:j+patch_size]
            patches.append(patch)
    return np.array(patches), padded_vv.shape

# 6. Pass patches through the model and get binary masks
# def predict_masks(model, patches):
#     masks = []
#     for patch in patches:
#         patch = np.expand_dims(patch, axis=(0, -1))  # Reshape to (1, 256, 256, 1) for model input
#         prediction = model.predict(patch)[0]
#         mask = (prediction > 0.5).astype(np.uint8)  # Convert to binary mask
#         masks.append(mask.squeeze())  # (256, 256)
#     return np.array(masks)

def predict_masks(model, patches, batch_size=8):
    masks = []
    num_patches = len(patches)
    
    for i in range(0, num_patches, batch_size):
        batch_patches = patches[i:i+batch_size]
        batch_patches = np.expand_dims(batch_patches, axis=-1)  # Add channel dimension
        predictions = model.predict(batch_patches)
        
        for prediction in predictions:
            mask = (prediction > 0.5).astype(np.uint8)  # Convert to binary mask
            masks.append(mask.squeeze())  # Remove single channel
    return np.array(masks)

# 7. Merge patches and apply no-data masking
def merge_patches(masks, original_shape, patch_size=256):
    merged_mask = np.zeros(original_shape, dtype=np.uint8)
    idx = 0
    for i in range(0, original_shape[0], patch_size):
        for j in range(0, original_shape[1], patch_size):
            merged_mask[i:i+patch_size, j:j+patch_size] = masks[idx]
            idx += 1
    return merged_mask

# 8. Plot the VV and generated mask on top of each other
def plot_vv_and_mask(vv_data, mask):
    plt.figure(figsize=(10, 10))
    plt.subplot(1, 2, 1)
    plt.title("VV Band")
    plt.imshow(vv_data, cmap='gray')
    
    plt.subplot(1, 2, 2)
    plt.title("Generated Mask")
    plt.imshow(mask, cmap='gray', alpha=0.5)
    
    plt.show()

# 9. Save the mask as a TIFF file
def save_mask_as_tiff(mask, reference_meta, output_path):
    # Update metadata to match the mask
    reference_meta.update({
        'count': 1,
        'dtype': 'uint8',  # mask is binary
        'nodata': 255
    })
    
    with rasterio.open(output_path, 'w', **reference_meta) as dst:
        dst.write(mask, 1)

In [16]:
# Pipeline

def main(vv_path, exclusion_mask_path, model_path, output_tif_path):
    # 1. Load VV and exclusion mask
    vv_data, vv_meta, exclusion_data, exclusion_nodata = load_and_reproject_to_vv(vv_path, exclusion_mask_path)
    
    # 2. Apply exclusion mask on VV
    vv_data_masked = apply_exclusion_mask(vv_data, exclusion_data, exclusion_nodata)
    
    # 3. Normalize VV
    vv_data_normalized = normalize_vv(vv_data_masked)
    
    # 4. Load Keras U-Net model
    model = load_unet_model(model_path)
    
    # 5. Create patches of (256, 256)
    patches, padded_shape = create_patches(vv_data_normalized)
    
    # 6. Predict binary masks
    predicted_masks = predict_masks(model, patches)
    
    # 7. Merge patches
    merged_mask = merge_patches(predicted_masks, padded_shape)
    
    # Crop the merged mask to match the original VV data shape
    cropped_mask = merged_mask[:vv_data.shape[0], :vv_data.shape[1]]

    # Apply no-data value to the mask where exclusion mask applied
    final_mask = np.where(vv_data == 0.0, 255, cropped_mask)
    
    # 8. Plot VV and mask
    # plot_vv_and_mask(vv_data, final_mask)
    
    # 9. Save the mask to TIFF
    save_mask_as_tiff(final_mask, vv_meta, output_tif_path)


In [17]:
# Call the main function
vv_path = '../data/inference/srilanka/vv.tif'
exclusion_mask_path = '../data/inference/srilanka/ex-mask/merged/mask.tif'
model_path = '../models_trial/unet_checkpoint_200_2_v2.keras'
output_tif_path = '../data/inference/srilanka/mask.tif'

main(vv_path, exclusion_mask_path, model_path, output_tif_path)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 982ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3