In [None]:
from google.colab import drive
drive.mount('/content/drive')
!unzip '/content/drive/MyDrive/BTechProject/ChangeDetectionMergedDividedSplit-tif3.zip' -d '/content/ChangeDetectionMergedDividedSplit-tif'

## Data Loader

In [None]:
!pip install rasterio

In [None]:
import os
import rasterio
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class ChangeDetectionDatasetTIF(Dataset):
    """Bi-temporal GeoTIFF dataset yielding ``(img_t2019, img_t2024, cd_mask)``.

    Tiles in the three directories are paired purely by their position after
    sorting the ``.tif`` filenames, so every directory must contain the same
    number of tiles and the names must sort into the same order.

    Args:
        t2019_dir: Directory with the earlier-date (2019) image tiles.
        t2024_dir: Directory with the later-date (2024) image tiles.
        mask_dir: Directory with the change-mask tiles (band 1 is used).
        classes: Sequence of change-class names; stored for reporting only.
        transform: Optional callable applied separately to each image tensor
            (never to the mask). NOTE(review): a random geometric transform
            would desynchronise the two images and the mask.

    Raises:
        ValueError: If the three directories hold different numbers of tiles.
    """

    def __init__(self, t2019_dir, t2024_dir, mask_dir, classes, transform=None):
        self.t2019_dir = t2019_dir
        self.t2024_dir = t2024_dir
        self.mask_dir = mask_dir
        self.classes = classes  # Change detection classes
        self.transform = transform

        # Load all paths (filenames only, sorted so that index i pairs up).
        self.t2019_paths = self._list_tifs(t2019_dir)
        self.t2024_paths = self._list_tifs(t2024_dir)
        self.mask_paths = self._list_tifs(mask_dir)

        # Index-based pairing silently mixes up tiles if the counts differ.
        counts = (len(self.t2019_paths), len(self.t2024_paths), len(self.mask_paths))
        if len(set(counts)) != 1:
            raise ValueError(
                f"Mismatched tile counts: T2019={counts[0]} ({t2019_dir}), "
                f"T2024={counts[1]} ({t2024_dir}), masks={counts[2]} ({mask_dir})"
            )

    @staticmethod
    def _list_tifs(directory):
        """Return the sorted names of the ``.tif`` files directly inside *directory*."""
        return sorted(f for f in os.listdir(directory) if f.endswith('.tif'))

    def __len__(self):
        return len(self.t2019_paths)

    def __getitem__(self, index):
        """Return the image pair and change mask for tile *index*.

        Images are float32 tensors of shape (bands, rows, cols) divided by 255;
        the mask is band 1 as an int64 (rows, cols) tensor of class ids.
        """
        # NOTE(review): dividing by 255 assumes 8-bit imagery -- confirm for the tiles.
        with rasterio.open(os.path.join(self.t2019_dir, self.t2019_paths[index])) as src:
            img_t2019 = src.read(out_dtype=np.float32) / 255.0
        with rasterio.open(os.path.join(self.t2024_dir, self.t2024_paths[index])) as src:
            img_t2024 = src.read(out_dtype=np.float32) / 255.0
        # Load masks
        with rasterio.open(os.path.join(self.mask_dir, self.mask_paths[index])) as src:
            cd_mask = src.read(1).astype(np.int64)

        # Convert to PyTorch tensors
        img_t2019 = torch.from_numpy(img_t2019)
        img_t2024 = torch.from_numpy(img_t2024)
        cd_mask = torch.from_numpy(cd_mask)

        # Apply transforms if any
        if self.transform is not None:
            img_t2019 = self.transform(img_t2019)
            img_t2024 = self.transform(img_t2024)

        return img_t2019, img_t2024, cd_mask

def describe_loader(loader_type):
    """Print a quick summary of a change-detection DataLoader.

    One batch is drawn to report the tensor shapes and the mask values present
    in it; dataset-level information (sample count, class names) follows.
    The loader's dataset must expose a ``classes`` attribute.
    """
    batch = next(iter(loader_type))
    before_imgs, after_imgs, change_masks = batch
    dataset = loader_type.dataset

    print("Batch size:", loader_type.batch_size)
    print("2019 Image Shape:", before_imgs.shape)
    print("2024 Image Shape:", after_imgs.shape)
    print("Change Mask Shape:", change_masks.shape)
    print("Number of images:", len(dataset))
    print("Classes:", dataset.classes)
    print("Unique CD values:", torch.unique(change_masks))

# Example usage:
# ROOT_DIRECTORY is relative to the working directory (/content on Colab, where
# the archive is unzipped). NOTE(review): confirm the zip's internal folder layout.
ROOT_DIRECTORY = "ChangeDetectionMergedDividedSplit-tif"
SAVING_DIR = "/content/drive/MyDrive/BTechProject"
CD_DIR = "cd2_Output"
#CLASSES = ['no_change','vegetation_increase','vegetation_decrease']
CLASSES = ['no_change', 'water_building', 'water_sparse', 'water_dense',
           'building_water', 'building_sparse', 'building_dense',
           'sparse_water', 'sparse_building', 'sparse_dense',
           'dense_water', 'dense_building', 'dense_sparse']


def make_split_dataset(split):
    """Build the ChangeDetectionDatasetTIF for one split ('train', 'val' or 'test')."""
    return ChangeDetectionDatasetTIF(
        t2019_dir=f"{ROOT_DIRECTORY}/{split}/Images/T2019",
        t2024_dir=f"{ROOT_DIRECTORY}/{split}/Images/T2024",
        mask_dir=f"{ROOT_DIRECTORY}/{split}/{CD_DIR}",
        classes=CLASSES,
    )


