In [None]:
# !pip install rasterio
# !pip install segmentation_models_pytorch

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')
# !unzip '/content/drive/MyDrive/BTechProject/ChangeDetectionMergedDividedSplit-tif3-reduced.zip' -d '/content/ChangeDetectionMergedDividedSplit-tif-reduced'

## Hyperparameters

In [None]:
# Dataset root (relative to the working directory) and the folder that receives
# checkpoints and JSON reports.
ROOT_DIRECTORY = "ChangeDetectionMergedDividedSplit-tif"
SAVING_DIR = "/content/drive/MyDrive/BTechProject"
CD_DIR = "cd2_Output"   #FOR STRATEGY4 ALWAYS USE cd2_Output

# Change-detection label names; the list index is the class id. Which list applies
# depends on the change-mask folder selected through CD_DIR.
if CD_DIR == "cd1_Output":
    CLASSES = ['no_change','vegetation_increase','vegetation_decrease']
elif CD_DIR == "cd2_Output":
    # 'no_change' followed by every "<from>_<to>" transition between the six
    # land-cover types water / built / bare / sparse / trees / crops.
    CLASSES = [
    'no_change','water_built', 'water_bare', 'water_sparse', 'water_trees',
    'water_crops', 'built_water', 'built_bare', 'built_sparse', 'built_trees',
    'built_crops',  'bare_water',  'bare_built',  'bare_sparse',  'bare_trees',
    'bare_crops',  'sparse_water',  'sparse_built',  'sparse_bare',
    'sparse_trees',  'sparse_crops',  'trees_water',  'trees_built',
    'trees_bare',  'trees_sparse',  'trees_crops',  'crops_water',
    'crops_built', 'crops_bare',  'crops_sparse',  'crops_trees']
else:
    # Fail here instead of with a NameError on CLASSES several cells later.
    raise ValueError(
        f"Unsupported CD_DIR {CD_DIR!r}; expected 'cd1_Output' or 'cd2_Output'")


# Land-cover (semantic segmentation) label names; the list index is the class id.
SEMANTIC_CLASSES = ['water', 'built', 'bare', 'sparse', 'trees', 'crops', 'others']

NUM_WORKERS = 8      # DataLoader worker processes
BATCH_SIZE = 32
NUM_EPOCHS = 100     # epoch budget for each training stage
MODEL_NAME = 'strat4'

## Dataloader

In [None]:
import os
import rasterio
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class ChangeDetectionDatasetTIF(Dataset):
    """Bi-temporal GeoTIFF dataset for joint land-cover mapping and change detection.

    Each item bundles two images (2019 and 2024), their two semantic masks and one
    change mask. Files are paired purely by their position in the alphabetically
    sorted '.tif' listing of each folder; only the file *counts* are checked, so the
    folders are presumably expected to share a naming scheme — verify against the data.
    """

    def __init__(self, t2019_dir, t2024_dir, sem_2019_dir, sem_2024_dir, mask_dir,
                 classes, semantic_classes, transform=None):
        # Folder locations
        self.t2019_dir, self.t2024_dir = t2019_dir, t2024_dir
        self.mask_dir = mask_dir
        self.sem_2019_dir, self.sem_2024_dir = sem_2019_dir, sem_2024_dir

        # Label vocabularies (change classes / land-cover classes) and optional image transform
        self.classes = classes
        self.semantic_classes = semantic_classes
        self.transform = transform

        # Sorted file names per folder
        self.t2019_paths = self._tif_names(t2019_dir)
        self.t2024_paths = self._tif_names(t2024_dir)
        self.mask_paths = self._tif_names(mask_dir)
        self.sem2019_paths = self._tif_names(sem_2019_dir)
        self.sem2024_paths = self._tif_names(sem_2024_dir)

        # All five folders must hold the same number of tiles
        sizes = {
            len(self.t2019_paths), len(self.t2024_paths), len(self.mask_paths),
            len(self.sem2019_paths), len(self.sem2024_paths),
        }
        assert len(sizes) == 1, "Mismatched number of images"

    @staticmethod
    def _tif_names(folder):
        """Return the alphabetically sorted names of the '.tif' entries in `folder`."""
        names = [entry for entry in os.listdir(folder) if entry.endswith('.tif')]
        names.sort()
        return names

    @staticmethod
    def _read_image(folder, filename):
        """Read every band as float32 and divide by 255."""
        with rasterio.open(os.path.join(folder, filename)) as src:
            return src.read(out_dtype=np.float32) / 255.0

    @staticmethod
    def _read_labels(folder, filename):
        """Read band 1 as an int64 label map."""
        with rasterio.open(os.path.join(folder, filename)) as src:
            return src.read(1).astype(np.int64)

    def __len__(self):
        return len(self.t2019_paths)

    def __getitem__(self, index):
        # Rasters are read in the order: 2019 image, 2024 image, change mask, 2019 mask, 2024 mask
        image_a = self._read_image(self.t2019_dir, self.t2019_paths[index])
        image_b = self._read_image(self.t2024_dir, self.t2024_paths[index])
        change = self._read_labels(self.mask_dir, self.mask_paths[index])
        labels_a = self._read_labels(self.sem_2019_dir, self.sem2019_paths[index])
        labels_b = self._read_labels(self.sem_2024_dir, self.sem2024_paths[index])

        # numpy -> torch
        img_t2019, img_t2024 = torch.from_numpy(image_a), torch.from_numpy(image_b)
        cd_mask = torch.from_numpy(change)
        sem_mask_2019 = torch.from_numpy(labels_a)
        sem_mask_2024 = torch.from_numpy(labels_b)

        # The transform, when given, is applied to each image on its own (masks are untouched)
        if self.transform is not None:
            img_t2019 = self.transform(img_t2019)
            img_t2024 = self.transform(img_t2024)

        return img_t2019, img_t2024, sem_mask_2019, sem_mask_2024, cd_mask


