<a href="https://colab.research.google.com/github/caaszj/TF/blob/main/Untitled0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.io import loadmat
from torch.utils.data import Dataset, DataLoader
import math

# 数据预处理和加载
class HSIChangeDetectionDataset(Dataset):
    """Pixel-wise change-detection dataset built from a bi-temporal hyperspectral pair.

    Each sample is a ``patch_size x patch_size`` neighbourhood centred on one
    labelled pixel, cut from both the "before" and the "after" cube, together
    with the label of the centre pixel (255 mapped to 1, 0 kept as 0).

    Args:
        before_path: ``.mat`` file holding the first-date cube, indexed (H, W, C).
        after_path: ``.mat`` file holding the second-date cube, indexed (H, W, C).
        gt_path: ``.mat`` file whose ``'lakelabel_v1'`` variable is the (H, W)
            label map; only pixels equal to 0 or 255 are used.
        patch_size: side length of the square patch.
        mode: ``'train'`` (first 70 % of the shuffled pixels), ``'val'`` (next
            15 %); any other value selects the remaining pixels (test split).
        seed: seed of the shuffle. All three splits must be built with the same
            seed so that they are disjoint.
    """

    def __init__(self, before_path, after_path, gt_path, patch_size=9, mode='train', seed=42):
        before = self.load_mat(before_path)  # (H, W, C)
        after = self.load_mat(after_path)
        self.gt = self.load_gt(gt_path)  # (H, W)

        # Global min-max normalisation of each cube to [0, 1].
        self.before = self._minmax(before)
        self.after = self._minmax(after)

        self.patch_size = patch_size
        self.half = patch_size // 2
        self.mode = mode

        # Centre coordinates of every usable patch.
        self.coords = self.get_valid_coords()

        # Split into train/val/test. A private RandomState yields exactly the
        # permutation of np.random.seed(seed) + np.random.permutation, without
        # resetting the global NumPy RNG as a side effect.
        idx = np.random.RandomState(seed).permutation(len(self.coords))
        train_num = int(len(idx) * 0.7)  # 70 % training
        val_num = int(len(idx) * 0.15)  # 15 % validation

        if mode == 'train':
            self.coords = self.coords[idx[:train_num]]
        elif mode == 'val':
            self.coords = self.coords[idx[train_num:train_num + val_num]]
        else:
            # Everything left over (~15 %) is the test split.
            self.coords = self.coords[idx[train_num + val_num:]]

    @staticmethod
    def _minmax(cube):
        """Scale *cube* globally to [0, 1]; a constant cube maps to all zeros (not NaN)."""
        lo = cube.min()
        span = cube.max() - lo
        if span == 0:
            return np.zeros_like(cube)
        return (cube - lo) / span

    def load_mat(self, path):
        """Load a hyperspectral cube from a MATLAB file as float32.

        The variable name is taken from the file name (``river_before`` or
        ``river_after``). For any other file, the single user variable stored
        in the ``.mat`` is used.

        Raises:
            KeyError: if the data variable cannot be identified unambiguously.
        """
        mat = loadmat(path)
        name = str(path)
        for key in ('river_before', 'river_after'):
            if key in name:
                return mat[key].astype(np.float32)
        # loadmat adds '__header__', '__version__', '__globals__' entries.
        data_keys = [k for k in mat if not k.startswith('__')]
        if len(data_keys) != 1:
            raise KeyError(f'cannot infer the data variable in {name!r}; found {data_keys}')
        return mat[data_keys[0]].astype(np.float32)

    def load_gt(self, path):
        """Load the ``'lakelabel_v1'`` label map from a MATLAB file as int64."""
        mat = loadmat(path)
        return mat['lakelabel_v1'].astype(np.int64)

    def get_valid_coords(self):
        """Return an (N, 2) array of (row, col) centres of usable patches, row-major.

        A pixel is usable when its label is 0 or 255 and it lies at least
        ``half`` pixels from every border, so the whole patch fits in the image.
        """
        H, W = self.gt.shape
        h = self.half
        labelled = np.isin(self.gt, (0, 255))
        interior = labelled[h:H - h, h:W - h]
        # argwhere scans row-major, matching the original nested-loop order.
        return np.argwhere(interior) + h

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        """Return ``(before_patch, after_patch, label)``; patches are CHW float tensors."""
        i, j = self.coords[idx]
        rows = slice(i - self.half, i + self.half + 1)
        cols = slice(j - self.half, j + self.half + 1)
        # HWC -> CHW, the layout expected by Conv2d.
        before_patch = torch.from_numpy(self.before[rows, cols, :]).permute(2, 0, 1).float()
        after_patch = torch.from_numpy(self.after[rows, cols, :]).permute(2, 0, 1).float()
        label = self.gt[i, j] // 255  # maps 255 -> 1, 0 -> 0

        return before_patch, after_patch, label

