In [None]:
!pip install segmentation_models_pytorch
!pip install rasterio

In [None]:
from google.colab import drive
drive.mount('/content/drive')
!unzip '/content/drive/MyDrive/BTechProject/ChangeDetectionMergedDividedSplit-tif3.zip' -d '/content/ChangeDetectionMergedDividedSplit-tif'

## Dataloader

In [None]:
import os
import rasterio
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json

class ChangeDetectionDatasetTIF(Dataset):
    """Bi-temporal GeoTIFF dataset for joint change detection and land-cover segmentation.

    Each item pairs a 2019 and a 2024 image with a change-detection mask and one
    semantic (land-cover) mask per date. Files are matched across the five
    directories purely by position after sorting the ``*.tif`` names in each
    directory, so corresponding files must sort into the same order everywhere.

    Args:
        t2019_dir, t2024_dir: Directories holding the image tiles for each date.
        sem_2019_dir, sem_2024_dir: Directories holding the per-date semantic masks.
        mask_dir: Directory holding the change-detection masks.
        classes: Change-detection class names (stored on the instance, not used here).
        semantic_classes: Land-cover class names (stored on the instance, not used here).
        transform: Optional callable applied to each image tensor only. Masks are
            NOT transformed, so the transform must not change spatial geometry.

    Raises:
        AssertionError: If the five directories contain different numbers of ``.tif`` files.
    """

    def __init__(self, t2019_dir, t2024_dir, sem_2019_dir, sem_2024_dir, mask_dir,
                 classes, semantic_classes, transform=None):
        self.t2019_dir = t2019_dir
        self.t2024_dir = t2024_dir
        self.mask_dir = mask_dir
        self.sem_2019_dir = sem_2019_dir
        self.sem_2024_dir = sem_2024_dir
        self.classes = classes  # Change detection classes
        self.semantic_classes = semantic_classes  # Land cover classes
        self.transform = transform

        # Sorted .tif file names per directory; index i of each list forms one sample.
        self.t2019_paths = self._list_tifs(t2019_dir)
        self.t2024_paths = self._list_tifs(t2024_dir)
        self.mask_paths = self._list_tifs(mask_dir)
        self.sem2019_paths = self._list_tifs(sem_2019_dir)
        self.sem2024_paths = self._list_tifs(sem_2024_dir)

        # Only the counts can be checked here (names may legitimately differ between
        # directories). Report every count so a mismatch is easy to locate. An
        # explicit raise (rather than `assert`) keeps the check under `python -O`,
        # while the exception type stays AssertionError.
        counts = {
            't2019': len(self.t2019_paths),
            't2024': len(self.t2024_paths),
            'cd_mask': len(self.mask_paths),
            'sem_2019': len(self.sem2019_paths),
            'sem_2024': len(self.sem2024_paths),
        }
        if len(set(counts.values())) != 1:
            raise AssertionError(f"Mismatched number of images: {counts}")

    @staticmethod
    def _list_tifs(directory):
        """Return the sorted names of the ``.tif`` files directly inside ``directory``."""
        return sorted(name for name in os.listdir(directory) if name.endswith('.tif'))

    @staticmethod
    def _read_image(path):
        """Read all bands of ``path`` as float32 and divide by 255.

        The 1/255 scaling presumably assumes 8-bit imagery; TODO confirm.
        """
        with rasterio.open(path) as src:
            return src.read(out_dtype=np.float32) / 255.0

    @staticmethod
    def _read_mask(path):
        """Read band 1 of ``path`` as an int64 label map."""
        with rasterio.open(path) as src:
            return src.read(1).astype(np.int64)

    def __len__(self):
        return len(self.t2019_paths)

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of tensors.

        Keys: 'img_t2019', 'img_t2024' (float32 images), and 'cd_mask',
        'sem_mask_2019', 'sem_mask_2024' (int64 label maps).
        """
        img_t2019 = torch.from_numpy(
            self._read_image(os.path.join(self.t2019_dir, self.t2019_paths[index])))
        img_t2024 = torch.from_numpy(
            self._read_image(os.path.join(self.t2024_dir, self.t2024_paths[index])))

        cd_mask = torch.from_numpy(
            self._read_mask(os.path.join(self.mask_dir, self.mask_paths[index])))
        sem_mask_2019 = torch.from_numpy(
            self._read_mask(os.path.join(self.sem_2019_dir, self.sem2019_paths[index])))
        sem_mask_2024 = torch.from_numpy(
            self._read_mask(os.path.join(self.sem_2024_dir, self.sem2024_paths[index])))

        # Image-only transforms (see class docstring).
        if self.transform is not None:
            img_t2019 = self.transform(img_t2019)
            img_t2024 = self.transform(img_t2024)

        return {
            'img_t2019': img_t2019,
            'img_t2024': img_t2024,
            'cd_mask': cd_mask,
            'sem_mask_2019': sem_mask_2019,
            'sem_mask_2024': sem_mask_2024
        }


def describe_loader(loader_type):
    """Summarise a DataLoader from its first batch.

    Prints the batch size, the shape of every tensor in one batch, the dataset
    length, both class-name lists, and the distinct label values found in that
    batch's change mask and 2019 semantic mask.
    """
    first_batch = next(iter(loader_type))
    ds = loader_type.dataset

    shape_rows = (
        ("Image 2019", 'img_t2019'),
        ("Image 2024", 'img_t2024'),
        ("Change Mask", 'cd_mask'),
        ("Semantic Mask 2019", 'sem_mask_2019'),
        ("Semantic Mask 2024", 'sem_mask_2024'),
    )

    print("Batch size:", loader_type.batch_size)
    print("Shapes:")
    for label, key in shape_rows:
        print(f"  {label}:", first_batch[key].shape)
    print("Number of images:", len(ds))
    print("Change Classes:", ds.classes)
    print("Semantic Classes:", ds.semantic_classes)

    # Distinct label values present in this batch only.
    print("\nValue ranges:")
    print("  Change Mask:", torch.unique(first_batch['cd_mask']))
    print("  Semantic Mask:", torch.unique(first_batch['sem_mask_2019']))

# --- Configuration --------------------------------------------------------
import random

ROOT_DIRECTORY = "ChangeDetectionMergedDividedSplit-tif"  # relative to the working directory
SAVING_DIR = '/content/drive/MyDrive/BTechProject'
CD_DIR = "cd1_Output"  # change-detection label folder inside each split
CLASSES = ['no_change', 'increase_vegetation', 'decrease_vegetation']
# Alternative 13-class scheme (cd2); presumably used together with a matching CD_DIR:
# CLASSES = ['no_change', 'water_building', 'water_sparse', 'water_dense',
#            'building_water', 'building_sparse', 'building_dense',
#            'sparse_water', 'sparse_building', 'sparse_dense',
#            'dense_water', 'dense_building', 'dense_sparse']

SEMANTIC_CLASSES = ['water', 'building', 'sparse_vegetation', 'dense_vegetation']

# Seed every RNG in use (sample visualisation, numpy, torch / DataLoader shuffling)
# so a fresh "Run All" is reproducible. torch.manual_seed also seeds CUDA.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def make_split_dataset(split):
    """Build the ChangeDetectionDatasetTIF for one split folder ('train', 'val' or 'test')."""
    split_root = f"{ROOT_DIRECTORY}/{split}"
    return ChangeDetectionDatasetTIF(
        t2019_dir=f"{split_root}/Images/T2019",
        t2024_dir=f"{split_root}/Images/T2024",
        sem_2019_dir=f"{split_root}/Masks/T2019",
        sem_2024_dir=f"{split_root}/Masks/T2024",
        mask_dir=f"{split_root}/{CD_DIR}",
        classes=CLASSES,
        semantic_classes=SEMANTIC_CLASSES,
    )


# Create datasets
train_dataset = make_split_dataset("train")
val_dataset = make_split_dataset("val")
test_dataset = make_split_dataset("test")

# Create dataloaders (only the training loader shuffles).
num_workers = 4
batch_size = 16
# Pinned memory only helps host->GPU copies; skip it (and its warning) on CPU-only runtimes.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                     pin_memory=torch.cuda.is_available())
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

for split_name, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"------------{split_name}------------")
    describe_loader(split_loader)

## Visualize

In [None]:
import random
import matplotlib.pyplot as plt
import torch

def visualize_samples(loader, num_samples=3):
    """Plot randomly chosen samples from ``loader.dataset`` in a grid.

    Each row shows one sample: 2019 image, 2024 image, change mask, and the two
    semantic masks. The file name and the distinct mask values of each sample are
    printed as well. Samples are read straight from the dataset, so the loader's
    batching and shuffling settings play no role.

    Args:
        loader: DataLoader whose ``dataset`` exposes ``classes``,
            ``semantic_classes`` and ``t2019_paths`` (as ChangeDetectionDatasetTIF does).
        num_samples: Number of rows to draw; capped at the dataset size.
    """
    dataset = loader.dataset
    # random.sample raises if asked for more items than exist, so cap the request.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("Dataset is empty - nothing to visualize.")
        return

    # Fixed colour ranges. Without vmin/vmax each mask is auto-scaled to its own
    # min/max, so the same colour could mean different classes in different rows.
    cd_vmax = max(len(dataset.classes) - 1, 1)
    sem_vmax = max(len(dataset.semantic_classes) - 1, 1)

    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 5, figsize=(15, 3 * num_samples), squeeze=False)

    # Randomly select indices for samples
    indices = random.sample(range(len(dataset)), num_samples)

    for row, idx in enumerate(indices):
        sample = dataset[idx]
        panels = [
            (sample['img_t2019'].permute(1, 2, 0), 'Image 2019', {}),
            (sample['img_t2024'].permute(1, 2, 0), 'Image 2024', {}),
            (sample['cd_mask'], 'Change Mask',
             {'cmap': 'turbo', 'vmin': 0, 'vmax': cd_vmax}),
            (sample['sem_mask_2019'], 'Semantic Mask 2019',
             {'cmap': 'viridis', 'vmin': 0, 'vmax': sem_vmax}),
            (sample['sem_mask_2024'], 'Semantic Mask 2024',
             {'cmap': 'viridis', 'vmin': 0, 'vmax': sem_vmax}),
        ]
        for col, (data, title, imshow_kwargs) in enumerate(panels):
            ax = axes[row, col]
            ax.imshow(data, **imshow_kwargs)
            ax.set_title(title, fontsize=8)
            ax.axis('off')

        print(f"\nSample {row} - {dataset.t2019_paths[idx]}:")
        print(f"Change mask values: {torch.unique(sample['cd_mask'])}")
        print(f"Semantic mask 2019 values: {torch.unique(sample['sem_mask_2019'])}")
        print(f"Semantic mask 2024 values: {torch.unique(sample['sem_mask_2024'])}")

    plt.tight_layout(pad=1.0, h_pad=1.0, w_pad=1.0)  # Adjust subplot spacing
    plt.show()

# Show a few random samples from every split, one figure per split.
for split_label, split_loader in (
    ("training", train_loader),
    ("validation", val_loader),
    ("test", test_loader),
):
    print(f"\nVisualizing {split_label} samples...")
    visualize_samples(split_loader)

## Model definition, utility functions, and training loop

In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


class ChangeDetectionModel(nn.Module):
    """Shared semantic-segmentation network with a change-detection head.

    One segmentation network (``self.sem_model``) is applied to both timestamps
    with the same weights. Its per-pixel class outputs for the two images are
    concatenated along the channel axis and passed through a small convolutional
    ``change_head`` that predicts per-pixel change classes.

    Args:
        architecture: One of 'unet', 'linknet', 'pspnet', 'deeplabv3plus' (case-insensitive).
        encoder: segmentation_models_pytorch encoder name, e.g. 'resnet34'.
        input_channels: Number of input image bands.
        num_semantic_classes: Output channels of the segmentation network.
        num_classes: Number of change-detection classes.
        encoder_weights: Pre-trained encoder weights passed to smp ("imagenet" by
            default; ``None`` gives a randomly initialised encoder).

    Raises:
        ValueError: If ``architecture`` is not one of the supported names.
    """

    # Accepted (lower-case) architecture names -> segmentation_models_pytorch class names.
    # Looked up with getattr only after validation, so an unsupported name fails
    # immediately instead of leaving `sem_model` undefined until forward().
    _ARCHITECTURES = {
        'unet': 'Unet',
        'linknet': 'Linknet',
        'pspnet': 'PSPNet',
        'deeplabv3plus': 'DeepLabV3Plus',
    }

    def __init__(self, architecture='unet', encoder='resnet34', input_channels=3,
                 num_semantic_classes=4, num_classes=3, encoder_weights="imagenet"):
        super().__init__()
        arch_key = architecture.lower()
        if arch_key not in self._ARCHITECTURES:
            raise ValueError(
                f"Unsupported architecture {architecture!r}; "
                f"expected one of {sorted(self._ARCHITECTURES)}"
            )

        # Semantic segmentation model shared by both timestamps.
        sem_model_cls = getattr(smp, self._ARCHITECTURES[arch_key])
        self.sem_model = sem_model_cls(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            in_channels=input_channels,
            classes=num_semantic_classes,
        )

        # Layer order is unchanged so existing checkpoints' state_dict keys still match.
        self.change_head = nn.Sequential(
            nn.Conv2d(num_semantic_classes * 2, 64, kernel_size=3, padding=1),  # both dates' semantic outputs
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, num_classes, kernel_size=1),  # per-pixel change-class logits
            # NOTE(review): dropout is applied to the final logits (active only in
            # train mode). Kept as-is to preserve training behaviour; dropout before
            # the last conv is the more conventional placement.
            nn.Dropout(0.3)
        )

    def forward(self, x1, x2):
        """Return ``(semantic_out_x1, semantic_out_x2, change_out)`` for the image pair."""
        sem_feat1 = self.sem_model(x1)
        sem_feat2 = self.sem_model(x2)

        # Concatenate both dates' semantic outputs along the channel axis.
        combined_feat = torch.cat([sem_feat1, sem_feat2], dim=1)
        change_out = self.change_head(combined_feat)

        return sem_feat1, sem_feat2, change_out


def calculate_metrics(predictions, targets, num_classes):
    """
    Calculate classification metrics for dense label maps from a single confusion matrix.

    The confusion matrix spans the label space ``{0, ..., num_classes - 1}`` plus any
    other label value that actually occurs, so unexpected labels are counted, never
    silently dropped. Per-class precision / recall / F1 / IoU are macro-averaged
    over the classes that occur in ``targets`` or ``predictions``. A class absent
    from both carries no information for this input and is skipped, rather than
    being scored as 0.

    Args:
        predictions (np.array): Predicted integer class labels (any shape).
        targets (np.array): Ground-truth integer class labels, same number of elements.
        num_classes (int): Number of classes expected in the label space.

    Returns:
        dict: 'accuracy', 'precision', 'recall', 'f1_score', 'kappa', 'miou'.

    Raises:
        ValueError: If predictions and targets have different numbers of elements.
    """
    pred_flat = np.asarray(predictions).ravel()
    target_flat = np.asarray(targets).ravel()
    if pred_flat.shape != target_flat.shape:
        raise ValueError(
            f"predictions and targets have different sizes: "
            f"{pred_flat.size} vs {target_flat.size}"
        )

    # Label space = expected classes plus any other observed value, mapped to
    # contiguous row/column indices.
    labels = np.union1d(np.arange(num_classes), np.union1d(target_flat, pred_flat))
    n_labels = labels.size
    target_idx = np.searchsorted(labels, target_flat)
    pred_idx = np.searchsorted(labels, pred_flat)
    # cm[i, j] = number of elements with true label labels[i] and predicted label labels[j].
    cm = np.bincount(target_idx * n_labels + pred_idx,
                     minlength=n_labels * n_labels).reshape(n_labels, n_labels)

    tp = np.diag(cm)
    pred_totals = cm.sum(axis=0)  # how often each class was predicted
    true_totals = cm.sum(axis=1)  # how often each class is the ground truth
    fp = pred_totals - tp
    fn = true_totals - tp
    present = (tp + fp + fn) > 0  # class occurs in targets or predictions

    n = cm.sum()
    metrics = {'accuracy': tp.sum() / n}

    # Per-class scores; the epsilon avoids 0/0 for classes that are never predicted.
    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-6)
    iou = tp / (tp + fp + fn + 1e-6)

    # Unweighted macro averages over the classes that occur.
    metrics['precision'] = np.mean(precision[present])
    metrics['recall'] = np.mean(recall[present])
    metrics['f1_score'] = np.mean(f1[present])

    # Cohen's kappa in counts: (observed agreement - chance agreement) / (n - chance agreement).
    chance_agreement = np.sum(pred_totals * true_totals) / n
    metrics['kappa'] = (tp.sum() - chance_agreement) / (n - chance_agreement + 1e-6)

    metrics['miou'] = np.mean(iou[present])

    return metrics


def calculate_effective_weights(train_loader, num_classes=3, device='cuda'):
    """Compute change-detection class weights from label frequencies in ``train_loader``.

    Iterates the whole loader once and counts how many ``cd_mask`` pixels fall in
    each class ``0 .. num_classes - 1``. Values outside that range still count
    toward the total pixel count. Weights are ``sqrt(1 / frequency)``, rescaled so
    they sum to ``num_classes``. A class that never occurs is counted as 1 pixel:
    this avoids division by zero, but gives that class a very large weight.

    Args:
        train_loader: Iterable of dict batches with an integer ``'cd_mask'`` tensor.
        num_classes: Number of change classes.
        device: Device the masks are moved to for counting.

    Returns:
        torch.Tensor: float32 weights of shape ``(num_classes,)`` on the CPU.
    """
    # int64 counters: a float32 accumulator only represents integers exactly up to
    # 2**24, which the pixel count of a training set easily exceeds.
    class_counts = torch.zeros(num_classes, dtype=torch.int64)
    total_pixels = 0

    for batch in train_loader:
        labels = batch['cd_mask'].to(device).flatten()
        # One bincount pass instead of one full comparison per class.
        in_range = labels[(labels >= 0) & (labels < num_classes)]
        class_counts += torch.bincount(in_range, minlength=num_classes).cpu()
        total_pixels += labels.numel()

    counts = class_counts.to(torch.float64)
    # Avoid division by zero
    counts = torch.where(counts == 0, torch.ones_like(counts), counts)

    class_frequencies = counts / total_pixels
    weights = torch.sqrt(1.0 / class_frequencies)
    weights = weights * (num_classes / weights.sum())

    return weights.to(torch.float32)


def train_model_pcc(model, train_loader, val_loader, num_epochs, device,
                   checkpoint_path, loss='CE', num_classes=3):
    """Train the PCC model and keep the checkpoint with the lowest validation loss.

    If ``checkpoint_path`` exists, training resumes from it. Model and optimizer
    state are restored, and the epoch counter and best validation loss continue
    from the saved values. The file holds the *best* epoch, which is not
    necessarily the last one run.

    Each epoch's loss is the mean semantic cross-entropy over both dates plus a
    class-weighted change cross-entropy. The checkpoint is overwritten whenever
    the validation loss improves.

    Args:
        model: Module returning ``(sem_out_2019, sem_out_2024, change_out)``.
        train_loader, val_loader: DataLoaders yielding the dataset's dict batches.
        num_epochs: Total number of epochs, including any already completed
            according to the checkpoint.
        device: Device to train on.
        checkpoint_path: Where the best checkpoint is read from and written to.
        loss: Currently unused; only cross-entropy is implemented.
        num_classes: Number of change-detection classes.

    Returns:
        tuple: ``(model, history)``. ``model`` is the same object that was passed
        in. ``history`` holds lists ``train_loss``, ``val_loss`` and ``val_metrics``
        for the epochs run in this call.

    Note:
        Validation metrics are computed per batch and then averaged over batches,
        which is not identical to computing them over the whole validation set.
    """
    # Create the optimizer first so its state can be restored from the checkpoint.
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    start_epoch = 0
    best_val_loss = float('inf')

    # Try to load checkpoint if it exists
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        try:
            # map_location lets a checkpoint saved on GPU be resumed on CPU and vice versa.
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
            model.load_state_dict(checkpoint['model_state_dict'])
            if 'optimizer_state_dict' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint['val_loss']  # Load best validation loss
            print(f"Resuming from epoch {start_epoch} with best val loss: {best_val_loss:.4f}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch")
            start_epoch = 0
            best_val_loss = float('inf')
    else:
        print("No checkpoint found. Starting training from scratch")

    sem_criterion = nn.CrossEntropyLoss()
    class_weights = calculate_effective_weights(train_loader, num_classes=num_classes, device=device)
    print(class_weights)
    change_criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

    def _batch_loss(batch):
        """Run the model on one batch; return (total loss, change logits, change mask)."""
        img_t2019 = batch['img_t2019'].to(device)
        img_t2024 = batch['img_t2024'].to(device)
        sem_mask_2019 = batch['sem_mask_2019'].to(device)
        sem_mask_2024 = batch['sem_mask_2024'].to(device)
        cd_mask = batch['cd_mask'].to(device)

        sem_out1, sem_out2, change_out = model(img_t2019, img_t2024)
        sem_loss = (sem_criterion(sem_out1, sem_mask_2019) +
                    sem_criterion(sem_out2, sem_mask_2024)) / 2
        change_loss = change_criterion(change_out, cd_mask)
        return sem_loss + change_loss, change_out, cd_mask

    history = {'train_loss': [], 'val_loss': [], 'val_metrics': []}

    if start_epoch >= num_epochs:
        print(f"Checkpoint is already at epoch {start_epoch} (num_epochs={num_epochs}); nothing to train.")

    # Continue from the restored epoch instead of restarting at 0.
    for epoch in range(start_epoch, num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for batch in progress_bar:
            optimizer.zero_grad()
            batch_loss, _, _ = _batch_loss(batch)
            batch_loss.backward()
            optimizer.step()

            train_loss += batch_loss.item()
            progress_bar.set_postfix({'loss': f'{batch_loss.item():.4f}'})

        avg_train_loss = train_loss / len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_metrics_list = []

        with torch.no_grad():
            for batch in val_loader:
                batch_loss, change_out, cd_mask = _batch_loss(batch)
                val_loss += batch_loss.item()

                preds = torch.argmax(change_out, dim=1).cpu().numpy()
                val_metrics_list.append(
                    calculate_metrics(preds, cd_mask.cpu().numpy(), num_classes))

        avg_val_loss = val_loss / len(val_loader)
        val_metrics = {key: np.mean([m[key] for m in val_metrics_list])
                       for key in val_metrics_list[0]}

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': float(best_val_loss),
                'train_loss': float(avg_train_loss),
                'val_accuracy': float(val_metrics['accuracy']),
                'val_kappa': float(val_metrics['kappa']),
                'val_miou': float(val_metrics['miou']),
                'val_f1_score': float(val_metrics['f1_score']),
            }, checkpoint_path)
            print(f"Saved best model checkpoint to {checkpoint_path}")

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_metrics'].append(val_metrics)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {avg_train_loss:.4f}")
        print(f"Val Loss: {avg_val_loss:.4f}")
        print(f"Metrics - Accuracy: {val_metrics['accuracy']:.4f}, "
              f"Kappa: {val_metrics['kappa']:.4f}, "
              f"mIoU: {val_metrics['miou']:.4f}, "
              f"F1 score: {val_metrics['f1_score']:.4f}, ")

    return model, history


def save_training_history(history, checkpoint_path, save_path, save_path_bestepoch):
    """Write the training history and a summary of the best checkpoint to two JSON files.

    Args:
        history: Dict as returned by ``train_model_pcc`` (lists of losses and of
            per-epoch metric dicts). Values with a ``.tolist()`` method (NumPy/torch
            scalars and arrays) are converted to plain Python values.
        checkpoint_path: Checkpoint file holding the best epoch's scalar metadata.
        save_path: Destination of the full history JSON.
        save_path_bestepoch: Destination of the best-epoch summary JSON.
    """
    def _jsonable(value):
        # NumPy/torch scalars and arrays expose .tolist(); plain values pass through.
        return value.tolist() if hasattr(value, 'tolist') else value

    processed_history = {}
    for phase, metrics in history.items():
        if isinstance(metrics, list):
            processed_history[phase] = [
                {metric: _jsonable(value) for metric, value in entry.items()}
                if isinstance(entry, dict) else _jsonable(entry)
                for entry in metrics
            ]
        else:
            processed_history[phase] = _jsonable(metrics)

    with open(save_path, 'w') as f:
        json.dump(processed_history, f, indent=4)
    print(f"Training history saved to: {save_path}")

    # Only scalar metadata is read, so load onto the CPU; without map_location a
    # checkpoint saved on GPU fails to load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)

    summary_keys = ('train_loss', 'val_loss', 'val_accuracy', 'val_kappa',
                    'val_miou', 'val_f1_score')
    epoch_data = {'best_epoch': checkpoint['epoch']}
    epoch_data.update((key, checkpoint[key]) for key in summary_keys)

    with open(save_path_bestepoch, 'w') as f:
        json.dump(epoch_data, f, indent=4)
    print(f"Best epoch info saved to: {save_path_bestepoch}")

## Model training

In [5]:
# Initialize model and device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
architecture = 'unet'  # 'unet', 'linknet', 'pspnet' or 'deeplabv3plus'
# Derive class counts from the class lists defined with the datasets so they cannot
# drift out of sync (len(CLASSES) is 3 for cd1, 13 for cd2; 4 semantic classes for both).
num_classes = len(CLASSES)
num_semantic_classes = len(SEMANTIC_CLASSES)
num_epochs = 100
loss = 'CE'  # only cross-entropy is implemented by train_model_pcc
checkpoint_path = f'{SAVING_DIR}/best_{architecture}-{num_classes}_classes_{num_epochs}_epochs.pt'

# Create model
model = ChangeDetectionModel(
    architecture=architecture,
    encoder='resnet34',
    input_channels=3,
    num_classes=num_classes,
    num_semantic_classes=num_semantic_classes,
).to(device)

# Train model (train_model_pcc returns the same model object it was given)
model2, history = train_model_pcc(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=num_epochs,
    device=device,
    checkpoint_path=checkpoint_path,
    loss=loss,
    num_classes=num_classes,
)

# Save training history and best-epoch summary
save_history_path = f'{SAVING_DIR}/PCC_{architecture}-{num_classes}_classes_{num_epochs}_history.json'
save_bestepoch_path = f'{SAVING_DIR}/PCC_{architecture}-{num_classes}_classes_{num_epochs}_best_epoch.json'
save_training_history(history, checkpoint_path=checkpoint_path,
                      save_path=save_history_path, save_path_bestepoch=save_bestepoch_path)

## Testing

In [None]:
def plot_results_pcc(img1, img2, sem_pred1, sem_pred2, sem_gt1, sem_gt2,
                    change_pred, change_gt, num_semantic_classes=None, num_change_classes=None):
    """Show one test sample: both images, semantic predictions and ground truth, and change maps.

    Label maps are drawn with ``vmin=0`` and a ``vmax`` shared between prediction
    and ground truth, so a colour means the same class in both panels. Otherwise
    matplotlib auto-scales each panel to its own min/max.

    Args:
        img1, img2: Channels-first image tensors for 2019 / 2024.
        sem_pred1, sem_pred2, sem_gt1, sem_gt2: Semantic label maps.
        change_pred, change_gt: Change label maps.
        num_semantic_classes, num_change_classes: Optional class counts that fix the
            colour range; when ``None``, the largest label present in the group is used.
    """
    def _color_max(num_labels, *label_maps):
        # vmax for a group of label maps that must share one colour scale.
        if num_labels is not None:
            return max(num_labels - 1, 1)
        return max(max(int(m.max()) for m in label_maps), 1)

    sem_kw = {'vmin': 0, 'vmax': _color_max(num_semantic_classes,
                                            sem_pred1, sem_pred2, sem_gt1, sem_gt2)}
    cd_kw = {'vmin': 0, 'vmax': _color_max(num_change_classes, change_pred, change_gt)}

    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    panels = [
        (axes[0, 0], img1.cpu().permute(1, 2, 0), 'Image 2019', {}),
        (axes[0, 1], img2.cpu().permute(1, 2, 0), 'Image 2024', {}),
        (axes[0, 2], sem_pred1.cpu(), 'Semantic Pred 2019', sem_kw),
        (axes[0, 3], sem_pred2.cpu(), 'Semantic Pred 2024', sem_kw),
        (axes[1, 0], sem_gt1.cpu(), 'Semantic GT 2019', sem_kw),
        (axes[1, 1], sem_gt2.cpu(), 'Semantic GT 2024', sem_kw),
        (axes[1, 2], change_pred.cpu(), 'Change Prediction', cd_kw),
        (axes[1, 3], change_gt.cpu(), 'Change GT', cd_kw),
    ]
    for ax, data, title, imshow_kwargs in panels:
        ax.imshow(data, **imshow_kwargs)
        ax.set_title(title)

    plt.tight_layout()
    plt.show()


def test_model_pcc(model, test_loader, checkpoint_path, device, num_samples_to_plot=5, num_classes=3,
                   num_semantic_classes=None):
    """Evaluate the best PCC checkpoint on ``test_loader``.

    Loads ``checkpoint_path`` into ``model`` and runs the whole loader. It prints
    change-detection metrics, semantic-segmentation metrics for each date, and
    their average. It also plots the first sample of ``num_samples_to_plot``
    randomly chosen batches.

    Args:
        model: Module returning ``(sem_out_2019, sem_out_2024, change_out)``.
        test_loader: DataLoader yielding the dataset's dict batches.
        checkpoint_path: Checkpoint containing ``'model_state_dict'``.
        device: Device to run on.
        num_samples_to_plot: Number of batches to plot (capped at the number of batches).
        num_classes: Number of change-detection classes.
        num_semantic_classes: Number of semantic classes used when scoring the
            semantic maps; defaults to the channel count of the model's semantic output.

    Returns:
        dict: keys 'change_detection', 'semantic_2019', 'semantic_2024' and
        'semantic_average', each mapping metric names to values.

    Note:
        Metrics are computed per batch and then averaged over batches, which is not
        identical to computing them once over the whole test set.
    """
    # The checkpoint holds only tensors and plain Python values, so the restricted
    # loader is sufficient (and avoids arbitrary-code unpickling).
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded checkpoint from {checkpoint_path}")
    model.eval()

    # Choose the batches to plot up front and keep only their first sample, instead
    # of holding every batch's images and predictions in memory.
    num_batches = len(test_loader)
    num_to_plot = max(0, min(num_samples_to_plot, num_batches))
    plot_order = (np.random.choice(num_batches, num_to_plot, replace=False).tolist()
                  if num_to_plot > 0 else [])
    plot_batches = set(plot_order)
    kept_samples = {}

    cd_metrics = []        # Change detection metrics
    sem_2019_metrics = []  # Semantic segmentation 2019 metrics
    sem_2024_metrics = []  # Semantic segmentation 2024 metrics

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_loader, desc='Testing')):
            img_t2019 = batch['img_t2019'].to(device)
            img_t2024 = batch['img_t2024'].to(device)
            sem_mask_2019 = batch['sem_mask_2019'].to(device)
            sem_mask_2024 = batch['sem_mask_2024'].to(device)
            cd_mask = batch['cd_mask'].to(device)

            sem_out1, sem_out2, change_out = model(img_t2019, img_t2024)

            sem_pred1 = torch.argmax(sem_out1, dim=1)
            sem_pred2 = torch.argmax(sem_out2, dim=1)
            change_pred = torch.argmax(change_out, dim=1)

            # Semantic maps are scored against the semantic class count, not the
            # change-detection class count.
            n_sem = num_semantic_classes if num_semantic_classes is not None else sem_out1.shape[1]

            cd_metrics.append(calculate_metrics(
                change_pred.cpu().numpy(), cd_mask.cpu().numpy(), num_classes))
            sem_2019_metrics.append(calculate_metrics(
                sem_pred1.cpu().numpy(), sem_mask_2019.cpu().numpy(), n_sem))
            sem_2024_metrics.append(calculate_metrics(
                sem_pred2.cpu().numpy(), sem_mask_2024.cpu().numpy(), n_sem))

            if batch_idx in plot_batches:
                # Positional order matches plot_results_pcc's parameters.
                kept_samples[batch_idx] = (
                    img_t2019[0].cpu(), img_t2024[0].cpu(),
                    sem_pred1[0].cpu(), sem_pred2[0].cpu(),
                    sem_mask_2019[0].cpu(), sem_mask_2024[0].cpu(),
                    change_pred[0].cpu(), cd_mask[0].cpu(),
                )

    def aggregate_metrics(metrics_list):
        """Mean of each metric over the per-batch metric dicts."""
        return {key: np.mean([m[key] for m in metrics_list]) for key in metrics_list[0]}

    final_cd_metrics = aggregate_metrics(cd_metrics)
    final_sem_2019_metrics = aggregate_metrics(sem_2019_metrics)
    final_sem_2024_metrics = aggregate_metrics(sem_2024_metrics)

    # Average of the two dates' semantic segmentation metrics
    avg_sem_metrics = {
        key: (final_sem_2019_metrics[key] + final_sem_2024_metrics[key]) / 2
        for key in final_sem_2019_metrics
    }

    for header, metrics in (("Change Detection", final_cd_metrics),
                            ("2019 Semantic Segmentation", final_sem_2019_metrics),
                            ("2024 Semantic Segmentation", final_sem_2024_metrics),
                            ("Average Semantic Segmentation", avg_sem_metrics)):
        print(f"\n=== {header} Metrics ===")
        print(f"Overall Accuracy: {metrics['accuracy']:.4f}")
        print(f"Kappa Score: {metrics['kappa']:.4f}")
        print(f"mIoU: {metrics['miou']:.4f}")
        print(f"F1 score: {metrics['f1_score']:.4f}")

    print(f"\nPlotting {num_to_plot} random samples...")
    for batch_idx in plot_order:
        plot_results_pcc(*kept_samples[batch_idx])

    return {
        'change_detection': final_cd_metrics,
        'semantic_2019': final_sem_2019_metrics,
        'semantic_2024': final_sem_2024_metrics,
        'semantic_average': avg_sem_metrics
    }


def save_test_metrics(history, save_path):
    """Write a (possibly nested) metrics dict to ``save_path`` as indented JSON.

    Values with a ``.tolist()`` method (NumPy/torch scalars and arrays) are
    converted to plain Python values at any nesting depth. ``json`` cannot
    serialize e.g. ``np.float32`` directly.
    """
    def _to_jsonable(obj):
        # Recurse through dicts and lists/tuples, converting array-likes and scalars.
        if isinstance(obj, dict):
            return {key: _to_jsonable(val) for key, val in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [_to_jsonable(val) for val in obj]
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        return obj

    with open(save_path, 'w') as f:
        json.dump(_to_jsonable(history), f, indent=4)

    print(f"Test metrics saved to: {save_path}")

# Evaluate the best checkpoint on the test split, then persist the metrics as JSON.
test_metrics = test_model_pcc(
    model,
    test_loader,
    checkpoint_path=checkpoint_path,
    device=device,
    num_samples_to_plot=3,
    num_classes=num_classes,
)

save_path = f'{SAVING_DIR}/PCC_{architecture}-{num_classes}_classes_{num_epochs}_test_metrics.json'
save_test_metrics(test_metrics, save_path=save_path)