<a href="https://colab.research.google.com/github/caaszj/GLAFormer/blob/main/Untitled1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.io import loadmat
from torch.utils.data import Dataset, DataLoader
import math
import os
from sklearn.metrics import precision_score, recall_score, f1_score

# Data preprocessing and loading
class HSIChangeDetectionDataset(Dataset):
    """Patch dataset for bi-temporal hyperspectral change detection.

    Loads a "before" cube, an "after" cube and a ground-truth map from
    ``.mat`` files, min-max normalizes each cube independently, and serves
    ``patch_size`` x ``patch_size`` patches centred on every labelled pixel
    (ground-truth value 0 or 255) that lies at least ``patch_size // 2``
    pixels away from the border.

    The labelled pixels are shuffled with a fixed seed and divided
    70% / 15% / remainder into train / validation / test; ``mode`` selects
    which part this instance exposes.

    Args:
        before_path: Path of the ``.mat`` file with the first acquisition.
        after_path: Path of the ``.mat`` file with the second acquisition.
        gt_path: Path of the ``.mat`` file with the ground-truth map.
        patch_size: Side length of the square patch.
        mode: ``'train'`` or ``'val'``; any other value selects the test split.

    Raises:
        ValueError: If a ``.mat`` file contains no data variable.
    """

    def __init__(self, before_path, after_path, gt_path, patch_size=9, mode='train'):
        # Cubes are indexed as (row, column, band) by __getitem__.
        self.before = self._normalize(self.load_mat(before_path))
        self.after = self._normalize(self.load_mat(after_path))
        self.gt = self.load_gt(gt_path)  # indexed as (row, column)

        self.patch_size = patch_size
        self.half = patch_size // 2
        self.mode = mode

        # Centres of all patches that fit inside the image and carry a label
        self.coords = self.get_valid_coords()

        # Train / validation / test split.
        # NOTE: this reseeds NumPy's *global* RNG, so every instance (whatever
        # its mode) sees the same permutation and the splits never overlap.
        np.random.seed(42)
        idx = np.random.permutation(len(self.coords))
        train_num = int(len(idx) * 0.7)   # 70% for training
        val_num = int(len(idx) * 0.15)    # 15% for validation

        if mode == 'train':
            self.coords = self.coords[idx[:train_num]]
        elif mode == 'val':
            self.coords = self.coords[idx[train_num:train_num + val_num]]
        else:
            # Everything that is left (about 15%) is the test split
            self.coords = self.coords[idx[train_num + val_num:]]

    @staticmethod
    def _normalize(cube):
        """Min-max scale ``cube`` to [0, 1]; a constant cube maps to all zeros.

        The constant case is handled explicitly because ``max - min`` would be
        zero there and the division would fill the cube with NaN.
        """
        lo = cube.min()
        span = cube.max() - lo
        if span == 0:
            return np.zeros_like(cube)
        return (cube - lo) / span

    @staticmethod
    def _read_variable(path, preferred, dtype, what):
        """Return one variable of the ``.mat`` file at ``path`` cast to ``dtype``.

        Variables named in ``preferred`` win (in the given order); otherwise the
        first variable whose name does not start with ``__`` is used.
        """
        mat = loadmat(path)
        keys = [k for k in mat.keys() if not k.startswith('__')]
        for name in preferred:
            if name in mat:
                return mat[name].astype(dtype)
        if keys:
            return mat[keys[0]].astype(dtype)
        raise ValueError(f"无法从{path}加载{what}")

    def load_mat(self, path):
        """Load an image cube as ``float32`` from the ``.mat`` file at ``path``."""
        return self._read_variable(path, ('river_before', 'river_after'), np.float32, "数据")

    def load_gt(self, path):
        """Load the ground-truth map as ``int64`` from the ``.mat`` file at ``path``."""
        return self._read_variable(path, ('lakelabel_v1',), np.int64, "标签")

    def get_valid_coords(self):
        """Return an ``(N, 2)`` array of ``(row, col)`` patch centres.

        A pixel qualifies when a full patch around it fits inside the image
        and its ground-truth value is 0 or 255. Rows are listed in row-major
        order, exactly as a nested ``for i ... for j ...`` scan would yield.
        """
        H, W = self.gt.shape
        margin = self.half
        # Restrict to the interior first, then shift indices back by the margin
        interior = self.gt[margin:H - margin, margin:W - margin]
        labelled = (interior == 0) | (interior == 255)
        return np.argwhere(labelled) + margin

    def __len__(self):
        """Number of patches in the selected split."""
        return len(self.coords)

    def __getitem__(self, idx):
        """Return ``(before_patch, after_patch, label)`` for sample ``idx``.

        Both patches are float tensors in channel-first layout; ``label`` is
        the ground-truth value at the patch centre integer-divided by 255
        (so 255 becomes 1 and 0 stays 0).
        """
        i, j = self.coords[idx]
        rows = slice(i - self.half, i + self.half + 1)
        cols = slice(j - self.half, j + self.half + 1)
        # Cut the same window out of both acquisitions
        before_patch = self.before[rows, cols, :]
        after_patch = self.after[rows, cols, :]
        # (rows, cols, bands) -> (bands, rows, cols)
        before_patch = torch.from_numpy(before_patch).permute(2, 0, 1).float()
        after_patch = torch.from_numpy(after_patch).permute(2, 0, 1).float()
        label = self.gt[i, j] // 255

        return before_patch, after_patch, label

