In [None]:
# from google.colab import drive
# drive.mount('/content/drive')
# !unzip '/content/drive/MyDrive/BTechProject/ChangeDetectionMergedDividedSplit-tif3.zip' -d '/content/ChangeDetectionMergedDividedSplit-tif'

In [None]:
# !pip install rasterio

## Hyperparameters

In [None]:
# Dataset location, output location and which change-mask variant to train on.
ROOT_DIRECTORY = "ChangeDetectionMergedDividedSplit-tif"
SAVING_DIR = "/content/drive/MyDrive/BTechProject"
CD_DIR = "cd1_Output"  #FOR STANET ALWAYS USE cd1_Output

# Class names; the position in the list is the integer value used as the
# training target for that class in the change masks.
if CD_DIR == "cd1_Output":
    CLASSES = ['no_change','vegetation_increase','vegetation_decrease']
elif CD_DIR == "cd2_Output":
    # Legacy 13-class scheme, kept for reference:
    # CLASSES = ['no_change', 'water_building', 'water_sparse', 'water_dense',
    #            'building_water', 'building_sparse', 'building_dense',
    #            'sparse_water', 'sparse_building', 'sparse_dense',
    #            'dense_water', 'dense_building', 'dense_sparse']
    CLASSES = [
    'no_change','water_built', 'water_bare', 'water_sparse', 'water_trees',
    'water_crops', 'built_water', 'built_bare', 'built_sparse', 'built_trees',
    'built_crops',  'bare_water',  'bare_built',  'bare_sparse',  'bare_trees',
    'bare_crops',  'sparse_water',  'sparse_built',  'sparse_bare',
    'sparse_trees',  'sparse_crops',  'trees_water',  'trees_built',
    'trees_bare',  'trees_sparse',  'trees_crops',  'crops_water',
    'crops_built', 'crops_bare',  'crops_sparse',  'crops_trees']
else:
    # Fail here instead of with a NameError on CLASSES several cells later.
    raise ValueError(f"Unknown CD_DIR {CD_DIR!r}; expected 'cd1_Output' or 'cd2_Output'")

NUM_WORKERS = 8
BATCH_SIZE = 32
NUM_EPOCHS = 5
MODEL_NAME = 'stanet'
ATTENTION_MODE = 'BAM' # 'BAM' or 'PAM', 'None'
if ATTENTION_MODE not in ('BAM', 'PAM', 'None'):
    # Catch typos before a full pass over the training data is spent on class weights.
    raise ValueError(f"Unknown ATTENTION_MODE {ATTENTION_MODE!r}; expected 'BAM', 'PAM' or 'None'")

## Data Loader

In [None]:
import os
import rasterio
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class ChangeDetectionDatasetTIF(Dataset):
    """Bi-temporal change-detection dataset backed by ``.tif`` tiles.

    Each item is ``(img_t2019, img_t2024, cd_mask)``:
      * the two images are every band read with rasterio as float32, divided by
        255 (this assumes 8-bit source data -- TODO confirm for these tiles);
      * the mask is band 1 read as int64.

    Files are paired purely by their sorted position inside each directory. The
    three directories must therefore contain the same number of ``.tif`` files
    whose sorted orders correspond. A count mismatch is rejected up front,
    because it would otherwise silently pair images with the wrong mask.

    Args:
        t2019_dir: directory with the earlier-date tiles.
        t2024_dir: directory with the later-date tiles.
        mask_dir: directory with the change masks.
        classes: class names; stored as ``self.classes`` and not used when loading.
        transform: optional callable applied to each image tensor. NOTE: it is
            not applied to the mask, so it must not move pixels spatially.

    Raises:
        ValueError: if the three directories hold different numbers of ``.tif`` files.
    """

    def __init__(self, t2019_dir, t2024_dir, mask_dir, classes, transform=None):
        self.t2019_dir = t2019_dir
        self.t2024_dir = t2024_dir
        self.mask_dir = mask_dir
        self.classes = classes  # Change detection classes
        self.transform = transform

        # File names only, sorted so that the pairing across directories is deterministic.
        self.t2019_paths = self._list_tifs(t2019_dir)
        self.t2024_paths = self._list_tifs(t2024_dir)
        self.mask_paths = self._list_tifs(mask_dir)

        counts = (len(self.t2019_paths), len(self.t2024_paths), len(self.mask_paths))
        if len(set(counts)) != 1:
            raise ValueError(
                f"Mismatched tile counts: {counts[0]} in {t2019_dir}, "
                f"{counts[1]} in {t2024_dir}, {counts[2]} in {mask_dir}"
            )

    @staticmethod
    def _list_tifs(directory):
        """Return the sorted names of the ``.tif`` files directly inside ``directory``."""
        return sorted(f for f in os.listdir(directory) if f.endswith('.tif'))

    def __len__(self):
        return len(self.t2019_paths)

    def __getitem__(self, index):
        # Images: all bands, float32, scaled by 1/255.
        with rasterio.open(os.path.join(self.t2019_dir, self.t2019_paths[index])) as src:
            img_t2019 = src.read(out_dtype=np.float32) / 255.0
        with rasterio.open(os.path.join(self.t2024_dir, self.t2024_paths[index])) as src:
            img_t2024 = src.read(out_dtype=np.float32) / 255.0
        # Mask: first band only, as int64 so it can be used directly as a loss target.
        with rasterio.open(os.path.join(self.mask_dir, self.mask_paths[index])) as src:
            cd_mask = src.read(1).astype(np.int64)

        img_t2019 = torch.from_numpy(img_t2019)
        img_t2024 = torch.from_numpy(img_t2024)
        cd_mask = torch.from_numpy(cd_mask)

        if self.transform is not None:
            img_t2019 = self.transform(img_t2019)
            img_t2024 = self.transform(img_t2024)

        return img_t2019, img_t2024, cd_mask

