# In [3]:
'''
Example Notebook for running the Wetlands Detection AI model

Prerequisites & Data Prep
Imagery Format: This notebook expects GeoTIFFs.

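Bands: The model expects a 6-band input made of the Sentinel-2 bands B2 (blue), B3 (green), B4 (red), B8 (NIR), B11 (SWIR1) and B12 (SWIR2). Bands are passed to the model in the file's own band order, so the GeoTIFF must be stacked in the same order the model was trained on.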

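Bit Depth: The normalization assumes raw 16-bit integer reflectance values on a 0-10000 scale: each pixel is divided by 10000 and clipped to [0, 1]. Adjust `infer_patches` for other scalings (e.g. 8-bit 0-255).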

Hardware: Large images are processed using a sliding window. A GPU is highly recommended for faster inference.
'''
import os
import numpy as np
import rasterio
from rasterio.windows import Window
import tensorflow as tf
from model.seg_model.resunet32 import ResUNet34

# Report how many GPUs TensorFlow can see (a GPU is recommended for inference speed).
print("Num GPUs Available: ", len(tf.config.list_physical_devices(device_type='GPU')))

# =======================
# 1. Load SwampEye Model
# =======================
def load_swampeye_model(model_path):
    """Build the ResUNet34 segmentation network and load pretrained weights.

    Args:
        model_path: Path to the weights file handed to ``load_weights``.

    Returns:
        The ResUNet34 instance with its weights loaded.
    """
    swamp_net = ResUNet34()
    swamp_net.load_weights(model_path)
    return swamp_net

# =======================
# 2. Infer on Patches
# =======================
def infer_patches(patch_list, model, batch_size=8):
    """Run the segmentation model over image patches, in mini-batches.

    Each patch is normalized from 16-bit reflectance (0-10000) to [0, 1]
    before inference. Patches are stacked into batches so the model is called
    once per ``batch_size`` patches instead of once per patch; results are
    identical to per-patch inference, just faster.

    Args:
        patch_list: Iterable of (bands, H, W) arrays, as returned by
            ``extract_patches``. Patches grouped into one batch must share a
            shape (``extract_patches`` always produces equal-sized patches).
        model: Callable taking a float32 (N, H, W, bands) batch and returning
            an array-like of shape (N, H', W', C).
        batch_size: Maximum number of patches per model call (>= 1). Larger
            values are faster but use more (GPU) memory.

    Returns:
        List with one (C, H', W') array per input patch, in input order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    import itertools

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    predictions = []
    patch_iter = iter(patch_list)  # accepts lists and generators alike
    while True:
        chunk = list(itertools.islice(patch_iter, batch_size))
        if not chunk:
            break
        # (N, bands, H, W) -> (N, H, W, bands): the model consumes channels-last input.
        batch = np.stack(chunk).transpose(0, 2, 3, 1)
        # Normalize: Assuming 16-bit input (0-10000). Adjust if using 8-bit (0-255).
        batch = np.clip(batch / 10000, 0, 1).astype(np.float32)
        batch_pred = np.asarray(model(batch))  # (N, H, W, C)
        # Back to channels-first (C, H, W), one array per patch.
        predictions.extend(np.moveaxis(pred, -1, 0) for pred in batch_pred)
    return predictions

# =======================
# 3. Extract Patches
# =======================
def extract_patches(img_path, patch_size=512, min_overlap=0.75):
    """Cut a raster into overlapping square patches that cover every pixel.

    Windows are placed every ``stride = patch_size * (1 - min_overlap)``
    pixels along each axis. When that grid does not reach the right/bottom
    edge, one extra window flush with the edge is added so no pixels are left
    uncovered. Images smaller than ``patch_size`` in some dimension are read
    with ``boundless=True`` and zero padding, so every patch has the full
    ``patch_size`` x ``patch_size`` extent.

    Args:
        img_path: Path to the input raster.
        patch_size: Side length of the square patches, in pixels.
        min_overlap: Minimum fractional overlap between neighbouring patches
            (edge windows may overlap more).

    Returns:
        Tuple ``(patches, positions, (height, width))``: ``patches`` is a list
        of (bands, patch_size, patch_size) arrays in the raster's dtype,
        ``positions`` the matching (x, y) top-left pixel offsets, and
        ``(height, width)`` the raster size.
    """
    with rasterio.open(img_path) as src:
        height, width = src.height, src.width
        # Clamp to >= 1: min_overlap >= 1 would give stride 0 and make range() raise.
        stride = max(1, int(patch_size * (1 - min_overlap)))
        # Padding is only needed when a patch cannot fit inside the image.
        needs_padding = height < patch_size or width < patch_size

        patches = []
        positions = []
        for y in _window_starts(height, patch_size, stride):
            for x in _window_starts(width, patch_size, stride):
                window = Window(x, y, patch_size, patch_size)
                if needs_padding:
                    patch = src.read(window=window, boundless=True, fill_value=0)
                else:
                    patch = src.read(window=window)  # (Bands, H, W)
                patches.append(patch)
                positions.append((x, y))
    return patches, positions, (height, width)


def _window_starts(length, patch_size, stride):
    """Return window start offsets along one axis that together cover [0, length)."""
    if length <= patch_size:
        return [0]
    last = length - patch_size
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)  # flush window so trailing pixels are not dropped
    return starts

# =======================
# 4. Assemble Mosaic
# =======================
def assemble_mosaic(predictions, positions, shape, threshold=0.5, patch_size=512, return_prob=False):
    """Average overlapping patch predictions into a full-size swamp map.

    Only the swamp channel is accumulated: channel 0 when the model outputs a
    single channel, otherwise channel 1. Each pixel's value is the mean over
    all patches covering it. Parts of a prediction that extend past the image
    (zero-padded edge patches) are cropped away.

    Args:
        predictions: List of (C, h, w) arrays, e.g. from ``infer_patches``.
        positions: Matching list of (x, y) top-left pixel offsets.
        shape: (H, W) of the full image.
        threshold: Pixels whose averaged value is strictly greater than this
            become 1 in the binary mask.
        patch_size: Kept for backward compatibility; each prediction is placed
            using its own spatial size.
        return_prob: If True, return the averaged map instead of the mask.

    Returns:
        float32 (H, W) averaged swamp-channel map when ``return_prob`` is
        True, otherwise a uint8 (H, W) mask of 0/1.

    Raises:
        ValueError: If ``predictions`` is empty.
    """
    if not predictions:
        raise ValueError("assemble_mosaic received no patch predictions to assemble")

    num_classes = predictions[0].shape[0]
    # If only 1 channel, it's taken as the swamp score; otherwise channel 1 is
    # assumed to be the swamp class (presumably channel 0 = background).
    channel = 0 if num_classes == 1 else 1
    H, W = shape
    acc = np.zeros((H, W), dtype=np.float32)
    weight = np.zeros((H, W), dtype=np.float32)

    for pred, (x, y) in zip(predictions, positions):
        # Crop predictions that overhang the image edge.
        h = min(pred.shape[1], H - y)
        w = min(pred.shape[2], W - x)
        acc[y:y + h, x:x + w] += pred[channel, :h, :w]
        weight[y:y + h, x:x + w] += 1

    # Uncovered pixels have weight 0; clipping avoids division by zero and leaves them at 0.
    swamp = acc / np.clip(weight, 1e-6, None)

    if return_prob:
        return swamp.astype(np.float32)
    return (swamp > threshold).astype(np.uint8)

        
# =======================
# 5. Save Output
# =======================
def save_output(ref_path, output_path, mosaic, return_prob=False, with_colormap=True):
    """Write a swamp map as a single-band, LZW-compressed raster georeferenced like ``ref_path``.

    Args:
        ref_path: Raster whose metadata (driver, size, CRS, transform) is copied.
        output_path: Destination file.
        mosaic: (H, W) array, e.g. from ``assemble_mosaic``: probabilities
            when ``return_prob`` is True, otherwise a 0/1 mask.
        return_prob: Write a float32 probability raster instead of a uint8
            binary mask with a brown/green colormap.
        with_colormap: Only with ``return_prob``: also write
            ``<name>_colormap<ext>`` holding the probabilities scaled to uint8
            0-255 with a brown-to-green palette.
    """
    # Palette endpoints shared by the binary mask and the probability ramp.
    non_swamp_rgba = (100, 65, 23, 255)  # Brown (non-swamp)
    swamp_rgba = (123, 245, 39, 255)     # Green (swamp)

    # Only the metadata is needed, so close the reference raster right away.
    with rasterio.open(ref_path) as src:
        meta = src.meta.copy()

    # Drop the reference nodata value: it can be out of range for the output
    # dtype (e.g. 65535 for uint8, which rasterio rejects) and would otherwise
    # mark valid outputs such as non-swamp (0) pixels as missing.
    meta.update({"count": 1, "compress": "lzw", "nodata": None})

    if not return_prob:
        # Save binary swamp mask (uint8)
        meta["dtype"] = "uint8"
        with rasterio.open(output_path, "w", **meta) as dst:
            dst.write(np.asarray(mosaic, dtype=np.uint8), 1)
            dst.write_colormap(1, {0: non_swamp_rgba, 1: swamp_rgba})
        print(f"Saved binary mask to {output_path}")
        return

    # Save probability raster (float32, 0-1)
    meta["dtype"] = "float32"
    with rasterio.open(output_path, "w", **meta) as dst:
        dst.write(mosaic.astype(np.float32), 1)
    print(f"Saved probability raster to {output_path}")

    if not with_colormap:
        return

    # splitext instead of str.replace: only the extension is touched, and a path
    # without ".tif" can no longer map onto (and overwrite) output_path.
    root, ext = os.path.splitext(output_path)
    vis_path = f"{root}_colormap{ext}"

    prob_uint8 = np.nan_to_num(mosaic, nan=0.0)  # clean NaNs
    prob_uint8 = (prob_uint8 * 255).clip(0, 255).astype(np.uint8)

    # Linear ramp from brown (0) to green (255), fully opaque.
    colormap = {}
    for i in range(256):
        t = i / 255
        rgb = tuple(int(lo + (hi - lo) * t) for lo, hi in zip(non_swamp_rgba[:3], swamp_rgba[:3]))
        colormap[i] = rgb + (255,)

    meta["dtype"] = "uint8"
    with rasterio.open(vis_path, "w", **meta) as dst:
        dst.write(prob_uint8, 1)
        dst.write_colormap(1, colormap)
    print(f"Saved colorized probability raster to {vis_path}")


# =======================
# 6. Main Driver
# =======================
def run_swampeye_segmentation(
    input_dir,
    swamp_model_path,
    output_dir,
    patch_size=512,
    min_overlap=0.75,
    threshold=0.5,
    return_prob = False
):
    """Segment every GeoTIFF in ``input_dir`` and write one swamp raster per image.

    For each ``*.tif``/``*.tiff`` file (case-insensitive, processed in sorted
    order), the image is tiled into overlapping patches, run through the
    model, re-assembled, and saved as ``<stem>_swamp.tif`` in ``output_dir``.

    Args:
        input_dir: Folder with input GeoTIFFs ("" means the current directory).
        swamp_model_path: Weights file for ``load_swampeye_model``.
        output_dir: Folder for results ("" means the current directory);
            created if missing.
        patch_size: Patch side length in pixels.
        min_overlap: Minimum fractional overlap between neighbouring patches.
        threshold: Cutoff for the binary mask (ignored when ``return_prob``).
        return_prob: Save a probability raster instead of a binary mask.
    """
    if output_dir:
        # os.makedirs("") raises FileNotFoundError; "" already means the cwd.
        os.makedirs(output_dir, exist_ok=True)
    model = load_swampeye_model(swamp_model_path)
    # When inputs and outputs share a folder, earlier results must not be re-processed.
    same_dir = os.path.abspath(input_dir) == os.path.abspath(output_dir)

    # sorted(): os.listdir order is arbitrary; a stable order makes runs reproducible.
    for filename in sorted(os.listdir(input_dir or ".")):
        if not filename.lower().endswith((".tif", ".tiff")):
            continue
        stem = os.path.splitext(filename)[0]
        if same_dir and stem.endswith(("_swamp", "_swamp_colormap")):
            continue

        print(f"Processing: {filename}")
        input_path = os.path.join(input_dir, filename)

        # Step 1: Extract patches
        patches, positions, img_shape = extract_patches(input_path, patch_size, min_overlap)

        # Step 2: Inference
        preds = infer_patches(patches, model)
        del patches  # release raw patches before allocating the mosaic

        # Step 3: Assemble the swamp probability map or binary mask
        swamp_output = assemble_mosaic(preds, positions, img_shape, threshold, patch_size, return_prob)
        del preds

        # Step 4: Save
        output_path = os.path.join(output_dir, f"{stem}_swamp.tif")
        save_output(input_path, output_path, swamp_output, return_prob)

if __name__ == "__main__":
    # input_dir / output_dir are blank placeholders: point them at the folder of
    # input GeoTIFFs and the folder that should receive the *_swamp.tif results.
    run_swampeye_segmentation(
        input_dir='',
        swamp_model_path='model/pretrained/resunet_focal_dice.h5py',
        output_dir='',
        threshold=0.5,
        return_prob=True,
    )


# Notebook output:
# Processing: Jamaicabay_01_2_2024.tif
#
#
# KeyboardInterrupt: