In [None]:
#Install all necessary dependencies - Copernicus-FM requires TorchGeo, plus the Torch and Torchvision versions pinned below
!pip install opencv-python-headless==4.10.0.82
!pip install thinc==8.2.2
!pip install tensorflow==2.18.0
!pip install cupy-cuda11x
!pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118
!pip install torchgeo geopandas rasterio scikit-learn matplotlib pandas shapely

In [None]:
#Install Numpy 1.26.4
!pip install numpy==1.26.4

In [None]:
#Import libraries
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import rasterio
import numpy as np
from sklearn.metrics import f1_score, jaccard_score
import matplotlib.pyplot as plt
from torchgeo.models import CopernicusFM
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as T
from sklearn.metrics import classification_report

In [None]:
#Select the metadata CSV (Google Drive method): the standard tile list, or the standard list merged with the augmented one.
USE_AUGMENTED_DATA = True

original_metadata = "/content/drive/MyDrive/SGDS_BD_DISS_DATA/BD_copernicus_metadata.csv"
augmented_metadata = "/content/drive/MyDrive/SGDS_BD_DISS_DATA/AUG_copernicus_metadata.csv"
combined_metadata_path = "/content/drive/MyDrive/SGDS_BD_DISS_DATA/COMBINED_METADATA.csv"

#Start from the standard metadata; only switch when augmentation is requested AND its CSV is present.
metadata_csv = original_metadata
if USE_AUGMENTED_DATA and os.path.exists(augmented_metadata):
    df_main, df_aug = (pd.read_csv(csv_path) for csv_path in (original_metadata, augmented_metadata))
    df_combined = pd.concat([df_main, df_aug], ignore_index=True)
    #Written back to Drive so the dataset class can load everything from one path.
    df_combined.to_csv(combined_metadata_path, index=False)
    metadata_csv = combined_metadata_path