# GLAM module
class GLAM(nn.Module):
    """Two-branch multi-head attention over a channel-split feature map.

    The input channels are split in half. The first half attends over all
    spatial positions of itself; the second half uses full-resolution queries
    against keys/values computed from an average-pooled copy of itself. The
    two results are concatenated and mixed by a 1x1 convolution.

    Args:
        dim: Number of input (and output) channels.
        num_heads: Total head count; each branch uses ``num_heads // 2`` heads.
        window_size: Kernel size of the average pooling in the second branch.
    """

    def __init__(self, dim, num_heads=8, window_size=3):
        super().__init__()
        branch_dim = dim // 2
        self.num_heads = num_heads
        self.split_dim = branch_dim  # channels handled by each branch
        self.num_heads_local = num_heads // 2
        self.head_dim_local = branch_dim // self.num_heads_local  # channels per head
        self.window_size = window_size

        # First branch: one projection producing q, k and v together
        self.local_qkv = nn.Conv2d(branch_dim, 3 * branch_dim, kernel_size=1)

        # Second branch: separate projections for queries and pooled keys/values
        self.global_q = nn.Conv2d(branch_dim, branch_dim, kernel_size=1)
        self.global_kv = nn.Conv2d(branch_dim, 2 * branch_dim, kernel_size=1)

        # Output projection over the concatenated branches
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def _heads(self, t, batch, tokens):
        """View ``t`` as ``[batch, heads, head_dim, tokens]``."""
        return t.view(batch, self.num_heads_local, self.head_dim_local, tokens)

    def _attend(self, q, k, v):
        """Scaled dot-product attention on ``[B, heads, head_dim, tokens]`` tensors."""
        scores = torch.matmul(q.transpose(-2, -1), k) / math.sqrt(self.head_dim_local)
        weights = F.softmax(scores, dim=-1)  # normalize over the key positions
        return torch.matmul(v, weights.transpose(-2, -1))

    def forward(self, x):
        """Apply both attention branches to ``x`` of shape ``[B, C, H, W]``."""
        B, C, H, W = x.shape
        n_tokens = H * W
        first_half, second_half = x.split(self.split_dim, dim=1)

        # Branch 1: attention among all H*W positions
        q, k, v = self.local_qkv(first_half).chunk(3, dim=1)
        local_out = self._attend(
            self._heads(q, B, n_tokens),
            self._heads(k, B, n_tokens),
            self._heads(v, B, n_tokens),
        ).view(B, self.split_dim, H, W)

        # Branch 2: keys/values come from the pooled map, queries from the full map
        pooled = F.avg_pool2d(second_half, self.window_size)
        n_pooled = pooled.shape[2] * pooled.shape[3]
        k_g, v_g = self.global_kv(pooled).chunk(2, dim=1)
        q_g = self._heads(self.global_q(second_half), B, n_tokens)
        global_out = self._attend(
            q_g,
            self._heads(k_g, B, n_pooled),
            self._heads(v_g, B, n_pooled),
        ).view(B, self.split_dim, H, W)

        merged = torch.cat([local_out, global_out], dim=1)
        return self.proj(merged)

