In [None]:
#Install all necessary dependencies
!pip install segmentation-models-pytorch
!pip install rasterio geopandas matplotlib albumentations

In [None]:
#Import libraries
import os
import torch
import numpy as np
import rasterio
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
import torch.nn.functional as F
import torch.nn as nn
from sklearn.metrics import classification_report

In [None]:
#Define dataset (Google Drive Method), generate or download metadata where necessary, select augmented, extended, or standard dataset.
USE_AUGMENTED_DATA = False

# All metadata files live under one Drive folder; build every path from it so
# the two branches below cannot drift apart (the augmented branch previously
# pointed at "...extended.csv.csv", which does not match the file used by the
# standard branch).
DATA_ROOT = "/content/drive/MyDrive/SGDS_BD_DISS_DATA"
original_metadata = f"{DATA_ROOT}/BD_copernicus_metadata_extended.csv"
augmented_metadata = f"{DATA_ROOT}/AUG_copernicus_metadata.csv"
combined_metadata_path = f"{DATA_ROOT}/COMBINED_METADATA_EXTENDED.csv"

# Default: train on the original (extended) metadata only.
metadata_csv = original_metadata

if USE_AUGMENTED_DATA:
    if os.path.exists(augmented_metadata):
        # Concatenate original + augmented rows and persist the result so the
        # dataset class can read a single CSV.
        df_main = pd.read_csv(original_metadata)
        df_aug = pd.read_csv(augmented_metadata)
        df_combined = pd.concat([df_main, df_aug], ignore_index=True)
        df_combined.to_csv(combined_metadata_path, index=False)
        metadata_csv = combined_metadata_path
    else:
        # Keep the best-effort fallback, but make it visible instead of silent.
        print(f"Augmented metadata not found at {augmented_metadata}; using {original_metadata}")

In [None]:
#Dataset Class, designed to function with CopFM encoder.
class SiameseCSVData(Dataset):
    """Pre/post image pairs plus a label mask per tile, indexed from a metadata CSV.

    The CSV must provide a ``path`` column (raster file paths) and a ``region``
    column. The tile ID is the third ``_``-separated token of each file's
    basename. A tile is kept only if all of the following can be resolved at
    construction time:

    * a row whose path contains ``tile_<id>_S1_PostMean``,
    * a row for the tile whose path contains ``PreMean``,
    * a row for the tile whose path contains ``PostMean``,
    * an existing mask file ``<mask_root>/<REGION>_MASKS/tile_<id>_mask.tif``.

    Resolving everything once up front means ``__getitem__`` no longer rescans
    the whole DataFrame per sample, and tiles with a missing pre-event row are
    filtered out instead of raising ``IndexError`` mid-training.

    NOTE(review): tile IDs are treated as unique across regions; if two regions
    reuse the same ID only the first matching rows are used - confirm against
    the metadata.

    Args:
        metadata_csv: Path to the metadata CSV.
        mask_root: Directory containing the ``<REGION>_MASKS`` folders.
            Defaults to ``DEFAULT_MASK_ROOT``.
        clip_range: ``(low, high)`` values the image bands are clipped to
            before being rescaled linearly to ``[0, 1]``.

    Raises:
        ValueError: If ``clip_range`` is not strictly increasing.
    """

    DEFAULT_MASK_ROOT = "/content/drive/MyDrive/SGDS_BD_DISS_DATA"

    def __init__(self, metadata_csv, mask_root=None, clip_range=(-35.0, 0.0)):
        self.data = pd.read_csv(metadata_csv)
        self.mask_root = self.DEFAULT_MASK_ROOT if mask_root is None else mask_root

        self.clip_min, self.clip_max = (float(v) for v in clip_range)
        if not self.clip_min < self.clip_max:
            raise ValueError(f"clip_range must be (low, high) with low < high, got {clip_range!r}")

        self.tiles = self.data['path'].apply(lambda p: os.path.basename(p).split('_')[2])

        # tile_id -> (pre_path, post_path, mask_path); dict order follows the
        # sorted tile IDs, so tile_ids stays sorted as before.
        self._records = {}
        for tile_id in sorted(set(self.tiles)):
            record = self._resolve_tile(tile_id)
            if record is not None:
                self._records[tile_id] = record
        self.tile_ids = list(self._records)

    def _mask_path(self, region, tile_id):
        """Build the mask file path for ``tile_id`` inside the region's mask folder."""
        return os.path.join(self.mask_root, f"{str(region).upper()}_MASKS", f"tile_{tile_id}_mask.tif")

    def _resolve_tile(self, tile_id):
        """Return ``(pre_path, post_path, mask_path)`` for a usable tile, else ``None``."""
        paths = self.data['path']
        # regex=False: the tile ID is data, not a pattern; na=False: ignore blank paths.
        tile_rows = self.data[paths.str.contains(f"tile_{tile_id}_", regex=False, na=False)]
        tile_paths = tile_rows['path']

        if not tile_paths.str.contains(f"tile_{tile_id}_S1_PostMean", regex=False, na=False).any():
            return None
        pre_rows = tile_rows[tile_paths.str.contains('PreMean', regex=False, na=False)]
        post_rows = tile_rows[tile_paths.str.contains('PostMean', regex=False, na=False)]
        if pre_rows.empty or post_rows.empty:
            return None

        pre_row, post_row = pre_rows.iloc[0], post_rows.iloc[0]
        # The mask is looked up under the pre-event row's region, which is the
        # region __getitem__ has always loaded the mask from.
        mask_path = self._mask_path(pre_row['region'], tile_id)
        if not os.path.exists(mask_path):
            return None
        return pre_row['path'], post_row['path'], mask_path

    def _mask_exists(self, tile_id):
        """Return True if ``tile_id`` has pre/post rows and a mask file on disk."""
        return self._resolve_tile(tile_id) is not None

    @staticmethod
    def _normalise(img, clip_min=-35.0, clip_max=0.0):
        """Clip ``img`` to ``[clip_min, clip_max]`` and rescale linearly to ``[0, 1]``."""
        img = np.clip(img, clip_min, clip_max)
        return (img - clip_min) / (clip_max - clip_min)

    def _load_image(self, path):
        """Read bands 1 and 2 of the raster at ``path`` as normalised float32."""
        with rasterio.open(path) as src:
            bands = src.read([1, 2]).astype(np.float32)
        return self._normalise(bands, self.clip_min, self.clip_max)

    def __len__(self):
        return len(self.tile_ids)

    def __getitem__(self, idx):
        """Return ``(pre_img, post_img, mask)`` tensors for the tile at ``idx``.

        Images are float32 with two bands scaled to ``[0, 1]``; the mask is the
        first band of the mask raster as int64.
        """
        tile_id = self.tile_ids[idx]
        pre_path, post_path, mask_path = self._records[tile_id]

        pre_img = self._load_image(pre_path)
        post_img = self._load_image(post_path)

        with rasterio.open(mask_path) as src:
            mask = src.read(1).astype(np.int64)

        return torch.tensor(pre_img), torch.tensor(post_img), torch.tensor(mask)