In [None]:
#Dataset Class, designed to function with CopFM encoder.
class CopernicusDataset(Dataset):
    """Pre/post tile pairs with a per-pixel class mask, sized for the CopFM encoder.

    Each item is ``(img, mask, metadata)``:

    * ``img``: float32 tensor holding the two bands read from the PreMean raster
      followed by the two bands read from the PostMean raster (4 channels),
      divided by 255 and resized to 256x256.
    * ``mask``: long tensor of class ids resized to 256x256 with nearest-neighbour
      interpolation, so resampling never invents label values that are not in
      the source mask.
    * ``metadata``: float32 tensor ``[lon, lat, delta_time, patch_area]``; the last
      two entries are the fixed constants 20000.0 and 1.0.

    Only tiles whose mask file exists and contains at least one non-zero pixel are
    kept. They are ordered by decreasing share of pixels equal to 2.

    Args:
        metadata_csv: path to a CSV with at least the columns ``path``, ``region``,
            ``lon`` and ``lat``.
    """

    def __init__(self, metadata_csv):
        self.data = pd.read_csv(metadata_csv)
        #Tile id = third underscore-separated field of the raster file name.
        self.tiles = self.data['path'].apply(lambda p: os.path.basename(p).split('_')[2])
        self.tile_ids = sorted(set(self.tiles))
        #Built once and reused by __getitem__ instead of re-creating transforms per sample.
        self.resize = T.Resize((256, 256), antialias=True)
        #Masks hold class ids: bilinear resizing would blend neighbouring labels
        #(e.g. a 0/2 boundary turning into 1), so they get nearest-neighbour.
        self.mask_resize = T.Resize((256, 256), interpolation=T.InterpolationMode.NEAREST)

        #Read every mask once: the same array decides validity and gives the sort key.
        damage_ratios = {}
        for tile_id in self.tile_ids:
            mask_path = self._get_mask_path(tile_id)
            if mask_path and os.path.exists(mask_path):
                with rasterio.open(mask_path) as src:
                    mask = src.read(1)
                if np.any(mask > 0):
                    damage_ratios[tile_id] = float(np.mean(mask == 2))
        #sorted() is stable, so ties keep the ascending tile-id order established above.
        self.tile_ids = sorted(damage_ratios, key=lambda tid: -damage_ratios[tid])

    #Acquire ground truth masks
    def _get_mask_path(self, tile_id):
        """Return the mask path for ``tile_id``, or None if it has no PostMean row.

        The path is ``<data root>/<REGION>_MASKS/tile_<tile_id>_mask.tif`` where
        REGION is the upper-cased ``region`` value of the first matching row.
        """
        #regex=False: the tile id is matched literally, not interpreted as a pattern.
        match = self.data[self.data['path'].str.contains(f"tile_{tile_id}_S1_PostMean", regex=False)]
        if match.empty:
            return None
        row = match.iloc[0]
        region = row['region'].upper()
        return os.path.join('/content/drive/MyDrive/SGDS_BD_DISS_DATA', f"{region}_MASKS", f"tile_{tile_id}_mask.tif")

    #Damage ratio function
    def _damage_ratio(self, tile_id):
        """Fraction of mask pixels equal to 2 for ``tile_id`` (0.0 if the mask is missing)."""
        mask_path = self._get_mask_path(tile_id)
        if not mask_path or not os.path.exists(mask_path):
            return 0.0
        with rasterio.open(mask_path) as src:
            mask = src.read(1)
        return np.mean(mask == 2)

    def __len__(self):
        """Number of tiles that passed the mask filter in ``__init__``."""
        return len(self.tile_ids)

    #Image resizing (upscaling) to 256x256 for 50x50px tiles to work with CopFM pretrain weights.
    def __getitem__(self, idx):
        """Load one tile and return ``(img, mask, metadata)`` as described on the class."""
        tile_id = self.tile_ids[idx]
        tile_rows = self.data[self.data['path'].str.contains(f"tile_{tile_id}_", regex=False)]
        pre_row = tile_rows[tile_rows['path'].str.contains('PreMean')].iloc[0]
        post_row = tile_rows[tile_rows['path'].str.contains('PostMean')].iloc[0]

        with rasterio.open(pre_row['path']) as src:
            pre_img = src.read(out_shape=(2, 50, 50))
        with rasterio.open(post_row['path']) as src:
            post_img = src.read(out_shape=(2, 50, 50))

        #Channel order is [pre bands..., post bands...]; the model slices on this layout.
        img = np.concatenate([pre_img, post_img], axis=0)
        img = torch.tensor(img, dtype=torch.float32) / 255.0
        img = self.resize(img)

        mask_path = self._get_mask_path(tile_id)
        with rasterio.open(mask_path) as src:
            mask = src.read(1, out_shape=(50, 50))
        mask = torch.tensor(mask, dtype=torch.long)
        mask = self.mask_resize(mask.unsqueeze(0)).squeeze(0).long()

        delta_time = 20000.0
        patch_area = 1.0
        metadata = torch.tensor([post_row['lon'], post_row['lat'], delta_time, patch_area], dtype=torch.float32)

        return img, mask, metadata

