In [None]:
# Mount Google Drive; both the zipped dataset and the output directory used
# later (SAVING_DIR) live under /content/drive/MyDrive/BTechProject.
from google.colab import drive
drive.mount('/content/drive')
# Extract the dataset archive from Drive into /content (VM-local storage).
!unzip '/content/drive/MyDrive/BTechProject/ChangeDetectionMergedDividedSplit-tif3.zip' -d '/content/ChangeDetectionMergedDividedSplit-tif'

## Data Loader

In [None]:
!pip install rasterio

In [None]:
import os
import rasterio
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class ChangeDetectionDatasetTIF(Dataset):
    """Bi-temporal GeoTIFF image pairs with a per-pixel change mask.

    Files in the three directories are matched purely by their position in
    sorted-filename order. Every directory must therefore contain the same
    number of ``.tif`` files, in corresponding order.

    Args:
        t2019_dir: directory with the earlier-date image tiles.
        t2024_dir: directory with the later-date image tiles.
        mask_dir: directory with the change masks (band 1 is read).
        classes: list of change-class names (stored for reference only).
        transform: optional callable applied *independently* to each image
            tensor. It is not applied to the mask, so it must not be a random
            geometric transform, or the images and mask would drift apart.

    Raises:
        ValueError: if the directories contain different numbers of ``.tif`` files.
    """

    def __init__(self, t2019_dir, t2024_dir, mask_dir, classes, transform=None):
        self.t2019_dir = t2019_dir
        self.t2024_dir = t2024_dir
        self.mask_dir = mask_dir
        self.classes = classes  # Change detection classes
        self.transform = transform

        # Load all paths
        self.t2019_paths = self._list_tifs(t2019_dir)
        self.t2024_paths = self._list_tifs(t2024_dir)
        self.mask_paths = self._list_tifs(mask_dir)

        # Pairing is index-based. A missing or extra file would silently
        # misalign every later sample, so fail loudly here instead.
        counts = {
            't2019_dir': len(self.t2019_paths),
            't2024_dir': len(self.t2024_paths),
            'mask_dir': len(self.mask_paths),
        }
        if len(set(counts.values())) != 1:
            raise ValueError(f"Mismatched number of .tif files: {counts}")

    @staticmethod
    def _list_tifs(directory):
        """Return the sorted names of the ``.tif`` files in ``directory``."""
        return sorted(f for f in os.listdir(directory) if f.endswith('.tif'))

    def __len__(self):
        return len(self.t2019_paths)

    def __getitem__(self, index):
        """Return ``(img_t2019, img_t2024, cd_mask)`` for one tile.

        The images are float32 tensors of shape (bands, H, W), divided by 255.
        That scaling assumes 8-bit source data (TODO confirm). The mask is an
        int64 tensor holding band 1 of the mask file.
        """
        # Load images using rasterio
        with rasterio.open(os.path.join(self.t2019_dir, self.t2019_paths[index])) as src:
            img_t2019 = src.read(out_dtype=np.float32) / 255.0
        with rasterio.open(os.path.join(self.t2024_dir, self.t2024_paths[index])) as src:
            img_t2024 = src.read(out_dtype=np.float32) / 255.0
        # Load masks
        with rasterio.open(os.path.join(self.mask_dir, self.mask_paths[index])) as src:
            cd_mask = src.read(1).astype(np.int64)

        # Convert to PyTorch tensors
        img_t2019 = torch.from_numpy(img_t2019)
        img_t2024 = torch.from_numpy(img_t2024)
        cd_mask = torch.from_numpy(cd_mask)

        # Apply transforms if any (images only; see class docstring)
        if self.transform is not None:
            img_t2019 = self.transform(img_t2019)
            img_t2024 = self.transform(img_t2024)

        return img_t2019, img_t2024, cd_mask

def describe_loader(loader_type):
    """Print a short summary of a DataLoader, based on its first batch.

    Args:
        loader_type: DataLoader whose dataset yields (img2019, img2024, mask)
            triples and has a ``classes`` attribute.
    """
    first_2019, first_2024, first_mask = next(iter(loader_type))
    # Values are produced lazily so each is evaluated just before it is printed.
    report = (
        ("Batch size:", lambda: loader_type.batch_size),
        ("2019 Image Shape:", lambda: first_2019.shape),
        ("2024 Image Shape:", lambda: first_2024.shape),
        ("Change Mask Shape:", lambda: first_mask.shape),
        ("Number of images:", lambda: len(loader_type.dataset)),
        ("Classes:", lambda: loader_type.dataset.classes),
        ("Unique CD values:", lambda: torch.unique(first_mask)),
    )
    for caption, produce in report:
        print(caption, produce())