In [None]:
#Dataset split (prev version)
# Fixed seed for the train/val split. Without it every fresh session draws a
# different split, so resuming from a checkpoint would validate on tiles the
# model has already been trained on.
SPLIT_SEED = 42

dataset = SiameseCSVData(metadata_csv)
if len(dataset) < 2:
    # Fail early with a readable message rather than inside the DataLoader sampler.
    raise RuntimeError(
        f"Need at least 2 usable tiles to build a train/val split, found {len(dataset)} in {metadata_csv}"
    )

# 80/20 split; the validation set takes whatever remains after flooring.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# NOTE(review): when augmented metadata is enabled, augmented copies of a tile
# may land on the other side of this random split - confirm whether that is acceptable.

#Data loaders
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=2)

In [None]:
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder

#Siamese U-Net Decoder
class SiamUNet(torch.nn.Module):
    """Two-branch U-Net that segments change between a pre- and a post-event image.

    Each input goes through its own encoder (the two encoders are separate
    modules and do not share weights). The per-stage feature maps are
    differenced (post - pre) and a single U-Net decoder turns the differences
    into per-pixel class logits.

    Args:
        num_classes: Number of output classes (channels of the returned logits).
        in_channels: Number of channels in each input image.
        encoder_name: Encoder architecture name passed to ``smp.encoders.get_encoder``.
        encoder_weights: Pretrained weights identifier for both encoders, or
            ``None`` for random initialisation.
        decoder_channels: Output width of each decoder block; its length also
            sets the encoder depth and the number of decoder blocks.

    The defaults reproduce the original fixed configuration, and the submodule
    names are unchanged, so existing checkpoints remain loadable.
    """

    def __init__(self, num_classes=3, in_channels=2, encoder_name='resnet18',
                 encoder_weights='imagenet', decoder_channels=(256, 128, 64, 32, 16)):
        super().__init__()

        decoder_channels = tuple(decoder_channels)
        depth = len(decoder_channels)

        self.encoder_pre = smp.encoders.get_encoder(
            encoder_name, in_channels=in_channels, depth=depth, weights=encoder_weights
        )
        self.encoder_post = smp.encoders.get_encoder(
            encoder_name, in_channels=in_channels, depth=depth, weights=encoder_weights
        )

        self.decoder = UnetDecoder(
            encoder_channels=self.encoder_pre.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=depth
        )

        # The head consumes the last decoder block's output, so derive its
        # input width from decoder_channels rather than repeating the number.
        self.segmentation_head = torch.nn.Conv2d(decoder_channels[-1], num_classes, kernel_size=1)

    def forward(self, pre_img, post_img):
        """Return class logits computed from the post-minus-pre feature differences.

        Raises:
            ValueError: If the two inputs differ in shape (the feature
                subtraction could otherwise broadcast silently).
        """
        if pre_img.shape != post_img.shape:
            raise ValueError(
                f"pre_img and post_img must have the same shape, got {tuple(pre_img.shape)} and {tuple(post_img.shape)}"
            )
        pre_feats = self.encoder_pre(pre_img)
        post_feats = self.encoder_post(post_img)
        diff_feats = [post - pre for pre, post in zip(pre_feats, post_feats)]
        x = self.decoder(diff_feats)
        return self.segmentation_head(x)