In [None]:
#U-Net Style Decoder (in line with Siamese U-Net Model)
class UNetDecoder(nn.Module):
    """Three-stage upsampling head that maps encoder features to class logits.

    Every stage doubles the spatial size with a stride-2 transposed convolution and
    then refines with a 3x3 convolution (ReLU after each). Channel widths go
    ``in_channels -> 512 -> 256 -> 128`` and a final 1x1 convolution produces
    ``num_classes`` channels, so the output is 8x the input resolution.
    """

    @staticmethod
    def _make_stage(c_in, c_out):
        """Build one 2x upsample + refine stage (layer order fixes the state-dict keys)."""
        return nn.Sequential(
            nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def __init__(self, in_channels=1024, num_classes=3):
        super().__init__()
        #Attribute names and creation order are kept so existing checkpoints still load.
        self.up1 = self._make_stage(in_channels, 512)
        self.up2 = self._make_stage(512, 256)
        self.up3 = self._make_stage(256, 128)
        self.final = nn.Conv2d(128, num_classes, kernel_size=1)

    def forward(self, x):
        """Run the three stages in order and project to class logits."""
        for stage in (self.up1, self.up2, self.up3):
            x = stage(x)
        return self.final(x)

In [None]:
#Siamese Segmentation Trainer, finding the difference in pre and post tokens to optimally learn change between SAR images.
class SiameseSegmentationTrainer(nn.Module):
    """Two CopernicusFM encoders (pre / post) whose token difference feeds a UNetDecoder.

    Both encoders are initialised from the same checkpoint but are separate modules,
    so their weights can diverge during training.

    Args:
        weight_path: checkpoint passed to ``torch.load``; expected to be a state dict.
        num_classes: number of output channels of the decoder.
    """

    def __init__(self, weight_path, num_classes=3):
        super().__init__()
        encoder_kwargs = dict(
            img_size=256,
            patch_size=16,
            num_classes=0,
            global_pool=False,
            embed_dim=1024,
            depth=24,
            num_heads=16,
        )
        self.encoder_pre = CopernicusFM(**encoder_kwargs)
        self.encoder_post = CopernicusFM(**encoder_kwargs)

        #NOTE(review): torch.load unpickles the file - only use checkpoints from a trusted source.
        state = torch.load(weight_path, map_location='cpu')
        if "pos_embed" in state:
            #256 / 16 = 16 patches per side for the encoders built above.
            state["pos_embed"] = self._resize_pos_embed(state["pos_embed"], new_grid=16)

        #strict=False: keys that do not line up with the encoder are skipped, not reported.
        for encoder in (self.encoder_pre, self.encoder_post):
            encoder.load_state_dict(state, strict=False)

        self.decoder = UNetDecoder(in_channels=1024, num_classes=num_classes)

    @staticmethod
    def _resize_pos_embed(pos_embed, new_grid):
        """Bilinearly resample the patch part of ``pos_embed`` to ``new_grid x new_grid``.

        The first token along dim 1 is kept untouched and re-attached in front.
        """
        leading_token = pos_embed[:, :1, :]
        old_grid = int((pos_embed.shape[1] - 1) ** 0.5)
        grid = pos_embed[:, 1:, :].reshape(1, old_grid, old_grid, -1).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=(new_grid, new_grid), mode='bilinear', align_corners=False)
        patches = grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, -1)
        return torch.cat([leading_token, patches], dim=1)

    @staticmethod
    def _apply_blocks(encoder, tokens):
        """Pass ``tokens`` through every block of ``encoder`` in order."""
        for block in encoder.blocks:
            tokens = block(tokens)
        return tokens

    def forward(self, x, meta):
        """Return 50x50 class logits for a batch of 4-channel inputs.

        ``x[:, :2]`` goes to the pre encoder and ``x[:, 2:]`` to the post encoder.
        ``meta`` is accepted for interface compatibility but is not used here.
        """
        pre_img, post_img = x[:, :2], x[:, 2:]
        #Same wavelength / bandwidth value for both bands of each image.
        band_wavelengths = torch.tensor([56000000.0, 56000000.0], device=x.device)
        band_bandwidths = torch.tensor([100000000.0, 100000000.0], device=x.device)

        #Embed both images first, then run the transformer stacks (order kept from the original).
        pre_tokens = self.encoder_pre.patch_embed_spectral(pre_img, band_wavelengths, band_bandwidths, kernel_size=7)
        post_tokens = self.encoder_post.patch_embed_spectral(post_img, band_wavelengths, band_bandwidths, kernel_size=7)
        pre_tokens = self._apply_blocks(self.encoder_pre, pre_tokens)
        post_tokens = self._apply_blocks(self.encoder_post, post_tokens)

        change = post_tokens - pre_tokens

        batch, n_tokens, channels = change.shape
        if n_tokens == 257:
            #256 patch tokens plus one extra leading token: drop the extra one.
            change = change[:, 1:, :]
            n_tokens -= 1

        side = int(n_tokens ** 0.5)
        feature_map = change.transpose(1, 2).reshape(batch, channels, side, side)
        logits = self.decoder(feature_map)
        return F.interpolate(logits, size=(50, 50), mode='bilinear', align_corners=False)

In [None]:
#Composite loss function combining CE, Dice and Focal loss (final version)
def dice_loss(pred, target, epsilon=1e-6):
    """Soft Dice loss: 1 minus the Dice score averaged over batch items and classes.

    Args:
        pred: raw logits of shape (B, C, H, W).
        target: integer class ids of shape (B, H, W).
        epsilon: smoothing term added to numerator and denominator.
    """
    probs = F.softmax(pred, dim=1)
    n_classes = probs.shape[1]
    #one_hot appends the class axis last; move it next to the batch axis to line up with probs.
    onehot = F.one_hot(target, num_classes=n_classes).permute(0, 3, 1, 2).float()

    spatial_dims = (2, 3)
    overlap = (probs * onehot).sum(dim=spatial_dims)
    mass = probs.sum(dim=spatial_dims) + onehot.sum(dim=spatial_dims)
    score = (2 * overlap + epsilon) / (mass + epsilon)
    return 1 - score.mean()