def describe_loader(loader_type):
    """Print a summary of a data loader based on its first batch.

    Reports the batch size, the shape of each of the five batch tensors, the dataset
    length, both label vocabularies, and the distinct label values found in the first
    batch's change mask and 2019 semantic mask.
    """
    first_batch = next(iter(loader_type))
    img2019, img2024, sem2019, sem2024, cd_mask = first_batch
    dataset = loader_type.dataset

    print("Batch size:", loader_type.batch_size)

    print("Shapes:")
    shape_rows = (
        ("  2019 Image:", img2019),
        ("  2024 Image:", img2024),
        ("  2019 Semantic Mask:", sem2019),
        ("  2024 Semantic Mask:", sem2024),
        ("  Change Mask:", cd_mask),
    )
    for label, tensor in shape_rows:
        print(label, tensor.shape)

    print("Number of images:", len(dataset))
    print("Classes:", dataset.classes)
    print("Semantic Classes:", dataset.semantic_classes)

    # Only the first batch is inspected, so rare labels may not show up here
    print("\nUnique values:")
    for label, tensor in (("  Change Mask:", cd_mask), ("  Semantic Mask 2019:", sem2019)):
        print(label, torch.unique(tensor))


# Build one dataset per split; the three splits share the same folder layout
def _make_split_dataset(split):
    """Return the dataset for `split` ('train', 'val' or 'test') under ROOT_DIRECTORY."""
    split_root = f"{ROOT_DIRECTORY}/{split}"
    return ChangeDetectionDatasetTIF(
        t2019_dir=f"{split_root}/Images/T2019",
        t2024_dir=f"{split_root}/Images/T2024",
        sem_2019_dir=f"{split_root}/Masks/T2019",
        sem_2024_dir=f"{split_root}/Masks/T2024",
        mask_dir=f"{split_root}/{CD_DIR}",
        classes=CLASSES,
        semantic_classes=SEMANTIC_CLASSES,
    )


train_dataset = _make_split_dataset("train")
val_dataset = _make_split_dataset("val")
test_dataset = _make_split_dataset("test")

# Wrap them in loaders; only the training split is shuffled
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

# Sanity-check each loader by printing a summary of its first batch
for _banner, _loader in (
    ("------------Train------------", train_loader),
    ("------------Val------------", val_loader),
    ("------------Test------------", test_loader),
):
    print(_banner)
    describe_loader(_loader)

## Visualize

In [None]:
import matplotlib.pyplot as plt
import random

# Preview five random training samples: one row per sample, one column per tensor
fig, axs = plt.subplots(5, 5, figsize=(10,10))

for i in range(5):
    j = random.randint(0, len(train_dataset) - 1)
    img_t2019, img_t2024, sem_mask_2019, sem_mask_2024, cd_mask = train_dataset[j]

    # Images come back channel-first; imshow needs channel-last.
    # NOTE(review): assumes the rasters hold 3 (or 4) bands scaled to [0, 1] — confirm.
    for col, (image, title) in enumerate([(img_t2019, "Real 2019"),
                                          (img_t2024, "Real 2024")]):
        axs[i, col].imshow(image.permute(1, 2, 0))
        axs[i, col].set_title(title)
        axs[i, col].axis("off")

    # Label maps; the distinct label ids of each one are printed as a quick sanity check
    for col, (mask, title) in enumerate([(sem_mask_2019, "Sem 2019 Mask"),
                                         (sem_mask_2024, "Sem 2024 Mask"),
                                         (cd_mask, "CD Mask")], start=2):
        axs[i, col].imshow(mask, cmap="turbo")
        print(np.unique(mask))
        axs[i, col].set_title(title)
        axs[i, col].axis("off")

# plt.plot() draws nothing and is not a display call; lay out and show the figure instead
plt.tight_layout()
plt.show()

