In [None]:
!pip install rasterio
!pip install segmentation_models_pytorch

In [None]:
from google.colab import drive
drive.mount('/content/drive')
!unzip '/content/drive/MyDrive/BTechProject/ChangeDetectionMergedDividedSplit-tif3.zip' -d '/content/ChangeDetectionMergedDividedSplit-tif'

### Dataloader

In [None]:
import os
import rasterio
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class ChangeDetectionDatasetTIF(Dataset):
    """Dataset of paired 2019/2024 GeoTIFF tiles with semantic and change masks.

    Every directory is scanned for ``.tif`` files and the names are sorted;
    sample ``i`` is formed from the i-th name of each directory. Only the
    number of files is verified, so pairing relies on the sorted names
    lining up across directories.

    Args:
        t2019_dir: Directory with the 2019 images.
        t2024_dir: Directory with the 2024 images.
        sem_2019_dir: Directory with the 2019 semantic masks.
        sem_2024_dir: Directory with the 2024 semantic masks.
        mask_dir: Directory with the change-detection masks.
        classes: Change-detection class names (stored, not used internally).
        semantic_classes: Land-cover class names (stored, not used internally).
        transform: Optional callable applied to each image tensor separately;
            the masks are never passed through it.
    """

    def __init__(self, t2019_dir, t2024_dir, sem_2019_dir, sem_2024_dir, mask_dir,
                 classes, semantic_classes, transform=None):
        self.t2019_dir = t2019_dir
        self.t2024_dir = t2024_dir
        self.mask_dir = mask_dir
        self.sem_2019_dir = sem_2019_dir
        self.sem_2024_dir = sem_2024_dir
        self.classes = classes  # Change detection classes
        self.semantic_classes = semantic_classes  # Land cover classes
        self.transform = transform

        # File names (not full paths), sorted so that index i addresses the
        # same tile in every directory.
        self.t2019_paths = self._list_tifs(t2019_dir)
        self.t2024_paths = self._list_tifs(t2024_dir)
        self.mask_paths = self._list_tifs(mask_dir)
        self.sem2019_paths = self._list_tifs(sem_2019_dir)
        self.sem2024_paths = self._list_tifs(sem_2024_dir)

        # All five listings must have the same number of entries.
        distinct_counts = {
            len(self.t2019_paths),
            len(self.t2024_paths),
            len(self.mask_paths),
            len(self.sem2019_paths),
            len(self.sem2024_paths),
        }
        assert len(distinct_counts) == 1, "Mismatched number of images"

    @staticmethod
    def _list_tifs(directory):
        """Return the sorted names of the ``.tif`` files directly inside ``directory``."""
        return sorted(name for name in os.listdir(directory) if name.endswith('.tif'))

    @staticmethod
    def _read_image(directory, filename):
        """Read all bands as float32, divide by 255 and wrap in a tensor."""
        with rasterio.open(os.path.join(directory, filename)) as src:
            return torch.from_numpy(src.read(out_dtype=np.float32) / 255.0)

    @staticmethod
    def _read_mask(directory, filename):
        """Read band 1 as an int64 label tensor."""
        with rasterio.open(os.path.join(directory, filename)) as src:
            return torch.from_numpy(src.read(1).astype(np.int64))

    def __len__(self):
        """Number of samples (length of the 2019 image listing)."""
        return len(self.t2019_paths)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            Tuple ``(img_t2019, img_t2024, sem_mask_2019, sem_mask_2024, cd_mask)``.
        """
        # Images first, then the change mask, then the two semantic masks.
        img_t2019 = self._read_image(self.t2019_dir, self.t2019_paths[index])
        img_t2024 = self._read_image(self.t2024_dir, self.t2024_paths[index])
        cd_mask = self._read_mask(self.mask_dir, self.mask_paths[index])
        sem_mask_2019 = self._read_mask(self.sem_2019_dir, self.sem2019_paths[index])
        sem_mask_2024 = self._read_mask(self.sem_2024_dir, self.sem2024_paths[index])

        # The transform (if any) is applied to each image independently.
        if self.transform is not None:
            img_t2019 = self.transform(img_t2019)
            img_t2024 = self.transform(img_t2024)

        return img_t2019, img_t2024, sem_mask_2019, sem_mask_2024, cd_mask


def describe_loader(loader_type):
    """Print a short report about a data loader, based on its first batch.

    Pulls one batch from ``loader_type`` and prints the batch size, the shape
    of each of the five batch tensors, the dataset length, the class-name
    lists stored on the dataset, and the unique label values found in the
    change mask and the 2019 semantic mask of that batch.
    """
    first_batch = next(iter(loader_type))
    img2019, img2024, sem2019, sem2024, cd_mask = first_batch

    print("Batch size:", loader_type.batch_size)

    print("Shapes:")
    shape_rows = (
        ("  2019 Image:", img2019),
        ("  2024 Image:", img2024),
        ("  2019 Semantic Mask:", sem2019),
        ("  2024 Semantic Mask:", sem2024),
        ("  Change Mask:", cd_mask),
    )
    for row_label, tensor in shape_rows:
        print(row_label, tensor.shape)

    print("Number of images:", len(loader_type.dataset))
    print("Classes:", loader_type.dataset.classes)
    print("Semantic Classes:", loader_type.dataset.semantic_classes)

    # Label values present in this one batch only (not the whole dataset).
    print("\nUnique values:")
    for row_label, tensor in (("  Change Mask:", cd_mask),
                              ("  Semantic Mask 2019:", sem2019)):
        print(row_label, torch.unique(tensor))

# Example usage:
ROOT_DIRECTORY = "/content/ChangeDetectionMergedDividedSplit-tif"
SAVING_DIR = '/content/drive/MyDrive/BTechProject'
#CLASSES = ['no_change', 'increase_vegetation', 'decrease_vegetation']
CLASSES = ['no_change', 'water_building', 'water_sparse', 'water_dense',
           'building_water', 'building_sparse', 'building_dense',
           'sparse_water', 'sparse_building', 'sparse_dense',
           'dense_water', 'dense_building', 'dense_sparse']

SEMANTIC_CLASSES = ['water', 'building', 'sparse_vegetation', 'dense_vegetation']


def build_split_dataset(split):
    """Build the dataset for one split ('train', 'val' or 'test').

    All splits share the same sub-directory layout below
    ``ROOT_DIRECTORY/<split>``: ``Images/T2019``, ``Images/T2024``,
    ``Masks/T2019``, ``Masks/T2024`` and ``cd2_Output``.
    """
    split_root = f"{ROOT_DIRECTORY}/{split}"
    return ChangeDetectionDatasetTIF(
        t2019_dir=f"{split_root}/Images/T2019",
        t2024_dir=f"{split_root}/Images/T2024",
        sem_2019_dir=f"{split_root}/Masks/T2019",
        sem_2024_dir=f"{split_root}/Masks/T2024",
        mask_dir=f"{split_root}/cd2_Output",
        classes=CLASSES,
        semantic_classes=SEMANTIC_CLASSES
    )


# Create datasets
train_dataset = build_split_dataset("train")
val_dataset = build_split_dataset("val")
test_dataset = build_split_dataset("test")

# Create dataloaders (only the training loader shuffles)
num_workers = 8
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

# Print a summary of the first batch of each loader.
for split_title, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"------------{split_title}------------")
    describe_loader(split_loader)

### Visualize

In [None]:
import matplotlib.pyplot as plt
import random

# Show five randomly drawn training samples, one per row. Columns:
# 2019 image, 2024 image, 2019 semantic mask, 2024 semantic mask, change mask.
fig, axs = plt.subplots(5, 5, figsize=(10,10))

for i in range(5):
    j = random.randint(0, len(train_dataset) - 1)
    img_t2019, img_t2024, sem_mask_2019, sem_mask_2024, cd_mask = train_dataset[j]

    # Image tensors are permuted from (C, H, W) to (H, W, C) for imshow.
    axs[i, 0].imshow(img_t2019.permute(1, 2, 0))
    axs[i, 0].set_title("Real 2019")
    axs[i, 0].axis("off")

    axs[i, 1].imshow(img_t2024.permute(1, 2, 0))
    axs[i, 1].set_title("Real 2024")
    axs[i, 1].axis("off")

    # The three label masks share the same rendering; the label values present
    # in each mask are printed for a quick sanity check.
    mask_panels = (
        (sem_mask_2019, "Sem 2019 Mask"),
        (sem_mask_2024, "Sem 2024 Mask"),
        (cd_mask, "CD Mask"),
    )
    for col, (mask, title) in enumerate(mask_panels, start=2):
        axs[i, col].imshow(mask, cmap="turbo")
        print(np.unique(mask))
        axs[i, col].set_title(title)
        axs[i, col].axis("off")

# plt.plot() with no arguments draws nothing; lay out and show the figure instead.
plt.tight_layout()
plt.show()

### Model Definition, Training function, Util functions

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

class Strategy3Model:
    """Container for two independent networks: change detection (CD) and land cover mapping (LCM).

    This is a plain Python object, not an ``nn.Module``: the two networks are
    exposed as ``cd_model`` and ``lcm_model`` and their parameters/state dicts
    must be accessed through those attributes.

    Args:
        cd_architecture: Architecture of the CD network, 'unet' or
            'deeplabv3plus' (case-insensitive).
        lcm_architecture: Architecture of the LCM network, same choices.
        cd_encoder: Encoder name passed to segmentation_models_pytorch for CD.
        lcm_encoder: Encoder name passed to segmentation_models_pytorch for LCM.
        input_channels: Channels of one input image. The CD network is built
            with twice this many input channels.
        num_classes: Accepted for interface compatibility; not used when
            building either network.
        num_semantic_classes: Number of output channels of the LCM network.

    Raises:
        ValueError: If an architecture name is not supported.
    """
    def __init__(self, cd_architecture='unet', lcm_architecture='unet',
                 cd_encoder='resnet34', lcm_encoder='resnet34',
                 input_channels=3, num_classes=13, num_semantic_classes=4):
        # Initialize CD model
        self.cd_model = self._create_cd_model(
            architecture=cd_architecture,
            encoder=cd_encoder,
            input_channels=input_channels
        )
        # Initialize LCM model
        self.lcm_model = self._create_lcm_model(
            architecture=lcm_architecture,
            encoder=lcm_encoder,
            input_channels=input_channels,
            num_semantic_classes=num_semantic_classes
        )

    def _create_cd_model(self, architecture, encoder, input_channels):
        """Create the binary change detection model (single output channel).

        Raises:
            ValueError: If ``architecture`` is not 'unet' or 'deeplabv3plus'.
        """
        arch = architecture.lower()
        if arch == 'unet':
            return smp.Unet(
                encoder_name=encoder,
                encoder_weights='imagenet',
                in_channels=input_channels*2,  # Concatenated images
                classes=1,  # Binary output
                encoder_depth=4,  # Reduce depth (def=5)
                decoder_channels=(256, 128, 64, 32)  # Reduce channels(def=(256, 128, 64, 32, 16))
            )
        if arch == 'deeplabv3plus':
            return smp.DeepLabV3Plus(
                encoder_name=encoder,
                encoder_weights='imagenet',
                in_channels=input_channels*2,
                classes=1,
            )
        # Add more architectures as needed. Previously an unknown name fell
        # through to `return model` and raised UnboundLocalError.
        raise ValueError(
            f"Unsupported CD architecture {architecture!r}; expected 'unet' or 'deeplabv3plus'"
        )

    def _create_lcm_model(self, architecture, encoder, input_channels, num_semantic_classes=4):
        """Create the land cover mapping model (``num_semantic_classes`` output channels).

        Raises:
            ValueError: If ``architecture`` is not 'unet' or 'deeplabv3plus'.
        """
        arch = architecture.lower()
        if arch == 'unet':
            return smp.Unet(
                encoder_name=encoder,
                encoder_weights='imagenet',
                in_channels=input_channels,
                classes=num_semantic_classes,  # 4 land cover classes by default
            )
        if arch == 'deeplabv3plus':
            return smp.DeepLabV3Plus(
                encoder_name=encoder,
                encoder_weights='imagenet',
                in_channels=input_channels,
                classes=num_semantic_classes,
            )
        # Add more architectures as needed
        raise ValueError(
            f"Unsupported LCM architecture {architecture!r}; expected 'unet' or 'deeplabv3plus'"
        )

    def to(self, device):
        """Move both networks to ``device`` and return ``self`` for chaining."""
        self.cd_model = self.cd_model.to(device)
        self.lcm_model = self.lcm_model.to(device)
        return self

    def train(self):
        """Set both networks to training mode."""
        self.cd_model.train()
        self.lcm_model.train()

    def eval(self):
        """Set both networks to evaluation mode."""
        self.cd_model.eval()
        self.lcm_model.eval()

def create_semantic_change_mask(binary_pred, lcm_pred_2019, lcm_pred_2024):
    """Combine a binary change map with two land-cover predictions into a 13-class mask.

    A pixel is treated as changed when channel 0 of ``binary_pred`` exceeds
    0.5. For changed pixels the output class is looked up from the pair
    (argmax class in 2019, argmax class in 2024); unchanged pixels, and
    changed pixels whose two land-cover classes coincide, get class 0.

    Args:
        binary_pred: Binary change prediction tensor (B, 1, H, W). The 0.5
            threshold is applied to the values as given.
        lcm_pred_2019: Land cover prediction tensor for 2019 (B, C, H, W).
        lcm_pred_2024: Land cover prediction tensor for 2024 (B, C, H, W).

    Returns:
        Long tensor (B, H, W) with values 0-12, on the device of ``binary_pred``.
    """
    num_classes = 4  # Water, Building, Sparse, Dense
    device = binary_pred.device

    # Lookup table indexed by from_class * num_classes + to_class.
    # The diagonal (same class at both dates) maps to 0 = no change.
    transitions = torch.tensor(
        [
            0, 1, 2, 3,      # Water    -> Water / Building / Sparse / Dense
            4, 0, 5, 6,      # Building -> Water / Building / Sparse / Dense
            7, 8, 0, 9,      # Sparse   -> Water / Building / Sparse / Dense
            10, 11, 12, 0,   # Dense    -> Water / Building / Sparse / Dense
        ],
        device=device,
    )

    # Hard land-cover labels for each date, shape (B, H, W).
    land_cover_before = lcm_pred_2019.argmax(dim=1)
    land_cover_after = lcm_pred_2024.argmax(dim=1)

    # Boolean (B, H, W) map of pixels flagged as changed.
    is_changed = binary_pred[:, 0] > 0.5

    output_shape = (binary_pred.shape[0], binary_pred.shape[2], binary_pred.shape[3])
    semantic_mask = torch.zeros(output_shape, dtype=torch.long, device=device)

    changed_at = torch.nonzero(is_changed, as_tuple=True)
    if changed_at[0].numel() == 0:
        # Nothing flagged as changed: the all-zero mask is the answer.
        return semantic_mask

    # Only the changed pixels are looked up in the transition table.
    table_index = land_cover_before[changed_at] * num_classes + land_cover_after[changed_at]
    semantic_mask[changed_at] = transitions[table_index]

    return semantic_mask

def calculate_class_weights(train_loader, device, num_classes=13):
    """Compute 'square_balanced' class weights from the label masks of a loader.

    The weight of a class is proportional to ``1 / sqrt(pixel_count)`` (the
    square root of the inverse class frequency), and the weights are scaled
    so that they sum to ``num_classes``.

    Args:
        train_loader: Iterable of batches; the LAST element of each batch is
            taken as the integer label mask to count.
        device: Unused; kept for interface compatibility. The result is
            returned on the CPU and callers move it where needed.
        num_classes: Number of classes. Labels outside ``[0, num_classes)``
            are ignored.

    Returns:
        float32 CPU tensor of shape ``(num_classes,)``. Classes that never
        occur get weight 0 (instead of the inf/NaN a division by a zero
        frequency would produce). If no valid label is found at all, uniform
        weights of 1 are returned.
    """
    # Integer counters: float32 accumulation stops being exact above 2**24 pixels.
    class_counts = torch.zeros(num_classes, dtype=torch.int64)

    print("Calculating class weights...")
    for batch in train_loader:
        labels = torch.as_tensor(batch[-1]).reshape(-1).to('cpu', torch.int64)
        # Drop out-of-range labels; negative values would otherwise index the
        # counters from the end.
        in_range = (labels >= 0) & (labels < num_classes)
        class_counts += torch.bincount(labels[in_range], minlength=num_classes)

    present = class_counts > 0
    if not bool(present.any()):
        return torch.ones(num_classes, dtype=torch.float32)

    # Square root of inverse frequencies (less aggressive balancing). The
    # common "total pixels" factor cancels in the normalisation below, so the
    # raw counts are sufficient.
    weights = torch.zeros(num_classes, dtype=torch.float32)
    inverse_sqrt = 1.0 / torch.sqrt(class_counts[present].to(torch.float64))
    weights[present] = inverse_sqrt.to(torch.float32)

    # Normalize weights to sum to num_classes
    weights = weights * (num_classes / weights.sum())

    return weights


def calculate_metrics(predictions, targets, num_classes):
    """Compute segmentation metrics from predicted and ground-truth label maps.

    Args:
        predictions: Integer class labels (numpy array or CPU tensor), any shape.
        targets: Ground-truth labels with the same number of elements.
        num_classes: Number of classes; pixels whose target or prediction is
            outside ``[0, num_classes)`` are excluded from every metric.

    Returns:
        dict with keys:
            'miou'      - unweighted mean IoU over classes present in the ground truth,
            'precision', 'recall', 'f1_score' - per-class values averaged with
                          the ground-truth pixel count of each class as weight,
            'accuracy'  - overall pixel accuracy,
            'kappa'     - Cohen's kappa.
        All values are 0 when there are no valid pixels.

    Raises:
        ValueError: If predictions and targets differ in number of elements.
    """
    # Flatten predictions and targets
    pred_flat = np.asarray(predictions).reshape(-1)
    target_flat = np.asarray(targets).reshape(-1)
    if pred_flat.shape != target_flat.shape:
        raise ValueError(
            f"predictions and targets differ in size: {pred_flat.size} vs {target_flat.size}"
        )

    # Confusion matrix (rows = ground truth, columns = prediction) built with a
    # single bincount over the flattened (target, prediction) pair index.
    in_range = ((target_flat >= 0) & (target_flat < num_classes) &
                (pred_flat >= 0) & (pred_flat < num_classes))
    pair_index = (target_flat[in_range].astype(np.int64) * num_classes
                  + pred_flat[in_range].astype(np.int64))
    cm = np.bincount(pair_index, minlength=num_classes * num_classes)
    cm = cm.reshape(num_classes, num_classes)

    # Per-class calculations
    class_metrics = []
    for i in range(num_classes):
        tp = cm[i, i]
        fp = np.sum(cm[:, i]) - tp
        fn = np.sum(cm[i, :]) - tp

        # Handle divide by zero
        union = tp + fp + fn
        iou = tp / union if union > 0 else 0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        class_metrics.append({
            'class': i,
            'iou': iou,
            'precision': precision,
            'recall': recall,
            'f1': f1
        })

    metrics = {}
    support = np.sum(cm, axis=1)            # ground-truth pixels per class
    present_classes = np.flatnonzero(support)
    n = np.sum(cm)

    # mIoU only over classes present in the ground truth; taking them from the
    # confusion matrix keeps out-of-range target labels from indexing past
    # class_metrics.
    if len(present_classes) > 0:
        metrics['miou'] = np.mean([class_metrics[i]['iou'] for i in present_classes])
    else:
        metrics['miou'] = 0.0

    # precision, recall and f1 are weighted by ground-truth support
    if n > 0:
        metrics['precision'] = np.average([m['precision'] for m in class_metrics], weights=support)
        metrics['recall'] = np.average([m['recall'] for m in class_metrics], weights=support)
        metrics['f1_score'] = np.average([m['f1'] for m in class_metrics], weights=support)
        metrics['accuracy'] = np.trace(cm) / n
    else:
        metrics['precision'] = 0.0
        metrics['recall'] = 0.0
        metrics['f1_score'] = 0.0
        metrics['accuracy'] = 0.0

    # Kappa calculation; marginals are multiplied in float64 so the products
    # cannot overflow int64 for very large pixel counts.
    sum_po = float(np.trace(cm))
    predicted = np.sum(cm, axis=0).astype(np.float64)
    sum_pe = float(np.sum(predicted * support.astype(np.float64)) / n) if n > 0 else 0.0
    metrics['kappa'] = (sum_po - sum_pe) / (float(n) - sum_pe + 1e-6)

    return metrics


def train_epoch(model, train_loader, cd_criterion, lcm_criterion,
                cd_optimizer, lcm_optimizer, device, num_classes=13):
    """Train both networks for one epoch and return the epoch metrics.

    The CD and LCM networks are optimised separately within each batch, each
    with its own criterion, optimizer and gradient clipping (max norm 1.0).

    Args:
        model: Object exposing ``cd_model``, ``lcm_model`` and ``train()``.
        train_loader: Yields ``(img_2019, img_2024, sem_2019, sem_2024, cd_mask)``.
        cd_criterion: Binary loss that takes RAW LOGITS (e.g.
            ``nn.BCEWithLogitsLoss``), the same way ``validate`` calls it.
        lcm_criterion: Multi-class loss taking LCM logits and label masks.
        cd_optimizer: Optimizer for ``model.cd_model``.
        lcm_optimizer: Optimizer for ``model.lcm_model``.
        device: Device the batch tensors are moved to.
        num_classes: Number of semantic-change classes used for the metrics.

    Returns:
        dict from ``calculate_metrics`` over the whole epoch, plus 'cd_loss'
        and 'lcm_loss' (mean per-batch losses).
    """
    model.train()
    total_cd_loss = 0
    total_lcm_loss = 0
    all_predictions = []
    all_targets = []

    for img_2019, img_2024, sem_2019, sem_2024, cd_mask in tqdm(train_loader):
        # Move data to device
        img_2019 = img_2019.to(device)
        img_2024 = img_2024.to(device)
        sem_2019 = sem_2019.to(device)
        sem_2024 = sem_2024.to(device)
        cd_mask = cd_mask.to(device)

        # Create binary change mask for CD network: any non-zero class = changed
        binary_mask = (cd_mask > 0).float().unsqueeze(1)

        # Train CD network. The loss receives the raw logits; applying sigmoid
        # first would squash the inputs a second time inside a
        # with-logits criterion.
        cd_optimizer.zero_grad()
        cd_logits = model.cd_model(torch.cat([img_2019, img_2024], dim=1))
        cd_loss = cd_criterion(cd_logits, binary_mask)
        cd_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.cd_model.parameters(), max_norm=1.0)
        cd_optimizer.step()

        # Train LCM network: the same network is applied to both dates and the
        # two losses are averaged.
        lcm_optimizer.zero_grad()
        lcm_pred_2019 = model.lcm_model(img_2019)
        lcm_pred_2024 = model.lcm_model(img_2024)
        lcm_loss = (lcm_criterion(lcm_pred_2019, sem_2019) +
                   lcm_criterion(lcm_pred_2024, sem_2024)) / 2
        lcm_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.lcm_model.parameters(), max_norm=1.0)
        lcm_optimizer.step()

        # Get semantic predictions for metrics. Probabilities (not logits) are
        # passed because the mask builder thresholds at 0.5.
        with torch.no_grad():
            cd_prob = torch.sigmoid(cd_logits)
            semantic_pred = create_semantic_change_mask(cd_prob, lcm_pred_2019, lcm_pred_2024)
            all_predictions.append(semantic_pred.cpu())
            all_targets.append(cd_mask.cpu())

        total_cd_loss += cd_loss.item()
        total_lcm_loss += lcm_loss.item()

    # Calculate metrics
    all_predictions = torch.cat(all_predictions, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    metrics = calculate_metrics(all_predictions, all_targets, num_classes=num_classes)

    # Average losses
    metrics['cd_loss'] = total_cd_loss / len(train_loader)
    metrics['lcm_loss'] = total_lcm_loss / len(train_loader)

    return metrics

def validate(model, val_loader, cd_criterion, lcm_criterion, device, num_classes):
    """Evaluate both networks on a loader without updating any weights.

    Args:
        model: Object exposing ``cd_model``, ``lcm_model`` and ``eval()``.
        val_loader: Yields ``(img_2019, img_2024, sem_2019, sem_2024, cd_mask)``.
        cd_criterion: Binary loss that takes raw logits.
        lcm_criterion: Multi-class loss taking LCM logits and label masks.
        device: Device the batch tensors are moved to.
        num_classes: Number of semantic-change classes used for the metrics.

    Returns:
        dict from ``calculate_metrics`` over the whole loader, plus 'cd_loss'
        and 'lcm_loss' (mean per-batch losses).
    """
    model.eval()
    total_cd_loss = 0
    total_lcm_loss = 0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for img_2019, img_2024, sem_2019, sem_2024, cd_mask in val_loader:
            img_2019 = img_2019.to(device)
            img_2024 = img_2024.to(device)
            sem_2019 = sem_2019.to(device)
            sem_2024 = sem_2024.to(device)
            cd_mask = cd_mask.to(device)

            # Any non-zero change class counts as "changed"
            binary_mask = (cd_mask > 0).float().unsqueeze(1)

            # CD predictions: the loss is computed on the raw logits
            cd_logits = model.cd_model(torch.cat([img_2019, img_2024], dim=1))
            cd_loss = cd_criterion(cd_logits, binary_mask)

            # LCM predictions
            lcm_pred_2019 = model.lcm_model(img_2019)
            lcm_pred_2024 = model.lcm_model(img_2024)
            lcm_loss = (lcm_criterion(lcm_pred_2019, sem_2019) +
                       lcm_criterion(lcm_pred_2024, sem_2024)) / 2

            # Get semantic predictions. The mask builder thresholds at 0.5, so
            # it must be given probabilities; thresholding logits at 0.5 would
            # only flag pixels whose probability exceeds about 0.62.
            cd_prob = torch.sigmoid(cd_logits)
            semantic_pred = create_semantic_change_mask(cd_prob, lcm_pred_2019, lcm_pred_2024)
            all_predictions.append(semantic_pred.cpu())
            all_targets.append(cd_mask.cpu())

            total_cd_loss += cd_loss.item()
            total_lcm_loss += lcm_loss.item()

    # Calculate metrics
    all_predictions = torch.cat(all_predictions, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    metrics = calculate_metrics(all_predictions, all_targets, num_classes=num_classes)

    # Average losses
    metrics['cd_loss'] = total_cd_loss / len(val_loader)
    metrics['lcm_loss'] = total_lcm_loss / len(val_loader)

    return metrics


def print_metrics_summary(train_metrics, val_metrics, train_total_loss, val_total_loss):
    """Print the training block followed by the validation block of metrics.

    Each block shows the total loss and then accuracy, precision, recall,
    F1-score, mIoU and kappa, all formatted with four decimals.

    Args:
        train_metrics: Mapping with keys 'accuracy', 'precision', 'recall',
            'f1_score', 'miou' and 'kappa' for the training phase.
        val_metrics: Same mapping for the validation phase.
        train_total_loss: Total training loss to display.
        val_total_loss: Total validation loss to display.
    """
    # (display label, key in the metrics mapping), in print order
    reported = (
        ("Accuracy", 'accuracy'),
        ("Precision", 'precision'),
        ("Recall", 'recall'),
        ("F1-Score", 'f1_score'),
        ("mIoU", 'miou'),
        ("Kappa", 'kappa'),
    )
    sections = (
        ("\nTraining Metrics:", train_total_loss, train_metrics),
        ("\nValidation Metrics:", val_total_loss, val_metrics),
    )

    for heading, total_loss, phase_metrics in sections:
        print(heading)
        print(f"  Total Loss: {total_loss:.4f}")
        for label, key in reported:
            print(f"  {label}: {phase_metrics[key]:.4f}")

def train_strategy3(model, train_loader, val_loader, num_epochs=50, device='cuda',
                   checkpoint_path=None, loss='CE', num_classes=13, num_semantic_classes=4):
    """Train the CD and LCM networks of a Strategy 3 model, checkpointing the best epoch.

    The CD network is trained with ``BCEWithLogitsLoss``; the LCM network with
    a class-weighted ``CrossEntropyLoss``. Each network has its own AdamW
    optimizer and ``ReduceLROnPlateau`` scheduler driven by its own validation
    loss. A checkpoint is written whenever the summed validation loss improves.

    Args:
        model: Object exposing ``cd_model`` and ``lcm_model``.
        train_loader: Yields ``(img_2019, img_2024, sem_2019, sem_2024, cd_mask)``.
        val_loader: Validation loader with the same batch layout.
        num_epochs: Epoch index to train up to (resume starts from the
            checkpoint's epoch).
        device: Device used for training and for loading the checkpoint.
        checkpoint_path: Where the best checkpoint is stored; if the file
            already exists, training resumes from it. ``None`` disables
            checkpointing.
        loss: Unused; kept for interface compatibility.
        num_classes: Number of semantic-change classes (metrics only).
        num_semantic_classes: Number of land-cover classes (LCM loss weights).

    Returns:
        ``(model, history)`` where ``history['train'|'val'][metric]`` is a list
        with one float per trained epoch.
    """
    # Loss functions
    cd_criterion = nn.BCEWithLogitsLoss()

    def _semantic_mask_batches():
        """Yield 1-tuples holding the two semantic masks of each training batch.

        ``calculate_class_weights`` counts the LAST element of every batch. The
        LCM loss is defined over the land-cover classes, so it must be weighted
        by the frequencies in the semantic masks (batch items 2 and 3), not by
        the change mask that sits at the end of the raw batch.
        """
        for batch in train_loader:
            yield (torch.cat([batch[2].reshape(-1), batch[3].reshape(-1)]),)

    class_weights = calculate_class_weights(
        _semantic_mask_batches(), num_classes=num_semantic_classes, device=device
    )  #square_balanced
    print(class_weights)
    lcm_criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

    # Optimizers
    cd_optimizer = optim.AdamW(model.cd_model.parameters(), lr=1e-4, weight_decay=0.01)
    lcm_optimizer = optim.AdamW(model.lcm_model.parameters(), lr=1e-4, weight_decay=0.01)

    # Learning rate schedulers
    cd_scheduler = ReduceLROnPlateau(cd_optimizer, mode='min', factor=0.5, patience=5)
    lcm_scheduler = ReduceLROnPlateau(lcm_optimizer, mode='min', factor=0.5, patience=5)

    # Load checkpoint if exists
    start_epoch = 0
    best_total_loss = float('inf')
    if checkpoint_path and os.path.exists(checkpoint_path):
        # weights_only=False: the checkpoint is produced by this function and
        # may hold non-tensor metric values (numpy scalars in older files).
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        # Restore the network weights together with the optimizer state;
        # resuming optimizers on freshly initialised networks would silently
        # restart training from scratch.
        model.cd_model.load_state_dict(checkpoint['cd_model'])
        model.lcm_model.load_state_dict(checkpoint['lcm_model'])
        cd_optimizer.load_state_dict(checkpoint['cd_optimizer'])  # Load optimizer state
        lcm_optimizer.load_state_dict(checkpoint['lcm_optimizer'])
        cd_scheduler.load_state_dict(checkpoint['cd_scheduler'])  # Load scheduler state
        lcm_scheduler.load_state_dict(checkpoint['lcm_scheduler'])
        start_epoch = checkpoint['epoch']
        best_total_loss = checkpoint.get('best_total_loss', float('inf'))
        print(f"Loaded checkpoint from epoch {start_epoch}")

    # 'loss' is the combined CD + LCM loss; the rest come from calculate_metrics.
    metric_names = ('loss', 'accuracy', 'precision', 'recall', 'f1_score', 'miou', 'kappa')
    history = {
        phase: {name: [] for name in metric_names}
        for phase in ('train', 'val')
    }

    for epoch in range(start_epoch, num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # Training
        train_metrics = train_epoch(
            model, train_loader, cd_criterion, lcm_criterion,
            cd_optimizer, lcm_optimizer, device, num_classes
        )

        # Validation
        val_metrics = validate(
            model, val_loader, cd_criterion, lcm_criterion,
            device, num_classes
        )

        # Calculate total loss for both phases
        train_total_loss = train_metrics['cd_loss'] + train_metrics['lcm_loss']
        val_total_loss = val_metrics['cd_loss'] + val_metrics['lcm_loss']

        # Update learning rates (each scheduler follows its own validation loss)
        cd_scheduler.step(val_metrics['cd_loss'])
        lcm_scheduler.step(val_metrics['lcm_loss'])

        # Store metrics in history - store overall metrics
        for phase, phase_metrics, phase_loss in (
            ('train', train_metrics, train_total_loss),
            ('val', val_metrics, val_total_loss),
        ):
            history[phase]['loss'].append(float(phase_loss))
            for name in metric_names[1:]:
                history[phase][name].append(float(phase_metrics[name]))

        print_metrics_summary(train_metrics, val_metrics, train_total_loss, val_total_loss)

        # Save checkpoint: best model based on total validation loss
        if checkpoint_path and val_total_loss < best_total_loss:
            best_total_loss = val_total_loss
            # Plain Python floats keep the checkpoint free of numpy objects.
            best_metrics = {'total_loss': float(val_total_loss)}
            for name in metric_names[1:]:
                best_metrics[name] = float(val_metrics[name])
            torch.save(
                {
                    'epoch': epoch + 1,
                    'cd_model': model.cd_model.state_dict(),
                    'lcm_model': model.lcm_model.state_dict(),
                    'cd_optimizer': cd_optimizer.state_dict(),
                    'lcm_optimizer': lcm_optimizer.state_dict(),
                    'cd_scheduler': cd_scheduler.state_dict(),
                    'lcm_scheduler': lcm_scheduler.state_dict(),
                    'best_total_loss': best_total_loss,
                    'metrics': best_metrics,
                },
                checkpoint_path
            )
            print(f"Saved new best model with total loss: {best_total_loss:.4f}")

    return model, history

import json
def _to_jsonable(value):
    """Recursively convert numpy/torch values inside dicts, lists and tuples to JSON-ready builtins."""
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    # numpy scalars/arrays and torch tensors all expose tolist()
    if hasattr(value, 'tolist'):
        return value.tolist()
    return value


def save_training_history(history, checkpoint_path, save_path, save_path_bestepoch):
    """
    Save training history and best epoch information to JSON files.

    Args:
        history (dict): Dictionary containing training and validation metrics.
            Values may be nested dicts/lists holding numpy or torch values;
            they are converted to plain Python numbers before dumping.
        checkpoint_path (str): Path to the model checkpoint file; it must
            contain the keys 'epoch' and 'metrics'.
        save_path (str): Path to save the training history
        save_path_bestepoch (str): Path to save the best epoch info
    """
    # Save processed history to a JSON file
    with open(save_path, 'w') as f:
        json.dump(_to_jsonable(history), f, indent=4)

    print(f"Training history saved to: {save_path}")

    # Load checkpoint to get best epoch info. map_location='cpu' lets a
    # checkpoint written on a GPU be read on a machine without one; only the
    # scalar metadata is needed here.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    best_metrics = checkpoint['metrics']

    # Extract best epoch information
    epoch_data = {
        'best_epoch': int(checkpoint['epoch']),
        'total_loss': float(best_metrics['total_loss']),
        'accuracy': float(best_metrics['accuracy']),
        'precision': float(best_metrics['precision']),
        'recall': float(best_metrics['recall']),
        'f1': float(best_metrics['f1_score']),
        'miou': float(best_metrics['miou']),
        'kappa': float(best_metrics['kappa'])
    }

    # Save the best epoch info
    with open(save_path_bestepoch, 'w') as f:
        json.dump(epoch_data, f, indent=4)
    print(f"Best epoch info saved to: {save_path_bestepoch}")


def save_test_metrics(history, save_path):
    """Save a dictionary of test metrics to a JSON file.

    Args:
        history (dict): Metrics to store; numpy or torch values (also when
            nested in dicts/lists) are converted to plain Python numbers.
        save_path (str): Path of the JSON file to write.
    """
    # Save processed history to a JSON file
    with open(save_path, 'w') as f:
        json.dump(_to_jsonable(history), f, indent=4)

    print(f"Testing history saved to: {save_path}")

### Model Training

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes=13   #Change Detection classes (3 for cd1, 13 for cd2)
num_semantic_classes=4  #Semantic Segmentation LCM classes (4 for both)
num_epochs = 100
# Informational only: train_strategy3 has no weighting-method parameter, and
# calculate_class_weights documents 'square_balanced' as its method.
weighting_method = 'square_balanced'
loss = 'CE' #'CE'
checkpoint_path = f'{SAVING_DIR}/best_Strat3_{num_epochs}_epochs.pt'  #'models/strategy3_model.pt'

# Create model
model = Strategy3Model(
    cd_architecture='unet',
    lcm_architecture='unet',
    cd_encoder='resnet34',
    lcm_encoder='resnet34',
    input_channels=3,
    num_classes=num_classes,
    num_semantic_classes=num_semantic_classes
).to(device)

# Train model. `weighting_method` is not passed: train_strategy3 does not accept
# that keyword and the call would raise TypeError.
model2, history = train_strategy3(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=num_epochs,
    device=device,
    checkpoint_path=checkpoint_path,
    loss=loss,
    num_classes=num_classes,
    num_semantic_classes=num_semantic_classes
)

# Persist the per-epoch history and the metrics of the best checkpoint.
save_path = f'{SAVING_DIR}/strat3_train_history.json'
save_path_bestepoch = f'{SAVING_DIR}/strat3_bestepoch.json'
save_training_history(history, checkpoint_path, save_path, save_path_bestepoch)

### Testing

In [None]:
import random
import matplotlib.pyplot as plt
def visualize_predictions(model, img_2019, img_2024, true_mask, seg_mask_2019, seg_mask_2024,
                          num_classes=None):
    """
    Plot one sample as a single row of nine panels and return the figure.

    Panels, left to right: the two input images, the sigmoid of the
    change-detection output, the two ground-truth semantic masks, the two
    predicted semantic masks (argmax over dim 1 of the LCM output), the
    ground-truth change mask and the mask produced by
    ``create_semantic_change_mask``.

    Args:
        model: Object exposing ``cd_model`` and ``lcm_model`` callables and
            ``eval()``. It is switched to eval mode and left that way.
        img_2019, img_2024: Image tensors, channel-first, with or without a
            leading batch dimension of size 1. They must already live on the
            same device as the model; no device transfer is done here.
        true_mask, seg_mask_2019, seg_mask_2024: Class-index masks, with or
            without leading singleton dimensions.
        num_classes: Optional number of change classes used to fix the colour
            scale of the two change-mask panels. When omitted, the scale is
            derived from the largest label found in either change mask, so the
            two panels still share one scale within the figure.

    Returns:
        The matplotlib ``Figure`` holding the nine panels. The caller is
        responsible for showing and closing it.
    """
    def _strip_singleton_dims(tensor, target_dim):
        # squeeze(0) only removes a size-1 axis, so test the size explicitly to
        # guarantee the loop terminates on real (non-singleton) batches.
        while tensor.dim() > target_dim and tensor.size(0) == 1:
            tensor = tensor.squeeze(0)
        return tensor

    def _categorical_kwargs(n_classes):
        # A fixed vmin/vmax keeps "class k -> colour k" identical across panels;
        # without it imshow autoscales each panel to its own min/max.
        # 'nearest' stops resampling from blending label indices at borders.
        # tab10 only has 10 entries, so switch to tab20 beyond that.
        return {
            'cmap': 'tab10' if n_classes <= 10 else 'tab20',
            'vmin': 0,
            'vmax': max(n_classes - 1, 1),
            'interpolation': 'nearest',
        }

    # Images are reduced to (C, H, W); masks to (H, W) so imshow accepts them.
    img_2019 = _strip_singleton_dims(img_2019, 3)
    img_2024 = _strip_singleton_dims(img_2024, 3)
    true_mask = _strip_singleton_dims(true_mask, 2)
    seg_mask_2019 = _strip_singleton_dims(seg_mask_2019, 2)
    seg_mask_2024 = _strip_singleton_dims(seg_mask_2024, 2)

    model.eval()
    with torch.no_grad():
        batch_2019 = img_2019.unsqueeze(0)
        batch_2024 = img_2024.unsqueeze(0)

        # The CD branch sees both dates stacked along the channel axis.
        cd_pred = model.cd_model(torch.cat([batch_2019, batch_2024], dim=1))
        lcm_pred_2019 = model.lcm_model(batch_2019)
        lcm_pred_2024 = model.lcm_model(batch_2024)
        semantic_pred = create_semantic_change_mask(cd_pred, lcm_pred_2019, lcm_pred_2024)
        semantic_pred = semantic_pred.squeeze(0)

    gt_change = true_mask.cpu()
    pred_change = semantic_pred.cpu()
    if num_classes is None:
        # Fall back to a range shared by both change panels of this figure.
        num_classes = int(max(gt_change.max().item(), pred_change.max().item())) + 1

    # The LCM output has one channel per semantic class (argmax is taken over dim 1).
    semantic_kwargs = _categorical_kwargs(lcm_pred_2019.shape[1])
    change_kwargs = _categorical_kwargs(num_classes)

    fig, axes = plt.subplots(1, 9, figsize=(24, 3))

    # Input images (channel-first -> channel-last for imshow)
    axes[0].imshow(img_2019.cpu().permute(1, 2, 0))
    axes[0].set_title('Image2019')
    axes[1].imshow(img_2024.cpu().permute(1, 2, 0))
    axes[1].set_title('Image2024')

    # Change probability; pin the grey scale to [0, 1] so panels from
    # different samples are comparable instead of being contrast-stretched.
    axes[2].imshow(torch.sigmoid(cd_pred).cpu().squeeze(), cmap='gray', vmin=0, vmax=1)
    axes[2].set_title('BinaryCD')

    # Ground-truth semantic masks
    axes[3].imshow(seg_mask_2019.cpu(), **semantic_kwargs)
    axes[3].set_title('GTSemMask2019')
    axes[4].imshow(seg_mask_2024.cpu(), **semantic_kwargs)
    axes[4].set_title('GTSemMask2024')

    # Predicted semantic masks
    axes[5].imshow(torch.argmax(lcm_pred_2019, dim=1).squeeze(0).cpu(), **semantic_kwargs)
    axes[5].set_title('PredSemMask2019')
    axes[6].imshow(torch.argmax(lcm_pred_2024, dim=1).squeeze(0).cpu(), **semantic_kwargs)
    axes[6].set_title('PredSemMask2024')

    # Semantic change masks
    axes[7].imshow(gt_change, **change_kwargs)
    axes[7].set_title('GTChangeMask')
    axes[8].imshow(pred_change, **change_kwargs)
    axes[8].set_title('PredChangeMask')

    # Hide ticks/frames and tighten the gaps between panels
    for ax in axes:
        ax.axis('off')
    plt.subplots_adjust(wspace=0.05, hspace=0)

    return fig

# (label, key) pairs in the order the summary metrics are printed.
_METRIC_ROWS = (
    ('Accuracy', 'accuracy'),
    ('Mean IoU', 'miou'),
    ('F1 Score', 'f1_score'),
    ('Precision', 'precision'),
    ('Recall', 'recall'),
    ('Kappa', 'kappa'),
)


def _print_metric_rows(metrics):
    """Print the six summary metrics of one metrics dict, four decimals each."""
    for label, key in _METRIC_ROWS:
        print(f"{label}: {metrics[key]:.4f}")


def _concat_to_numpy(batches):
    """Concatenate a list of CPU tensors along dim 0 and return a NumPy array."""
    return torch.cat(batches, dim=0).numpy()


def test_strategy3(model, test_loader, device, num_samples_to_plot=5,
                   checkpoint_path='best_model.pt', num_classes=13, num_semantic_classes=4):
    """
    Evaluate the model on ``test_loader`` and report segmentation and change metrics.

    The checkpoint is expected to hold two state dicts under the keys
    ``'cd_model'`` and ``'lcm_model'``, which are loaded into ``model.cd_model``
    and ``model.lcm_model``. Metrics are computed with ``calculate_metrics``
    over the whole test set, printed, and a few randomly chosen samples are
    rendered with ``visualize_predictions``.

    Args:
        model: Object exposing ``cd_model`` / ``lcm_model`` sub-models and ``eval()``.
        test_loader: Iterable yielding
            ``(img_2019, img_2024, sem_2019, sem_2024, cd_mask)`` batches and
            exposing ``.dataset`` with a length.
        device: Device the batches are moved to and the checkpoint is mapped onto.
        num_samples_to_plot: Upper bound on the number of samples visualised;
            clamped to the dataset size (and to zero for negative values).
        checkpoint_path: Path of the checkpoint file.
        num_classes: Class count passed to ``calculate_metrics`` for the change mask.
        num_semantic_classes: Class count passed to ``calculate_metrics`` for
            the two semantic masks.

    Returns:
        Dict with keys ``'segmentation_2019'``, ``'segmentation_2024'``,
        ``'segmentation_avg'`` and ``'change_detection'``, or ``None`` when the
        state dicts could not be loaded (the error is printed).

    Raises:
        ValueError: If the loader yields no batches.
    """
    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Best-effort load: report the problem and bail out with None rather than
    # evaluating a model whose weights were not restored.
    try:
        model.cd_model.load_state_dict(checkpoint['cd_model'])
        model.lcm_model.load_state_dict(checkpoint['lcm_model'])
        print("Successfully loaded both CD and LCM models")
    except Exception as e:
        print(f"Error loading models: {e}")
        return

    model.eval()

    # Per-batch predictions/targets, concatenated once after the loop.
    change_pred_batches = []
    change_target_batches = []
    seg_pred_batches_2019 = []
    seg_target_batches_2019 = []
    seg_pred_batches_2024 = []
    seg_target_batches_2024 = []

    # Pick the dataset positions to visualise. random.sample raises if asked
    # for more items than exist, so clamp the request to the dataset size.
    stored_samples = []
    total_samples = len(test_loader.dataset)
    plot_count = max(0, min(num_samples_to_plot, total_samples))
    random_indices = set(random.sample(range(total_samples), plot_count))
    current_idx = 0  # running position of the first item of the current batch

    print("Testing model...")
    with torch.no_grad():
        for img_2019, img_2024, sem_2019, sem_2024, cd_mask in tqdm(test_loader):
            img_2019 = img_2019.to(device)
            img_2024 = img_2024.to(device)
            sem_2019 = sem_2019.to(device)
            sem_2024 = sem_2024.to(device)
            cd_mask = cd_mask.to(device)

            # The CD branch sees both dates stacked along the channel axis.
            cd_pred = model.cd_model(torch.cat([img_2019, img_2024], dim=1))
            lcm_pred_2019 = model.lcm_model(img_2019)
            lcm_pred_2024 = model.lcm_model(img_2024)
            semantic_pred = create_semantic_change_mask(cd_pred, lcm_pred_2019, lcm_pred_2024)

            change_pred_batches.append(semantic_pred.cpu())
            change_target_batches.append(cd_mask.cpu())
            seg_pred_batches_2019.append(torch.argmax(lcm_pred_2019, dim=1).cpu())
            seg_target_batches_2019.append(sem_2019.cpu())
            seg_pred_batches_2024.append(torch.argmax(lcm_pred_2024, dim=1).cpu())
            seg_target_batches_2024.append(sem_2024.cpu())

            # Keep only what visualize_predictions consumes; it recomputes the
            # predictions itself, so storing them here would just pin memory.
            batch_size = img_2019.size(0)
            for i in range(batch_size):
                if current_idx + i in random_indices:
                    stored_samples.append({
                        'img_2019': img_2019[i],
                        'img_2024': img_2024[i],
                        'true_mask': cd_mask[i],
                        'seg_mask_2019': sem_2019[i],
                        'seg_mask_2024': sem_2024[i]
                    })
            current_idx += batch_size

    if not change_pred_batches:
        raise ValueError("test_loader yielded no batches; cannot compute metrics")

    print("\n" + "="*50)
    print("Semantic Segmentation Metrics:")
    print("="*50)

    # 2019 segmentation
    metrics_seg_2019 = calculate_metrics(
        _concat_to_numpy(seg_pred_batches_2019),
        _concat_to_numpy(seg_target_batches_2019),
        num_classes=num_semantic_classes,
    )
    print("\n2019 Segmentation:")
    _print_metric_rows(metrics_seg_2019)

    # 2024 segmentation
    metrics_seg_2024 = calculate_metrics(
        _concat_to_numpy(seg_pred_batches_2024),
        _concat_to_numpy(seg_target_batches_2024),
        num_classes=num_semantic_classes,
    )
    print("\n2024 Segmentation:")
    _print_metric_rows(metrics_seg_2024)

    # Plain mean of the two years, metric by metric.
    print("\nAverage Segmentation:")
    avg_seg_metrics = {
        key: (metrics_seg_2019[key] + metrics_seg_2024[key]) / 2
        for _, key in _METRIC_ROWS
    }
    _print_metric_rows(avg_seg_metrics)

    print("\n" + "="*50)
    print("Change Detection Metrics:")
    print("="*50)

    metrics_change = calculate_metrics(
        _concat_to_numpy(change_pred_batches),
        _concat_to_numpy(change_target_batches),
        num_classes=num_classes,
    )
    _print_metric_rows(metrics_change)

    print("\nPlotting random samples...")
    for sample in stored_samples:
        fig = visualize_predictions(
            model,
            sample['img_2019'],
            sample['img_2024'],
            sample['true_mask'],
            sample['seg_mask_2019'],
            sample['seg_mask_2024']
        )
        plt.show()
        plt.close(fig)  # release the figure; many open figures accumulate memory

    return {
        'segmentation_2019': metrics_seg_2019,
        'segmentation_2024': metrics_seg_2024,
        'segmentation_avg': avg_seg_metrics,
        'change_detection': metrics_change
    }

### Test model
test_metrics = test_strategy3(
    model, test_loader, device,
    num_samples_to_plot=5,
    checkpoint_path=checkpoint_path,
    num_classes=num_classes,
    num_semantic_classes=num_semantic_classes,
)

# test_strategy3 returns None when the checkpoint's state dicts cannot be
# loaded; in that case there are no metrics to export.
if test_metrics is not None:
    save_test_metrics(test_metrics, save_path=f'{SAVING_DIR}/strat3_test_metrics.json')
else:
    print("Skipping test metrics export: test_strategy3 returned no metrics.")