# GLAM模块实现

class GLAM(nn.Module):
    """Global-local attention mixer operating on channel-first feature maps.

    The channels are split in half:

    * the first half goes through multi-head self-attention over all ``H*W``
      positions of the input (called the "local" branch);
    * the second half attends from every position to keys/values computed on
      an ``avg_pool2d``-downsampled map (kernel/stride ``window_size``), called
      the "global" branch.

    Both branches use ``num_heads // 2`` heads. Their outputs are concatenated
    and mixed by a 1x1 convolution.

    Args:
        dim: number of input/output channels; must be even.
        num_heads: total head count; ``num_heads // 2`` must be >= 1 and must
            divide ``dim // 2``.
        window_size: pooling kernel/stride used to build the global keys/values.

    Raises:
        ValueError: if ``dim``/``num_heads`` cannot be split as described.
    """

    def __init__(self, dim, num_heads=8, window_size=3):
        super().__init__()
        if dim % 2:
            raise ValueError(f'dim must be even, got {dim}')
        if num_heads < 2 or (dim // 2) % (num_heads // 2):
            raise ValueError(
                f'num_heads // 2 must be >= 1 and divide dim // 2 (dim={dim}, num_heads={num_heads})'
            )
        self.num_heads = num_heads
        self.split_dim = dim // 2  # channels per branch
        self.num_heads_local = num_heads // 2  # heads per branch
        self.head_dim_local = self.split_dim // self.num_heads_local
        self.window_size = window_size

        # Local branch: q, k and v from one 1x1 conv.
        self.local_qkv = nn.Conv2d(self.split_dim, self.split_dim * 3, kernel_size=1)

        # Global branch: q at full resolution, k/v on the pooled map.
        self.global_q = nn.Conv2d(self.split_dim, self.split_dim, kernel_size=1)
        self.global_kv = nn.Conv2d(self.split_dim, self.split_dim * 2, kernel_size=1)

        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def _attend(self, q, k, v):
        """Scaled dot-product attention on channel-first head tensors.

        q: [B, heads, head_dim, Nq]; k, v: [B, heads, head_dim, Nk].
        Returns [B, heads, head_dim, Nq].
        """
        attn = (q.transpose(-2, -1) @ k) / math.sqrt(self.head_dim_local)  # [B, heads, Nq, Nk]
        # softmax normalises over keys; transpose to [Nk, Nq] so that each query's
        # output is a convex combination of the value vectors.
        return v @ attn.softmax(dim=-1).transpose(-2, -1)

    def forward(self, x):
        B, _, H, W = x.shape
        heads, head_dim = self.num_heads_local, self.head_dim_local
        x_local, x_global = torch.split(x, self.split_dim, dim=1)

        # Local branch: self-attention across all H*W positions.
        q, k, v = torch.chunk(self.local_qkv(x_local), 3, dim=1)  # each [B, split_dim, H, W]
        q = q.view(B, heads, head_dim, H * W)
        k = k.view(B, heads, head_dim, H * W)
        v = v.view(B, heads, head_dim, H * W)
        local_out = self._attend(q, k, v).view(B, self.split_dim, H, W)

        # Global branch: full-resolution queries against pooled keys/values.
        x_pool = F.avg_pool2d(x_global, self.window_size)
        n_pool = x_pool.shape[2] * x_pool.shape[3]
        k_g, v_g = torch.chunk(self.global_kv(x_pool), 2, dim=1)
        q_g = self.global_q(x_global).view(B, heads, head_dim, H * W)
        k_g = k_g.view(B, heads, head_dim, n_pool)
        v_g = v_g.view(B, heads, head_dim, n_pool)
        global_out = self._attend(q_g, k_g, v_g).view(B, self.split_dim, H, W)

        return self.proj(torch.cat([local_out, global_out], dim=1))

# CGFN模块实现
class CGFN(nn.Module):
    """Convolutional gated feed-forward network.

    A 1x1 conv expands ``dim`` to ``dim * expansion`` channels. The result is fed
    to two depthwise branches (3x3 and 5x5). Each branch is scaled by its own
    sigmoid gate and receives an additive cross-connection: the 3x3 branch adds
    the raw 5x5 output, and the 5x5 branch adds the already-updated 3x3 result.
    The two are concatenated and projected back to ``dim`` with a 1x1 conv.
    """

    def __init__(self, dim, expansion=4):
        super().__init__()
        width = dim * expansion

        # Submodules are created in the original order so seeded init is unchanged.
        self.conv1 = nn.Conv2d(dim, width, 1)
        self.dwconv3 = nn.Conv2d(width, width, 3, padding=1, groups=width)
        self.dwconv5 = nn.Conv2d(width, width, 5, padding=2, groups=width)
        self.conv2 = nn.Conv2d(2 * width, dim, 1)
        self.gate1 = self._make_gate(width)
        self.gate2 = self._make_gate(width)

    @staticmethod
    def _make_gate(channels):
        """Build a 1x1 conv -> GELU -> 1x1 conv -> sigmoid gate."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, 1),
            nn.GELU(),
            nn.Conv2d(channels, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        hidden = self.conv1(x)
        branch3 = self.dwconv3(hidden)
        branch5 = self.dwconv5(hidden)

        mixed3 = branch3 * self.gate1(branch3) + branch5
        # Note: adds the *updated* 3x3 branch, not the raw one.
        mixed5 = branch5 * self.gate2(branch5) + mixed3
        return self.conv2(torch.cat((mixed3, mixed5), dim=1))

# GLAFormer Block
class GLAFormerBlock(nn.Module):
    """Pre-norm residual block: ``x + GLAM(LN(x))`` followed by ``x + CGFN(LN(x))``.

    Inputs and outputs are channel-first ``[B, dim, H, W]``. ``nn.LayerNorm(dim)``
    normalises the last axis, so the tensor is permuted to channel-last around
    each norm and back again.
    """

    def __init__(self, dim, num_heads, window_size=3):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.glam = GLAM(dim, num_heads, window_size)
        self.norm2 = nn.LayerNorm(dim)
        self.cgfn = CGFN(dim)

    @staticmethod
    def _norm_channels_first(norm, x):
        """Apply *norm* over dim 1 of a ``[B, C, H, W]`` tensor."""
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        attended = x + self.glam(self._norm_channels_first(self.norm1, x))
        return attended + self.cgfn(self._norm_channels_first(self.norm2, attended))

# 完整模型

class GLAFormer(nn.Module):
    """Bi-temporal change-detection network producing 2-class logits per patch.

    The two acquisitions are concatenated along channels and embedded by a 3x3
    conv + BatchNorm + ReLU. They then pass through ``num_blocks`` GLAFormer
    blocks and are mapped to 2-channel logit maps. The final logits blend the
    spatially averaged map with the value at the patch centre.

    Args:
        in_channels: channels per acquisition (the stem sees ``2 * in_channels``).
        dim: embedding width.
        num_blocks: number of ``GLAFormerBlock`` layers.
        num_heads: heads per block.
        patch_size: spatial size of the input patches; ``patch_size // 2`` is the
            centre index used for the centre-pixel logits.
        global_weight: weight of the averaged logits; the centre logits get
            ``1 - global_weight``. The default 0.6 reproduces the original 0.6/0.4 blend.
    """

    def __init__(self, in_channels=198, dim=256, num_blocks=4, num_heads=8, patch_size=9,
                 global_weight=0.6):
        super().__init__()
        self.patch_size = patch_size
        self.center = patch_size // 2  # centre index, precomputed
        self.global_weight = global_weight

        # Input embedding of the concatenated image pair.
        self.embed = nn.Sequential(
            nn.Conv2d(in_channels * 2, dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True)
        )

        # Backbone.
        self.blocks = nn.Sequential(
            *[GLAFormerBlock(dim, num_heads) for _ in range(num_blocks)]
        )

        # Classification head: per-pixel 2-class logit maps.
        self.conv_block = nn.Sequential(
            nn.Conv2d(dim, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 2, kernel_size=1)
        )
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        )

    def center_extract(self, x):
        """Return the features at the patch centre: ``[B, C, H, W] -> [B, C]``.

        A plain method rather than a lambda-wrapping module, so the model stays
        picklable (e.g. ``torch.save(model)``). The state-dict is unaffected
        because the old wrapper had no parameters.
        """
        return x[:, :, self.center, self.center]

    def forward(self, x1, x2):
        x = torch.cat([x1, x2], dim=1)  # [B, 2*in_channels, P, P]
        x = self.embed(x)               # [B, dim, P, P]
        x = self.blocks(x)              # [B, dim, P, P]
        x = self.conv_block(x)          # [B, 2, P, P]

        global_feat = self.global_pool(x)     # [B, 2]
        center_feat = self.center_extract(x)  # [B, 2]

        # Weighted blend of the averaged and centre-pixel logits.
        return self.global_weight * global_feat + (1 - self.global_weight) * center_feat

# 辅助模块
class LambdaLayer(nn.Module):
    """Parameter-free module whose forward pass applies a stored callable."""

    def __init__(self, lambd):
        super().__init__()
        # The callable applied in forward(); kept under its original attribute name.
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        result = fn(x)
        return result

# 训练配置
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = GLAFormer().to(device)
# Class 1 is weighted 10x relative to class 0 to counter class imbalance.
class_weights = torch.tensor([1.0, 10.0], device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=6e-4)

# 数据加载

DATA_DIR = '/content/drive/MyDrive/dataset zuixin'
BEFORE_PATH = f'{DATA_DIR}/river_before.mat'
AFTER_PATH = f'{DATA_DIR}/river_after.mat'
GT_PATH = f'{DATA_DIR}/groundtruth.mat'
BATCH_SIZE = 128


def _make_split(mode):
    """Build one split; every split uses the same files, and the dataset's fixed seed keeps the splits disjoint."""
    return HSIChangeDetectionDataset(BEFORE_PATH, AFTER_PATH, GT_PATH, mode=mode)


train_dataset = _make_split('train')
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Evaluation metrics are sums over all samples, so order is irrelevant:
# keep val/test iteration deterministic instead of reshuffling every pass.
val_dataset = _make_split('val')
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

test_dataset = _make_split('test')
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)