## Model Definition

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTaskChangeDetectionModel(nn.Module):
    """Siamese multi-task network: land-cover mapping (LCM) on two dates plus change detection (CD).

    One convolutional encoder is applied to both images. A single LCM decoder (same
    weights for both dates) turns each encoding into semantic logits, and a separate
    CD decoder turns the absolute difference of the two encodings into change logits.

    Args:
        input_channels: number of channels in each input image.
        num_semantic_classes: number of output channels of the LCM head.
        num_cd_classes: number of output channels of the CD head.
    """

    def __init__(self, input_channels, num_semantic_classes, num_cd_classes):
        super().__init__()

        # Shared Encoder: two conv-BN-ReLU x2 blocks, each followed by 2x max-pooling
        # (overall spatial reduction of 4x)
        self.encoder = nn.Sequential(
            # Initial block
            nn.Conv2d(input_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Second block
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Land Cover Mapping Decoder (shared weights): two stride-2 transposed
        # convolutions, i.e. 4x upsampling
        self.lcm_decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, num_semantic_classes, kernel_size=4, stride=2, padding=1)
        )

        # Change Detection Decoder: same layout as the LCM decoder, separate weights
        self.cd_decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, num_cd_classes, kernel_size=4, stride=2, padding=1)
        )

    @staticmethod
    def _restore_resolution(logits, spatial_size):
        """Resize `logits` to `spatial_size` (H, W); no-op when the size already matches.

        The pooling layers floor odd sizes, so for inputs whose height or width is not
        a multiple of 4 the decoders come out slightly smaller than the input. Without
        this step the logits would not line up with the label masks.
        """
        if logits.shape[-2:] == spatial_size:
            return logits
        return F.interpolate(logits, size=tuple(spatial_size), mode='bilinear', align_corners=False)

    def forward(self, x1, x2):
        """Return (lcm_logits_for_x1, lcm_logits_for_x2, change_logits).

        All three outputs have the spatial size of the inputs.

        Raises:
            AssertionError: if x1 and x2 do not have identical shapes.
        """
        # Ensure input images are the same size
        assert x1.shape == x2.shape, "Input images must have the same dimensions"
        spatial_size = x1.shape[-2:]

        # Encode both images with the shared encoder
        enc1 = self.encoder(x1)
        enc2 = self.encoder(x2)

        # Land Cover Mapping for both time periods
        lcm1 = self._restore_resolution(self.lcm_decoder(enc1), spatial_size)
        lcm2 = self._restore_resolution(self.lcm_decoder(enc2), spatial_size)

        # Change Detection (using difference of encodings)
        cd_input = torch.abs(enc1 - enc2)  # Or torch.cat([enc1, enc2], dim=1)
        cd_output = self._restore_resolution(self.cd_decoder(cd_input), spatial_size)

        return lcm1, lcm2, cd_output

## Util Functions and Training loop

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import numpy as np
from sklearn.metrics import confusion_matrix

def calculate_metrics(predictions, targets, num_classes):
    """
    Calculate segmentation metrics from a confusion matrix.

    Pixels whose target or prediction lies outside ``[0, num_classes)`` are ignored
    (this matches what a confusion matrix restricted to ``range(num_classes)`` does),
    so an unexpected label id no longer raises an IndexError.

    Args:
        predictions: array-like of predicted class ids (any shape)
        targets: array-like of ground-truth class ids (same number of elements)
        num_classes: number of classes; valid ids are 0 .. num_classes - 1

    Returns:
        dict with float entries:
            'miou'      - unweighted mean IoU over the classes present in `targets`
            'precision' - per-class precision averaged with ground-truth support as weights
            'recall'    - per-class recall, support-weighted
            'f1_score'  - per-class F1, support-weighted
            'accuracy'  - overall pixel accuracy
            'kappa'     - Cohen's kappa (with a 1e-6 stabiliser in the denominator)
        Every entry is 0.0 when there is no valid pixel to score.

    Raises:
        ValueError: if predictions and targets differ in number of elements.
    """
    # Flatten predictions and targets
    pred_flat = np.asarray(predictions).reshape(-1)
    target_flat = np.asarray(targets).reshape(-1)
    if pred_flat.shape != target_flat.shape:
        raise ValueError(
            f"predictions and targets differ in size: {pred_flat.size} vs {target_flat.size}")

    # Confusion matrix (rows = ground truth, columns = prediction) via one bincount
    # over the combined index target * num_classes + prediction
    in_range = ((target_flat >= 0) & (target_flat < num_classes) &
                (pred_flat >= 0) & (pred_flat < num_classes))
    combined = (num_classes * target_flat[in_range].astype(np.int64)
                + pred_flat[in_range].astype(np.int64))
    cm = np.bincount(combined, minlength=num_classes * num_classes)
    cm = cm.reshape(num_classes, num_classes)

    def _ratio(numerator, denominator):
        # Element-wise division that yields 0 wherever the denominator is 0
        numerator = np.asarray(numerator, dtype=np.float64)
        denominator = np.asarray(denominator, dtype=np.float64)
        return np.divide(numerator, denominator,
                         out=np.zeros_like(numerator), where=denominator > 0)

    # Per-class counts
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp

    iou = _ratio(tp, tp + fp + fn)
    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * precision * recall, precision + recall)

    metrics = {}

    # mIoU is averaged only over classes that occur in the ground truth
    present_classes = np.unique(target_flat)
    present_classes = present_classes[(present_classes >= 0) & (present_classes < num_classes)]
    present_classes = present_classes.astype(np.int64)
    metrics['miou'] = float(iou[present_classes].mean()) if present_classes.size else 0.0

    # precision, recall and f1 are weighted by ground-truth support
    support = cm.sum(axis=1)
    has_support = support.sum() > 0
    metrics['precision'] = float(np.average(precision, weights=support)) if has_support else 0.0
    metrics['recall'] = float(np.average(recall, weights=support)) if has_support else 0.0
    metrics['f1_score'] = float(np.average(f1, weights=support)) if has_support else 0.0

    # Overall accuracy and kappa
    n = cm.sum()
    if n == 0:
        metrics['accuracy'] = 0.0
        metrics['kappa'] = 0.0
        return metrics

    sum_po = tp.sum()
    metrics['accuracy'] = float(sum_po / n)
    sum_pe = np.sum(cm.sum(axis=0) * cm.sum(axis=1)) / n
    metrics['kappa'] = float((sum_po - sum_pe) / (n - sum_pe + 1e-6))

    return metrics