# Create datasets
train_dataset = make_split_dataset("train")
val_dataset = make_split_dataset("val")
test_dataset = make_split_dataset("test")

# Create dataloaders.
# More workers than CPU cores triggers a PyTorch warning and oversubscribes the
# machine (a standard Colab runtime has 2 vCPUs), so cap at the core count.
num_workers = min(8, os.cpu_count() or 1)
batch_size = 32


def make_loader(dataset, shuffle):
    """Wrap *dataset* in a DataLoader using the shared batch/worker settings."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers, pin_memory=True)


train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset, shuffle=False)
test_loader = make_loader(test_dataset, shuffle=False)

print("------------Train-----------")
describe_loader(train_loader)
print("------------Val------------")
describe_loader(val_loader)
print("------------Test------------")
describe_loader(test_loader)

## Data Visualization

In [None]:
import matplotlib.pyplot as plt
import random

# Show five random training triplets: 2019 image, 2024 image and change mask.
fig, axs = plt.subplots(5, 3, figsize=(10, 10))

# Distinct indices so the same tile is not drawn twice (capped by dataset size).
sample_indices = random.sample(range(len(train_dataset)), k=min(5, len(train_dataset)))

for i, j in enumerate(sample_indices):
    image1, image2, mask = train_dataset[j]

    # Display images ((C, H, W) -> (H, W, C) for imshow)
    axs[i, 0].imshow(image1.permute(1, 2, 0))
    axs[i, 0].set_title("Real 2019")
    axs[i, 0].axis("off")

    axs[i, 1].imshow(image2.permute(1, 2, 0))
    axs[i, 1].set_title("Real 2024")
    axs[i, 1].axis("off")

    # A fixed colour range gives the same class id the same colour in every
    # row; without vmin/vmax each mask is rescaled to its own min/max.
    axs[i, 2].imshow(mask, cmap="turbo", vmin=0, vmax=len(CLASSES) - 1)
    print(np.unique(mask))
    axs[i, 2].set_title("CD Mask")
    axs[i, 2].axis("off")

plt.show()

## Model Definition

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d
import torch.optim as optim

class conv_block_nested(nn.Module):
    """Two 3x3 conv + BatchNorm stages with a residual shortcut.

    The shortcut is the raw output of the first convolution (taken before its
    BatchNorm/ReLU) and is added to the normalised output of the second stage
    before the final ReLU, so ``mid_ch`` must equal ``out_ch``.
    Spatial size is preserved (3x3 kernels with padding 1).
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        # Attribute names form the state_dict keys; keep them stable for checkpoints.
        self.activation = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        shortcut = self.conv1(x)
        hidden = self.activation(self.bn1(shortcut))
        residual = self.bn2(self.conv2(hidden))
        return self.activation(residual + shortcut)

class up(nn.Module):
    """Double the spatial resolution of a feature map.

    Uses parameter-free bilinear interpolation (``align_corners=True``) by
    default; with ``bilinear=False`` a learnable 2x2 stride-2 transposed
    convolution that keeps ``in_ch`` channels is used instead.
    """

    def __init__(self, in_ch, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch, in_ch, 2, stride=2)

    def forward(self, x):
        return self.up(x)