# Example usage:
# Paths are relative to the working directory (/content by default on Colab,
# where the dataset archive is extracted).
ROOT_DIRECTORY = "ChangeDetectionMergedDividedSplit-tif"
SAVING_DIR = "/content/drive/MyDrive/BTechProject"
CD_DIR = "cd2_Output"
# Alternative 3-class scheme: ['no_change', 'vegetation_increase', 'vegetation_decrease']
CLASSES = ['no_change', 'water_building', 'water_sparse', 'water_dense',
           'building_water', 'building_sparse', 'building_dense',
           'sparse_water', 'sparse_building', 'sparse_dense',
           'dense_water', 'dense_building', 'dense_sparse']


def make_split_dataset(split):
    """Build the ChangeDetectionDatasetTIF for one split ('train', 'val' or 'test')."""
    split_root = f"{ROOT_DIRECTORY}/{split}"
    return ChangeDetectionDatasetTIF(
        t2019_dir=f"{split_root}/Images/T2019",
        t2024_dir=f"{split_root}/Images/T2024",
        mask_dir=f"{split_root}/{CD_DIR}",
        classes=CLASSES,
    )


# Create datasets
train_dataset = make_split_dataset("train")
val_dataset = make_split_dataset("val")
test_dataset = make_split_dataset("test")

# Create dataloaders.
# Cap the worker count at the available CPUs: Colab VMs have few cores, and
# more workers than cores only adds overhead and a DataLoader warning.
num_workers = min(8, os.cpu_count() or 1)
batch_size = 32
# Pinned memory only helps host->GPU copies; it just warns when CUDA is absent.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                     pin_memory=torch.cuda.is_available())
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

for split_name, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"------------{split_name}------------")
    describe_loader(split_loader)

## Data Visualization

In [None]:
import matplotlib.pyplot as plt
import random

# Show N_VIS_SAMPLES distinct random training tiles: 2019 image, 2024 image, change mask.
N_VIS_SAMPLES = 5
fig, axs = plt.subplots(N_VIS_SAMPLES, 3, figsize=(10, 10), squeeze=False)
for ax in axs.flat:
    ax.axis("off")

# random.sample draws without replacement, so no tile is shown twice.
sample_indices = random.sample(range(len(train_dataset)),
                               k=min(N_VIS_SAMPLES, len(train_dataset)))

for row, idx in enumerate(sample_indices):
    image1, image2, mask = train_dataset[idx]

    # (C, H, W) -> (H, W, C) for imshow; assumes 3-band RGB tiles — TODO confirm.
    axs[row, 0].imshow(image1.permute(1, 2, 0))
    axs[row, 0].set_title("Real 2019")

    axs[row, 1].imshow(image2.permute(1, 2, 0))
    axs[row, 1].set_title("Real 2024")

    # Fixed colour range so the same class has the same colour in every row.
    axs[row, 2].imshow(mask, cmap="turbo", vmin=0, vmax=len(CLASSES) - 1)
    axs[row, 2].set_title("CD Mask")
    print(np.unique(mask))

plt.tight_layout()
plt.show()