def describe_loader(loader_type):
    """Print a short summary of a change-detection DataLoader from its first batch.

    Reports the batch size, the shapes of the two image tensors and the mask
    tensor of one batch, the dataset length, its class names, and the distinct
    mask values present in that batch.
    """
    first_2019, first_2024, first_mask = next(iter(loader_type))
    dataset = loader_type.dataset
    summary = (
        ("Batch size:", loader_type.batch_size),
        ("2019 Image Shape:", first_2019.shape),
        ("2024 Image Shape:", first_2024.shape),
        ("Change Mask Shape:", first_mask.shape),
        ("Number of images:", len(dataset)),
        ("Classes:", dataset.classes),
        ("Unique CD values:", torch.unique(first_mask)),
    )
    for caption, value in summary:
        print(caption, value)

# Every split shares one folder layout:
#   <ROOT_DIRECTORY>/<split>/Images/T2019, .../Images/T2024, .../<CD_DIR>
def _make_split_dataset(split):
    """Build the ChangeDetectionDatasetTIF for one split ('train', 'val' or 'test')."""
    split_root = f"{ROOT_DIRECTORY}/{split}"
    return ChangeDetectionDatasetTIF(
        t2019_dir=f"{split_root}/Images/T2019",
        t2024_dir=f"{split_root}/Images/T2024",
        mask_dir=f"{split_root}/{CD_DIR}",
        classes=CLASSES,
    )

train_dataset = _make_split_dataset("train")
val_dataset = _make_split_dataset("val")
test_dataset = _make_split_dataset("test")

# Only the training loader is shuffled; evaluation order stays fixed.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

for banner, loader in (
    ("------------Train-----------", train_loader),
    ("------------Val------------", val_loader),
    ("------------Test------------", test_loader),
):
    print(banner)
    describe_loader(loader)

## Data Visualization

In [None]:
import matplotlib.pyplot as plt
import random

# Preview five random training triplets: 2019 image, 2024 image, change mask.
# The distinct mask values of each shown sample are printed as well.
fig, axs = plt.subplots(5, 3, figsize=(10, 10))

for ax_before, ax_after, ax_mask in axs:
    sample_idx = random.randint(0, len(train_dataset) - 1)
    before_img, after_img, change_mask = train_dataset[sample_idx]

    for ax, img, title in ((ax_before, before_img, "Real 2019"), (ax_after, after_img, "Real 2024")):
        ax.imshow(img.permute(1, 2, 0))  # CHW -> HWC for matplotlib
        ax.set_title(title)
        ax.axis("off")

    ax_mask.imshow(change_mask, cmap="turbo")
    print(np.unique(change_mask))
    ax_mask.set_title("CD Mask")
    ax_mask.axis("off")

plt.show()

## Model Definition

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import ModuleList

class PAMBlock(nn.Module):
    """The basic implementation for self-attention block/non-local block

    Local (windowed) self-attention over a bi-temporal feature map. ``forward``
    expects the two dates concatenated along the width axis, (B, C, H, 2W), and
    splits that axis in half. Each date is cut into a ``scale`` x ``scale`` grid
    of windows. Inside a window, attention runs over the positions of BOTH
    dates at once.

    The output has ``value_channels`` channels and size (h, 2w), where h, w are
    the size after the optional pooling. When ``ds != 1`` it is interpolated to
    (h * ds, 2 * w * ds).

    Args:
        in_channels: channels of the input map.
        key_channels: channels of the query/key projections.
        value_channels: channels of the value projection, i.e. of the output.
        scale: number of windows per spatial side.
        ds: AvgPool2d factor applied before attention.
    """
    def __init__(self, in_channels, key_channels, value_channels, scale=1, ds=1):
        super(PAMBlock, self).__init__()
        self.scale = scale
        self.ds = ds
        self.pool = nn.AvgPool2d(self.ds)
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.value_channels = value_channels

        # 1x1 projections. Key and query share a channel count so their dot
        # products form the similarity map.
        self.f_key = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                     kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.key_channels)
        )
        self.f_query = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                     kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.key_channels)
        )
        self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels,
                                kernel_size=1, stride=1, padding=0)

    def forward(self, input):
        x = input
        if self.ds != 1:
            x = self.pool(input)

        # w is the width of ONE date: the input holds both dates side by side.
        batch_size, c, h, w = x.size(0), x.size(1), x.size(2), x.size(3) // 2

        # Window boundaries, one (start, end) pair per window in row-major order.
        # local_x holds row bounds and local_y holds column bounds. The last
        # row/column of windows absorbs any remainder of h or w.
        local_y = []
        local_x = []
        step_h, step_w = h // self.scale, w // self.scale
        for i in range(0, self.scale):
            for j in range(0, self.scale):
                start_x, start_y = i * step_h, j * step_w
                end_x, end_y = min(start_x + step_h, h), min(start_y + step_w, w)
                if i == (self.scale - 1):
                    end_x = h
                if j == (self.scale - 1):
                    end_y = w
                local_x += [start_x, end_x]
                local_y += [start_y, end_y]

        value = self.f_value(x)
        query = self.f_query(x)
        key = self.f_key(x)

        # Split the two dates and stack them on a new trailing axis: (B, C, H, W, 2).
        value = torch.stack([value[:, :, :, :w], value[:, :, :, w:]], 4)
        query = torch.stack([query[:, :, :, :w], query[:, :, :, w:]], 4)
        key = torch.stack([key[:, :, :, :w], key[:, :, :, w:]], 4)

        # local_x / local_y hold two entries (start, end) per window.
        local_block_cnt = 2 * self.scale * self.scale

        def self_attention(value_local, query_local, key_local):
            # Scaled dot-product attention in which every position of a window in
            # both dates (h_local * w_local * 2 positions) attends to every other.
            batch_size_new = value_local.size(0)
            h_local, w_local = value_local.size(2), value_local.size(3)
            value_local = value_local.contiguous().view(batch_size_new, self.value_channels, -1)

            query_local = query_local.contiguous().view(batch_size_new, self.key_channels, -1)
            query_local = query_local.permute(0, 2, 1)
            key_local = key_local.contiguous().view(batch_size_new, self.key_channels, -1)

            sim_map = torch.bmm(query_local, key_local)
            sim_map = (self.key_channels ** -.5) * sim_map
            sim_map = F.softmax(sim_map, dim=-1)

            context_local = torch.bmm(value_local, sim_map.permute(0, 2, 1))
            context_local = context_local.view(batch_size_new, self.value_channels, h_local, w_local, 2)
            return context_local

        # Batch all windows along dim 0 so attention runs as a single bmm.
        # torch.cat needs every window to have the same size, so h and w must be
        # divisible by scale. Otherwise the enlarged last window makes this fail.
        v_list = [value[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]] for i in range(0, local_block_cnt, 2)]
        v_locals = torch.cat(v_list, dim=0)
        q_list = [query[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]] for i in range(0, local_block_cnt, 2)]
        q_locals = torch.cat(q_list, dim=0)
        k_list = [key[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]] for i in range(0, local_block_cnt, 2)]
        k_locals = torch.cat(k_list, dim=0)
        context_locals = self_attention(v_locals, q_locals, k_locals)

        # Undo the batching. Window (i, j) occupies rows
        # [batch_size * (j + i * scale), + batch_size) of context_locals.
        # Windows are tiled into rows (dim 3), and rows into the full map (dim 2).
        context_list = []
        for i in range(0, self.scale):
            row_tmp = []
            for j in range(0, self.scale):
                left = batch_size * (j + i * self.scale)
                right = batch_size * (j + i * self.scale) + batch_size
                tmp = context_locals[left:right]
                row_tmp.append(tmp)
            context_list.append(torch.cat(row_tmp, 3))

        context = torch.cat(context_list, 2)
        # Put the two dates back side by side along the width: (B, C, H, 2W).
        context = torch.cat([context[:, :, :, :, 0], context[:, :, :, :, 1]], 3)

        # Resize back to the pre-pooling resolution.
        if self.ds != 1:
            context = F.interpolate(context, [h * self.ds, 2 * w * self.ds])

        return context