def stage_one_training(model, train_loader, val_loader, device, checkpoint_path, numepochs=50):
    """
    Stage 1: Train only the LCM branches

    The loss is the mean of the two land-cover cross-entropies (2019 and 2024); the
    change-detection output of the model is ignored. After every epoch the model is
    validated, and whenever the validation loss improves a checkpoint (weights,
    optimizer, scheduler, losses, metrics) is written to `checkpoint_path`.

    If `checkpoint_path` already exists, weights, optimizer and scheduler state are
    restored from it and training continues from the stored epoch.

    Args:
        model: network whose forward(img1, img2) returns (lcm1_logits, lcm2_logits, cd_logits)
        train_loader: yields (img1, img2, mask1, mask2, cd_mask) batches for training
        val_loader: same batch layout, used for validation
        device: device the batches are moved to
        checkpoint_path: file used both for resuming and for saving the best model
        numepochs: total number of epochs (a resumed run only performs the remainder)

    Returns:
        (model, history) where history has the keys 'train_loss', 'val_loss' and
        'val_metrics', each a list with one entry per epoch run in this call.
    """
    print("\nStarting Stage 1: Training LCM branches...")

    # Optimization setup
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

    # Initialize tracking variables
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_metrics': []}

    # Try to load checkpoint if it exists
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        try:
            # map_location lets a checkpoint written on GPU be resumed on another device
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            model.load_state_dict(checkpoint['model_state_dict'])
            # Resume with the saved Adam moments and learning-rate schedule, not fresh ones
            if 'optimizer_state_dict' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if 'scheduler_state_dict' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint.get('val_loss', float('inf'))
            print(f"Resuming from epoch {start_epoch} with best val loss: {best_val_loss:.4f}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch")

    for epoch in range(start_epoch, numepochs):
        print(f"\nEpoch {epoch+1}/{numepochs}")

        # Training phase
        model.train()
        train_loss = 0.0
        for data in tqdm(train_loader, desc='Training'):
            # Move all tensors to device
            img1, img2, mask1, mask2, _ = [x.to(device) for x in data]

            optimizer.zero_grad()

            # Forward pass - only LCM branches
            lcm1_out, lcm2_out, _ = model(img1, img2)

            # Calculate loss
            loss = (criterion(lcm1_out, mask1) + criterion(lcm2_out, mask2)) / 2

            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            # Clear memory
            del lcm1_out, lcm2_out
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        avg_train_loss = train_loss / len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0
        all_metrics_lcm1 = []
        all_metrics_lcm2 = []

        with torch.no_grad():
            for data in tqdm(val_loader, desc='Validation'):
                # Move all tensors to device
                img1, img2, mask1, mask2, _ = [x.to(device) for x in data]

                # Forward pass
                lcm1_out, lcm2_out, _ = model(img1, img2)

                # Calculate loss
                loss = (criterion(lcm1_out, mask1) + criterion(lcm2_out, mask2)) / 2
                val_loss += loss.item()

                # Calculate metrics
                lcm1_preds = torch.argmax(lcm1_out, dim=1)
                lcm2_preds = torch.argmax(lcm2_out, dim=1)

                # The class count is taken from the logits' channel dimension so the
                # metrics always cover exactly the classes the model predicts
                num_semantic = lcm1_out.shape[1]
                metrics_lcm1 = calculate_metrics(lcm1_preds.cpu().numpy(),
                                              mask1.cpu().numpy(),
                                              num_classes=num_semantic)
                metrics_lcm2 = calculate_metrics(lcm2_preds.cpu().numpy(),
                                              mask2.cpu().numpy(),
                                              num_classes=num_semantic)

                all_metrics_lcm1.append(metrics_lcm1)
                all_metrics_lcm2.append(metrics_lcm2)

                # Clear memory
                del lcm1_out, lcm2_out, lcm1_preds, lcm2_preds
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        # Calculate average metrics
        avg_val_loss = val_loss / len(val_loader)

        # Aggregate metrics (plain mean of the per-batch values)
        val_metrics = {
            'lcm1': {key: np.mean([m[key] for m in all_metrics_lcm1])
                    for key in all_metrics_lcm1[0].keys()},
            'lcm2': {key: np.mean([m[key] for m in all_metrics_lcm2])
                    for key in all_metrics_lcm2[0].keys()}
        }

        # Save best model based on validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': best_val_loss,
                'train_loss': avg_train_loss,
                'val_metrics': val_metrics
            }, checkpoint_path)
            print(f"Saved new best model with validation loss: {best_val_loss:.4f}")

        # Update learning rate based on val loss
        scheduler.step(avg_val_loss)

        # Store metrics
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_metrics'].append(val_metrics)

        # Print epoch results
        print(f"Train Loss: {avg_train_loss:.4f}")
        print(f"Val Loss: {avg_val_loss:.4f}")
        print(f"LCM 2019 - mIoU: {val_metrics['lcm1']['miou']:.4f}, Accuracy: {val_metrics['lcm1']['accuracy']:.4f}")
        print(f"LCM 2024 - mIoU: {val_metrics['lcm2']['miou']:.4f}, Accuracy: {val_metrics['lcm2']['accuracy']:.4f}")

    return model, history