## Model Definition

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d
import torch.optim as optim
import json
import os
class SiamUnet_conc(nn.Module):
    """SiamUnet_conc segmentation network for multiclass change detection.

    Both input images go through the same (weight-tied) 4-stage encoder. The
    decoder upsamples the bottleneck of the second image and, at every stage,
    concatenates the skip features of *both* images.

    Args:
        input_nbr: number of channels of each input image.
        label_nbr: number of output change classes.
        apply_softmax: if False (default), ``forward`` returns raw logits. That
            is what ``nn.CrossEntropyLoss`` expects, since it applies
            log-softmax internally; feeding it probabilities would apply
            softmax twice and flatten the gradients. Set True to get per-pixel
            class probabilities instead.
    """

    def __init__(self, input_nbr, label_nbr, apply_softmax=False):
        super(SiamUnet_conc, self).__init__()

        self.input_nbr = input_nbr
        self.label_nbr = label_nbr
        self.apply_softmax = apply_softmax

        # Layer names and registration order are kept unchanged so existing
        # checkpoints (state_dict keys) still load.

        # Encoder Layers (shared by both images)
        self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(16)
        self.do11 = nn.Dropout2d(p=0.2)
        self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(16)
        self.do12 = nn.Dropout2d(p=0.2)

        self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(32)
        self.do21 = nn.Dropout2d(p=0.2)
        self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(32)
        self.do22 = nn.Dropout2d(p=0.2)

        self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(64)
        self.do31 = nn.Dropout2d(p=0.2)
        self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(64)
        self.do32 = nn.Dropout2d(p=0.2)
        self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(64)
        self.do33 = nn.Dropout2d(p=0.2)

        self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(128)
        self.do41 = nn.Dropout2d(p=0.2)
        self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(128)
        self.do42 = nn.Dropout2d(p=0.2)
        self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(128)
        self.do43 = nn.Dropout2d(p=0.2)

        self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)

        # Decoder Layers: input channels = upsampled + skip(x1) + skip(x2)
        self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(128)
        self.do43d = nn.Dropout2d(p=0.2)
        self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(128)
        self.do42d = nn.Dropout2d(p=0.2)
        self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(64)
        self.do41d = nn.Dropout2d(p=0.2)

        self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(64)
        self.do33d = nn.Dropout2d(p=0.2)
        self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(64)
        self.do32d = nn.Dropout2d(p=0.2)
        self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(32)
        self.do31d = nn.Dropout2d(p=0.2)

        self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(32)
        self.do22d = nn.Dropout2d(p=0.2)
        self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(16)
        self.do21d = nn.Dropout2d(p=0.2)

        self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(16)
        self.do12d = nn.Dropout2d(p=0.2)
        self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)

        # Parameter-free; applied only when apply_softmax=True.
        self.sm = nn.Softmax(dim=1)

    def _encode(self, x):
        """Run one image through the shared encoder.

        Returns the last activation of each of the four stages (used as skip
        connections), followed by the pooled output of stage 4.
        """
        x11 = self.do11(F.relu(self.bn11(self.conv11(x))))
        x12 = self.do12(F.relu(self.bn12(self.conv12(x11))))
        x1p = F.max_pool2d(x12, kernel_size=2, stride=2)

        x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
        x22 = self.do22(F.relu(self.bn22(self.conv22(x21))))
        x2p = F.max_pool2d(x22, kernel_size=2, stride=2)

        x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
        x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
        x33 = self.do33(F.relu(self.bn33(self.conv33(x32))))
        x3p = F.max_pool2d(x33, kernel_size=2, stride=2)

        x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
        x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
        x43 = self.do43(F.relu(self.bn43(self.conv43(x42))))
        x4p = F.max_pool2d(x43, kernel_size=2, stride=2)

        return x12, x22, x33, x43, x4p

    @staticmethod
    def _pad_to(x, ref):
        """Replication-pad ``x`` on the right/bottom to ``ref``'s H and W.

        This is needed when a pooled dimension was odd, so upsampling lands
        one pixel short.
        """
        pad = ReplicationPad2d((0, ref.size(3) - x.size(3), 0, ref.size(2) - x.size(2)))
        return pad(x)

    def forward(self, x1, x2):
        """Return per-pixel class logits (or probabilities if ``apply_softmax``).

        The output has shape (N, label_nbr, H, W), matching the spatial size of
        the inputs.
        """
        x12_1, x22_1, x33_1, x43_1, _ = self._encode(x1)
        # The decoder starts from the second image's bottleneck.
        x12_2, x22_2, x33_2, x43_2, x4p = self._encode(x2)

        # Stage 4d
        x4d = torch.cat((self._pad_to(self.upconv4(x4p), x43_1), x43_1, x43_2), 1)
        x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
        x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
        x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))

        # Stage 3d
        x3d = torch.cat((self._pad_to(self.upconv3(x41d), x33_1), x33_1, x33_2), 1)
        x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
        x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
        x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))

        # Stage 2d
        x2d = torch.cat((self._pad_to(self.upconv2(x31d), x22_1), x22_1, x22_2), 1)
        x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
        x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))

        # Stage 1d
        x1d = torch.cat((self._pad_to(self.upconv1(x21d), x12_1), x12_1, x12_2), 1)
        x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
        x11d = self.conv11d(x12d)

        return self.sm(x11d) if self.apply_softmax else x11d

## Util Functions and Training Loop

In [None]:
import numpy as np
import json
import os
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