# CGFN module
class CGFN(nn.Module):
    """Feed-forward block with two gated depthwise-convolution branches.

    The input is expanded by a 1x1 convolution, passed through a 3x3 and a
    5x5 depthwise convolution, each branch is scaled by its own sigmoid gate
    and combined with the other branch, and the concatenation is projected
    back to ``dim`` channels.

    Args:
        dim: Number of input (and output) channels.
        expansion: Multiplier giving the hidden width ``dim * expansion``.
    """

    def __init__(self, dim, expansion=4):
        super().__init__()
        width = dim * expansion

        self.conv1 = nn.Conv2d(dim, width, 1)
        # groups == channels makes these depthwise convolutions
        self.dwconv3 = nn.Conv2d(width, width, 3, padding=1, groups=width)
        self.dwconv5 = nn.Conv2d(width, width, 5, padding=2, groups=width)
        self.conv2 = nn.Conv2d(width * 2, dim, 1)

        self.gate1 = self._make_gate(width)
        self.gate2 = self._make_gate(width)

    @staticmethod
    def _make_gate(width):
        """Build a 1x1 conv -> GELU -> 1x1 conv -> sigmoid gate of ``width`` channels."""
        return nn.Sequential(
            nn.Conv2d(width, width, 1),
            nn.GELU(),
            nn.Conv2d(width, width, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the gated two-branch feed-forward on ``x``."""
        hidden = self.conv1(x)
        branch3 = self.dwconv3(hidden)
        branch5 = self.dwconv5(hidden)

        # Each gate is computed from its own, not yet combined, branch
        gate3 = self.gate1(branch3)
        gate5 = self.gate2(branch5)

        mixed3 = branch3 * gate3 + branch5
        # The second branch adds the already-combined first branch
        mixed5 = branch5 * gate5 + mixed3
        return self.conv2(torch.cat([mixed3, mixed5], dim=1))

# Helper modules
class LambdaLayer(nn.Module):
    """Adapter that lets a plain callable be used where an ``nn.Module`` is expected.

    Args:
        lambd: Callable taking the layer input and returning the layer output.
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return the wrapped callable applied to ``x``."""
        fn = self.lambd
        return fn(x)

# GLAFormer block
class GLAFormerBlock(nn.Module):
    """Pre-norm residual block: LayerNorm -> GLAM, then LayerNorm -> CGFN.

    Args:
        dim: Channel count of the ``[B, C, H, W]`` feature map.
        num_heads: Head count forwarded to ``GLAM``.
        window_size: Pooling window forwarded to ``GLAM``.
    """

    def __init__(self, dim, num_heads, window_size=3):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.glam = GLAM(dim, num_heads, window_size)
        self.norm2 = nn.LayerNorm(dim)
        self.cgfn = CGFN(dim)

    @staticmethod
    def _norm_channels(norm, feat):
        """Apply ``norm`` over the channel axis of a channel-first map.

        ``nn.LayerNorm(dim)`` normalizes the last axis, so the map is moved to
        channel-last layout for the call and moved back afterwards.
        """
        channels_last = feat.permute(0, 2, 3, 1)
        return norm(channels_last).permute(0, 3, 1, 2)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual steps."""
        attended = x + self.glam(self._norm_channels(self.norm1, x))
        return attended + self.cgfn(self._norm_channels(self.norm2, attended))

# Full model
class GLAFormer(nn.Module):
    """Two-class change classifier over a pair of co-located patches.

    The two patches are concatenated along channels, embedded by a 3x3
    convolution, refined by a stack of ``GLAFormerBlock`` modules and mapped
    to a 2-channel score map. The returned logits blend the spatial average
    of that map with its value at the patch centre.

    Args:
        in_channels: Channels of each input patch.
        dim: Width of the embedded feature map.
        num_blocks: Number of stacked ``GLAFormerBlock`` modules.
        num_heads: Head count forwarded to every block.
        patch_size: Spatial side of the input patches; used to locate the centre pixel.
    """

    def __init__(self, in_channels=198, dim=256, num_blocks=4, num_heads=8, patch_size=9):
        super().__init__()
        self.patch_size = patch_size
        self.center = patch_size // 2  # index of the centre row/column

        # Input embedding of the channel-concatenated patch pair
        self.embed = nn.Sequential(
            nn.Conv2d(in_channels * 2, dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
        )

        # Backbone
        self.blocks = nn.Sequential(
            *(GLAFormerBlock(dim, num_heads) for _ in range(num_blocks))
        )

        # Classification head producing one score map per class
        self.conv_block = nn.Sequential(
            nn.Conv2d(dim, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 2, kernel_size=1),
        )
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )
        # Reads self.center at call time, so it follows later changes of that attribute
        self.center_extract = LambdaLayer(
            lambda score_map: score_map[:, :, self.center, self.center]
        )

    def forward(self, x1, x2):
        """Return ``[B, 2]`` logits for the patch pair ``(x1, x2)``."""
        pair = torch.cat([x1, x2], dim=1)   # [B, 2*in_channels, h, w]
        features = self.blocks(self.embed(pair))  # [B, dim, h, w]
        score_map = self.conv_block(features)     # [B, 2, h, w]

        # Two read-outs of the same score map
        pooled_scores = self.global_pool(score_map)     # [B, 2]
        center_scores = self.center_extract(score_map)  # [B, 2]

        # Fixed-weight blend of the two read-outs
        return 0.6 * pooled_scores + 0.4 * center_scores

# Evaluation function
def evaluate(model, data_loader, criterion, device):
    """Score ``model`` on every batch of ``data_loader`` without gradients.

    Args:
        model: Network called as ``model(before_patch, after_patch)``.
        data_loader: Iterable of ``(before_patch, after_patch, labels)`` batches
            that also exposes ``.dataset`` for the sample count.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(loss, accuracy, precision, recall, f1)``: the sample-weighted
        mean loss, accuracy in percent, and binary precision / recall / F1.
    """
    model.eval()
    loss_sum = 0.0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in data_loader:
            before_patch, after_patch, labels = (part.to(device) for part in batch)

            outputs = model(before_patch, after_patch)
            batch_loss = criterion(outputs, labels)
            # Weight by batch size so a smaller final batch does not skew the mean
            loss_sum += batch_loss.item() * before_patch.size(0)

            predicted = torch.max(outputs.data, 1).indices
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    val_loss = loss_sum / len(data_loader.dataset)
    hits = np.array(all_preds) == np.array(all_labels)
    accuracy = np.mean(hits) * 100

    metric_options = {'average': 'binary', 'zero_division': 0}
    precision = precision_score(all_labels, all_preds, **metric_options)
    recall = recall_score(all_labels, all_preds, **metric_options)
    f1 = f1_score(all_labels, all_preds, **metric_options)

    return val_loss, accuracy, precision, recall, f1

def main():
    """Train GLAFormer, keep the checkpoint with the best validation F1, then test it.

    Steps: seed the RNGs, build the train/val/test datasets and loaders, train
    with Adam and a plateau learning-rate scheduler, stop early when the
    validation F1 has not improved for ``patience`` epochs, and finally report
    loss / accuracy / precision / recall / F1 on the test split.

    If the datasets cannot be built, the error is printed and the function
    returns without training.
    """
    # Seed every RNG in use so runs are repeatable
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
    np.random.seed(42)

    # Configuration
    in_channels = 198  # fallback band count; replaced below by the value read from the data
    patch_size = 9
    batch_size = 512
    num_epochs = 50
    learning_rate = 0.0006
    weight_decay = 1e-4
    model_save_path = 'best_model.pth'

    # Pick the GPU when one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    data_dir = '/content/drive/MyDrive/dataset zuixin'
    before_path = f'{data_dir}/river_before.mat'
    after_path = f'{data_dir}/river_after.mat'
    gt_path = f'{data_dir}/groundtruth.mat'

    # Build datasets and loaders. Any failure here (missing file, bad .mat
    # content, empty split) is reported and aborts the run.
    try:
        train_dataset, val_dataset, test_dataset = (
            HSIChangeDetectionDataset(
                before_path, after_path, gt_path,
                patch_size=patch_size,
                mode=mode,
            )
            for mode in ('train', 'val', 'test')
        )

        # Inspect one sample to learn the real band count
        sample = train_dataset[0]
        in_channels = sample[0].shape[0]
        print(f"检测到输入通道数: {in_channels}")
        print(f"数据样本形状: before={sample[0].shape}, after={sample[1].shape}, label={sample[2]}")

        # Pinned host memory only helps when batches are copied to a GPU
        loader_options = {
            'batch_size': batch_size,
            'num_workers': 2,
            'pin_memory': device.type == 'cuda',
        }
        train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
        val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
        test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

        print(f"训练样本数: {len(train_dataset)}")
        print(f"验证样本数: {len(val_dataset)}")
        print(f"测试样本数: {len(test_dataset)}")

    except Exception as e:
        print(f"加载数据集时出错: {e}")
        return

    # Model
    model = GLAFormer(in_channels=in_channels, dim=256, num_blocks=4, num_heads=8, patch_size=patch_size).to(device)

    # Class-imbalance handling: class 1 (changed) is weighted 10x. The weights
    # are fixed, not derived from the data; tune them to the real class ratio.
    criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 10.0]).to(device))

    # Optimizer and learning-rate scheduler (halves the LR when val loss stalls)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    # Training state
    best_f1 = 0.0
    patience = 10              # early-stopping patience in epochs
    counter = 0                # epochs since the last F1 improvement
    checkpoint_saved = False   # whether THIS run has written model_save_path

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for before_patch, after_patch, labels in train_loader:
            before_patch = before_patch.to(device)
            after_patch = after_patch.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(before_patch, after_patch)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * before_patch.size(0)

        train_loss /= len(train_loader.dataset)

        # Validation phase
        val_loss, val_accuracy, val_precision, val_recall, val_f1 = evaluate(model, val_loader, criterion, device)

        # Learning-rate adjustment
        scheduler.step(val_loss)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {train_loss:.4f}')
        print(f'  Val Loss: {val_loss:.4f}, Acc: {val_accuracy:.2f}%, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}')

        # Keep the model with the best validation F1
        if val_f1 > best_f1:
            # Remember the old best before overwriting it so the log line can
            # show the real improvement
            previous_best = best_f1
            best_f1 = val_f1
            counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Plain Python floats keep the checkpoint free of NumPy scalars,
                # which restricted (weights-only) loading may refuse to unpickle
                'val_loss': float(val_loss),
                'val_f1': float(val_f1),
            }, model_save_path)
            checkpoint_saved = True
            print(f'  模型已保存: val_f1 从 {previous_best:.4f} 提升到 {val_f1:.4f}')
        else:
            counter += 1
            print(f'  F1未提升: {counter}/{patience}')

        # Early stopping
        if counter >= patience:
            print(f'早停: 验证F1已经{patience}个epoch没有提升')
            break

    # Test phase: restore the best weights of this run. When nothing was saved
    # (validation F1 never rose above 0) keep the current weights instead of
    # crashing on a missing file or loading a stale checkpoint of an older run.
    if checkpoint_saved:
        checkpoint = torch.load(model_save_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"加载最佳模型（epoch {checkpoint['epoch']+1}，验证F1: {checkpoint['val_f1']:.4f}）")
    else:
        print("未保存任何模型（验证F1始终为0），使用当前模型权重进行测试")

    # Evaluate on the test split
    test_loss, test_accuracy, test_precision, test_recall, test_f1 = evaluate(model, test_loader, criterion, device)

    print("\n最终测试结果:")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.2f}%")
    print(f"Test Precision: {test_precision:.4f}")
    print(f"Test Recall: {test_recall:.4f}")
    print(f"Test F1 Score: {test_f1:.4f}")

# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

Using device: cuda
检测到输入通道数: 198
数据样本形状: before=torch.Size([198, 9, 9]), after=torch.Size([198, 9, 9]), label=0
训练样本数: 74210
验证样本数: 15902
测试样本数: 15903
Epoch 1/50:
  Train Loss: 0.2447
  Val Loss: 0.1485, Acc: 91.96%, Precision: 0.5211, Recall: 0.9770, F1: 0.6797
  模型已保存: val_f1 从 0.0000 提升到 0.6797
Epoch 2/50:
  Train Loss: 0.1359
  Val Loss: 0.1479, Acc: 96.23%, Precision: 0.7242, Recall: 0.9186, F1: 0.8099
  模型已保存: val_f1 从 0.0000 提升到 0.8099