def focal_loss(pred, target, alpha=1.0, gamma=2.0):
    """Focal loss: cross-entropy down-weighted where the true class is already likely.

    Args:
        pred: raw logits, class axis at dim 1.
        target: integer class ids.
        alpha: constant scale applied to every element.
        gamma: exponent of the ``(1 - p_true)`` modulating factor.
    """
    per_pixel_ce = F.cross_entropy(pred, target, reduction='none')
    #exp(-CE) recovers the probability assigned to the true class.
    p_true = torch.exp(-per_pixel_ce)
    modulated = alpha * (1 - p_true) ** gamma * per_pixel_ce
    return modulated.mean()

def total_loss(pred, target):
    """Composite objective: 0.5 * weighted CE + 0.3 * Dice + 0.2 * focal.

    The cross-entropy term uses fixed per-class weights (0.5, 1.0, 5.0) for class
    ids 0, 1, 2, so ``pred`` must carry exactly three class channels.
    """
    class_weights = torch.tensor([0.5, 1.0, 5.0], device=pred.device)
    weighted_ce = F.cross_entropy(pred, target, weight=class_weights)
    dice_term = dice_loss(pred, target)
    focal_term = focal_loss(pred, target)
    return 0.5 * weighted_ce + 0.3 * dice_term + 0.2 * focal_term

In [None]:
#Model Evaluation
def evaluate(model, dataloader, device, num_classes=3):
    """Score ``model`` on ``dataloader`` at pixel level.

    Args:
        model: callable as ``model(x, meta)`` returning logits with the class axis at dim 1.
        dataloader: yields ``(x, y, meta)`` batches.
        device: device the batches are moved to.
        num_classes: number of class ids listed in the per-class report.

    Returns:
        ``(accuracy, macro_f1, macro_iou, report)`` where ``report`` is the dict form of
        sklearn's ``classification_report`` keyed by "background", "intact", "damaged".
    """
    model.eval()
    pred_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch_x, batch_y, batch_meta in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_meta = batch_meta.to(device)
            #Bring labels to 50x50 with nearest-neighbour so class ids stay intact.
            batch_y = F.interpolate(batch_y.unsqueeze(1).float(), size=(50, 50), mode='nearest').squeeze(1).long()
            batch_pred = model(batch_x, batch_meta).argmax(1)
            pred_chunks.append(batch_pred.cpu().numpy().flatten())
            label_chunks.append(batch_y.cpu().numpy().flatten())

    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)

    class_ids = list(range(num_classes))
    report = classification_report(
        all_labels,
        all_preds,
        labels=class_ids,
        target_names=["background", "intact", "damaged"],
        digits=3,
        output_dict=True,
    )
    macro_f1 = f1_score(all_labels, all_preds, average='macro')
    macro_iou = jaccard_score(all_labels, all_preds, average='macro')
    acc = (all_preds == all_labels).mean()
    return acc, macro_f1, macro_iou, report


In [None]:
#Training
import os
import torch
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F

#Checkpoints (best and latest) are written here.
save_dir = "/content/drive/MyDrive"
os.makedirs(save_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#resume_from = os.path.join(save_dir, "")  #Path to resume checkpoint if necessary

#Load Copernicus Pretrain
model = SiameseSegmentationTrainer("/content/drive/MyDrive/copernicus_fm/CopernicusFM_ViT_large_varlang_e100.pth").to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = GradScaler()
best_f1 = 0.0
#The training loop starts from this value, so it must exist even when the
#resume block below stays commented out (otherwise the loop raises NameError).
start_epoch = 0

#Path to resume checkpoint
#if os.path.exists(resume_from):
    #print(f"Resuming from checkpoint: {resume_from}")
    #checkpoint = torch.load(resume_from)
    #model.load_state_dict(checkpoint['model_state_dict'])
    #optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    #start_epoch = checkpoint.get('epoch', 0)
    #best_f1 = checkpoint.get('val_f1', 0.0)

dataset = CopernicusDataset(metadata_csv)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
#Seeded generator: the train/val membership is identical on every run, so validation
#scores are comparable across runs and a resumed run does not validate on tiles it trained on.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size], generator=split_generator)

train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=2)