def calculate_effective_weights(train_loader, device, num_classes=3, method='square_balanced',
                                adjustment_factors=None):
    """Calculate class weights with different strategies to handle class imbalance.

    Pixels are counted per class over every batch of ``train_loader``. Labels
    outside ``[0, num_classes)`` are not counted towards any class, but they do
    count towards the total pixel number.

    Args:
        train_loader: iterable yielding ``(_, _, labels)`` batches.
        device: torch device the labels are moved to while counting.
        num_classes: number of classes (default: 3).
        method: weighting strategy:
            - 'balanced': 1 / frequency
            - 'square_balanced': sqrt(1 / frequency), less aggressive
            - 'custom': (1 / frequency) * ``adjustment_factors``
        adjustment_factors: per-class multipliers for 'custom'. Must have
            ``num_classes`` entries. Defaults to ``[0.7, 1.2, 1.2]``, which
            only fits ``num_classes == 3``.

    Returns:
        (weights, class_frequencies): float32 tensors of length ``num_classes``.
        The weights are normalised to sum to ``num_classes``. Classes that never
        occur get weight 0 instead of inf, which would otherwise turn every
        weight into NaN. Such classes never appear as targets, so their weight
        does not affect the loss.

    Raises:
        ValueError: unknown ``method``, wrong ``adjustment_factors`` length, or
            a loader with no label pixels.
    """
    # int64 counts: float32 loses integer precision above 2**24 pixels.
    class_counts = torch.zeros(num_classes, dtype=torch.long)
    total_pixels = 0

    # Count class frequencies with one bincount per batch (no per-class .item() syncs).
    for _, _, labels in train_loader:
        labels = labels.to(device).flatten()
        in_range = labels[(labels >= 0) & (labels < num_classes)]
        class_counts += torch.bincount(in_range, minlength=num_classes).cpu()
        total_pixels += labels.numel()

    if total_pixels == 0:
        raise ValueError("train_loader yielded no label pixels")

    class_frequencies = class_counts.double() / total_pixels
    present = class_counts > 0
    inverse = torch.zeros(num_classes, dtype=torch.float64)
    inverse[present] = 1.0 / class_frequencies[present]

    if method == 'balanced':
        # Standard balanced weighting (inverse frequency)
        weights = inverse
    elif method == 'square_balanced':
        # Square root of inverse frequencies (less aggressive balancing)
        weights = torch.sqrt(inverse)
    elif method == 'custom':
        # Custom weighting that keeps some of the natural class distribution;
        # tune the factors from domain knowledge.
        if adjustment_factors is None:
            adjustment_factors = [0.7, 1.2, 1.2]
        factors = torch.as_tensor(adjustment_factors, dtype=torch.float64)
        if factors.numel() != num_classes:
            raise ValueError(
                f"adjustment_factors has {factors.numel()} entries, expected {num_classes}")
        weights = inverse * factors
    else:
        raise ValueError(f"Unknown weighting method: {method!r}")

    # Normalize weights to sum to num_classes
    total = weights.sum()
    if total > 0:
        weights = weights * (num_classes / total)

    return weights.float(), class_frequencies.float()