class PAM(nn.Module):
    """PAM (Position Attention Module)

    Multi-scale wrapper around ``PAMBlock``. It holds one block per entry of
    ``sizes``; each block uses that entry as its ``scale``, i.e. a
    ``size`` x ``size`` grid of attention windows. The block outputs are
    concatenated on the channel axis and fused by a bias-free 1x1 convolution.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by every block and by the fusion conv.
            The key/query width is ``out_channels // 8``.
        sizes: window-grid sizes, one ``PAMBlock`` per entry.
        ds: average-pool factor forwarded to every block.
    """
    def __init__(self, in_channels, out_channels, sizes=([1]), ds=1):
        super(PAM, self).__init__()
        self.group = len(sizes)
        self.ds = ds
        self.value_channels = out_channels
        self.key_channels = out_channels // 8

        self.stages = nn.ModuleList(
            [PAMBlock(in_channels, self.key_channels, self.value_channels, size, self.ds)
             for size in sizes])
        # Each stage emits value_channels maps, so the fusion conv receives
        # value_channels * group channels. The previous in_channels * group
        # only worked when in_channels == out_channels.
        self.conv_bn = nn.Sequential(
            nn.Conv2d(self.value_channels * self.group, out_channels, kernel_size=1, padding=0, bias=False),
        )

    def forward(self, feats):
        """Run every scale on ``feats`` and fuse the results channel-wise."""
        context = [stage(feats) for stage in self.stages]
        return self.conv_bn(torch.cat(context, 1))

class BAM(nn.Module):
    """Basic self-attention module

    Global (non-local) self-attention with a residual connection:
      1. average-pool the input by ``ds``;
      2. compute scaled dot-product attention over all pooled positions;
      3. upsample the result (nearest) back to the input's spatial size;
      4. add it to the input.

    Args:
        in_dim: channels of the input (and of the output).
        ds: pooling factor applied before attention.
        activation: stored as ``self.activation``; not used by ``forward``.
    """
    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
        super(BAM, self).__init__()
        self.chanel_in = in_dim
        self.key_channel = self.chanel_in // 8
        self.activation = activation
        self.ds = ds
        self.pool = nn.AvgPool2d(self.ds)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Not used in forward(); kept so existing checkpoints keep identical state_dict keys.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        x = self.pool(input)
        m_batchsize, C, height, width = x.size()
        n_positions = height * width
        proj_query = self.query_conv(x).view(m_batchsize, -1, n_positions).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, n_positions)
        energy = torch.bmm(proj_query, proj_key)
        energy = (self.key_channel ** -.5) * energy
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, n_positions)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)
        # Resize to the input's exact size rather than pooled_size * ds.
        # AvgPool2d floors, so the latter is smaller than the input whenever
        # H or W is not a multiple of ds, and the residual add would then fail.
        out = F.interpolate(out, size=input.shape[-2:])
        return out + input