In [None]:
#Loss functions, Evals

# Prefer the GPU when one is visible, otherwise run on CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = SiamUNet()
model = model.to(device)

# Path of a checkpoint to resume from; with the empty default nothing is loaded.
CHECKPOINT_PATH = ''

# The optimizer is built after the model has been moved to its device.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
start_epoch = 0

if not os.path.exists(CHECKPOINT_PATH):
    print("No checkpoint found.")
else:
    # Restore weights and optimizer state, then continue with the epoch after
    # the one stored in the checkpoint.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint.get('epoch', 0) + 1
    print(f"Loaded checkpoint from epoch {start_epoch}")

def dice_loss(pred, target, epsilon=1e-6):
    """Soft Dice loss: one minus the Dice score averaged over samples and classes.

    Args:
        pred: Raw logits with the class axis at dim 1 and two trailing spatial dims.
        target: Integer class labels, one per pixel.
        epsilon: Smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor ``1 - mean(dice)``.
    """
    probs = F.softmax(pred, dim=1)
    n_classes = probs.shape[1]

    # one_hot appends the class axis last; move it to dim 1 to line up with probs.
    onehot = F.one_hot(target, num_classes=n_classes)
    onehot = onehot.permute(0, 3, 1, 2).float()

    spatial_dims = (2, 3)
    overlap = torch.sum(probs * onehot, dim=spatial_dims)
    total = torch.sum(probs, dim=spatial_dims) + torch.sum(onehot, dim=spatial_dims)

    score = (2 * overlap + epsilon) / (total + epsilon)
    return 1 - score.mean()

def focal_loss(pred, target, alpha=1.0, gamma=2.0):
    """Focal loss: per-pixel cross-entropy scaled by ``alpha * (1 - p_true) ** gamma``, averaged.

    Args:
        pred: Raw logits with the class axis at dim 1.
        target: Integer class labels.
        alpha: Constant multiplier on the loss.
        gamma: Exponent of the ``(1 - p_true)`` modulating factor.

    Returns:
        Scalar tensor with the mean focal loss.
    """
    per_pixel_ce = F.cross_entropy(pred, target, reduction='none')
    # exp(-CE) recovers the probability assigned to the true class.
    p_true = torch.exp(-per_pixel_ce)
    modulator = (1 - p_true) ** gamma
    weighted = alpha * modulator * per_pixel_ce
    return weighted.mean()

def combined_loss(pred, target, class_weights=(0.05, 1.0, 5.0),
                  ce_weight=0.5, dice_weight=0.3, focal_weight=0.2):
    """Weighted sum of class-weighted cross-entropy, Dice loss and focal loss.

    Args:
        pred: Raw logits with the class axis at dim 1.
        target: Integer class labels.
        class_weights: Per-class weights for the cross-entropy term; needs one
            entry per class in ``pred``. The default matches three classes.
        ce_weight: Coefficient of the cross-entropy term.
        dice_weight: Coefficient of the Dice term.
        focal_weight: Coefficient of the focal term.

    Returns:
        Scalar tensor ``ce_weight * CE + dice_weight * Dice + focal_weight * Focal``.
        With the default arguments this equals the original fixed formulation.
    """
    # Build the weight tensor on the logits' device so cross_entropy accepts it.
    weights = torch.tensor(class_weights, device=pred.device)
    ce = F.cross_entropy(pred, target, weight=weights)
    dice = dice_loss(pred, target)
    focal = focal_loss(pred, target, alpha=1.0, gamma=2.0)
    return ce_weight * ce + dice_weight * dice + focal_weight * focal

def pixel_accuracy(preds, labels):
    """Fraction of pixels whose arg-max class (over dim 1 of ``preds``) equals ``labels``.

    Returns:
        A Python float.
    """
    predicted_classes = preds.argmax(dim=1)
    correct = predicted_classes == labels
    return correct.float().mean().item()