def stage_two_training(model, train_loader, val_loader, device, checkpoint_path, numepochs=50):
    """
    Stage 2: Train the change-detection branch on top of a frozen encoder and LCM decoder

    The encoder and LCM decoder get requires_grad=False and are kept in eval mode
    during training so their BatchNorm running statistics stay fixed as well; only
    the CD decoder is optimised, with cross-entropy against the change mask.
    After every epoch the model is validated, and whenever the validation loss
    improves a checkpoint is written to `checkpoint_path`.

    If `checkpoint_path` already exists, weights, optimizer and scheduler state are
    restored from it and training continues from the stored epoch.

    Note: the parameters frozen here stay frozen on `model` after the call returns.

    Args:
        model: network with `encoder`, `lcm_decoder` and `cd_decoder` submodules whose
            forward(img1, img2) returns (lcm1_logits, lcm2_logits, cd_logits)
        train_loader: yields (img1, img2, mask1, mask2, cd_mask) batches for training
        val_loader: same batch layout, used for validation
        device: device the batches are moved to
        checkpoint_path: file used both for resuming and for saving the best model
        numepochs: total number of epochs (a resumed run only performs the remainder)

    Returns:
        (model, history) where history has the keys 'train_loss', 'val_loss' and
        'val_metrics', each a list with one entry per epoch run in this call.
    """
    print("\nStarting Stage 2: Training CD branch...")

    # Freeze encoder and LCM decoder weights
    frozen_modules = (model.encoder, model.lcm_decoder)
    for module in frozen_modules:
        for param in module.parameters():
            param.requires_grad = False

    # Optimization setup
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

    # Initialize tracking variables
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_metrics': []}

    # Try to load checkpoint if it exists
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        try:
            # map_location lets a checkpoint written on GPU be resumed on another device
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            model.load_state_dict(checkpoint['model_state_dict'])
            # Resume with the saved Adam moments and learning-rate schedule, not fresh ones
            if 'optimizer_state_dict' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if 'scheduler_state_dict' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint.get('val_loss', float('inf'))
            print(f"Resuming from epoch {start_epoch} with best val loss: {best_val_loss:.4f}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch")

    for epoch in range(start_epoch, numepochs):
        print(f"\nEpoch {epoch+1}/{numepochs}")

        # Training phase
        model.train()
        # requires_grad=False does not stop BatchNorm from updating its running
        # statistics in train mode; eval mode keeps the frozen branches truly fixed
        for module in frozen_modules:
            module.eval()

        train_loss = 0.0
        for data in tqdm(train_loader, desc='Training'):
            # Move all tensors to device
            img1, img2, _, _, cd_mask = [x.to(device) for x in data]

            optimizer.zero_grad()

            # Forward pass
            _, _, cd_out = model(img1, img2)

            # Calculate loss
            loss = criterion(cd_out, cd_mask)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Clear memory
            del cd_out
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        avg_train_loss = train_loss / len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0
        all_metrics_cd = []

        with torch.no_grad():
            for data in tqdm(val_loader, desc='Validation'):
                # Move all tensors to device
                img1, img2, _, _, cd_mask = [x.to(device) for x in data]

                # Forward pass
                _, _, cd_out = model(img1, img2)

                # Calculate loss
                loss = criterion(cd_out, cd_mask)
                val_loss += loss.item()

                # Move tensors to CPU and convert to numpy
                cd_preds = torch.argmax(cd_out, dim=1).cpu().numpy()
                cd_mask_np = cd_mask.cpu().numpy()

                # The class count is taken from the logits' channel dimension so the
                # metrics always cover exactly the classes the model predicts
                metrics = calculate_metrics(cd_preds,
                                         cd_mask_np,
                                         num_classes=cd_out.shape[1])
                all_metrics_cd.append(metrics)

                # Clear memory
                del cd_out, cd_preds
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        # Calculate average metrics (plain mean of the per-batch values)
        avg_val_loss = val_loss / len(val_loader)
        val_metrics = {key: np.mean([m[key] for m in all_metrics_cd])
                      for key in all_metrics_cd[0].keys()}

        # Save best model based on validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': best_val_loss,
                'train_loss': avg_train_loss,
                'val_metrics': val_metrics
            }, checkpoint_path)
            print(f"Saved new best model with validation loss: {best_val_loss:.4f}")

        # Update learning rate based on val loss
        scheduler.step(avg_val_loss)

        # Store metrics
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_metrics'].append(val_metrics)

        # Print epoch results
        print(f"Train Loss: {avg_train_loss:.4f}")
        print(f"Val Loss: {avg_val_loss:.4f}")
        print(f"CD Metrics - mIoU: {val_metrics['miou']:.4f}, Accuracy: {val_metrics['accuracy']:.4f}")

    return model, history


