In [None]:
#

# Levee detection demo

This notebook demonstrates the Swin UNETR model for detecting levees. The original model runs in a Docker environment, which is published in the Git repository.

This version of the workflow predicts levees, runs post-processing, and saves the results.

## Clone the repository

In [1]:
!git clone https://github.com/nazarb/2025_levees_DL.git

Cloning into '2025_levees_DL'...
remote: Enumerating objects: 225, done.[K
remote: Counting objects: 100% (48/48), done.[K
remote: Compressing objects: 100% (36/36), done.[K
remote: Total 225 (delta 31), reused 8 (delta 8), pack-reused 177 (from 1)[K
Receiving objects: 100% (225/225), 1.01 MiB | 3.73 MiB/s, done.
Resolving deltas: 100% (83/83), done.


## Set up paths for further processing

In [22]:
import os
# Every path is anchored on the working directory (``/content`` on Colab) and
# printed so the reader can confirm where files will be read from and written to.
basepath = os.getcwd()
git_path = os.path.join(basepath, "2025_levees_DL")
Swin_UNETR_path = os.path.join(git_path, "Swin_UNETR")
for _shown_path in (basepath, git_path, Swin_UNETR_path):
    print(_shown_path)

# Final outputs go here; created up front so later cells can write into it.
results_path = os.path.join(Swin_UNETR_path, "results/Swin_UNETR/Aug")
os.makedirs(results_path, exist_ok=True)
print(results_path)


/content
/content/2025_levees_DL
/content/2025_levees_DL/Swin_UNETR
/content/2025_levees_DL/Swin_UNETR/results/Swin_UNETR/Aug


## Install libraries

In [4]:
!pip install --no-cache-dir monai==1.3.2


Collecting monai==1.3.2
  Downloading monai-1.3.2-py3-none-any.whl.metadata (10 kB)
Downloading monai-1.3.2-py3-none-any.whl (1.4 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━[0m [32m0.9/1.4 MB[0m [31m25.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m34.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.3.2


In [5]:
!sudo apt install megatools
!pip install rasterio


Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following NEW packages will be installed:
  megatools
0 upgraded, 1 newly installed, 0 to remove and 38 not upgraded.
Need to get 207 kB of archives.
After this operation, 898 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 megatools amd64 1.10.3-1build1 [207 kB]
Fetched 207 kB in 1s (195 kB/s)
debconf: unable to initialize frontend: Dialog
debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 78, <> line 1.)
debconf: falling back to frontend: Readline
debconf: unable to initialize frontend: Readline
debconf: (This frontend requires a controlling tty.)
debconf: falling back to frontend: Teletype
dpkg-preconfigure: unable to re-open stdin: 
Selecting previously unselected package megatools.
(Reading database ... 126675 files and directorie

In [None]:

import warnings
warnings.filterwarnings("ignore")
import torch

# CUDA reads this when it initialises, so it must be set before the first CUDA call.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Report the accelerator without crashing on CPU-only runtimes:
# torch.cuda.current_device() / get_device_name() raise when no GPU is present.
if torch.cuda.is_available():
    print(f"CUDA devices: {torch.cuda.device_count()}, current: {torch.cuda.current_device()}")
    device_name = torch.cuda.get_device_name(0)
else:
    device_name = "cpu"
device_name

'Tesla T4'

## Download the data and the model

The raster data used in this study were computed with the provided Google Earth Engine code. The model is published in the Dane Badawcze UW repository (published version: 2025-09-24):

```
Buławka, Nazarij; Orengo, Hector A.; Lumbreras Ruiz, Felipe; Berganzo-Besga, Iban; Gupta, Ekta, 2025, "Traces of ancient irrigation in central Iraq detected using deep learning model", https://doi.org/10.58132/MY8CCL, Dane Badawcze UW, V1
```


In [6]:
import os
import shutil
import requests
import os
import shutil

# Download the input raster (computed with the provided Google Earth Engine code)
# from MEGA and move it into the repository's data directory.
import subprocess

# MEGA file URL
mega_url = "https://mega.nz/file/E9gBCCjR#UCDklOgOVOwQ0hRAIy5pvAdX9Y-VLBQktDlsrj4R1Ms"

# Temporary download path (megadl saves the file under its original name)
download_dir = "/content"
downloaded_file = os.path.join(download_dir, "CFE_a_selected_L5_S2_S1_MSRM_PCA_GLO_N48.tif")

# Target directory
target_dir1 = f'{Swin_UNETR_path}/data'
os.makedirs(target_dir1, exist_ok=True)
target_file1 = os.path.join(target_dir1, "CFE_a_selected_L5_S2_S1_MSRM_PCA_GLO_N48.tif")

if os.path.exists(target_file1):
    # Already fetched by an earlier run; skipping keeps the cell safe to re-run.
    print(f"File already present: {target_file1}")
else:
    # Check the exit status here: a failed download would otherwise only show up
    # as a confusing FileNotFoundError from shutil.move below.
    megadl_run = subprocess.run(
        ["megadl", mega_url, "--path", download_dir],
        capture_output=True,
        text=True,
    )
    if megadl_run.returncode != 0:
        raise RuntimeError(f"megadl failed ({megadl_run.returncode}): {megadl_run.stderr.strip()}")
    print(megadl_run.stdout.strip())

    # Move the file to the target directory
    shutil.move(downloaded_file, target_file1)
    print(f"File moved to: {target_dir1}")




[0KDownloaded CFE_a_selected_L5_S2_S1_MSRM_PCA_GLO_N48.tif
File moved to: /content/2025_levees_DL/Swin_UNETR/data


In [7]:
# Download the model from the Dane Badawcze UW repository (Published version: 2025-09-24):
## Buławka, Nazarij; Orengo, Hector A.; Lumbreras Ruiz, Felipe; Berganzo-Besga, Iban; Gupta, Ekta, 2025, "Traces of ancient irrigation in central Iraq detected using deep learning model", https://doi.org/10.58132/MY8CCL, Dane Badawcze UW, V1

## URL of the file
url = "https://danebadawcze.uw.edu.pl/api/access/datafile/17759"

## Download location
download_path = "/content/Levees_SWINUNETR_48_best.pth"

## Target directory
target_dir2 = f'{Swin_UNETR_path}/model/Swin_UNETR/Aug'
os.makedirs(target_dir2, exist_ok=True)  # create dir if it doesn't exist
final_path = os.path.join(target_dir2, "Levees_SWINUNETR_48_best.pth")

if os.path.exists(final_path):
    # Already downloaded by an earlier run; skipping keeps the cell safe to re-run.
    print(f"File already present: {final_path}")
else:
    print("Downloading...")
    # raise_for_status() stops an HTTP error page from being saved as the checkpoint;
    # iter_content() applies any Content-Encoding, which copying response.raw would not.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(download_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    print("Download complete.")

    # Move to target directory
    shutil.move(download_path, final_path)
    print(f"File moved to: {final_path}")



Downloading...
Download complete.
File moved to: /content/2025_levees_DL/Swin_UNETR/model/Swin_UNETR/Aug/Levees_SWINUNETR_48_best.pth


In [18]:
## Predict

Swin_UNETR_pth = os.path.join(target_dir2, "Levees_SWINUNETR_48_best.pth")
print(Swin_UNETR_pth)  # print the path itself, not a one-element set
data_path = os.path.join(target_dir1, "CFE_a_selected_L5_S2_S1_MSRM_PCA_GLO_N48.tif")
print(data_path)

# Fail here, with a clear message, if a download cell above did not produce its file.
for _required_file in (Swin_UNETR_pth, data_path):
    if not os.path.isfile(_required_file):
        raise FileNotFoundError(f"Required input is missing: {_required_file}")


{'/content/2025_levees_DL/Swin_UNETR/model/Swin_UNETR/Aug/Levees_SWINUNETR_48_best.pth'}
/content/2025_levees_DL/Swin_UNETR/data/CFE_a_selected_L5_S2_S1_MSRM_PCA_GLO_N48.tif


In [23]:
import monai
from monai.inferers import sliding_window_inference
from monai.networks.nets import SwinUNETR, AttentionUnet
import einops
import warnings
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import os
import rasterio
from rasterio.windows import Window
import tifffile as tiff


def get_filename_without_extension():
    """Return the base name of the global ``data_path`` with its extension stripped."""
    stem, _extension = os.path.splitext(os.path.basename(data_path))
    return stem


filename_to_detect = get_filename_without_extension()

class NumpyDataset(Dataset):
    """Map-style dataset over ``.npy`` images and their labels.

    ``label_data[k]`` may be either a path to a ``.npy`` file (str) or an
    already-loaded array. Each item is ``(image, label, image_path)``, where
    image and label are ``float32`` tensors with NaN/±inf entries set to 0.
    """

    def __init__(self, image_paths, label_data):
        self.image_paths = image_paths
        # Entries are either file paths (str) or preloaded arrays.
        self.label_data = label_data

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        source_path = self.image_paths[idx]
        label_entry = self.label_data[idx]

        raw_image = np.load(source_path)
        raw_label = np.load(label_entry) if isinstance(label_entry, str) else label_entry

        clean_image = self.replace_nans_in_array(raw_image)
        clean_label = self.replace_nans_in_array(raw_label)

        return (
            torch.tensor(clean_image, dtype=torch.float32),
            torch.tensor(clean_label, dtype=torch.float32),
            source_path,
        )

    @staticmethod
    def replace_nans_in_array(arr):
        """Set every NaN and ±inf entry of ``arr`` to 0 in place and return ``arr``."""
        arr[~np.isfinite(arr)] = 0
        return arr


def load_dataset_json(json_path):
    """Parse and return the JSON document stored at ``json_path``."""
    with open(json_path, 'r') as handle:
        return json.load(handle)

def prepare_test_loader(test_image_paths, batch_size, dummy_label_dir="./model/temp_labels/", num_workers=4):
    """Build an ordered DataLoader over ``.npy`` tiles for inference.

    ``NumpyDataset`` yields (image, label, path) triples, so an all-zero label
    ``.npy`` with the same shape and dtype as each image is written to
    ``dummy_label_dir`` (created if missing). The labels are not used for prediction.

    Args:
        test_image_paths: ``.npy`` tile paths, in the order predictions should be produced.
        batch_size: tiles per batch.
        dummy_label_dir: where placeholder labels are written. The default is the
            original relative path, so it depends on the current working directory.
        num_workers: DataLoader worker processes.

    Returns:
        DataLoader with ``shuffle=False``, so batch order follows ``test_image_paths``.
    """
    os.makedirs(dummy_label_dir, exist_ok=True)

    test_labels = []
    for image_path in test_image_paths:
        stem = os.path.splitext(os.path.basename(image_path))[0]
        dummy_label_path = os.path.join(dummy_label_dir, f"{stem}_label.npy")
        # mmap_mode='r' maps the file instead of reading it, which is enough to learn
        # the shape and dtype without loading the whole tile into memory.
        image_header = np.load(image_path, mmap_mode='r')
        np.save(dummy_label_path, np.zeros(image_header.shape, dtype=image_header.dtype))
        test_labels.append(dummy_label_path)

    test_dataset = NumpyDataset(test_image_paths, test_labels)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

def initialize_model(in_channels=48, out_channels=1, img_size=(96, 96), feature_size=48):
    """Build the 2-D SwinUNETR used for levee segmentation.

    The defaults reproduce the architecture the downloaded checkpoint is loaded
    into (48 input bands, one output channel that is passed through a sigmoid at
    prediction time, 96x96 tiles). Change them only together with a matching checkpoint.

    Args:
        in_channels: number of raster bands per tile.
        out_channels: number of output channels (1 = single levee logit map).
        img_size: tile size passed to SwinUNETR.
        feature_size: base embedding width.

    Returns:
        A freshly initialised (untrained) ``SwinUNETR``.
    """
    model = SwinUNETR(
        img_size=img_size,
        in_channels=in_channels,
        out_channels=out_channels,
        use_checkpoint=True,  # gradient checkpointing: trades compute for memory when gradients are tracked
        feature_size=feature_size,
        depths=(3, 9, 18, 3),
        num_heads=(4, 8, 16, 32),
        drop_rate=0.1,
        attn_drop_rate=0.1,
        dropout_path_rate=0.2,
        spatial_dims=2,
    )
    return model


def load_adapted_model(model, checkpoint_path, strict=False):
    """Load weights from ``checkpoint_path`` into ``model`` and return it.

    Accepts either a bare ``state_dict`` or a dict that wraps one under
    ``"state_dict"`` / ``"model_state_dict"``, and strips a ``"module."`` key
    prefix (as left by ``DataParallel``). With ``strict=False`` (the default, as
    before) mismatched keys are tolerated but reported. A silently empty load
    would otherwise leave the network at its random initialisation.

    Args:
        model: ``torch.nn.Module`` to populate.
        checkpoint_path: file written by ``torch.save``.
        strict: forwarded to ``load_state_dict``.

    Returns:
        The same ``model`` instance.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    if isinstance(checkpoint, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            if isinstance(checkpoint.get(wrapper_key), dict):
                checkpoint = checkpoint[wrapper_key]
                break
        if checkpoint and all(isinstance(k, str) and k.startswith("module.") for k in checkpoint):
            checkpoint = {k[len("module."):]: v for k, v in checkpoint.items()}

    load_result = model.load_state_dict(checkpoint, strict=strict)
    if load_result.missing_keys or load_result.unexpected_keys:
        # print rather than warnings.warn: the notebook filters all warnings.
        print(
            f"Warning: checkpoint/model key mismatch - "
            f"{len(load_result.missing_keys)} missing, {len(load_result.unexpected_keys)} unexpected "
            f"(e.g. missing {load_result.missing_keys[:3]}, unexpected {load_result.unexpected_keys[:3]})"
        )
    return model

### Image loading, tiling and mosaicking helpers

def load_large_image(image_path):
    """Read all bands of a raster.

    Returns:
        ``(image, transform)``: the ``(bands, rows, cols)`` array returned by
        rasterio's ``read()`` and the dataset's affine geotransform.
    """
    print(f"Loading large image from {image_path}")
    with rasterio.open(image_path) as dataset:
        pixels = dataset.read()
        geotransform = dataset.transform  # affine pixel -> map coordinates
    return pixels, geotransform

def split_to_tiles(image, tile_size, save_dir, pad_edges=False):
    """Cut a ``(channels, height, width)`` array into square tiles saved as TIFF.

    Tiles are written as ``tile_{row}_{col}.tif``, where ``row``/``col`` are the
    pixel offsets of the tile's top-left corner; ``merge_tiles`` parses this name.

    Args:
        image: array shaped ``(channels, height, width)``.
        tile_size: edge length of each square tile in pixels.
        save_dir: existing directory to write the tiles into.
        pad_edges: when False (default, original behaviour), partial tiles at the
            right/bottom edges are skipped, so those strips are never predicted.
            When True, they are zero-padded to ``tile_size`` so the whole image is covered.

    Returns:
        List of written tile paths in row-major order.
    """
    print(f"Splitting image into tiles of size {tile_size}x{tile_size}")
    tiles = []
    _, height, width = image.shape
    for i in range(0, height, tile_size):
        for j in range(0, width, tile_size):
            tile = image[:, i:i+tile_size, j:j+tile_size]
            if tile.shape[1] != tile_size or tile.shape[2] != tile_size:
                if not pad_edges:
                    continue
                tile = np.pad(
                    tile,
                    ((0, 0), (0, tile_size - tile.shape[1]), (0, tile_size - tile.shape[2])),
                    mode='constant',
                )
            tile_path = os.path.join(save_dir, f'tile_{i}_{j}.tif')
            print(f"Saving tile to {tile_path}")
            tiff.imwrite(tile_path, tile)
            tiles.append(tile_path)
    print(f"Total number of tiles: {len(tiles)}")
    return tiles

def convert_tiles_to_npy(tile_paths, npy_dir):
    """Re-save each TIFF tile as ``.npy`` in ``npy_dir`` (created if missing).

    Each output keeps the tile's base name with its extension replaced by ``.npy``.

    Returns:
        List of ``.npy`` paths, in the same order as ``tile_paths``.
    """
    os.makedirs(npy_dir, exist_ok=True)
    npy_paths = []
    for tile_path in tile_paths:
        image = tiff.imread(tile_path)
        # splitext rather than str.replace('.tif', '.npy'): a '.tiff' tile would
        # otherwise become '.npyf', and any other '.tif' substring would be rewritten.
        stem = os.path.splitext(os.path.basename(tile_path))[0]
        npy_path = os.path.join(npy_dir, stem + '.npy')
        print(f"Converting {tile_path} to {npy_path}")
        np.save(npy_path, image)
        npy_paths.append(npy_path)
    return npy_paths

def merge_tiles(tiles, image_shape, tile_size):
    """Mosaic single-band prediction tiles back into one ``uint8`` array.

    Tile positions come from the ``tile_{row}_{col}.tif`` names written by
    ``split_to_tiles``. Leading size-1 axes (such as a ``(1, H, W)`` prediction)
    are dropped, and tiles that overhang the right/bottom edge (zero-padded edge
    tiles) are cropped to fit. Pixels not covered by any tile stay 0.

    Args:
        tiles: paths of the tile TIFFs.
        image_shape: ``(channels, height, width)`` of the full image; only height
            and width are used.
        tile_size: nominal tile edge length in pixels.

    Returns:
        ``uint8`` array of shape ``(height, width)``.
    """
    print(f"Merging tiles back into full image of shape {image_shape}")
    _, height, width = image_shape
    full_image = np.zeros((height, width), dtype=np.uint8)

    for tile_path in tiles:
        filename = os.path.basename(tile_path)
        parts = os.path.splitext(filename)[0].replace('tile_', '').split('_')
        if len(parts) != 2:
            print(f"Unexpected filename format: {filename}")
            continue
        i, j = map(int, parts)

        tile = tiff.imread(tile_path)
        while tile.ndim > 2 and tile.shape[0] == 1:
            tile = tile[0]
        # Basic slicing gives a view clipped to the canvas; crop the tile to match.
        target = full_image[i:i+tile_size, j:j+tile_size]
        target[...] = tile[:target.shape[0], :target.shape[1]]

    return full_image
###

def predict_and_save(model, test_loader, device, tile_paths, save_dir='model/temp_pred', threshold=0.5):
    """Run ``model`` over ``test_loader`` and save a binary mask per tile.

    Each output is passed through a sigmoid and binarised with ``> threshold``;
    the mask is written as a ``uint8`` TIFF named after its source tile.

    Args:
        model: segmentation network returning logits per tile.
        test_loader: DataLoader yielding ``(images, labels, paths)`` batches.
        device: torch device to run inference on.
        tile_paths: original tile paths. This argument is kept for compatibility. Output
            names now come from the paths the loader yields, so they stay correct
            regardless of batch size or ordering.
        save_dir: directory for the prediction tiles (created if missing).
        threshold: probability cut-off for the binary mask.

    Returns:
        List of saved prediction paths, in loader order.
    """
    os.makedirs(save_dir, exist_ok=True)

    print(f"Starting prediction on device: {device}")
    model.eval()
    model.to(device)
    predictions = []

    with torch.no_grad():
        for batch_idx, (images, _, batch_paths) in enumerate(test_loader):
            print(f"Predicting batch {batch_idx + 1}/{len(test_loader)}")
            print(f"Images type: {type(images)}, Images shape: {images.shape if isinstance(images, torch.Tensor) else 'unknown'}")

            probabilities = torch.sigmoid(model(images.to(device))).cpu().numpy()

            for output_image, source_path in zip(probabilities, batch_paths):
                binary_output = (output_image > threshold).astype(np.uint8)
                stem = os.path.splitext(os.path.basename(source_path))[0]
                save_path = os.path.join(save_dir, stem + '.tif')
                print(f"Saving prediction to {save_path}")
                tiff.imwrite(save_path, binary_output)
                predictions.append(save_path)

    return predictions



def save_full_raster(predictions, image_shape, tile_size, save_path, transform=None, crs=None):
    """Merge prediction tiles and write them as a single-band GeoTIFF.

    Args:
        predictions: prediction tile paths (see ``merge_tiles``).
        image_shape: ``(channels, height, width)`` of the full image.
        tile_size: tile edge length in pixels.
        save_path: output GeoTIFF path.
        transform: optional affine geotransform to embed. When omitted (the
            default, as before), the file has no georeferencing of its own and
            rasterio may warn that it is not georeferenced.
        crs: optional CRS to embed alongside ``transform``.
    """
    print(f"Saving full raster to {save_path}")
    full_raster = merge_tiles(predictions, image_shape, tile_size)
    georef = {}
    if transform is not None:
        georef['transform'] = transform
    if crs is not None:
        georef['crs'] = crs
    with rasterio.open(save_path, 'w', driver='GTiff', height=full_raster.shape[0],
                       width=full_raster.shape[1], count=1, dtype=full_raster.dtype,
                       **georef) as dst:
        dst.write(full_raster, 1)

def create_tfw_file(transform, tfw_path):
    """Write an ESRI world file (``.tfw``) for an affine ``transform``.

    World-file line order is A, D, B, E, C, F, i.e. ``a, d, b, e, c, f`` in
    rasterio ``Affine`` terms. Writing ``b`` before ``d`` would transpose the
    two shear terms; that only goes unnoticed for north-up rasters, where both are 0.

    NOTE(review): the world-file spec places C/F at the *centre* of the
    upper-left pixel, whereas ``transform.c``/``transform.f`` are its outer
    *corner*. The corner is still written because ``read_tfw_coordinates`` in
    this notebook reads C/F back unchanged as the corner. External GIS tools will
    see a half-pixel shift until writer and reader are changed together.
    """
    print(f"Creating .tfw file at {tfw_path}")
    world_file_terms = (
        transform.a,  # A: pixel size in x
        transform.d,  # D: change in y per column (0 for north-up)
        transform.b,  # B: change in x per row (0 for north-up)
        transform.e,  # E: pixel size in y (usually negative)
        transform.c,  # C: x of the upper-left corner of the upper-left pixel
        transform.f,  # F: y of the upper-left corner of the upper-left pixel
    )
    with open(tfw_path, 'w') as f:
        f.writelines(f"{term}\n" for term in world_file_terms)

# Inference driver. In a notebook __name__ is always "__main__", so this runs on
# every execution of the cell. The names it assigns stay as module globals, and
# later cells read some of them (e.g. result_path).
if __name__ == "__main__":
    image_path = data_path

    # tile_size must match img_size=(96, 96) used in initialize_model().
    tile_size = 96
    batch_size = 12
    model = initialize_model()

    print("Loading and preprocessing the large image")
    large_image, transform = load_large_image(image_path)
    # NaN/±inf pixels are zeroed in place before tiling.
    large_image = NumpyDataset.replace_nans_in_array(large_image)
    image_shape = large_image.shape
    print(f"Large image shape: {image_shape}")

    # Intermediate tiles (TIFF, .npy, per-tile predictions) and final outputs.
    temp_dir = f'{Swin_UNETR_path}/results/temp/'
    npy_dir = f'{Swin_UNETR_path}/results/temp_npy/'
    pred_dir = f'{Swin_UNETR_path}/results/temp_pred/'
    result_path = f'{Swin_UNETR_path}/results/Swin_UNETR/Aug/{filename_to_detect}_swinunetr.tif'
    tfw_path = f'{Swin_UNETR_path}/results/Swin_UNETR/Aug/{filename_to_detect}_swinunetr.tfw'

    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir)

    # NOTE(review): split_to_tiles skips partial edge tiles by default, so bottom/right
    # strips narrower than tile_size (18 px for the 2994x1746 input shown below) stay 0.
    tile_paths = split_to_tiles(large_image, tile_size, temp_dir)
    npy_paths = convert_tiles_to_npy(tile_paths, npy_dir)
    # Placeholder labels go to prepare_test_loader's default ./model/temp_labels/,
    # which is relative to the current working directory.
    test_loader = prepare_test_loader(npy_paths, batch_size)

    adapted_model_path = Swin_UNETR_pth
    model = load_adapted_model(model, adapted_model_path)

    # Re-binds the global `device` set in the setup cell (same expression).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    predictions = predict_and_save(model, test_loader, device, tile_paths, save_dir=pred_dir)

    # Adjust the image_shape to reflect the single-channel output
    single_channel_image_shape = (1, image_shape[1], image_shape[2])
    # The GeoTIFF is written without a transform; the .tfw below carries the location.
    save_full_raster(predictions, single_channel_image_shape, tile_size, result_path)

    # Create the .tfw file using the transform from the original large image
    create_tfw_file(transform, tfw_path)

    print("Prediction, saving, and .tfw file creation complete")



Loading and preprocessing the large image
Loading large image from /content/2025_levees_DL/Swin_UNETR/data/CFE_a_selected_L5_S2_S1_MSRM_PCA_GLO_N48.tif
Large image shape: (48, 2994, 1746)
Splitting image into tiles of size 96x96
Saving tile to /content/2025_levees_DL/Swin_UNETR/results/temp/tile_0_0.tif
Saving tile to /content/2025_levees_DL/Swin_UNETR/results/temp/tile_0_96.tif
Saving tile to /content/2025_levees_DL/Swin_UNETR/results/temp/tile_0_192.tif
Saving tile to /content/2025_levees_DL/Swin_UNETR/results/temp/tile_0_288.tif
Saving tile to /content/2025_levees_DL/Swin_UNETR/results/temp/tile_0_384.tif
Saving tile to /content/2025_levees_DL/Swin_UNETR/results/temp/tile_0_480.tif
Saving tile to /content/2025_levees_DL/Swin_UNETR/results/temp/tile_0_576.tif
Saving tile to /content/2025_levees_DL/Swin_UNETR/results/temp/tile_0_672.tif
Saving tile to /content/2025_levees_DL/Swin_UNETR/results/temp/tile_0_768.tif
Saving tile to /content/2025_levees_DL/Swin_UNETR/results/temp/tile_0_86

  dataset = writer(


In [24]:
import numpy as np
import cv2
import os
import rasterio
from skimage import io
from skimage.morphology import thin
from scipy.ndimage import distance_transform_edt
from skimage.morphology import skeletonize

# ====== INPUT data ======
# raster_r = input("Enter path to raster data: ").strip('"') # change to process other files
raster_r = result_path


def get_filename_without_extension():
    """Return the base name of ``raster_r`` without its extension.

    Shadows the earlier definition of the same name, so ``filename_to_detect``
    now carries the ``_swinunetr`` suffix of the prediction file.
    """
    stem, _extension = os.path.splitext(os.path.basename(raster_r))
    return stem


filename_to_detect = get_filename_without_extension()
def load_large_image(image_path):
    """Read only band 1 of a raster.

    Shadows the all-band loader of the same name from the inference cell.

    Returns:
        ``(band, transform)``: band 1 as returned by rasterio's ``read(1)`` and
        the dataset's affine geotransform.
    """
    print(f"Loading large image from {image_path}")
    with rasterio.open(image_path) as dataset:
        first_band = dataset.read(1)
        geotransform = dataset.transform
    return first_band, geotransform

binary_raster, transform = load_large_image(raster_r)
binary_raster = (binary_raster > 0).astype(np.uint8)  # Ensure binary

# REMOVE SMALL OBJECTS: drop 8-connected components smaller than min_area pixels.
min_area = 350
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary_raster, connectivity=8)

# Keep only features above the selected min_area parameter. One boolean per label,
# then index it with the label image: a single pass instead of one full-image
# comparison per component (O(components x pixels)).
keep_component = stats[:, cv2.CC_STAT_AREA] >= min_area
keep_component[0] = False  # label 0 is the background
filtered_raster = keep_component[labels].astype(np.uint8)

# Save TFW file to preserve geographic location
def create_tfw_file(transform, tfw_path):
    """Write an ESRI world file (``.tfw``) for an affine ``transform``.

    World-file line order is A, D, B, E, C, F, i.e. ``a, d, b, e, c, f`` in
    rasterio ``Affine`` terms. Writing ``b`` before ``d`` would transpose the
    two shear terms; that only goes unnoticed for north-up rasters, where both are 0.

    NOTE(review): the world-file spec places C/F at the *centre* of the
    upper-left pixel, whereas ``transform.c``/``transform.f`` are its outer
    *corner*. The corner is still written because ``read_tfw_coordinates`` in
    this notebook reads C/F back unchanged as the corner.
    """
    print(f"Creating .tfw file at {tfw_path}")
    world_file_terms = (
        transform.a,  # A: pixel size in x
        transform.d,  # D: change in y per column (0 for north-up)
        transform.b,  # B: change in x per row (0 for north-up)
        transform.e,  # E: pixel size in y (usually negative)
        transform.c,  # C: x of the upper-left corner of the upper-left pixel
        transform.f,  # F: y of the upper-left corner of the upper-left pixel
    )
    with open(tfw_path, 'w') as f:
        f.writelines(f"{term}\n" for term in world_file_terms)

# Output
output_dir = results_path
os.makedirs(output_dir, exist_ok=True)
filtered_dir = os.path.join(output_dir, "Filtered")
os.makedirs(filtered_dir, exist_ok=True)

# Save with cv2.imwrite, like the closing and skeleton outputs (skimage's imsave
# also warns about low contrast for 0/1 masks by default). cv2.imwrite reports
# failure only through its return value, so check it.
filtered_tif_path = os.path.join(filtered_dir, f'{filename_to_detect}_filtered.tif')
if not cv2.imwrite(filtered_tif_path, filtered_raster.astype(np.uint8)):
    raise OSError(f"Could not write {filtered_tif_path}")
create_tfw_file(transform, os.path.splitext(filtered_tif_path)[0] + '.tfw')

print("Filtering complete. Small features removed and .tfw file saved.")

# CLOSING: morphological closing (dilate then erode) bridges gaps smaller than the kernel.
closing_kernel_size = 7
kernel = np.ones((closing_kernel_size, closing_kernel_size), np.uint8)
closing_raster = cv2.morphologyEx(filtered_raster, cv2.MORPH_CLOSE, kernel)

# SAVE CLOSING results (cv2.imwrite reports failure only through its return value)
closing_tif_path = os.path.join(filtered_dir, f'{filename_to_detect}_closing.tif')
if not cv2.imwrite(closing_tif_path, closing_raster.astype(np.uint8)):
    raise OSError(f"Could not write {closing_tif_path}")
create_tfw_file(transform, os.path.splitext(closing_tif_path)[0] + '.tfw')


# Skeletonize: thin the closed mask to centrelines using Lee's method.
skeletonize_raster = skeletonize(closing_raster > 0, method='lee')
skeletonize_raster = skeletonize_raster.astype(np.uint8)

# Save the final results (cv2.imwrite reports failure only through its return value)
skeletonize_raster_tif_path = os.path.join(output_dir, f'{filename_to_detect}_skeletonize.tif')
if not cv2.imwrite(skeletonize_raster_tif_path, skeletonize_raster):
    raise OSError(f"Could not write {skeletonize_raster_tif_path}")
create_tfw_file(transform, os.path.splitext(skeletonize_raster_tif_path)[0] + '.tfw')


Loading large image from /content/2025_levees_DL/Swin_UNETR/results/Swin_UNETR/Aug/CFE_a_selected_L5_S2_S1_MSRM_PCA_GLO_N48_swinunetr.tif


  return func(*args, **kwargs)


Creating .tfw file at /content/2025_levees_DL/Swin_UNETR/results/Swin_UNETR/Aug/Filtered/CFE_a_selected_L5_S2_S1_MSRM_PCA_GLO_N48_swinunetr_filtered.tfw
Filtering complete. Small features removed and .tfw file saved.
Creating .tfw file at /content/2025_levees_DL/Swin_UNETR/results/Swin_UNETR/Aug/Filtered/CFE_a_selected_L5_S2_S1_MSRM_PCA_GLO_N48_swinunetr_closing.tfw
Creating .tfw file at /content/2025_levees_DL/Swin_UNETR/results/Swin_UNETR/Aug/CFE_a_selected_L5_S2_S1_MSRM_PCA_GLO_N48_swinunetr_skeletonize.tfw


In [35]:



# Download the reference levee raster from MEGA for the accuracy assessment.
import subprocess

# MEGA file URL
mega_url2 = "https://mega.nz/file/5txRRQiZ#dDRqdI4F695r7yG465W5KIISVdQ1Iw2O8EWK2BIdoFs"

# Temporary download path (megadl saves the file under its original name)
download_dir3 = "/content"
# Built from this cell's own download_dir3 (previously download_dir from the
# input-raster cell), so the cell does not depend on another cell's variable.
downloaded_file = os.path.join(download_dir3, "CFE_a_levee_reference.tif")

# Target directory
target_dir3 = f'{Swin_UNETR_path}/data'
os.makedirs(target_dir3, exist_ok=True)
Reference_data_path = os.path.join(target_dir3, "CFE_a_levee_reference.tif")

if os.path.exists(Reference_data_path):
    # Already fetched by an earlier run; skipping keeps the cell safe to re-run.
    print(f"File already present: {Reference_data_path}")
else:
    megadl_run = subprocess.run(
        ["megadl", mega_url2, "--path", download_dir3],
        capture_output=True,
        text=True,
    )
    if megadl_run.returncode != 0:
        raise RuntimeError(f"megadl failed ({megadl_run.returncode}): {megadl_run.stderr.strip()}")
    print(megadl_run.stdout.strip())

    # Move the file to the target directory
    shutil.move(downloaded_file, Reference_data_path)
    print(f"File moved to: {target_dir3}")


[0KDownloaded CFE_a_levee_reference.tif
File moved to: /content/2025_levees_DL/Swin_UNETR/data


In [40]:
import os
import numpy as np
import rasterio
from rasterio.enums import Resampling
from rasterio.warp import reproject
from rasterio.transform import Affine
from scipy.ndimage import binary_dilation, binary_erosion

# ------------------------------------------------------------
# Step 1. Read spatial coordinates from TFW
# ------------------------------------------------------------
def read_tfw_coordinates(tif_path):
    """Build an ``Affine`` geotransform from the ``.tfw`` world file beside ``tif_path``.

    World-file lines are A, D, B, E, C, F and map to ``Affine(A, B, C, D, E, F)``.
    Blank lines are ignored, so a trailing empty line does not invalidate the file.

    NOTE(review): per the world-file spec, C/F locate the *centre* of the
    upper-left pixel, but they are used here unchanged as the *corner*
    (``Affine.c``/``Affine.f``). That matches ``create_tfw_file`` in this
    notebook, which writes the corner; a world file produced by other software
    would be read half a pixel off.

    Raises:
        FileNotFoundError: there is no ``.tfw`` beside ``tif_path``.
        ValueError: the world file does not contain exactly six values.
    """
    tfw_path = os.path.splitext(tif_path)[0] + ".tfw"
    if not os.path.exists(tfw_path):
        raise FileNotFoundError(f"World file not found: {tfw_path}")

    with open(tfw_path, "r") as f:
        lines = [line for line in f if line.strip()]
    if len(lines) != 6:
        raise ValueError(f"Invalid world file format: {tfw_path}")

    # A/E: pixel sizes; D/B: rotation/shear terms; C/F: upper-left position.
    A, D, B, E, C, F = (float(line) for line in lines)

    transform = Affine(A, B, C, D, E, F)
    print("\n[TFW] Affine transform read from file:")
    print(transform)

    return transform


# ------------------------------------------------------------
# Step 2. Align reference raster to predicted raster
# ------------------------------------------------------------

def align_reference_with_crs(reference_tif, predicted_tif, predicted_transform, dst_crs, output_path):
    """Resample band 1 of ``reference_tif`` onto the prediction grid and save it.

    The output grid takes its width/height from ``predicted_tif`` and its
    geotransform and CRS from the explicit ``predicted_transform`` and ``dst_crs``.
    Nearest-neighbour resampling keeps label values unchanged; the destination
    buffer starts as zeros in the reference raster's dtype.

    Returns:
        ``output_path``.

    Raises:
        ValueError: the reference raster has no CRS.
    """
    with rasterio.open(predicted_tif) as pred:
        out_width, out_height = pred.width, pred.height
        out_profile = pred.profile.copy()

    with rasterio.open(reference_tif) as ref:
        if ref.crs is None:
            raise ValueError("Reference raster has no CRS — cannot reproject.")
        ref_dtype = ref.dtypes[0]
        aligned = np.zeros((out_height, out_width), dtype=ref_dtype)
        reproject(
            source=rasterio.band(ref, 1),
            destination=aligned,
            src_transform=ref.transform,
            src_crs=ref.crs,
            dst_transform=predicted_transform,
            dst_crs=dst_crs,
            resampling=Resampling.nearest,
            num_threads=2,
        )

    out_profile.update(
        dtype=ref_dtype,
        count=1,
        compress="lzw",
        driver="GTiff",
        transform=predicted_transform,
        crs=dst_crs,
        width=out_width,
        height=out_height,
    )
    with rasterio.open(output_path, "w", **out_profile) as dst:
        dst.write(aligned, 1)

    print(f"\nAligned reference raster saved to:\n{output_path}")
    return output_path



def align_reference(reference_tif, predicted_tif, predicted_transform, output_path):
    """Align ``reference_tif`` to the prediction grid, taking the CRS from ``predicted_tif``.

    Same as ``align_reference_with_crs`` except that the destination CRS is read
    from the prediction raster. The function delegates instead of duplicating that body.

    Returns:
        ``output_path``.

    Raises:
        ValueError: ``predicted_tif`` has no CRS (use ``align_reference_with_crs``
            with an explicit CRS instead), or the reference raster has no CRS.
    """
    with rasterio.open(predicted_tif) as pred:
        dst_crs = pred.crs
    if dst_crs is None:
        raise ValueError(
            f"Predicted raster has no CRS: {predicted_tif}. "
            "Use align_reference_with_crs() with an explicit CRS."
        )
    return align_reference_with_crs(reference_tif, predicted_tif, predicted_transform, dst_crs, output_path)


# ------------------------------------------------------------
# Step 3. Compare rasters
# ------------------------------------------------------------
def compare_rasters(reference_tif, predicted_tif, IoU_buf):
    """Score a prediction against an aligned reference and save a confusion raster.

    Non-zero pixels count as positive in both rasters. The reference is dilated
    and the prediction eroded with an ``IoU_buf`` x ``IoU_buf`` square
    (``IoU_buf=1`` leaves both unchanged). Pixels are labelled TP=1, FP=-1, FN=2
    (0 elsewhere) in an int8 raster saved beside ``predicted_tif`` as
    ``<stem>_comparison_buf{IoU_buf}.tif``, using the reference raster's profile.

    Args:
        reference_tif: reference raster already on the prediction grid.
        predicted_tif: prediction raster (band 1 is used).
        IoU_buf: structuring-element size in pixels (>= 1).

    Returns:
        dict with ``iou``, ``precision``, ``recall``, ``f1``, ``tp``, ``fp``,
        ``fn`` and ``output_path``.

    Raises:
        ValueError: ``IoU_buf < 1`` or the rasters differ in shape.
    """
    if IoU_buf < 1:
        raise ValueError(f"IoU_buf must be >= 1, got {IoU_buf}")

    with rasterio.open(reference_tif) as ref_src:
        ref_data = ref_src.read(1)
        profile = ref_src.profile

    with rasterio.open(predicted_tif) as pred_src:
        pred_data = pred_src.read(1)

    if ref_data.shape != pred_data.shape:
        raise ValueError("Rasters must have the same dimensions and alignment.")

    struct_element = np.ones((IoU_buf, IoU_buf))
    buffered_ref = binary_dilation(ref_data, structure=struct_element).astype(ref_data.dtype)
    eroded_pred_data = binary_erosion(pred_data, structure=struct_element).astype(pred_data.dtype)

    TP = (buffered_ref == 1) & (eroded_pred_data == 1)
    FP = (buffered_ref == 0) & (eroded_pred_data == 1)
    FN = (buffered_ref == 1) & (eroded_pred_data == 0)

    output_data = np.zeros_like(eroded_pred_data, dtype=np.int8)
    output_data[TP] = 1
    output_data[FP] = -1
    output_data[FN] = 2

    tp_count = int(np.sum(TP))
    fp_count = int(np.sum(FP))
    fn_count = int(np.sum(FN))

    precision = tp_count / (tp_count + fp_count) if (tp_count + fp_count) > 0 else 0
    recall = tp_count / (tp_count + fn_count) if (tp_count + fn_count) > 0 else 0
    f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    iou = tp_count / (tp_count + fp_count + fn_count) if (tp_count + fp_count + fn_count) > 0 else 0

    print(f"\n--- Evaluation Metrics ---")
    print(f"IoU:        {iou:.4f}")
    print(f"Precision:  {precision:.4f}")
    print(f"Recall:     {recall:.4f}")
    print(f"F1 Score:   {f1_score:.4f}")

    # splitext instead of str.replace(".tif", ...): a path without ".tif" would
    # otherwise map to itself and overwrite the prediction, and a ".tif" anywhere
    # else in the path would be rewritten.
    output_path = f"{os.path.splitext(predicted_tif)[0]}_comparison_buf{IoU_buf}.tif"
    profile.update(dtype=rasterio.int8, count=1)

    with rasterio.open(output_path, "w", **profile) as dst:
        dst.write(output_data, 1)

    print(f"Comparison raster saved to:\n{output_path}")

    return {
        "iou": iou,
        "precision": precision,
        "recall": recall,
        "f1": f1_score,
        "tp": tp_count,
        "fp": fp_count,
        "fn": fn_count,
        "output_path": output_path,
    }


# ------------------------------------------------------------
# Step 4. Main workflow
# ------------------------------------------------------------
def main(IoU_buf=1, epsg=4326):
    """Align the reference raster to the closed prediction and report metrics.

    Reads the globals ``Reference_data_path`` and ``closing_tif_path`` set by
    earlier cells.

    Args:
        IoU_buf: buffer size for dilation/erosion in ``compare_rasters``.
        epsg: EPSG code of the prediction grid described by its ``.tfw`` (a world
            file stores no CRS). Replace it if the input raster uses another CRS.

    Returns:
        Whatever ``compare_rasters`` returns.
    """
    from rasterio.crs import CRS

    # The name passed to align_reference_with_crs below is ``reference_tif``;
    # binding ``reference_tif_`` left it undefined on a fresh kernel.
    reference_tif = Reference_data_path
    #reference_tif = input("Enter path to reference data: ").strip('"') # If manual selection is needed
    predicted_tif = closing_tif_path
    #predicted_tif = input("Enter path to predicted data: ").strip('"') # If manual selection is needed

    # Step 1: Read transform from TFW file
    predicted_transform = read_tfw_coordinates(predicted_tif)

    # Step 1b: CRS of the prediction grid (see ``epsg`` above)
    dst_crs = CRS.from_epsg(epsg)

    # Step 2: Align reference raster
    aligned_ref_path = os.path.splitext(predicted_tif)[0] + "_ref_aligned.tif"
    aligned_ref = align_reference_with_crs(reference_tif, predicted_tif, predicted_transform, dst_crs, aligned_ref_path)

    # Step 3: Compare rasters
    return compare_rasters(aligned_ref, predicted_tif, IoU_buf)


if __name__ == "__main__":
    main()



[TFW] Affine transform read from file:

Aligned reference raster saved to:
/content/2025_levees_DL/Swin_UNETR/results/Swin_UNETR/Aug/Filtered/CFE_a_selected_L5_S2_S1_MSRM_PCA_GLO_N48_swinunetr_closing_ref_aligned.tif

--- Evaluation Metrics ---
IoU:        0.3605
Precision:  0.4661
Recall:     0.6142
F1 Score:   0.5300
Comparison raster saved to:
/content/2025_levees_DL/Swin_UNETR/results/Swin_UNETR/Aug/Filtered/CFE_a_selected_L5_S2_S1_MSRM_PCA_GLO_N48_swinunetr_closing_comparison_buf1.tif


## Visualize the results



In [None]:
#!pip install leaflet
!pip install leafmap
!pip install localtileserver

In [None]:
import leafmap

# Map center coordinates
# NOTE(review): confirm these fall inside the analysed raster's extent.
latitude = 31.52806
longitude = 65.24722

# Initialize map
m = leafmap.Map(center=[latitude, longitude], zoom=14)

# Comparison raster written by compare_rasters() with IoU_buf=1, derived from the
# closing output path rather than a hard-coded absolute path.
raster_path = os.path.splitext(closing_tif_path)[0] + "_comparison_buf1.tif"

# One colour per integer class from vmin=-1 to vmax=2, so every class sits exactly
# on its own colour. With only three colours stretched over four values, class 1
# (TP) does not land on green.
palette = [
    "#e5350e",  # -1 = False Positives (red)
    "#000000",  #  0 = background (hidden by nodata=0)
    "#0dff0d",  #  1 = True Positives (green)
    "#729bff",  #  2 = False Negatives (blue)
]

m.add_raster(
    raster_path,
    palette=palette,
    vmin=-1,
    vmax=2,
    nodata=0,
    layer_name="Detection Results",
)

# Add legend
legend_dict = {
    "False Positives (-1)": "#e5350e",
    "True Positives (1)": "#0dff0d",
    "False Negatives (2)": "#729bff",
}
m.add_legend(title="Detection Results", legend_dict=legend_dict)

# Add layer control for interactivity
m.add_layer_control()

# Show map
m