class Siam_NestedUNet_Conc(nn.Module):
    """Siamese nested U-Net (UNet++-style) for bi-temporal change detection.

    Both images pass through the *same* encoder blocks ``conv0_0``..``conv3_0``;
    the deepest block ``conv4_0`` is evaluated on image B only. Each decoder
    node concatenates the skip features of both dates, the earlier nodes on its
    level, and the upsampled node below it. The four top-level decoder nodes
    each feed a deep-supervision head, and ``conv_final`` fuses their logits.

    Args:
        in_ch: Channels of each input image.
        out_ch: Number of output classes (logit channels).

    Forward inputs are two tensors of identical shape (B, in_ch, H, W) with H
    and W divisible by 16 (four 2x poolings must be undone by 2x upsampling).
    The output is raw logits (B, out_ch, H, W) -- no softmax, for CrossEntropyLoss.
    """

    def __init__(self, in_ch=3, out_ch=3):
        super(Siam_NestedUNet_Conc, self).__init__()
        n1 = 32     # Initial number of channels
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder path shared by both images, plus one upsampler per level.
        # (Registration order is kept so seeded initialisation is unchanged.)
        self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
        self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
        self.Up1_0 = up(filters[1])
        self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
        self.Up2_0 = up(filters[2])
        self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
        self.Up3_0 = up(filters[3])
        self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])
        self.Up4_0 = up(filters[4])

        # Nested dense decoder nodes: input = both dates' skips + earlier nodes
        # on the same level + the upsampled node one level below.
        self.conv0_1 = conv_block_nested(filters[0] * 2 + filters[1], filters[0], filters[0])
        self.conv1_1 = conv_block_nested(filters[1] * 2 + filters[2], filters[1], filters[1])
        self.Up1_1 = up(filters[1])
        self.conv2_1 = conv_block_nested(filters[2] * 2 + filters[3], filters[2], filters[2])
        self.Up2_1 = up(filters[2])
        self.conv3_1 = conv_block_nested(filters[3] * 2 + filters[4], filters[3], filters[3])
        self.Up3_1 = up(filters[3])

        self.conv0_2 = conv_block_nested(filters[0] * 3 + filters[1], filters[0], filters[0])
        self.conv1_2 = conv_block_nested(filters[1] * 3 + filters[2], filters[1], filters[1])
        self.Up1_2 = up(filters[1])
        self.conv2_2 = conv_block_nested(filters[2] * 3 + filters[3], filters[2], filters[2])
        self.Up2_2 = up(filters[2])

        self.conv0_3 = conv_block_nested(filters[0] * 4 + filters[1], filters[0], filters[0])
        self.conv1_3 = conv_block_nested(filters[1] * 4 + filters[2], filters[1], filters[1])
        self.Up1_3 = up(filters[1])

        self.conv0_4 = conv_block_nested(filters[0] * 5 + filters[1], filters[0], filters[0])

        # Deep-supervision heads on the four top-level decoder nodes.
        self.final1 = self._make_head(filters[0], out_ch)
        self.final2 = self._make_head(filters[0], out_ch)
        self.final3 = self._make_head(filters[0], out_ch)
        self.final4 = self._make_head(filters[0], out_ch)

        # Fuses the four heads' logits (out_ch * 4 channels) into the final logits.
        self.conv_final = nn.Sequential(
            nn.Conv2d(out_ch * 4, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(inplace=True),
            nn.Conv2d(filters[0], filters[0] // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0] // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(filters[0] // 2, out_ch, kernel_size=1)
        )

        self._initialize_weights()

    @staticmethod
    def _make_head(in_ch, out_ch):
        """3x3 conv -> BN -> ReLU -> 1x1 conv mapping decoder features to class logits."""
        return nn.Sequential(
            nn.Conv2d(in_ch, in_ch // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_ch // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_ch // 2, out_ch, kernel_size=1)
        )

    def _initialize_weights(self):
        """Xavier-uniform for 1x1 convs, Kaiming-normal (fan_out, ReLU) for other
        convs, zero biases, and BatchNorm weight=1 / bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Use Xavier/Glorot initialization for final layers
                if m.kernel_size[0] == 1:
                    nn.init.xavier_uniform_(m.weight)
                else:
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, xA, xB):
        """Return change logits for the image pair (xA = earlier, xB = later date)."""
        if xA.shape != xB.shape:
            raise ValueError(
                f"xA and xB must have the same shape, got {tuple(xA.shape)} and {tuple(xB.shape)}")
        height, width = xA.shape[-2:]
        if height % 16 or width % 16:
            raise ValueError(
                f"Input height and width must be divisible by 16, got {height}x{width}")

        # Encoder Path A (the deepest level is computed for B only)
        x0_0A = self.conv0_0(xA)
        x1_0A = self.conv1_0(self.pool(x0_0A))
        x2_0A = self.conv2_0(self.pool(x1_0A))
        x3_0A = self.conv3_0(self.pool(x2_0A))

        # Encoder Path B
        x0_0B = self.conv0_0(xB)
        x1_0B = self.conv1_0(self.pool(x0_0B))
        x2_0B = self.conv2_0(self.pool(x1_0B))
        x3_0B = self.conv3_0(self.pool(x2_0B))
        x4_0B = self.conv4_0(self.pool(x3_0B))

        # Nested Dense Connections and Decoder Path
        x0_1 = self.conv0_1(torch.cat([x0_0A, x0_0B, self.Up1_0(x1_0B)], 1))
        x1_1 = self.conv1_1(torch.cat([x1_0A, x1_0B, self.Up2_0(x2_0B)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0A, x0_0B, x0_1, self.Up1_1(x1_1)], 1))

        x2_1 = self.conv2_1(torch.cat([x2_0A, x2_0B, self.Up3_0(x3_0B)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0A, x1_0B, x1_1, self.Up2_1(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0A, x0_0B, x0_1, x0_2, self.Up1_2(x1_2)], 1))

        x3_1 = self.conv3_1(torch.cat([x3_0A, x3_0B, self.Up4_0(x4_0B)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0A, x2_0B, x2_1, self.Up3_1(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0A, x1_0B, x1_1, x1_2, self.Up2_2(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0A, x0_0B, x0_1, x0_2, x0_3, self.Up1_3(x1_3)], 1))

        # Get outputs at different scales
        output1 = self.final1(x0_1)
        output2 = self.final2(x0_2)
        output3 = self.final3(x0_3)
        output4 = self.final4(x0_4)

        # Combine outputs
        output = self.conv_final(torch.cat([output1, output2, output3, output4], 1))

        # Return logits without softmax for CrossEntropyLoss
        return output

## Util Functions and Training Loop

In [4]:
def calculate_effective_weights(train_loader, device, num_classes=3, method='square_balanced'):
    """Calculate per-class loss weights from label frequencies to counter class imbalance.

    Args:
        train_loader: Iterable of ``(input1, input2, labels)`` batches; only the
            integer ``labels`` tensors are used.
        device: torch device on which label counting runs.
        num_classes: Number of classes (default: 3).
        method: Weighting strategy:
            'balanced'        -- inverse frequency;
            'square_balanced' -- square root of inverse frequency (gentler);
            'custom'          -- inverse frequency scaled by 0.7 for class 0 and
                                 1.2 for every other class.

    Returns:
        (weights, class_frequencies): float32 CPU tensors of length num_classes.
        Weights are normalised to sum to ``num_classes``. Classes that never
        occur get weight 0 (instead of the inf/NaN that 1/0 would produce).
        Frequencies are relative to all label pixels, including any value
        outside ``[0, num_classes)`` (such values are not counted for any class).

    Raises:
        ValueError: For an unknown ``method``, or when the loader contains no
            pixels of any class.
    """
    class_counts = torch.zeros(num_classes, dtype=torch.float64)
    total_pixels = 0

    # Count class frequencies in one bincount pass per batch.
    for _, _, labels in train_loader:
        flat = labels.to(device).reshape(-1)
        in_range = flat[(flat >= 0) & (flat < num_classes)]
        class_counts += torch.bincount(in_range, minlength=num_classes).cpu().double()
        total_pixels += flat.numel()

    if total_pixels == 0:
        raise ValueError("train_loader yielded no label pixels; cannot compute class weights")

    class_frequencies = (class_counts / total_pixels).float()
    present = class_frequencies > 0
    if not present.any():
        raise ValueError(f"No label pixel falls in the class range [0, {num_classes})")

    # Inverse frequency for observed classes only; absent classes stay at 0.
    inverse = torch.zeros(num_classes)
    inverse[present] = 1.0 / class_frequencies[present]

    if method == 'balanced':
        # Standard balanced weighting (inverse frequency)
        weights = inverse
    elif method == 'square_balanced':
        # Square root of inverse frequencies (less aggressive balancing)
        weights = torch.sqrt(inverse)
    elif method == 'custom':
        # Custom weighting that keeps some of the natural class distribution:
        # reduce the weight of class 0 (no change), increase all others.
        adjustment_factors = torch.full((num_classes,), 1.2)
        adjustment_factors[0] = 0.7
        weights = inverse * adjustment_factors
    else:
        raise ValueError(
            f"Unknown weighting method {method!r}; expected 'balanced', 'square_balanced' or 'custom'")

    # Normalize weights to sum to num_classes (sum > 0: at least one class is present)
    weights = weights * (num_classes / weights.sum())

    return weights, class_frequencies

# Define Focal Loss
class FocalLoss(nn.Module):
    """Multi-class focal loss with optional per-class weights.

    loss = mean over elements of  w_y * (1 - p_y) ** gamma * (-log p_y)

    ``p_y`` is the *unweighted* softmax probability of the true class. The class
    weight scales only the cross-entropy term, so it does not distort the
    focusing factor. Inputs are logits and integer targets, in the same form
    ``F.cross_entropy`` accepts.

    Args:
        weight: Optional 1-D tensor of per-class weights. It is registered as a
            buffer, so ``module.to(device)`` moves it.
        gamma: Focusing parameter; 0 reduces to (weighted) cross-entropy.
    """

    def __init__(self, weight=None, gamma=2.0):
        super().__init__()
        self.register_buffer('weight', weight)
        self.gamma = gamma

    def forward(self, input, target):
        # Unweighted CE gives -log p_y, so pt is the true-class probability.
        # (exp of the weighted CE would be p_y ** w_y.)
        ce_loss = F.cross_entropy(input, target, reduction='none')
        pt = torch.exp(-ce_loss)
        if self.weight is not None:
            ce_loss = F.cross_entropy(input, target, reduction='none', weight=self.weight)
        focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()
        return focal_loss

class MulticlassBCLLoss(nn.Module):
    """Multiclass batch-balanced contrastive-style loss for change detection.

    For every class c among the C logit channels:
      * on valid pixels labelled c, the softmax probability p_c is penalised by
        ``clamp(margin - p_c, 0) ** 2``, averaged over those pixels;
      * on valid pixels of other classes, p_c is penalised by ``p_c ** 2``,
        averaged over those pixels.
    The per-class sums are averaged over all C classes. Pixels whose target
    equals ``ignore_index`` are excluded.

    NOTE(review): probabilities lie in [0, 1], so with the default margin of
    2.0 the positive term is never zero.
    """

    def __init__(self, margin=2.0, ignore_index=255):
        super().__init__()
        self.margin = margin
        self.ignore_index = ignore_index
        self.eps = 1e-4

    def forward(self, pred, target):
        """
        Args:
            pred: Logits of shape (B, C, H, W); C is the number of classes.
            target: Class-id labels of shape (B, H, W); ``ignore_index`` marks
                pixels to skip.

        Returns:
            Scalar loss tensor.
        """
        num_classes = pred.size(1)

        # Apply softmax to get probabilities, then flatten to (N, C) / (N,)
        probs = torch.softmax(pred, dim=1)
        probs = probs.permute(0, 2, 3, 1).reshape(-1, num_classes)
        target = target.reshape(-1)

        # Create mask for valid pixels
        mask = (target != self.ignore_index).float()

        total_loss = 0.0
        for class_idx in range(num_classes):
            total_loss = total_loss + self._calculate_margin_loss(probs, target, mask, class_idx)

        return total_loss / num_classes

    def _calculate_margin_loss(self, pred, target, mask, class_idx):
        """Positive + negative margin loss for one class.

        *pred* is the flattened (N, C) probability matrix, *target* the (N,)
        labels and *mask* the (N,) 0/1 validity mask.
        """
        # Binary target for current class
        class_target = (target == class_idx).float() * mask
        class_pred = pred[:, class_idx]

        # Pixel counts; eps is added once to each raw count so the two
        # denominators are always positive.
        n_pos = class_target.sum()
        n_class = n_pos + self.eps
        n_others = (mask.sum() - n_pos) + self.eps

        # Pixels of this class: push the probability up towards the margin
        positive_loss = torch.sum(
            class_target * torch.pow(torch.clamp(self.margin - class_pred, min=0.), 2)
        ) / n_class

        # Valid pixels of other classes: push this class's probability down
        negative_loss = torch.sum(
            (1 - class_target) * torch.pow(class_pred, 2) * mask
        ) / n_others

        return positive_loss + negative_loss

# def calculate_metrics(outputs, labels):
#     """
#     Calculate comprehensive metrics for change detection using confusion matrix.
#     Returns only overall metrics for simpler interpretation.

#     Args:
#         outputs (torch.Tensor): Model outputs (N, C, H, W)
#         labels (torch.Tensor): Ground truth labels (N, H, W)

#     Returns:
#         dict: Dictionary of overall performance metrics
#     """
#     # Get predicted classes
#     preds = torch.argmax(outputs, dim=1)
#     num_classes = outputs.size(1)

#     # Create confusion matrix
#     num_samples = labels.numel()
#     confusion_matrix = torch.zeros((num_classes, num_classes), device=outputs.device)
#     indices = num_classes * labels.long() + preds.long()
#     confusion_matrix = confusion_matrix.view(-1).index_add_(
#         0, indices.view(-1), torch.ones(num_samples, device=outputs.device)
#     ).view(num_classes, num_classes)

#     # Move to CPU for numpy operations
#     cm = confusion_matrix.cpu().numpy()

#     # True positives, false positives, false negatives for each class
#     tp = np.diag(cm)
#     fp = np.sum(cm, axis=0) - tp
#     fn = np.sum(cm, axis=1) - tp

#     # Overall accuracy
#     accuracy = np.sum(tp) / np.sum(cm)

#     # Per-class precision, recall, F1 to calculate weighted averages
#     precision = tp / (tp + fp + 1e-6)
#     recall = tp / (tp + fn + 1e-6)
#     f1 = 2 * (precision * recall) / (precision + recall + 1e-6)

#     # Weighted averages using class frequencies
#     weights = np.sum(cm, axis=1)
#     weighted_precision = np.average(precision, weights=weights)
#     weighted_recall = np.average(recall, weights=weights)
#     weighted_f1 = np.average(f1, weights=weights)

#     # Calculate Kappa
#     n = np.sum(cm)
#     sum_po = np.sum(np.diag(cm))
#     sum_pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / n
#     kappa = (sum_po - sum_pe) / (n - sum_pe + 1e-6)

#     # Mean IoU
#     iou_per_class = tp / (tp + fp + fn + 1e-6)
#     miou = np.mean(iou_per_class)

#     metrics = {
#         'accuracy': float(accuracy),
#         'precision': float(weighted_precision),
#         'recall': float(weighted_recall),
#         'f1_score': float(weighted_f1),
#         'miou': float(miou),
#         'kappa': float(kappa)
#     }

#     return metrics

from sklearn.metrics import confusion_matrix
def calculate_metrics(outputs, labels, num_classes=3):
    """
    Calculate pixel-level metrics for change detection from one confusion matrix.

    Args:
        outputs (torch.Tensor or array-like): Either model logits as a tensor
            of shape (N, C, ...) -- argmax over dim 1 gives the predictions --
            or already-computed integer predictions (non-tensor input).
        labels (torch.Tensor or array-like): Ground-truth class ids.
        num_classes (int): Number of classes in the dataset.

    Returns:
        dict: 'accuracy', 'precision', 'recall', 'f1_score', 'kappa', 'miou' as
        Python floats. Precision, recall, F1 and mIoU are unweighted (macro)
        means over the classes that occur in the labels or the predictions;
        classes absent from both would otherwise count as 0 and drag the means
        down. Pixels whose label or prediction lies outside
        ``[0, num_classes)`` are ignored. All values are NaN if no pixel remains.
    """
    # Convert to numpy if inputs are torch tensors
    if torch.is_tensor(outputs):
        # For model outputs, get predictions first
        predictions = torch.argmax(outputs, dim=1).cpu().numpy()
    else:
        predictions = np.asarray(outputs)

    if torch.is_tensor(labels):
        labels = labels.cpu().numpy()

    pred_flat = np.asarray(predictions).reshape(-1).astype(np.int64)
    target_flat = np.asarray(labels).reshape(-1).astype(np.int64)

    # Confusion matrix (rows = true class, columns = predicted class) via one bincount.
    valid = ((target_flat >= 0) & (target_flat < num_classes)
             & (pred_flat >= 0) & (pred_flat < num_classes))
    cm = np.bincount(num_classes * target_flat[valid] + pred_flat[valid],
                     minlength=num_classes * num_classes).reshape(num_classes, num_classes)

    metric_names = ('accuracy', 'precision', 'recall', 'f1_score', 'kappa', 'miou')
    n = cm.sum()
    if n == 0:
        return {name: float('nan') for name in metric_names}

    # True positives, false positives, false negatives for each class
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    present = (tp + fp + fn) > 0  # class occurs in labels or predictions

    # Per-class precision, recall, F1 and IoU
    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-6)
    iou_per_class = tp / (tp + fp + fn + 1e-6)

    # Cohen's kappa: observed vs. chance agreement
    sum_po = tp.sum()
    sum_pe = np.sum(cm.sum(axis=0) * cm.sum(axis=1)) / n

    return {
        'accuracy': float(sum_po / n),
        'precision': float(precision[present].mean()),
        'recall': float(recall[present].mean()),
        'f1_score': float(f1[present].mean()),
        'kappa': float((sum_po - sum_pe) / (n - sum_pe + 1e-6)),
        'miou': float(iou_per_class[present].mean()),
    }

def train_model_balanced(model, train_loader, val_loader, num_epochs=50, num_classes=3, device='cuda',
                         weighting_method='square_balanced', loss='CE',
                         checkpoint_path='/content/drive/MyDrive/best_model_multiclass.pt'):
    """
    Train a two-input change-detection model with a class-balanced loss.

    Args:
        model: Module called as ``model(inputs1, inputs2)`` returning logits.
        train_loader, val_loader: Yield ``(inputs1, inputs2, labels)`` batches.
        num_epochs: Index of the last epoch to run (resuming continues from
            the checkpoint's epoch up to this number).
        num_classes: Number of classes in the labels.
        device: Device for the model and batches.
        weighting_method: Passed to ``calculate_effective_weights`` (computed
            on the training set).
        loss: 'focal' (class-weighted FocalLoss, gamma=2), 'bcl'
            (MulticlassBCLLoss, margin=2) or anything else for class-weighted
            cross-entropy. Case-insensitive, same selection as ``test_model``.
        checkpoint_path: Best-model checkpoint. It is written whenever the
            validation loss improves and resumed from if it exists.

    Returns:
        (model, history): ``history[phase][metric]`` is a list of per-epoch
        floats for phase in {'train', 'val'}.

    Notes:
        Epoch metrics are batch-size-weighted means of per-batch metrics, not
        metrics of one epoch-wide confusion matrix.
        The checkpoint holds the *best* epoch, so resuming restarts after the
        best epoch rather than the last completed one.
    """
    metric_names = ('loss', 'accuracy', 'precision', 'recall', 'f1_score', 'miou', 'kappa')
    start_epoch = 0
    best_val_loss = float('inf')
    history = {phase: {name: [] for name in metric_names} for phase in ('train', 'val')}

    # Move the model first so the optimizer is built over the on-device parameters.
    model.to(device)

    # Setup criterion (class weights from the training set), optimizer, scheduler
    class_weights, _ = calculate_effective_weights(train_loader, device, num_classes=num_classes,
                                                   method=weighting_method)
    print(class_weights)
    class_weights = class_weights.to(device)

    loss_name = loss.lower()
    if loss_name == 'focal':
        criterion = FocalLoss(weight=class_weights, gamma=2.0)
    elif loss_name == 'bcl':
        criterion = MulticlassBCLLoss(margin=2.0).to(device)
        print("Using Batch-balanced Contrastive Loss")
    else:  # 'CE'
        criterion = nn.CrossEntropyLoss(weight=class_weights)

    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    # The scheduler's `verbose` flag is deprecated; LR reductions are printed below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    # Load checkpoint if exists
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        try:
            # weights_only=True works because metrics are stored as plain floats
            # below; map_location lets a GPU checkpoint resume on CPU and vice versa.
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
            resumed_epoch = checkpoint['epoch']
            resumed_best = checkpoint['best_val_loss']
            resumed_history = checkpoint.get('history', history)
            model.load_state_dict(checkpoint['model_state_dict'])
            if 'optimizer_state_dict' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if 'scheduler_state_dict' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            start_epoch, best_val_loss, history = resumed_epoch, resumed_best, resumed_history
            print(f"Resuming from epoch {start_epoch} with best val loss: {best_val_loss:.4f}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch")

    def process_epoch(phase, data_loader):
        """Run one pass over *data_loader*; train only when phase == 'train'."""
        is_train = phase == 'train'
        model.train(is_train)

        running_metrics = dict.fromkeys(metric_names, 0.0)
        samples_count = 0

        with torch.set_grad_enabled(is_train):
            for inputs1, inputs2, labels in data_loader:
                inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
                batch_size = inputs1.size(0)

                if is_train:
                    optimizer.zero_grad()

                outputs = model(inputs1, inputs2)
                batch_loss = criterion(outputs, labels)

                if is_train:
                    batch_loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()

                # Per-batch metrics; cast to float so checkpoints hold no NumPy scalars.
                batch_metrics = calculate_metrics(outputs.detach(), labels, num_classes=num_classes)
                batch_metrics['loss'] = batch_loss.item()
                for key in metric_names:
                    running_metrics[key] += float(batch_metrics[key]) * batch_size
                samples_count += batch_size

        if samples_count == 0:
            raise ValueError(f"The {phase} loader yielded no batches")

        epoch_metrics = {key: value / samples_count for key, value in running_metrics.items()}
        for key in metric_names:
            history[phase][key].append(epoch_metrics[key])
        return epoch_metrics

    def print_metrics(phase, metrics):
        """Pretty-print one phase's epoch metrics."""
        print(f'\n{phase.capitalize()} Metrics:')
        print(f'  Loss: {metrics["loss"]:.4f}')
        print(f'  Accuracy: {metrics["accuracy"]:.4f}')
        print(f'  Precision: {metrics["precision"]:.4f}')
        print(f'  Recall: {metrics["recall"]:.4f}')
        print(f'  F1-score: {metrics["f1_score"]:.4f}')
        print(f'  mIoU: {metrics["miou"]:.4f}')
        print(f'  Kappa: {metrics["kappa"]:.4f}')

    # Training loop
    for epoch in range(start_epoch, num_epochs):
        print(f'\nEpoch {epoch + 1}/{num_epochs}:')

        train_metrics = process_epoch('train', train_loader)
        val_metrics = process_epoch('val', val_loader)

        print_metrics('train', train_metrics)
        print_metrics('val', val_metrics)

        # Update learning rate scheduler and report any reduction
        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_metrics['loss'])
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr < previous_lr:
            print(f'Reducing learning rate: {previous_lr:.2e} -> {current_lr:.2e}')

        # Save checkpoint if it's the best model
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': best_val_loss,
                'metrics': val_metrics,
                'history': history
            }
            torch.save(checkpoint, checkpoint_path)
            print(f'\nSaved new best model with validation loss: {val_metrics["loss"]:.4f}')

    return model, history