class CDSA(nn.Module):
    """Self attention module for change detection

    Concatenates the two dates' feature maps side by side along the width axis
    (dim 3), so a single shared attention module lets positions of both dates
    attend to each other. The result is then split back into two maps of the
    original width.

    Args:
        in_c: channels of each input feature map.
        ds: pooling factor forwarded to BAM/PAM.
        mode: 'BAM', 'PAM', or 'None' (identity, no attention).

    Raises:
        ValueError: if ``mode`` is not one of the supported values. Previously
            this only surfaced later, as an AttributeError inside ``forward``.
    """
    def __init__(self, in_c, ds=1, mode='BAM'):
        super(CDSA, self).__init__()
        self.in_C = in_c
        self.ds = ds
        self.mode = mode
        if self.mode == 'BAM':
            self.Self_Att = BAM(self.in_C, ds=self.ds)
        elif self.mode == 'PAM':
            self.Self_Att = PAM(in_channels=self.in_C, out_channels=self.in_C, sizes=[1, 2, 4, 8], ds=self.ds)
        elif self.mode == 'None':
            self.Self_Att = nn.Identity()
        else:
            raise ValueError(f"Unknown attention mode {mode!r}; expected 'BAM', 'PAM' or 'None'")

    def forward(self, x1, x2):
        """Return the attended ``(x1, x2)`` pair, each cut back to ``x1``'s width."""
        width = x1.shape[3]
        x = torch.cat((x1, x2), 3)
        x = self.Self_Att(x)
        return x[:, :, :, 0:width], x[:, :, :, width:]