import json
def _history_to_jsonable(value):
    """Recursively convert numpy/torch scalars and arrays inside dicts and lists to plain Python."""
    if isinstance(value, dict):
        return {key: _history_to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_history_to_jsonable(item) for item in value]
    if hasattr(value, 'tolist'):
        return value.tolist()
    return value


def save_training_history(stage1_history, stage2_history, checkpoint_path, 
                          save_path_stage1, save_path_stage2, save_path_bestepoch):
    """
    Save both training histories and the best-epoch summary to JSON files.

    Histories are converted recursively, so nested metric dicts (such as the
    per-date 'lcm1' / 'lcm2' dicts produced by stage 1) serialise whatever numeric
    type they hold.

    Args:
        stage1_history (dict): history returned by stage-1 training
        stage2_history (dict): history returned by stage-2 training
        checkpoint_path (str): checkpoint to read the best-epoch information from; its
            'val_metrics' entry must be a flat dict with the keys 'accuracy',
            'precision', 'recall', 'f1_score', 'miou' and 'kappa'
        save_path_stage1 (str): output JSON path for the stage-1 history
        save_path_stage2 (str): output JSON path for the stage-2 history
        save_path_bestepoch (str): output JSON path for the best-epoch summary

    Raises:
        KeyError: if the checkpoint lacks one of the expected entries.
    """
    for history, save_path in ((stage1_history, save_path_stage1),
                               (stage2_history, save_path_stage2)):
        with open(save_path, 'w') as f:
            json.dump(_history_to_jsonable(history), f, indent=4)
        print(f"Training history saved to: {save_path}")

    # Load checkpoint to get best epoch info; only scalars are read, so loading onto
    # the CPU is enough and also works on a machine without a GPU
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    best_metrics = checkpoint['val_metrics']

    # Extract best epoch information
    epoch_data = {
        'best_epoch': int(checkpoint['epoch']),
        'total_loss': float(checkpoint['val_loss']),
        'accuracy': float(best_metrics['accuracy']),
        'precision': float(best_metrics['precision']),
        'recall': float(best_metrics['recall']),
        'f1': float(best_metrics['f1_score']),
        'miou': float(best_metrics['miou']),
        'kappa': float(best_metrics['kappa'])
    }

    # Save the best epoch info
    with open(save_path_bestepoch, 'w') as f:
        json.dump(epoch_data, f, indent=4)
    print(f"Best epoch info saved to: {save_path_bestepoch}")

## Model Training

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoints and history files are written under SAVING_DIR. Create it up front so a
# missing folder does not surface only when the first checkpoint is saved, i.e. after
# a full training epoch.
os.makedirs(SAVING_DIR, exist_ok=True)

# Model configuration (head sizes follow the label lists from the config cell)
input_channels = 3
num_semantic_classes = len(SEMANTIC_CLASSES)
num_cd_classes = len(CLASSES)

# Create model
model = MultiTaskChangeDetectionModel(
    input_channels=input_channels,
    num_semantic_classes=num_semantic_classes,
    num_cd_classes=num_cd_classes
).to(device)

# Define checkpoint paths
lcm_checkpoint_path = f'{SAVING_DIR}/best_lcm_model_{NUM_EPOCHS}.pt'
full_checkpoint_path = f'{SAVING_DIR}/best_full_model_{NUM_EPOCHS}.pt'

# Stage 1: Train LCM branches
model1, stage1_history = stage_one_training(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    checkpoint_path=lcm_checkpoint_path,
    numepochs=NUM_EPOCHS
)