from sklearn.metrics import confusion_matrix
def calculate_metrics(outputs, labels, num_classes=3):
    """Compute change-detection metrics from a single confusion matrix.

    Args:
        outputs: either a torch tensor of scores of shape (N, C, ...), which is
            argmax-ed over dim 1, or an array of class predictions.
        labels: ground-truth class labels (torch tensor or array), with the
            same number of elements as the predictions.
        num_classes: number of classes. Pixels whose label or prediction falls
            outside ``[0, num_classes)`` are ignored, as with sklearn's
            ``confusion_matrix(..., labels=range(num_classes))``.

    Returns:
        dict of Python floats with the keys 'accuracy', 'precision', 'recall',
        'f1_score', 'kappa' and 'miou'.
        - precision, recall, F1 and mIoU are macro averages over the classes
          that occur in the target or the prediction. Classes absent from both
          are undefined (0/0) and are skipped; counting them as 0 would drag
          every per-batch average down.
        - All values are NaN if no valid pixel remains.
    """
    # Convert to numpy if inputs are torch tensors
    if torch.is_tensor(outputs):
        predictions = torch.argmax(outputs, dim=1).cpu().numpy()
    else:
        predictions = outputs

    if torch.is_tensor(labels):
        labels = labels.cpu().numpy()

    pred_flat = np.asarray(predictions).ravel().astype(np.int64, copy=False)
    target_flat = np.asarray(labels).ravel().astype(np.int64, copy=False)

    # Confusion matrix via bincount: rows = true class, columns = predicted class.
    valid = ((target_flat >= 0) & (target_flat < num_classes)
             & (pred_flat >= 0) & (pred_flat < num_classes))
    cm = np.bincount(num_classes * target_flat[valid] + pred_flat[valid],
                     minlength=num_classes * num_classes).reshape(num_classes, num_classes)

    if cm.sum() == 0:
        nan = float('nan')
        return {'accuracy': nan, 'precision': nan, 'recall': nan,
                'f1_score': nan, 'kappa': nan, 'miou': nan}

    # float64 avoids int64 overflow in the row*column products used by kappa.
    cm = cm.astype(np.float64)
    n = cm.sum()
    tp = np.diag(cm)
    col_totals = cm.sum(axis=0)  # predicted per class
    row_totals = cm.sum(axis=1)  # actual per class
    fp = col_totals - tp
    fn = row_totals - tp
    present = (tp + fp + fn) > 0

    # Per-class precision, recall, F1, IoU
    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-6)
    iou = tp / (tp + fp + fn + 1e-6)

    # Cohen's kappa on count scale: (sum_po - sum_pe) / (n - sum_pe) == (po - pe) / (1 - pe)
    sum_po = tp.sum()
    sum_pe = (col_totals * row_totals).sum() / n
    kappa = (sum_po - sum_pe) / (n - sum_pe + 1e-6)

    return {
        'accuracy': float(sum_po / n),
        'precision': float(precision[present].mean()),
        'recall': float(recall[present].mean()),
        'f1_score': float(f1[present].mean()),
        'kappa': float(kappa),
        'miou': float(iou[present].mean()),
    }