def per_class_iou_f1(preds, labels, num_classes=3):
    """Compute IoU and F1 for every class index in ``range(num_classes)``.

    Args:
        preds: Raw logits with the class axis at dim 1.
        labels: Integer class labels.
        num_classes: Number of class indices to score.

    Returns:
        ``(ious, f1s)``: two lists of length ``num_classes``. An entry is NaN
        when the class occurs in neither the prediction nor the labels.
    """
    hard_preds = preds.argmax(dim=1).cpu().numpy()
    truth = labels.cpu().numpy()

    ious = []
    f1s = []
    for cls in range(num_classes):
        in_pred = hard_preds == cls
        in_truth = truth == cls

        # True positives are exactly the intersection of the two binary maps.
        tp = (in_pred & in_truth).sum()
        either = (in_pred | in_truth).sum()
        fp = in_pred.sum() - tp
        fn = in_truth.sum() - tp
        f1_denominator = 2 * tp + fp + fn

        ious.append(np.nan if either == 0 else tp / either)
        f1s.append(np.nan if f1_denominator == 0 else 2 * tp / f1_denominator)

    return ious, f1s

In [None]:
#Training and Val

NUM_EPOCHS = 30
MODEL_PATH = '/content/drive/MyDrive/final_siamunet_imagenet_CEDiceFocal_extended_checkpoint.pth'

for epoch in range(start_epoch, NUM_EPOCHS):
    # ---- optimisation pass over the training loader ----
    model.train()
    batch_losses, train_accs = [], []

    for pre, post, y in train_loader:
        pre = pre.to(device)
        post = post.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        pred = model(pre, post)
        loss = combined_loss(pred, y)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        # Accuracy is measured on the logits produced before this step's update.
        train_accs.append(pixel_accuracy(pred, y))

    total_loss = sum(batch_losses)
    mean_train_loss = total_loss / len(train_loader)
    mean_train_acc = np.mean(train_accs)

    # ---- validation pass, no gradients ----
    model.eval()
    val_accs = []
    iou_scores = []
    f1_scores = []
    with torch.no_grad():
        for pre, post, y in val_loader:
            pre = pre.to(device)
            post = post.to(device)
            y = y.to(device)
            pred = model(pre, post)

            acc = pixel_accuracy(pred, y)
            iou, f1 = per_class_iou_f1(pred, y)

            val_accs.append(acc)
            iou_scores.append(iou)
            f1_scores.append(f1)

    # Per-batch scores are averaged per class; NaN entries (class absent from
    # both prediction and labels in a batch) are skipped by nanmean.
    mean_iou = np.nanmean(iou_scores, axis=0)
    mean_f1 = np.nanmean(f1_scores, axis=0)
    mean_val_acc = np.mean(val_accs)

    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
    print(f"Train Loss: {mean_train_loss:.4f} | Train Acc: {mean_train_acc:.3f}")
    print(f"Val Acc: {mean_val_acc:.3f}")
    print(f"Val IoU: {[f'{m:.2f}' for m in mean_iou]} (bg/intact/damaged)")
    print(f"Val F1 : {[f'{m:.2f}' for m in mean_f1]}")

    # Overwrite the single checkpoint file after every epoch.
    checkpoint_state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint_state, MODEL_PATH)
    print(f"Saved model to {MODEL_PATH}")

In [None]:
#Evaluation and Visualisation

model.eval()
with torch.no_grad():
    index = 0
    pre, post, y = val_ds[index]

    # Retrieve region for visualisation: map the position in the validation
    # split back to the underlying dataset's tile ID, then look up its row.
    if hasattr(val_ds, 'indices'):
        val_indices = val_ds.indices
    else:
        val_indices = list(range(len(val_ds)))
    tile_id = dataset.tile_ids[val_indices[index]]
    tile_rows = dataset.data[dataset.data['path'].str.contains(f"tile_{tile_id}_")]
    region = tile_rows.iloc[0]['region']

    # Add a batch dimension for the forward pass, then reduce to a class map.
    pre_tensor = pre.unsqueeze(0).to(device)
    post_tensor = post.unsqueeze(0).to(device)
    pred = model(pre_tensor, post_tensor)
    pred_mask = pred.squeeze().argmax(dim=0).cpu().numpy()

    # One entry per panel: image, title, imshow keyword arguments.
    panels = [
        (pre[0].cpu(), f"{region.capitalize()} VV Pre-War", dict(cmap='gray')),
        (post[0].cpu(), f"{region.capitalize()} VV Post-War", dict(cmap='gray')),
        (y.cpu(), f"{region.capitalize()} Ground Truth", dict(cmap='viridis', vmin=0, vmax=2)),
        (pred_mask, f"{region.capitalize()} Predicted", dict(cmap='viridis', vmin=0, vmax=2)),
    ]

    fig, axs = plt.subplots(1, 4, figsize=(20, 5))
    for ax, (image, title, imshow_kwargs) in zip(axs, panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()