# Load best LCM model before stage 2 (stage 1 returns the last-epoch weights,
# which are not necessarily the best ones)
if os.path.exists(lcm_checkpoint_path):
    checkpoint = torch.load(lcm_checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best LCM model from epoch {checkpoint['epoch']}")

# Stage 2: train the change-detection branch
model2, stage2_history = stage_two_training(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    checkpoint_path=full_checkpoint_path,
    numepochs=NUM_EPOCHS
)

# Save training history
save_path_stage1 = f'{SAVING_DIR}/strategy4_training_history_stage1_{NUM_EPOCHS}.json'
save_path_stage2 = f'{SAVING_DIR}/strategy4_training_history_stage2_{NUM_EPOCHS}.json'
save_path_bestepoch = f'{SAVING_DIR}/strategy4_best_epoch_{NUM_EPOCHS}.json'

save_training_history(stage1_history,stage2_history, 
                      full_checkpoint_path, 
                      save_path_stage1, save_path_stage2, save_path_bestepoch)

## Testing

In [None]:
import matplotlib.pyplot as plt
def test_model_complete(model, test_loader, device, checkpoint_path, num_classes=13, num_semantic_classes=4):
    """
    Evaluate a checkpointed multi-task model on a test loader.

    Loads the weights stored in `checkpoint_path`, runs every test batch through the
    model and reports change-detection (CD) and land-cover-mapping (LCM) results.
    Metrics are computed per batch and then averaged over batches; the LCM metrics
    of the two dates are averaged into one set. Losses are averaged per sample.
    A few samples are collected and passed to `visualize_results`.

    Args:
        model: network whose forward(img1, img2) returns (lcm1_logits, lcm2_logits, cd_logits)
        test_loader: yields (img1, img2, mask1, mask2, cd_mask) batches
        device: device used for inference
        checkpoint_path: checkpoint holding a 'model_state_dict' entry
        num_classes: number of change classes used for the CD metrics
        num_semantic_classes: number of land-cover classes used for the LCM metrics

    Returns:
        dict with 'cd_metrics', 'lcm_metrics', 'cd_loss' and 'lcm_loss'.
    """
    print("\nStarting model testing...")

    # Restore the weights under test and switch to inference mode
    saved_state = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(saved_state['model_state_dict'])
    model.eval()
    print("Loaded model checkpoint")

    # Running totals and per-batch metric logs
    loss_fn = nn.CrossEntropyLoss()
    cd_loss_sum = 0
    lcm_loss_sum = 0
    samples = 0
    cd_metric_log = []
    lcm_metric_log = []   # one entry per batch, already averaged over the two dates
    preview = []          # up to five samples kept for plotting

    def score(pred, truth, class_count):
        # Metrics are computed on CPU numpy copies
        return calculate_metrics(pred.cpu().numpy(), truth.cpu().numpy(), num_classes=class_count)

    print("\nProcessing test data...")
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_loader)):
            img1, img2, mask1, mask2, cd_mask = batch
            img1 = img1.to(device)
            img2 = img2.to(device)
            mask1 = mask1.to(device)
            mask2 = mask2.to(device)
            cd_mask = cd_mask.to(device)

            lcm1_out, lcm2_out, cd_out = model(img1, img2)

            # Batch losses, accumulated weighted by the batch size
            cd_loss = loss_fn(cd_out, cd_mask)
            lcm_loss = (loss_fn(lcm1_out, mask1) + loss_fn(lcm2_out, mask2)) / 2

            batch_size = img1.size(0)
            cd_loss_sum += cd_loss.item() * batch_size
            lcm_loss_sum += lcm_loss.item() * batch_size
            samples += batch_size

            # Hard predictions
            cd_preds = torch.argmax(cd_out, dim=1)
            lcm1_preds = torch.argmax(lcm1_out, dim=1)
            lcm2_preds = torch.argmax(lcm2_out, dim=1)

            # Per-batch metrics; the two LCM dates are merged key by key
            cd_batch = score(cd_preds, cd_mask, num_classes)
            lcm1_batch = score(lcm1_preds, mask1, num_semantic_classes)
            lcm2_batch = score(lcm2_preds, mask2, num_semantic_classes)
            lcm_batch = {name: (lcm1_batch[name] + lcm2_batch[name]) / 2 for name in lcm1_batch}

            cd_metric_log.append(cd_batch)
            lcm_metric_log.append(lcm_batch)

            # Keep the leading samples of every 10th batch until five are collected
            if batch_idx % 10 == 0 and len(preview) < 5:
                take = min(batch_size, 5 - len(preview))
                for i in range(take):
                    preview.append({
                        'img1': img1[i].cpu(),
                        'img2': img2[i].cpu(),
                        'lcm1_pred': lcm1_preds[i].cpu(),
                        'lcm1_true': mask1[i].cpu(),
                        'lcm2_pred': lcm2_preds[i].cpu(),
                        'lcm2_true': mask2[i].cpu(),
                        'cd_pred': cd_preds[i].cpu(),
                        'cd_true': cd_mask[i].cpu(),
                        'cd_probs': torch.softmax(cd_out[i], dim=0).cpu()
                    })

            # Release the batch tensors before the next iteration
            del lcm1_out, lcm2_out, cd_out, cd_preds, lcm1_preds, lcm2_preds
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Per-sample average losses
    avg_cd_loss = cd_loss_sum / samples
    avg_lcm_loss = lcm_loss_sum / samples

    # Average every metric over the batches
    cd_metrics = {name: np.mean([entry[name] for entry in cd_metric_log])
                  for name in cd_metric_log[0].keys()}
    lcm_metrics = {name: np.mean([entry[name] for entry in lcm_metric_log])
                   for name in lcm_metric_log[0].keys()}

    # Report
    print("\nTest Results:")
    print("\nChange Detection Metrics:")
    print(f"Loss: {avg_cd_loss:.4f}")
    print(f"Accuracy: {cd_metrics['accuracy']:.4f}")
    print(f"mIoU: {cd_metrics['miou']:.4f}")
    print(f"F1-Score: {cd_metrics['f1_score']:.4f}")
    print(f"Kappa: {cd_metrics['kappa']:.4f}")

    print("\nLand Cover Mapping Metrics (Averaged):")
    print(f"Loss: {avg_lcm_loss:.4f}")
    print(f"Accuracy: {lcm_metrics['accuracy']:.4f}")
    print(f"mIoU: {lcm_metrics['miou']:.4f}")
    print(f"F1-Score: {lcm_metrics['f1_score']:.4f}")
    print(f"Kappa: {lcm_metrics['kappa']:.4f}")

    # Plot the collected samples, if any
    if not preview:
        print("\nNo samples available for visualization")
    else:
        visualize_results(preview)

    return {
        'cd_metrics': cd_metrics,
        'lcm_metrics': lcm_metrics,
        'cd_loss': avg_cd_loss,
        'lcm_loss': avg_lcm_loss
    }