def train_model_balanced(model, train_loader, val_loader, num_epochs=50, num_classes=3,
                         device='cuda', weighting_method='square_balanced',
                         checkpoint_path='best_model_multiclass.pt'):
    """Train ``model`` with class-weighted cross-entropy, checkpointing the best epoch.

    - Class weights come from the label distribution of ``train_loader``
      (``calculate_effective_weights``).
    - After each epoch the validation loss drives a ``ReduceLROnPlateau``
      scheduler.
    - Whenever the validation loss improves, a checkpoint is written to
      ``checkpoint_path``. It holds the model, optimizer and scheduler state,
      plus the metrics and history.
    - If the file already exists, training resumes from it. Because only the
      *best* epoch is saved, epochs after it are re-run.

    Args:
        model: module called as ``model(inputs1, inputs2)``. Its output goes
            straight into ``nn.CrossEntropyLoss``, so it should be raw logits.
        train_loader, val_loader: yield ``(inputs1, inputs2, labels)`` batches.
        num_epochs: total epoch count, including epochs already completed in a
            resumed checkpoint.
        num_classes: number of change classes.
        device: device to train on.
        weighting_method: strategy name passed to ``calculate_effective_weights``.
        checkpoint_path: path of the best-model checkpoint (read and written).

    Returns:
        ``(model, history)``:
        - ``model``: the model after the final epoch, which is not necessarily
          the best one.
        - ``history``: ``{'train'|'val': {metric: [per-epoch float]}}``.

    Raises:
        ValueError: if a loader yields no batches.
    """
    metric_names = ('loss', 'accuracy', 'precision', 'recall', 'f1_score', 'miou', 'kappa')
    history = {phase: {name: [] for name in metric_names} for phase in ('train', 'val')}
    start_epoch = 0
    best_val_loss = float('inf')

    model.to(device)

    # Build criterion/optimizer/scheduler *before* loading the checkpoint so
    # their saved state can be restored as well.
    class_weights, _ = calculate_effective_weights(train_loader, device, num_classes=num_classes,
                                                   method=weighting_method)
    print(class_weights)
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    # `verbose` is deprecated for LR schedulers; the LR is printed every epoch instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    # Load checkpoint if exists
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        try:
            # Checkpoints written below contain only tensors and builtin types, so
            # the safe weights_only loader works (numpy scalars used to break it).
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
            resume_epoch = checkpoint['epoch']
            resume_best = checkpoint['best_val_loss']
            model.load_state_dict(checkpoint['model_state_dict'])
            if 'optimizer_state_dict' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if 'scheduler_state_dict' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            for phase, saved_metrics in (checkpoint.get('history') or {}).items():
                if phase in history:
                    for name in metric_names:
                        history[phase][name] = list(saved_metrics.get(name, []))
            start_epoch, best_val_loss = resume_epoch, resume_best
            print(f"Resuming from epoch {start_epoch} with best val loss: {best_val_loss:.4f}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch")

    def process_epoch(phase, data_loader):
        """Run one pass over ``data_loader``, then record and return the mean metrics.

        The metrics are per-batch values averaged with batch-size weights.
        They are not computed from a single epoch-level confusion matrix.
        """
        is_train = phase == 'train'
        model.train(is_train)

        totals = dict.fromkeys(metric_names, 0.0)
        samples_count = 0

        with torch.set_grad_enabled(is_train):
            for inputs1, inputs2, labels in data_loader:
                inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
                n_batch = inputs1.size(0)

                if is_train:
                    optimizer.zero_grad()

                outputs = model(inputs1, inputs2)
                loss = criterion(outputs, labels)

                if is_train:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()

                batch_metrics = calculate_metrics(outputs.detach(), labels, num_classes=num_classes)
                batch_metrics['loss'] = loss.item()

                # float(): keep everything a builtin type so checkpoints stay
                # loadable with weights_only=True and serialisable to JSON.
                for name in metric_names:
                    totals[name] += float(batch_metrics[name]) * n_batch
                samples_count += n_batch

        if samples_count == 0:
            raise ValueError(f"{phase} loader yielded no batches")

        epoch_metrics = {name: totals[name] / samples_count for name in metric_names}
        for name in metric_names:
            history[phase][name].append(epoch_metrics[name])
        return epoch_metrics

    def print_metrics(phase, metrics):
        """Pretty-print one phase's epoch metrics."""
        print(f'\n{phase.capitalize()} Metrics:')
        print(f'  Loss: {metrics["loss"]:.4f}')
        print(f'  Accuracy: {metrics["accuracy"]:.4f}')
        print(f'  Precision: {metrics["precision"]:.4f}')
        print(f'  Recall: {metrics["recall"]:.4f}')
        print(f'  F1-score: {metrics["f1_score"]:.4f}')
        print(f'  mIoU: {metrics["miou"]:.4f}')
        print(f'  Kappa: {metrics["kappa"]:.4f}')

    # Training loop
    for epoch in range(start_epoch, num_epochs):
        print(f'\nEpoch {epoch + 1}/{num_epochs}:')

        train_metrics = process_epoch('train', train_loader)
        val_metrics = process_epoch('val', val_loader)

        print_metrics('train', train_metrics)
        print_metrics('val', val_metrics)

        # Update learning rate scheduler
        scheduler.step(val_metrics['loss'])
        print(f"  Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

        # Save checkpoint if it's the best model
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': best_val_loss,
                'metrics': val_metrics,
                'history': history
            }
            torch.save(checkpoint, checkpoint_path)
            print(f'\nSaved new best model with validation loss: {val_metrics["loss"]:.4f}')

    return model, history