import json

import json
import torch
import numpy as np

def save_training_files(history, checkpoint_path, history_filename, bestepoch_filename):
    """Save training history and best-epoch info to two JSON files.

    Args:
        history: ``{phase: {metric: [values, ...]}}`` as returned by training.
        checkpoint_path: Checkpoint holding 'epoch', 'best_val_loss' and 'metrics'.
        history_filename: Output path for the history JSON.
        bestepoch_filename: Output path for the best-epoch JSON.
    """

    def convert_to_serializable(value):
        """Recursively convert numpy/torch values into JSON-compatible Python types."""
        if isinstance(value, (np.ndarray, torch.Tensor)):
            return value.tolist()
        if isinstance(value, np.generic):
            # NumPy scalars (np.float32, np.int64, ...) are not JSON serialisable.
            return value.item()
        if isinstance(value, dict):
            return {k: convert_to_serializable(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [convert_to_serializable(item) for item in value]
        return value

    history_data = {
        phase: {
            metric: convert_to_serializable(values)
            for metric, values in metrics.items()
        }
        for phase, metrics in history.items()
    }

    with open(history_filename, 'w') as f:
        json.dump(history_data, f, indent=4)

    # The checkpoint is a file this notebook wrote itself, so full unpickling is
    # acceptable. It is needed because checkpoints from earlier runs store NumPy
    # scalars, which the weights_only=True default (PyTorch >= 2.6) rejects.
    # map_location='cpu' avoids loading the weights onto the GPU just to read metadata.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    epoch_data = {
        'best_epoch': convert_to_serializable(checkpoint['epoch']),
        'best_val_loss': convert_to_serializable(checkpoint['best_val_loss']),
        'val_metrics': convert_to_serializable(checkpoint['metrics'])
    }

    with open(bestepoch_filename, 'w') as f:
        json.dump(epoch_data, f, indent=4)

    print(f"\nSaved training history to: {history_filename}")
    print(f"Saved best epoch info to: {bestepoch_filename}")


## Model Run

In [None]:
import random

# Initialize and train the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = 'snunet_conc'
strategy = 'st2'  # change detection strategy tag used in output filenames (original note: {1,2,3,4})
num_classes = 13  # num classes in change mask
num_epochs = 2
weighting_method = 'square_balanced'  # or 'balanced' / 'custom'
loss = 'CE'  # or 'focal' / 'bcl'
checkpoint_path = f'{SAVING_DIR}/best_{strategy}_{model_name}-{num_classes}_classes_{num_epochs}.pt'

# The model's output channels must cover every class id in the masks.
assert num_classes == len(CLASSES), f"num_classes={num_classes} but {len(CLASSES)} CLASSES are defined"

# Seed every RNG used here (Python, NumPy, PyTorch incl. CUDA) so that weight
# initialisation and DataLoader shuffling are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model = Siam_NestedUNet_Conc(in_ch=3, out_ch=num_classes).to(device)
model2, history = train_model_balanced(model, train_loader, val_loader,
                                       num_epochs=num_epochs, num_classes=num_classes,
                                       device=device,
                                       weighting_method=weighting_method, loss=loss,
                                       checkpoint_path=checkpoint_path)


history_filename = f"{SAVING_DIR}/{strategy}_{model_name}-{num_classes}_classes_{num_epochs}_history.json"
bestepoch_filename = f"{SAVING_DIR}/{strategy}_{model_name}-{num_classes}_classes_{num_epochs}_best_epoch.json"
save_training_files(history=history, checkpoint_path=checkpoint_path,
                    history_filename=history_filename, bestepoch_filename=bestepoch_filename)

## Model Testing

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
import random


def test_model(model, test_loader, loss='CE', device='cuda',
               num_classes=3, weighting_method='square_balanced',
               checkpoint_path=None, class_weights=None):
    """Evaluate a saved checkpoint on the test set and collect samples for plotting.

    Args:
        model: Module called as ``model(inputs1, inputs2)``; the checkpoint's
            'model_state_dict' is loaded into it.
        test_loader: Yields ``(inputs1, inputs2, labels)`` batches.
        loss: 'focal', 'bcl' or anything else for weighted cross-entropy
            (case-insensitive).
        device: Device for the model and batches.
        num_classes: Number of classes in the masks.
        weighting_method: Used by ``calculate_effective_weights`` when class
            weights are computed here.
        checkpoint_path: Checkpoint to load. Defaults to the notebook-level
            ``checkpoint_path`` variable (the previous behaviour).
        class_weights: Optional precomputed loss weights, e.g. from the
            training set. When None they are computed from *test_loader*, as before.

    Returns:
        (random_samples, test_metrics):
        * random_samples: exactly five dicts with 'image1', 'image2', 'label',
          'pred' and 'probabilities' CPU tensors. If fewer than five were
          picked, the last one is repeated; if none were picked, zero
          placeholders are used.
        * test_metrics: ``calculate_metrics`` over the whole test set, plus
          'loss' (batch loss averaged per sample).
    """
    num_vis = 5  # number of samples returned for visualisation

    if checkpoint_path is None:
        checkpoint_path = globals().get('checkpoint_path')
        if checkpoint_path is None:
            raise ValueError("checkpoint_path was not given and no notebook-level checkpoint_path is defined")

    # Trusted checkpoint written by this notebook. Full unpickling is needed
    # because checkpoints from earlier runs hold NumPy scalars.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded checkpoint from {checkpoint_path}")
    model.eval()

    # Calculate class weights
    if class_weights is None:
        class_weights, _ = calculate_effective_weights(test_loader, device,
                                                       num_classes=num_classes,
                                                       method=weighting_method)
    print(f"Class weights: {class_weights}")

    # Select loss function
    loss_name = loss.lower()
    if loss_name == 'focal':
        focal_gamma = 2.0
        criterion = FocalLoss(weight=class_weights.to(device), gamma=focal_gamma)
    elif loss_name == 'bcl':
        criterion = MulticlassBCLLoss(margin=2.0).to(device)
        print("Using Batch-balanced Contrastive Loss")
    else:  # 'CE'
        criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

    random_samples = []
    total_loss = 0.0
    total_samples = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for inputs1, inputs2, labels in test_loader:
            inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)

            outputs = model(inputs1, inputs2)
            batch_loss = criterion(outputs, labels)

            total_loss += batch_loss.item() * inputs1.size(0)
            total_samples += inputs1.size(0)

            preds = torch.argmax(outputs, dim=1)
            all_predictions.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

            # Each sample in the batch (not just the first few) has a 20% chance.
            for i in range(inputs1.size(0)):
                if len(random_samples) >= num_vis:
                    break
                if random.random() < 0.2:
                    random_samples.append({
                        'image1': inputs1[i].cpu(),
                        'image2': inputs2[i].cpu(),
                        'label': labels[i].cpu(),
                        'pred': preds[i].cpu(),
                        'probabilities': torch.softmax(outputs[i], dim=0).cpu()
                    })

    if total_samples == 0:
        raise ValueError("test_loader yielded no batches")

    # Metrics over the whole test set (one confusion matrix)
    all_predictions = np.concatenate(all_predictions)
    all_labels = np.concatenate(all_labels)
    test_metrics = calculate_metrics(all_predictions, all_labels, num_classes)
    test_metrics['loss'] = total_loss / total_samples

    # Make sure we have exactly num_vis samples
    while len(random_samples) < num_vis:
        random_samples.append(random_samples[-1] if random_samples else {
            'image1': torch.zeros(3, 64, 64),
            'image2': torch.zeros(3, 64, 64),
            'label': torch.zeros(64, 64),
            'pred': torch.zeros(64, 64),
            'probabilities': torch.zeros(num_classes, 64, 64)
        })

    return random_samples, test_metrics