# 训练循环
def evaluate(net, loader):
    """Return ``(mean loss, accuracy in %)`` of *net* on *loader*, without gradients.

    Uses the module-level ``criterion`` and ``device``. Returns ``(nan, nan)``
    for an empty loader instead of dividing by zero.
    """
    net.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for before_patch, after_patch, labels in loader:
            before_patch = before_patch.to(device)
            after_patch = after_patch.to(device)
            labels = labels.to(device)

            outputs = net(before_patch, after_patch)
            total_loss += criterion(outputs, labels).item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        return float('nan'), float('nan')
    return total_loss / total, 100 * correct / total


# Training loop.
num_epochs = 50
best_val_loss = float('inf')
checkpoint_path = 'best_model.pth'

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for before_patch, after_patch, labels in train_loader:
        before_patch = before_patch.to(device)
        after_patch = after_patch.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(before_patch, after_patch)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * before_patch.size(0)

    train_loss /= max(len(train_loader.dataset), 1)

    # Validation.
    val_loss, val_accuracy = evaluate(model, val_loader)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')

    # Keep the best model. Also save while no finite best exists yet (first epoch,
    # or a NaN validation loss), so the checkpoint loaded below always exists.
    if val_loss < best_val_loss or not math.isfinite(best_val_loss):
        best_val_loss = val_loss
        torch.save(model.state_dict(), checkpoint_path)

# Test with the best checkpoint; map_location keeps loading on the current device.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
test_loss, test_accuracy = evaluate(model, test_loader)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.2f}%')

Epoch 1/50, Train Loss: 0.2152, Val Loss: 0.2109, Val Acc: 85.08%
Epoch 2/50, Train Loss: 0.1364, Val Loss: 0.1745, Val Acc: 90.22%