def visualize_results(samples):
    """
    Visualize test results including original images, predictions, 
    and ground truth for both LCM and CD

    One row is drawn per sample with eight panels: the two input images, then
    prediction / ground-truth pairs for LCM 2019, LCM 2024 and change detection.
    Each prediction / ground-truth pair shares one colour scale, so a given class id
    gets the same colour in both panels.

    Args:
        samples: list of dicts with the keys 'img1', 'img2', 'lcm1_pred', 'lcm1_true',
            'lcm2_pred', 'lcm2_true', 'cd_pred' and 'cd_true'. Images are expected
            channel-first (they are permuted to channel-last for display).
    """
    num_samples = len(samples)
    if num_samples == 0:
        # plt.subplots cannot build a grid with zero rows
        print("No samples to visualize")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single sample
    fig, axes = plt.subplots(num_samples, 8, figsize=(32, 4*num_samples), squeeze=False)

    image_panels = (('img1', 'Image 2019'), ('img2', 'Image 2024'))
    mask_pairs = (
        ('lcm1_pred', 'LCM 2019 Pred', 'lcm1_true', 'LCM 2019 GT'),
        ('lcm2_pred', 'LCM 2024 Pred', 'lcm2_true', 'LCM 2024 GT'),
        ('cd_pred', 'CD Prediction', 'cd_true', 'CD Ground Truth'),
    )

    for idx, sample in enumerate(samples):
        # Original images
        for col, (key, title) in enumerate(image_panels):
            axes[idx, col].imshow(sample[key].permute(1, 2, 0))
            axes[idx, col].set_title(title)
            axes[idx, col].axis('off')

        # Prediction / ground-truth pairs
        for pair_idx, (pred_key, pred_title, true_key, true_title) in enumerate(mask_pairs):
            pred, truth = sample[pred_key], sample[true_key]
            # imshow rescales each image to its own min/max by default; pinning the
            # range makes the two panels of a pair directly comparable
            vmax = max(int(pred.max()), int(truth.max()), 1)
            col = 2 + 2 * pair_idx
            for offset, (mask, title) in enumerate(((pred, pred_title), (truth, true_title))):
                axes[idx, col + offset].imshow(mask, cmap='tab10', vmin=0, vmax=vmax)
                axes[idx, col + offset].set_title(title)
                axes[idx, col + offset].axis('off')

    plt.tight_layout()
    plt.show()

def save_test_metrics(history, save_path):
    """Write a metrics dictionary to `save_path` as indented JSON.

    Values exposing `.tolist()` (arrays, tensors, numpy scalars) are converted to
    plain Python. For list values the conversion reaches one level into dict
    entries; other list entries and all remaining values are written unchanged.
    """
    def plain(value):
        # Arrays / tensors / numpy scalars -> built-in Python types
        if hasattr(value, 'tolist'):
            return value.tolist()
        return value

    def plain_entry(entry):
        # Dict entries of a list are converted value by value, anything else is kept
        if not isinstance(entry, dict):
            return entry
        return {metric: plain(value) for metric, value in entry.items()}

    serialisable = {}
    for phase, metrics in history.items():
        if isinstance(metrics, list):
            serialisable[phase] = list(map(plain_entry, metrics))
        else:
            serialisable[phase] = plain(metrics)

    with open(save_path, 'w') as f:
        json.dump(serialisable, f, indent=4)

    print(f"Testing history saved to: {save_path}")

# Evaluate the best stage-2 checkpoint on the test split, then persist the numbers as JSON
test_metrics = test_model_complete(
    model, test_loader, device, full_checkpoint_path,
    num_classes=num_cd_classes,
    num_semantic_classes=num_semantic_classes,
)

save_path = f'{SAVING_DIR}/strategy4_test_metrics.json'
save_test_metrics(test_metrics, save_path)