def visualize_results(random_samples, num_classes=3):
    """Plot one row per sample: image 1, image 2, predicted mask, ground-truth mask.

    Args:
        random_samples: List of dicts with CPU tensors under 'image1' and
            'image2' (C, H, W) and under 'pred' and 'label' (H, W), as
            returned by ``test_model``.
        num_classes: Number of classes; fixes the mask colour scale to
            ``0..num_classes-1`` so colours are comparable across rows.
    """
    if not random_samples:
        print("No samples to visualize.")
        return

    # A qualitative colormap needs enough colours for every class id; with
    # 'tab10' and 13 classes several ids would share a colour.
    if num_classes <= 10:
        cmap = 'tab10'
    elif num_classes <= 20:
        cmap = 'tab20'
    else:
        cmap = 'turbo'

    def to_display(image):
        """(C, H, W) tensor -> (H, W, C) array min-max scaled to [0, 1]; constant -> zeros."""
        array = image.numpy().transpose(1, 2, 0)
        low, high = array.min(), array.max()
        if high > low:
            return (array - low) / (high - low)
        return np.zeros_like(array)

    n_rows = len(random_samples)
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 4, figsize=(25, 5 * n_rows), squeeze=False)

    for idx, sample in enumerate(random_samples):
        axes[idx, 0].imshow(to_display(sample['image1']))
        axes[idx, 0].set_title('Image 1')
        axes[idx, 0].axis('off')

        axes[idx, 1].imshow(to_display(sample['image2']))
        axes[idx, 1].set_title('Image 2')
        axes[idx, 1].axis('off')

        axes[idx, 2].imshow(sample['pred'].numpy(), cmap=cmap, vmin=0, vmax=num_classes - 1)
        axes[idx, 2].set_title('Predicted Change')
        axes[idx, 2].axis('off')

        axes[idx, 3].imshow(sample['label'].numpy(), cmap=cmap, vmin=0, vmax=num_classes - 1)
        axes[idx, 3].set_title('Ground Truth')
        axes[idx, 3].axis('off')

    plt.tight_layout()
    plt.show()