def save_training_files(history, checkpoint_path, history_filename, bestepoch_filename):
    """Save the training history and the best-epoch summary as two JSON files.

    Args:
        history: ``{phase: {metric: [values]}}`` as returned by the trainer.
        checkpoint_path: checkpoint containing 'epoch', 'best_val_loss' and 'metrics'.
        history_filename: output path for the history JSON.
        bestepoch_filename: output path for the best-epoch JSON.
    """

    def convert_to_serializable(value):
        """Recursively convert numpy/torch values to builtin Python types."""
        if isinstance(value, (np.ndarray, torch.Tensor)):
            return value.tolist()
        if isinstance(value, np.generic):
            # numpy scalars such as np.float32 / np.int64 are not JSON-serialisable.
            return value.item()
        if isinstance(value, dict):
            return {k: convert_to_serializable(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [convert_to_serializable(item) for item in value]
        return value

    history_data = {
        phase: {metric: convert_to_serializable(values) for metric, values in metrics.items()}
        for phase, metrics in history.items()
    }

    with open(history_filename, 'w') as f:
        json.dump(history_data, f, indent=4)

    # Full unpickling, which the original comment intended. Since torch 2.6
    # torch.load defaults to weights_only=True, which rejects the numpy scalars
    # that older checkpoints store. The file is produced by this notebook, so it
    # is trusted; never point this at untrusted files.
    # map_location='cpu' lets a GPU-saved checkpoint load without CUDA.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    epoch_data = {
        'best_epoch': checkpoint['epoch'],
        'best_val_loss': convert_to_serializable(checkpoint['best_val_loss']),
        'val_metrics': convert_to_serializable(checkpoint['metrics'])
    }

    with open(bestepoch_filename, 'w') as f:
        json.dump(epoch_data, f, indent=4)

    print(f"\nSaved training history to: {history_filename}")
    print(f"Saved best epoch info to: {bestepoch_filename}")

## Model Run

In [None]:
# Initialize and train the model
import random

# Seed every RNG used here (weight init, dropout, shuffling, sample picking)
# so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = 'siamunet_conc'
strategy = 'st2'  # change-detection strategy tag (used in output file names)
num_classes = len(CLASSES)  # derived from CLASSES so the two can't drift apart
num_epochs = 2
weighting_method = 'square_balanced'  # see calculate_effective_weights for options

checkpoint_path = f'{SAVING_DIR}/best_{strategy}_{model_name}-{num_classes}_classes_{num_epochs}.pt'

model = SiamUnet_conc(input_nbr=3, label_nbr=num_classes).to(device)
model2, history = train_model_balanced(model, train_loader, val_loader,
                                       num_epochs=num_epochs, num_classes=num_classes,
                                       device=device,
                                       weighting_method=weighting_method,
                                       checkpoint_path=checkpoint_path)


history_filename = f"{SAVING_DIR}/{strategy}_{model_name}-{num_classes}_classes_{num_epochs}_history.json"
bestepoch_filename = f"{SAVING_DIR}/{strategy}_{model_name}-{num_classes}_classes_{num_epochs}_best_epoch.json"
save_training_files(history=history, checkpoint_path=checkpoint_path,
                    history_filename=history_filename, bestepoch_filename=bestepoch_filename)

## Model Testing

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
import random

def test_model(model, test_loader, device='cuda',
               num_classes=3, weighting_method='square_balanced',
               checkpoint_path=None, class_weights=None, num_vis_samples=5):
    """Evaluate the best checkpoint on ``test_loader`` and pick samples to plot.

    Args:
        model: the model to load the checkpoint into and evaluate.
        test_loader: yields ``(inputs1, inputs2, labels)`` batches.
        device: evaluation device.
        num_classes: number of change classes.
        weighting_method: strategy name for ``calculate_effective_weights``.
            Used only when ``class_weights`` is None.
        checkpoint_path: checkpoint to evaluate. Defaults to the notebook-level
            ``checkpoint_path`` global, which the previous version read
            implicitly.
        class_weights: optional loss weights, e.g. the training weights, so the
            test loss is comparable to the validation loss. If None they are
            derived from the *test* labels, as before.
        num_vis_samples: number of samples returned for visualisation.

    Returns:
        ``(random_samples, test_metrics)``:
        - ``random_samples``: list of ``num_vis_samples`` dicts with CPU
          tensors under 'image1', 'image2', 'label', 'pred' and 'probabilities'.
        - ``test_metrics``: the ``calculate_metrics`` output plus the
          sample-weighted mean 'loss'.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    if checkpoint_path is None:
        checkpoint_path = globals()['checkpoint_path']

    # Full unpickling: the checkpoint comes from this notebook and may contain
    # numpy scalars. Do not point this at untrusted files.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded checkpoint from {checkpoint_path}")
    model.to(device)
    model.eval()

    if class_weights is None:
        class_weights, _ = calculate_effective_weights(test_loader, device,
                                                       num_classes=num_classes,
                                                       method=weighting_method)
    print(f"Class weights: {class_weights}")

    criterion = nn.CrossEntropyLoss(weight=torch.as_tensor(class_weights, dtype=torch.float32).to(device))

    random_samples = []
    total_loss = 0.0
    total_samples = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for inputs1, inputs2, labels in test_loader:
            inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)

            outputs = model(inputs1, inputs2)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * inputs1.size(0)
            total_samples += inputs1.size(0)

            preds = torch.argmax(outputs, dim=1)
            all_predictions.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

            # Each sample in the batch has a 20% chance of being kept until
            # enough are collected; every index is considered, not just the first few.
            for i in range(inputs1.size(0)):
                if len(random_samples) >= num_vis_samples:
                    break
                if random.random() < 0.2:
                    random_samples.append({
                        'image1': inputs1[i].cpu(),
                        'image2': inputs2[i].cpu(),
                        'label': labels[i].cpu(),
                        'pred': preds[i].cpu(),
                        # Assumes the model returns logits — TODO confirm for
                        # models that already apply softmax.
                        'probabilities': torch.softmax(outputs[i], dim=0).cpu()
                    })

    if total_samples == 0:
        raise ValueError("test_loader yielded no batches")

    all_predictions = np.concatenate(all_predictions)
    all_labels = np.concatenate(all_labels)

    test_metrics = calculate_metrics(all_predictions, all_labels, num_classes)
    test_metrics['loss'] = total_loss / total_samples

    # Pad to exactly num_vis_samples entries (repeat the last pick, or use blanks).
    while len(random_samples) < num_vis_samples:
        random_samples.append(random_samples[-1] if random_samples else {
            'image1': torch.zeros(3, 64, 64),
            'image2': torch.zeros(3, 64, 64),
            'label': torch.zeros(64, 64),
            'pred': torch.zeros(64, 64),
            'probabilities': torch.zeros(num_classes, 64, 64)
        })

    return random_samples, test_metrics

def visualize_results(random_samples, num_classes=3):
    """Plot each sample as one row: image 1, image 2, predicted mask, ground truth.

    Args:
        random_samples: list of dicts with CPU tensors. 'image1' and 'image2'
            have shape (C, H, W); 'pred' and 'label' have shape (H, W). This is
            the format returned by ``test_model``.
        num_classes: number of classes. It fixes the colour scale so a class
            has the same colour in every panel.
    """
    n_rows = max(len(random_samples), 1)
    fig, axes = plt.subplots(n_rows, 4, figsize=(25, 5 * n_rows), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    # tab10 has only 10 colours, so more classes would share colours.
    if num_classes <= 10:
        cmap = 'tab10'
    elif num_classes <= 20:
        cmap = 'tab20'
    else:
        cmap = 'turbo'
    mask_style = dict(cmap=cmap, vmin=0, vmax=num_classes - 1)

    def _to_display(img):
        """Convert a (C, H, W) tensor to an (H, W, C) array min-max scaled to [0, 1]."""
        arr = img.numpy().transpose(1, 2, 0)
        lo, hi = arr.min(), arr.max()
        if hi > lo:
            return (arr - lo) / (hi - lo)
        # Constant image (e.g. the zero placeholders): avoid 0/0 -> NaN.
        return np.zeros_like(arr)

    for idx, sample in enumerate(random_samples):
        panels = (
            (_to_display(sample['image1']), 'Image 1', {}),
            (_to_display(sample['image2']), 'Image 2', {}),
            (sample['pred'].numpy(), 'Predicted Change', mask_style),
            (sample['label'].numpy(), 'Ground Truth', mask_style),
        )
        for col, (data, title, style) in enumerate(panels):
            axes[idx, col].imshow(data, **style)
            axes[idx, col].set_title(title)

    plt.tight_layout()
    plt.show()

def save_test_metrics(test_metrics, save_dir, model_name, strategy, num_epochs, num_classes=None):
    """Save test metrics to JSON.

    The file is written to
    ``{save_dir}/{strategy}_{model_name}-{num_classes}_classes_{num_epochs}_test_metrics.json``.

    Args:
        test_metrics: dict of metric name -> value. numpy scalars and arrays are
            converted to builtin types.
        save_dir, model_name, strategy, num_epochs: components of the file name.
        num_classes: class count used in the file name. Defaults to the
            notebook-level ``num_classes`` global, which the previous version
            read implicitly.
    """
    if num_classes is None:
        num_classes = globals()['num_classes']

    metrics_file = os.path.join(save_dir, f"{strategy}_{model_name}-{num_classes}_classes_{num_epochs}_test_metrics.json")

    def _to_builtin(value):
        """Fallback for json.dump: convert numpy scalars and arrays."""
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    with open(metrics_file, 'w') as f:
        json.dump(test_metrics, f, indent=4, default=_to_builtin)

    print(f"\nSaved test metrics to: {metrics_file}")

# Evaluate the best checkpoint on the held-out test split.
random_samples, test_metrics = test_model(
    model,
    test_loader,
    device=device,
    num_classes=num_classes,
)

# Persist the test metrics next to the training artefacts.
save_test_metrics(
    test_metrics=test_metrics,
    save_dir=SAVING_DIR,
    model_name=model_name,
    strategy=strategy,
    num_epochs=num_epochs,
)

# Plot a few predictions beside their ground truth.
visualize_results(random_samples, num_classes=num_classes)