class STANet(nn.Module):
    """STANet for multiclass change detection

    Architecture:
      * Siamese encoder: three conv-BN-ReLU stages, each followed by a 2x
        max-pool. The same weights are used for both dates.
      * ``CDSA`` attention applied jointly to the two feature maps.
      * Channel concatenation of the attended maps.
      * Three stride-2 transposed-conv-BN-ReLU stages.
      * A 1x1 classifier.

    With three 2x poolings and three 2x upsamplings, the output has the input's
    spatial size only when H and W are multiples of 8.

    ``forward`` always returns raw logits of shape (B, num_cd_classes, H, W).
    It used to return softmax probabilities in eval mode. Because validation and
    test feed the output to ``CrossEntropyLoss`` (which applies log-softmax
    itself), that computed a softmax-of-softmax loss and distorted checkpoint
    selection. Argmax predictions are unaffected by the change; apply
    ``self.softmax`` to the logits to get probabilities.
    """
    def __init__(self, input_channels=3, hidden_channels=32, num_cd_classes=3, attention_mode='BAM'):
        super(STANet, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_cd_classes = num_cd_classes
        self.attention_mode = attention_mode

        # Encoder/Backbone layers (shared by both dates)
        self.conv1 = nn.Conv2d(input_channels, hidden_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(hidden_channels)
        self.conv2 = nn.Conv2d(hidden_channels, hidden_channels*2, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(hidden_channels*2)
        self.conv3 = nn.Conv2d(hidden_channels*2, hidden_channels*4, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(hidden_channels*4)

        # Self-attention module
        self.sa = CDSA(in_c=hidden_channels*4, ds=1, mode=attention_mode)

        # Decoder layers; the first takes both dates' features (2 * 4 * hidden channels).
        self.upconv1 = nn.ConvTranspose2d(hidden_channels*8, hidden_channels*4, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.upbn1 = nn.BatchNorm2d(hidden_channels*4)
        self.upconv2 = nn.ConvTranspose2d(hidden_channels*4, hidden_channels*2, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.upbn2 = nn.BatchNorm2d(hidden_channels*2)
        self.upconv3 = nn.ConvTranspose2d(hidden_channels*2, hidden_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.upbn3 = nn.BatchNorm2d(hidden_channels)

        # Final classification layer
        self.final_conv = nn.Conv2d(hidden_channels, num_cd_classes, kernel_size=1)
        # Parameter-free; kept for callers that want probabilities from the logits.
        self.softmax = nn.Softmax(dim=1)

    def encode(self, x):
        """Encoder path: three conv-BN-ReLU + 2x max-pool stages (1/8 resolution)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2)
        return x

    def decode(self, x):
        """Decoder path: three stride-2 transposed-conv-BN-ReLU stages (8x upsampling)."""
        x = F.relu(self.upbn1(self.upconv1(x)))
        x = F.relu(self.upbn2(self.upconv2(x)))
        x = F.relu(self.upbn3(self.upconv3(x)))
        return x

    def forward(self, x1, x2):
        """Return per-pixel change logits for the image pair ``(x1, x2)``."""
        feat1 = self.encode(x1)
        feat2 = self.encode(x2)

        att1, att2 = self.sa(feat1, feat2)
        combined = torch.cat([att1, att2], dim=1)
        decoded = self.decode(combined)

        # Logits in both train and eval mode: CrossEntropyLoss and argmax both expect them.
        return self.final_conv(decoded)

def init_weights(m):
    """Initialise one module in place.

    Conv2d / ConvTranspose2d weights get Kaiming-normal initialisation
    (fan_out, ReLU gain) and their biases, if any, are zeroed. BatchNorm2d gets
    weight 1 and bias 0. Every other module type is left untouched, which makes
    the function suitable for ``model.apply(init_weights)``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        return
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

## Util Functions and Training Loop

In [None]:
from sklearn.metrics import confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
from tqdm import tqdm
import json

def calculate_effective_weights(train_loader, device, num_cd_classes=3, method='square_balanced',
                                adjustment_factors=None):
    """Calculate class weights with different strategies to handle class imbalance

    Counts how often each class index ``0..num_cd_classes-1`` occurs in the label
    tensors (third element of every batch from ``train_loader``). Label values
    outside that range are included in the pixel total but not in any class.
    The frequencies are turned into loss weights that sum to ``num_cd_classes``.

    A class that never occurs gets weight 0. Previously it got an infinite
    weight, which turned every weight into NaN after normalisation. Such a class
    never appears as a target in this data, so its weight cannot affect the loss.

    Args:
        train_loader: iterable of ``(_, _, labels)`` batches (e.g. a DataLoader).
        device: torch device the counting runs on.
        num_cd_classes: number of classes (default: 3).
        method: weighting strategy:
            * 'balanced': inverse frequency;
            * 'square_balanced': square root of the inverse frequency;
            * 'custom': inverse frequency times ``adjustment_factors``.
        adjustment_factors: per-class multipliers for ``method='custom'``.
            Defaults to ``[0.7, 1.2, 1.2]``, which only fits 3 classes.

    Returns:
        (weights, class_frequencies): float32 CPU tensors of length ``num_cd_classes``.

    Raises:
        ValueError: unknown ``method``, a wrong number of adjustment factors, or
            no in-range label in the whole loader.
    """
    if method not in ('balanced', 'square_balanced', 'custom'):
        raise ValueError(f"Unknown weighting method {method!r}; "
                         "expected 'balanced', 'square_balanced' or 'custom'")

    # Exact integer counts: float32 accumulation loses precision past ~16.7M pixels.
    class_counts = torch.zeros(num_cd_classes, dtype=torch.int64)
    total_pixels = 0

    for _, _, labels in train_loader:
        labels = labels.to(device)
        in_range = labels[(labels >= 0) & (labels < num_cd_classes)]
        class_counts += torch.bincount(in_range.long(), minlength=num_cd_classes).cpu()
        total_pixels += labels.numel()

    present = class_counts > 0
    if total_pixels == 0 or not bool(present.any()):
        raise ValueError("No label in range [0, num_cd_classes) found; cannot compute class weights")

    class_frequencies = class_counts.double() / total_pixels

    # Inverse frequency for classes that occur, 0 for absent ones (avoids inf -> NaN).
    inverse = torch.zeros(num_cd_classes, dtype=torch.float64)
    inverse[present] = 1.0 / class_frequencies[present]

    if method == 'balanced':
        weights = inverse
    elif method == 'square_balanced':
        # Less aggressive than plain inverse frequency.
        weights = torch.sqrt(inverse)
    else:  # 'custom'
        if adjustment_factors is None:
            adjustment_factors = [0.7, 1.2, 1.2]
        factors = torch.as_tensor(adjustment_factors, dtype=torch.float64)
        if factors.numel() != num_cd_classes:
            raise ValueError(f"Expected {num_cd_classes} adjustment factors, got {factors.numel()}")
        weights = inverse * factors

    # Normalize weights to sum to num_cd_classes
    weights = weights * (num_cd_classes / weights.sum())

    # float32 so the weights match the dtype of float32 logits in CrossEntropyLoss.
    return weights.float(), class_frequencies.float()

def calculate_metrics(outputs, labels, num_cd_classes=3, weighted_metrics=False):
    """
    Calculate comprehensive metrics for change detection using a single confusion matrix

    Args:
        outputs (torch.Tensor or np.array): if a tensor, model outputs of shape
            (B, C, H, W), reduced with argmax over dim 1. Otherwise, class
            predictions that have already been computed.
        labels (torch.Tensor or np.array): Ground truth class labels, with the
            same number of elements as the predictions.
        num_cd_classes (int): Number of classes in the dataset
        weighted_metrics (bool): if truthy, precision/recall/F1 are averaged over
            classes weighted by each class's ground-truth pixel count. Otherwise
            they are a plain (macro) mean over classes.

    Returns:
        dict: 'accuracy', 'precision', 'recall', 'f1_score', 'kappa', 'miou'.

    Pixels whose label or prediction lies outside ``[0, num_cd_classes)`` are
    ignored, matching ``sklearn.metrics.confusion_matrix(..., labels=range(n))``.
    The matrix is built in one linear pass with ``np.bincount``. If no pixel is
    in range, accuracy and kappa are 0.0 instead of NaN.
    """
    if torch.is_tensor(outputs):
        predictions = torch.argmax(outputs, dim=1).cpu().numpy()
    else:
        predictions = outputs

    if torch.is_tensor(labels):
        labels = labels.cpu().numpy()

    pred_flat = np.asarray(predictions).reshape(-1).astype(np.int64)
    target_flat = np.asarray(labels).reshape(-1).astype(np.int64)

    # Confusion matrix (rows = ground truth, columns = prediction) over in-range pixels.
    valid = ((target_flat >= 0) & (target_flat < num_cd_classes)
             & (pred_flat >= 0) & (pred_flat < num_cd_classes))
    cm = np.bincount(num_cd_classes * target_flat[valid] + pred_flat[valid],
                     minlength=num_cd_classes ** 2).reshape(num_cd_classes, num_cd_classes)

    metrics = {}

    # True positives, false positives, false negatives for each class
    tp = np.diag(cm)
    pred_totals = np.sum(cm, axis=0)
    support = np.sum(cm, axis=1)
    fp = pred_totals - tp
    fn = support - tp
    n = np.sum(cm)

    metrics['accuracy'] = np.sum(tp) / n if n > 0 else 0.0

    # Per-class precision, recall, F1
    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-6)

    # Truthiness, not `== True/False`, so any flag value selects a branch. Fall
    # back to the plain mean when there is no support to weight by.
    avg_weights = support if (weighted_metrics and support.sum() > 0) else None
    metrics['precision'] = np.average(precision, weights=avg_weights)
    metrics['recall'] = np.average(recall, weights=avg_weights)
    metrics['f1_score'] = np.average(f1, weights=avg_weights)

    # Cohen's kappa from confusion-matrix counts. Float64 products avoid int64
    # overflow on very large pixel totals.
    if n > 0:
        sum_po = np.sum(tp)
        sum_pe = np.sum(pred_totals.astype(np.float64) * support) / n
        metrics['kappa'] = (sum_po - sum_pe) / (n - sum_pe + 1e-6)
    else:
        metrics['kappa'] = 0.0

    # IoU from confusion matrix
    iou_per_class = tp / (tp + fp + fn + 1e-6)
    metrics['miou'] = np.mean(iou_per_class)

    return metrics


def train_model(model, train_loader, val_loader, num_epochs=50, num_cd_classes=3,
                device='cuda', learning_rate=1e-4, weight_decay=0.01,
                checkpoint_path='best_stanet_model.pt',
                weighting_method='square_balanced', weighted_metrics=False):
    """
    Training function for STANet model with comprehensive metrics tracking.

    Training setup:
      * AdamW optimizer;
      * class-weighted CrossEntropyLoss, with weights from
        ``calculate_effective_weights`` on the training labels;
      * gradient-norm clipping at 1.0;
      * ReduceLROnPlateau, which halves the LR after 5 epochs without a
        validation-loss improvement.

    A checkpoint is written only when the validation loss improves, so
    ``checkpoint_path`` always holds the best epoch so far. If it already
    exists, training resumes after that (best) epoch. Model weights, optimizer
    state, scheduler state and metric history are all restored.

    Args:
        model: STANet model instance
        train_loader: Training data loader
        val_loader: Validation data loader
        num_epochs: Total number of epochs, including any already completed
        num_cd_classes: Number of classes for change detection
        device: Device to run training on
        learning_rate: Initial learning rate (overridden by a restored optimizer state)
        weight_decay: Weight decay for optimizer
        checkpoint_path: Path to save best model checkpoint
        weighting_method: class-weighting strategy passed to ``calculate_effective_weights``
        weighted_metrics: passed to ``calculate_metrics``; support-weighted vs macro averages

    Returns:
        (model, history):
          * ``model`` is the model after the final epoch; the best model is in
            the checkpoint.
          * ``history`` is ``{'train'|'val': {metric: [per-epoch values]}}``.
            Epoch values are batch-size-weighted means of per-batch metrics.
    """
    # Initialize starting values
    start_epoch = 0
    best_val_loss = float('inf')

    # Initialize metrics history
    history = {
        'train': {
            'loss': [], 'accuracy': [], 'precision': [],
            'recall': [], 'f1_score': [], 'miou': [], 'kappa': []
        },
        'val': {
            'loss': [], 'accuracy': [], 'precision': [],
            'recall': [], 'f1_score': [], 'miou': [], 'kappa': []
        }
    }

    # Move the model first. Optimizer state restored below is cast to the
    # parameters' device at load time, so the parameters must already be on `device`.
    model = model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # `verbose` is deprecated in recent PyTorch; LR reductions are printed in the loop instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    def _restore_state(obj, checkpoint, key):
        """Best-effort ``obj.load_state_dict(checkpoint[key])``; warns instead of aborting the resume."""
        if key not in checkpoint:
            return
        try:
            obj.load_state_dict(checkpoint[key])
        except Exception as e:
            print(f"Could not restore {key} ({e}); continuing with a fresh one")

    # Load checkpoint if exists
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = None
        try:
            # weights_only=False: the checkpoint also pickles the metric history
            # (NumPy scalars). Only load checkpoints you produced yourself.
            # map_location lets a GPU-saved checkpoint load on any device.
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            model.load_state_dict(checkpoint['model_state_dict'])
            start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint['best_val_loss']
            history = checkpoint['history']
            print(f"Resuming from epoch {start_epoch} with best val loss: {best_val_loss:.4f}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch")
            checkpoint = None
        if checkpoint is not None:
            # Keep AdamW moments, the current LR and the plateau counter, so the
            # resumed run continues instead of restarting the optimizer.
            _restore_state(optimizer, checkpoint, 'optimizer_state_dict')
            _restore_state(scheduler, checkpoint, 'scheduler_state_dict')

    class_weights, _ = calculate_effective_weights(train_loader, device, num_cd_classes=num_cd_classes, method=weighting_method)
    print(class_weights)
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

    def process_epoch(phase, data_loader):
        """Run one training or validation epoch, append its metrics to history and return them."""
        if phase == 'train':
            model.train()
        else:
            model.eval()

        running_metrics = {
            'loss': 0.0, 'accuracy': 0.0, 'precision': 0.0,
            'recall': 0.0, 'f1_score': 0.0, 'miou': 0.0, 'kappa': 0.0
        }
        samples_count = 0

        pbar = tqdm(data_loader, desc=f'{phase.capitalize()} Epoch')

        with torch.set_grad_enabled(phase == 'train'):
            for inputs1, inputs2, labels in pbar:
                inputs1 = inputs1.to(device)
                inputs2 = inputs2.to(device)
                labels = labels.to(device)
                batch_size = inputs1.size(0)

                if phase == 'train':
                    optimizer.zero_grad()

                outputs = model(inputs1, inputs2)
                loss = criterion(outputs, labels)

                if phase == 'train':
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()

                batch_metrics = calculate_metrics(outputs, labels, num_cd_classes=num_cd_classes,
                                                  weighted_metrics=weighted_metrics)
                batch_metrics['loss'] = loss.item()

                # Batch-size-weighted running sums (the last batch may be smaller).
                for key in running_metrics:
                    running_metrics[key] += batch_metrics[key] * batch_size
                samples_count += batch_size

                pbar.set_postfix({
                    'loss': f"{batch_metrics['loss']:.4f}",
                    'miou': f"{batch_metrics['miou']:.4f}"
                })

        epoch_metrics = {key: value / samples_count for key, value in running_metrics.items()}

        for key in history[phase]:
            history[phase][key].append(epoch_metrics[key])

        return epoch_metrics

    def print_metrics(phase, metrics):
        """Pretty-print one phase's epoch metrics."""
        print(f'\n{phase.capitalize()} Metrics:')
        print(f"  Loss: {metrics['loss']:.4f}")
        print(f"  Accuracy: {metrics['accuracy']:.4f}")
        print(f"  Precision: {metrics['precision']:.4f}")
        print(f"  Recall: {metrics['recall']:.4f}")
        print(f"  F1-score: {metrics['f1_score']:.4f}")
        print(f"  mIoU: {metrics['miou']:.4f}")
        print(f"  Kappa: {metrics['kappa']:.4f}")

    # Training loop
    for epoch in range(start_epoch, num_epochs):
        print(f'\nEpoch {epoch + 1}/{num_epochs}:')

        train_metrics = process_epoch('train', train_loader)
        val_metrics = process_epoch('val', val_loader)

        print_metrics('train', train_metrics)
        print_metrics('val', val_metrics)

        # Update learning rate scheduler and report reductions.
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_metrics['loss'])
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Reducing learning rate to {lr_after:.4e}")

        # Save checkpoint if it's the best model
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': best_val_loss,
                'metrics': val_metrics,
                'history': history
            }
            torch.save(checkpoint, checkpoint_path)
            print(f'\nSaved new best model with validation loss: {val_metrics["loss"]:.4f}')

    return model, history


def save_training_files(history, checkpoint_path, history_filename, bestepoch_filename):
    """Save training history and best epoch info to separate JSON files

    Args:
        history: ``{phase: {metric: [per-epoch values]}}`` as returned by ``train_model``.
        checkpoint_path: checkpoint written by ``train_model``. Its 'epoch',
            'best_val_loss' and 'metrics' entries are exported.
        history_filename: output path for the history JSON.
        bestepoch_filename: output path for the best-epoch JSON.
    """

    def convert_to_serializable(value):
        """Recursively convert numpy/torch types to basic Python types"""
        if isinstance(value, (np.ndarray, torch.Tensor)):
            return value.tolist()
        if isinstance(value, np.generic):
            # NumPy scalars: np.float32 / np.int64 etc. are not JSON-serialisable as-is.
            return value.item()
        if isinstance(value, dict):
            return {k: convert_to_serializable(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [convert_to_serializable(item) for item in value]
        return value

    history_data = convert_to_serializable(history)

    with open(history_filename, 'w') as f:
        json.dump(history_data, f, indent=4)

    # The checkpoint pickles NumPy scalars (metrics/history). Recent PyTorch
    # defaults to weights_only=True and refuses to unpickle those, so load it as
    # a trusted file. Never do this with checkpoints from an untrusted source.
    # map_location='cpu' lets a GPU-saved checkpoint be read anywhere.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    epoch_data = {
        'best_epoch': convert_to_serializable(checkpoint['epoch']),
        'best_val_loss': convert_to_serializable(checkpoint['best_val_loss']),
        'val_metrics': convert_to_serializable(checkpoint['metrics'])
    }

    with open(bestepoch_filename, 'w') as f:
        json.dump(epoch_data, f, indent=4)

    print(f"\nSaved training history to: {history_filename}")
    print(f"Saved best epoch info to: {bestepoch_filename}")


## Model Run

In [None]:
# Build the model, train it (resuming from checkpoint_path if that file exists),
# then export the training history and best-epoch summary as JSON in SAVING_DIR.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_cd_classes = len(CLASSES)  # number of classes in the change mask

weighting_method = 'square_balanced'  # one of: 'balanced', 'square_balanced', 'custom'

# Many-class masks get support-weighted metric averages; the 3-class mask gets plain macro averages.
weighted_metrics = num_cd_classes > 5

name = f"{MODEL_NAME}_{ATTENTION_MODE}-{num_cd_classes}_classes_{NUM_EPOCHS}"
checkpoint_path = f'{SAVING_DIR}/best_{name}.pt'
history_filename = f"{SAVING_DIR}/{name}_history.json"
bestepoch_filename = f"{SAVING_DIR}/{name}_best_epoch.json"

# NOTE(review): init_weights is defined above but not applied here, so PyTorch's
# default initialisation is used. Add model.apply(init_weights) if that was intended.
model = STANet(
    input_channels=3,
    hidden_channels=32,
    num_cd_classes=num_cd_classes,
    attention_mode=ATTENTION_MODE,
)

model, history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCHS,
    num_cd_classes=num_cd_classes,
    device=device,
    learning_rate=1e-4,
    weight_decay=0.01,
    checkpoint_path=checkpoint_path,
    weighting_method=weighting_method,
    weighted_metrics=weighted_metrics,
)

save_training_files(
    history=history,
    checkpoint_path=checkpoint_path,
    history_filename=history_filename,
    bestepoch_filename=bestepoch_filename,
)

## Model Testing

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import random

def test_model(model, test_loader, device='cuda',
               num_cd_classes=3, weighting_method='square_balanced',
               weighted_metrics=False, checkpoint_path='best_stanet_model.pt'):
    """Evaluate the checkpointed model on ``test_loader``.

    The function:
      * loads ``model_state_dict`` from ``checkpoint_path``;
      * computes a class-weighted CrossEntropy loss, with weights derived from
        the test labels themselves via ``calculate_effective_weights``;
      * computes metrics once over all test pixels;
      * collects up to 5 random samples for plotting.

    Returns:
        (random_samples, test_metrics):
          * ``random_samples`` is a list of exactly 5 dicts with keys 'image1',
            'image2', 'label', 'pred' and 'probabilities'. It is padded by
            repetition, or with blank placeholders if no sample was picked.
          * ``test_metrics`` is the ``calculate_metrics`` dict plus 'loss'.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    # weights_only=False: the checkpoint also pickles the metric history (NumPy
    # scalars). Only load checkpoints you produced yourself.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model = model.to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded checkpoint from {checkpoint_path}")
    model.eval()

    # NOTE: these weights come from the test labels, so this loss is not directly
    # comparable with the validation loss (which uses training-set weights).
    class_weights, _ = calculate_effective_weights(test_loader, device,
                                                   num_cd_classes=num_cd_classes,
                                                   method=weighting_method)
    print(f"Class weights: {class_weights}")

    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

    random_samples = []
    total_loss = 0.0
    total_samples = 0

    # Predictions/labels of the whole set, for dataset-level metrics.
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for inputs1, inputs2, labels in test_loader:
            inputs1 = inputs1.to(device)
            inputs2 = inputs2.to(device)
            labels = labels.to(device)

            outputs = model(inputs1, inputs2)
            loss = criterion(outputs, labels)

            # Sample-weighted loss sum (the last batch may be smaller).
            total_loss += loss.item() * inputs1.size(0)
            total_samples += inputs1.size(0)

            preds = torch.argmax(outputs, dim=1)

            all_predictions.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

            # Each sample has a 20% chance of being kept, until 5 are collected.
            if len(random_samples) < 5:
                for i in range(min(inputs1.size(0), 5 - len(random_samples))):
                    if random.random() < 0.2:
                        random_samples.append({
                            'image1': inputs1[i].cpu(),
                            'image2': inputs2[i].cpu(),
                            'label': labels[i].cpu(),
                            'pred': preds[i].cpu(),
                            # NOTE(review): assumes `outputs` are logits; if the model
                            # already returns probabilities this is a second softmax.
                            'probabilities': torch.softmax(outputs[i], dim=0).cpu()
                        })

    if total_samples == 0:
        raise ValueError("test_loader yielded no samples")

    all_predictions = np.concatenate(all_predictions)
    all_labels = np.concatenate(all_labels)

    test_metrics = calculate_metrics(all_predictions, all_labels, num_cd_classes,
                                     weighted_metrics=weighted_metrics)
    test_metrics['loss'] = total_loss / total_samples

    # Pad to exactly 5 samples. Repeat the last one, or use blank 64x64
    # placeholders (with num_cd_classes probability maps) if none was picked.
    while len(random_samples) < 5:
        random_samples.append(random_samples[-1] if random_samples else {
            'image1': torch.zeros(3, 64, 64),
            'image2': torch.zeros(3, 64, 64),
            'label': torch.zeros(64, 64),
            'pred': torch.zeros(64, 64),
            'probabilities': torch.zeros(num_cd_classes, 64, 64)
        })

    return random_samples, test_metrics

def visualize_results(random_samples, num_cd_classes=3):
    """Plot one row per sample: image 1, image 2, predicted mask, ground-truth mask.

    Args:
        random_samples: list of dicts with CHW tensors 'image1'/'image2' and HW
            tensors 'pred'/'label' (the format returned by ``test_model``).
        num_cd_classes: number of classes. It pins the colour scale to
            [0, num_cd_classes - 1], so a class has the same colour in both masks.
    """
    if not random_samples:
        print("No samples to visualize")
        return

    n_rows = len(random_samples)
    # tab10 has only 10 distinct colours; with more classes, several would share one.
    mask_cmap = 'tab10' if num_cd_classes <= 10 else 'turbo'
    mask_style = dict(cmap=mask_cmap, vmin=0, vmax=num_cd_classes - 1)

    def _to_display(img):
        """CHW tensor -> HWC array min-max scaled to [0, 1]; constant images map to 0."""
        arr = img.numpy().transpose(1, 2, 0)
        lo, hi = arr.min(), arr.max()
        if hi > lo:
            return (arr - lo) / (hi - lo)
        return np.zeros_like(arr)  # avoid 0/0 -> NaN for blank images

    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(n_rows, 4, figsize=(25, 5 * n_rows), squeeze=False)

    for row_axes, sample in zip(axes, random_samples):
        panels = (
            (_to_display(sample['image1']), {}, 'Image 1'),
            (_to_display(sample['image2']), {}, 'Image 2'),
            (sample['pred'].numpy(), mask_style, 'Predicted Change'),
            (sample['label'].numpy(), mask_style, 'Ground Truth'),
        )
        for ax, (data, style, title) in zip(row_axes, panels):
            ax.imshow(data, **style)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

def save_test_metrics(test_metrics, save_path):
    """Save test metrics to JSON

    Writes ``test_metrics`` to ``save_path`` as indented JSON. Values from
    ``calculate_metrics`` are NumPy scalars. ``np.float64`` subclasses ``float``,
    but ``np.float32`` / ``np.int64`` do not, so a ``default`` hook converts any
    NumPy/torch scalar or array to plain Python values.
    """
    def _to_builtin(value):
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, (np.ndarray, torch.Tensor)):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    with open(save_path, 'w') as f:
        json.dump(test_metrics, f, indent=4, default=_to_builtin)

    print(f"\nSaved test metrics to: {save_path}")

# Evaluate the best checkpoint on the test split, save its metrics next to the
# training artefacts, and plot a handful of predictions.
random_samples, test_metrics = test_model(
    model,
    test_loader,
    device=device,
    num_cd_classes=num_cd_classes,
    weighting_method=weighting_method,
    weighted_metrics=weighted_metrics,
    checkpoint_path=checkpoint_path,
)

test_metrics_path = f'{SAVING_DIR}/{name}_test_metrics.json'
save_test_metrics(test_metrics=test_metrics, save_path=test_metrics_path)

visualize_results(random_samples, num_cd_classes=num_cd_classes)