def save_test_metrics(test_metrics, save_dir, model_name, strategy, num_epochs, num_classes=None):
    """Save test metrics to ``<strategy>_<model>-<classes>_classes_<epochs>_test_metrics.json``.

    Args:
        test_metrics: Flat dict of metric name -> number.
        save_dir: Output directory.
        model_name, strategy, num_epochs: Used in the filename.
        num_classes: Used in the filename. Defaults to the notebook-level
            ``num_classes`` variable (the previous behaviour).
    """
    if num_classes is None:
        num_classes = globals().get('num_classes')
        if num_classes is None:
            raise ValueError("num_classes was not given and no notebook-level num_classes is defined")

    metrics_file = os.path.join(save_dir, f"{strategy}_{model_name}-{num_classes}_classes_{num_epochs}_test_metrics.json")

    def to_builtin(value):
        """json.dump fallback: NumPy scalars/arrays and tensors -> Python types."""
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, (np.ndarray, torch.Tensor)):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    with open(metrics_file, 'w') as f:
        json.dump(test_metrics, f, indent=4, default=to_builtin)

    print(f"\nSaved test metrics to: {metrics_file}")

# Evaluate the best checkpoint on the held-out test split.
random_samples, test_metrics = test_model(
    model,
    test_loader,
    loss=loss,
    device=device,
    num_classes=num_classes,
)

# Persist the aggregate test metrics next to the training artefacts.
save_test_metrics(
    test_metrics=test_metrics,
    save_dir=SAVING_DIR,
    model_name=model_name,
    strategy=strategy,
    num_epochs=num_epochs,
)

# Show a few qualitative predictions beside their ground truth.
visualize_results(random_samples, num_classes=num_classes)