#Training Loop
for epoch in range(start_epoch, 50):
    model.train()
    epoch_loss = 0
    for x, y, meta in train_loader:
        x, meta = x.to(device), meta.to(device)
        #Labels are brought to 50x50 (the model's output size) with nearest-neighbour to keep class ids intact.
        y = F.interpolate(y.unsqueeze(1).float(), size=(50, 50), mode='nearest').squeeze(1).long()
        y = y.to(device)

        optimizer.zero_grad()
        with autocast():
            logits = model(x, meta)
            loss = total_loss(logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()

    avg_train_loss = epoch_loss / len(train_loader)
    acc, f1, iou, report = evaluate(model, val_loader, device)

    print(f"\nEpoch {epoch+1}")
    print(f"Train Loss: {avg_train_loss:.4f} | Val Acc: {acc:.4f} | Macro F1: {f1:.4f} | Macro IoU: {iou:.4f}")
    for cls in ["background", "intact", "damaged"]:
        precision = report[cls]["precision"]
        recall = report[cls]["recall"]
        f1_cls = report[cls]["f1-score"]
        #IoU = TP / (TP + FP + FN) = P*R / (P + R - P*R).
        #The previous P*R / (P + R) is F1/2, not IoU. The epsilon guards the P = R = 0 case.
        iou_cls = precision * recall / (precision + recall - precision * recall + 1e-6)
        print(f"  {cls.capitalize():<10} F1: {f1_cls:.3f}, IoU (est): {iou_cls:.3f}")

    #Same payload for the "best" and the "latest" checkpoint, built once per epoch.
    epoch_state = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_acc': acc,
        'val_f1': f1,
        'val_iou': iou
    }

    #Save best model
    if f1 > best_f1:
        best_f1 = f1
        torch.save(epoch_state, os.path.join(save_dir, "copernicus_aug_ext_siam_best_model.pth"))
        print(f"Saved best model (Epoch {epoch+1}, F1={f1:.4f})")

    #Save latest model (every epoch)
    torch.save(epoch_state, os.path.join(save_dir, "copernicus_au_ext_siam_model.pth"))

    torch.cuda.empty_cache()

In [None]:
#Evaluation and Visualisation
import torch.nn.functional as F

#checkpoint = torch.load("/content/drive/MyDrive/copernicus_aug_siam_model.pth")
#model.load_state_dict(checkpoint['model_state_dict'])

#Visualise the first dataset tile: both VV panels, the reference mask and the model's prediction.
model.eval()

#Load sample
x, y, meta = dataset[0]

#Recover region from dataset for visualisation (only used in the panel titles)
tile_id = dataset.tile_ids[0]
row = dataset.data[dataset.data['path'].str.contains(f"tile_{tile_id}_")].iloc[0]
region = row['region']

#Run pred on a batch of one
with torch.no_grad():
    x = x.unsqueeze(0).to(device)
    meta = meta.unsqueeze(0).to(device)
    pred = model(x, meta)
    pred = pred.argmax(1)

pred_mask = F.interpolate(pred.unsqueeze(1).float(), size=(50, 50), mode='nearest').squeeze().cpu().int().numpy()

#Channel 0 and channel 2 are the first band of the pre and post image respectively.
x_np = x.squeeze().cpu().numpy()
vv_pre = x_np[0]
vv_post = x_np[2]

#Plot vis: one (image, title, imshow options) entry per panel, left to right
region_label = region.capitalize()
sar_style = {"cmap": 'gray'}
class_style = {"cmap": 'viridis', "vmin": 0, "vmax": 2}
panels = [
    (vv_pre, f"{region_label} VV Pre-War", sar_style),
    (vv_post, f"{region_label} VV Post-War", sar_style),
    (y.cpu(), f"{region_label} Ground Truth", class_style),
    (pred_mask, f"{region_label} Predicted", class_style),
]

fig, axs = plt.subplots(1, 4, figsize=(20, 5))
for ax, (image, title, style) in zip(axs, panels):
    ax.imshow(image, **style)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()