<a href="https://colab.research.google.com/github/caaszj/GLAFormer/blob/main/Untitled3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.io import loadmat
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import math
import os
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from collections import Counter

# ===== 数据预处理和加载 =====
class HSIChangeDetectionDataset(Dataset):
    """Patch-level dataset for bi-temporal hyperspectral change detection.

    Two co-registered image cubes (``before`` / ``after``, band axis last) and
    a ground-truth map are loaded from ``.mat`` files.  Every labelled pixel
    that lies at least ``patch_size // 2`` pixels away from the image border
    becomes one sample: a square window around it is cut from both cubes.

    Ground-truth value 255 is the "changed" class (label 1) and 0 is the
    "unchanged" class (label 0); pixels with any other value are ignored.

    The labelled pixels are split 70 % / 15 % / rest into train / val / test,
    separately per class so every subset keeps the class ratio.

    Args:
        before_path: ``.mat`` file holding the first-date cube.
        after_path: ``.mat`` file holding the second-date cube.
        gt_path: ``.mat`` file holding the ground-truth map.
        patch_size: nominal window size; the extracted window is
            ``2 * (patch_size // 2) + 1`` pixels wide.
        mode: ``'train'``, ``'val'``; any other value selects the test split.
        augment: enable rotation/flip and MixUp (only honoured for ``'train'``).
        mixup_prob: probability of applying MixUp to a sample.
        mixup_alpha: parameter of the Beta distribution the mix factor is drawn from.
    """

    # Ground-truth codes recognised by this dataset.
    _UNCHANGED_VALUE = 0
    _CHANGED_VALUE = 255

    def __init__(self, before_path, after_path, gt_path, patch_size=9, mode='train',
                 augment=False, mixup_prob=0.5, mixup_alpha=0.2):
        self.before = self.load_mat(before_path)  # (H, W, C)
        self.after = self.load_mat(after_path)
        self.gt = self.load_gt(gt_path)  # (H, W)

        # Report which label codes are present in the ground truth.
        unique_labels = np.unique(self.gt)
        print(f"标签中的唯一值: {unique_labels}")

        # Per-band min-max normalisation.
        self.before = self._normalize(self.before)
        self.after = self._normalize(self.after)

        # Augmentation is only ever active for the training split.
        self.augment = augment and mode == 'train'
        self.mixup_prob = mixup_prob
        self.mixup_alpha = mixup_alpha

        self.patch_size = patch_size
        self.half = patch_size // 2
        self.mode = mode

        # Centre coordinates of all usable samples.
        self.coords = self.get_valid_coords()

        # Class histogram over all usable samples.
        self.class_counts = self._get_class_distribution()
        print(f"类别分布 - 未变化: {self.class_counts[0]}, 变化: {self.class_counts[1]}")

        # Fixed seed so that the train/val/test instances agree on the split.
        # NOTE: this reseeds NumPy's *global* generator as a side effect.
        np.random.seed(42)

        # Per-class lists of positions into self.coords (ascending order),
        # then shuffled: changed class first, unchanged class second.
        coord_labels = self._labels_for(self.coords)
        change_indices = np.flatnonzero(coord_labels == 1).tolist()
        no_change_indices = np.flatnonzero(coord_labels == 0).tolist()
        np.random.shuffle(change_indices)
        np.random.shuffle(no_change_indices)

        train_ratio, val_ratio = 0.7, 0.15

        def take_split(indices):
            """Return the slice of one class's shuffled indices for `mode`."""
            n_train = int(len(indices) * train_ratio)
            n_val = int(len(indices) * val_ratio)
            if mode == 'train':
                return indices[:n_train]
            if mode == 'val':
                return indices[n_train:n_train + n_val]
            return indices[n_train + n_val:]  # test

        self.selected_indices = take_split(change_indices) + take_split(no_change_indices)

        # Class histogram of the selected subset.
        self.subset_class_counts = self._get_subset_distribution()
        print(f"{mode} 集合 - 未变化: {self.subset_class_counts[0]}, 变化: {self.subset_class_counts[1]}")

    def _normalize(self, data):
        """Min-max scale every band of `data` (H, W, C) to [0, 1] independently.

        A band whose minimum equals its maximum is left as all zeros.
        Returns a new float32 array.
        """
        normalized_data = np.zeros_like(data, dtype=np.float32)
        for i in range(data.shape[2]):
            band = data[:, :, i]
            band_min, band_max = band.min(), band.max()
            if band_max > band_min:
                normalized_data[:, :, i] = (band - band_min) / (band_max - band_min)
        return normalized_data

    def _labels_for(self, coords):
        """Return the 0/1 label (1 = ground truth equals 255) for each (row, col)."""
        coords = np.asarray(coords, dtype=np.intp).reshape(-1, 2)
        values = self.gt[coords[:, 0], coords[:, 1]]
        return (values == self._CHANGED_VALUE).astype(np.int64)

    def _get_class_distribution(self):
        """Return ``[n_unchanged, n_changed]`` over all valid coordinates."""
        labels = self._labels_for(self.coords)
        changed = int(labels.sum())
        return [int(labels.size) - changed, changed]

    def _get_subset_distribution(self):
        """Return ``[n_unchanged, n_changed]`` over the selected subset."""
        subset = np.asarray(self.selected_indices, dtype=np.intp)
        labels = self._labels_for(self.coords)[subset]
        changed = int(labels.sum())
        return [int(labels.size) - changed, changed]

    def load_mat(self, path):
        """Load an image cube from a ``.mat`` file as float32.

        The variables ``river_before`` and ``river_after`` are preferred (in
        that order); otherwise the first non-private variable is used.

        Raises:
            ValueError: if the file contains no usable variable.
        """
        mat = loadmat(path)
        # Keys starting with '__' are loadmat metadata, not data.
        keys = [k for k in mat.keys() if not k.startswith('__')]
        if 'river_before' in mat:
            return mat['river_before'].astype(np.float32)
        elif 'river_after' in mat:
            return mat['river_after'].astype(np.float32)
        elif len(keys) > 0:
            return mat[keys[0]].astype(np.float32)
        else:
            raise ValueError(f"无法从{path}加载数据")

    def load_gt(self, path):
        """Load the ground-truth map from a ``.mat`` file as int64.

        The variable ``lakelabel_v1`` is preferred; otherwise the first
        non-private variable is used.

        Raises:
            ValueError: if the file contains no usable variable.
        """
        mat = loadmat(path)
        keys = [k for k in mat.keys() if not k.startswith('__')]
        if 'lakelabel_v1' in mat:
            return mat['lakelabel_v1'].astype(np.int64)
        elif len(keys) > 0:
            return mat[keys[0]].astype(np.int64)
        else:
            raise ValueError(f"无法从{path}加载标签")

    def get_valid_coords(self):
        """Return an ``(N, 2)`` array of (row, col) centres usable as samples.

        A pixel qualifies when a full window fits around it and its
        ground-truth value is 0 or 255.  Rows are in row-major scan order.
        """
        H, W = self.gt.shape
        interior = self.gt[self.half:H - self.half, self.half:W - self.half]
        valid = (interior == self._UNCHANGED_VALUE) | (interior == self._CHANGED_VALUE)
        # argwhere indexes the cropped interior; shift back to image coordinates.
        return np.argwhere(valid) + self.half

    def __len__(self):
        return len(self.selected_indices)

    def _crop(self, cube, i, j):
        """Return a copy of the window of `cube` centred on pixel (i, j)."""
        return cube[i - self.half:i + self.half + 1,
                    j - self.half:j + self.half + 1, :].copy()

    def __getitem__(self, idx):
        """Return one sample.

        Returns:
            before_patch: first-date patch, float tensor in (C, h, w) layout.
            after_patch: second-date patch, same layout.
            label: 1 if the centre pixel is changed, else 0 (label of the
                primary sample even when MixUp was applied).
            mixup_params: ``{'lam', 'other_label'}`` when MixUp was applied,
                otherwise ``None``.
        """
        coord_idx = self.selected_indices[idx]
        i, j = self.coords[coord_idx]

        before_patch = self._crop(self.before, i, j)
        after_patch = self._crop(self.after, i, j)

        # Geometric augmentation is applied to half of the training samples.
        if self.augment and np.random.random() < 0.5:
            before_patch, after_patch = self._augment_patches(before_patch, after_patch)

        # HWC -> CHW.  The copy is required because rot90/flip return views
        # with negative strides, which torch.from_numpy cannot wrap.
        before_patch = torch.from_numpy(before_patch.copy()).permute(2, 0, 1).float()
        after_patch = torch.from_numpy(after_patch.copy()).permute(2, 0, 1).float()
        label = 1 if self.gt[i, j] == self._CHANGED_VALUE else 0

        # MixUp (training only): blend with a second, randomly drawn sample.
        mixup_params = None
        if self.augment and np.random.random() < self.mixup_prob:
            other_idx = np.random.randint(0, len(self.selected_indices))
            other_coord_idx = self.selected_indices[other_idx]
            other_i, other_j = self.coords[other_coord_idx]

            other_before = torch.from_numpy(
                self._crop(self.before, other_i, other_j)).permute(2, 0, 1).float()
            other_after = torch.from_numpy(
                self._crop(self.after, other_i, other_j)).permute(2, 0, 1).float()
            other_label = 1 if self.gt[other_i, other_j] == self._CHANGED_VALUE else 0

            # Mixing coefficient.
            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)

            before_patch = lam * before_patch + (1 - lam) * other_before
            after_patch = lam * after_patch + (1 - lam) * other_after

            mixup_params = {
                'lam': lam,
                'other_label': other_label
            }

        return before_patch, after_patch, label, mixup_params

    def _augment_patches(self, before_patch, after_patch):
        """Apply the same random 90-degree rotation and flips to both patches."""
        k = np.random.randint(0, 4)  # number of quarter turns, 0..3

        before_patch = np.rot90(before_patch, k=k, axes=(0, 1))
        after_patch = np.rot90(after_patch, k=k, axes=(0, 1))

        # Flip along the first spatial axis.
        if np.random.random() < 0.5:
            before_patch = np.flip(before_patch, axis=0)
            after_patch = np.flip(after_patch, axis=0)

        # Flip along the second spatial axis.
        if np.random.random() < 0.5:
            before_patch = np.flip(before_patch, axis=1)
            after_patch = np.flip(after_patch, axis=1)

        return before_patch, after_patch

# ===== 损失函数 =====
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``alpha_t * (1 - p_t) ** gamma * CE``.

    Args:
        alpha: optional per-class weights (indexable by class id, e.g. a 1-D
            tensor); ``None`` disables class weighting.
        gamma: focusing exponent; ``0`` reduces the loss to cross-entropy.
        reduction: ``'mean'``, ``'sum'``; any other value returns the
            per-sample losses unreduced.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha  # class weights
        self.gamma = gamma  # focusing exponent
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss of logits `inputs` against class ids `targets`."""
        # p_t must come from the *unweighted* cross-entropy.  Passing alpha as
        # `weight=` here would make exp(-ce) equal p_t ** alpha and distort
        # the focusing term.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # predicted probability of the true class
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            # Follow the logits' device/dtype so a CPU-built alpha still works
            # after the model has been moved.
            alpha = torch.as_tensor(self.alpha, dtype=inputs.dtype, device=inputs.device)
            focal_loss = alpha[targets] * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# MixUp损失处理
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the losses of `pred` against two target sets with weight `lam`.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    primary_term = lam * criterion(pred, y_a)
    secondary_term = (1 - lam) * criterion(pred, y_b)
    return primary_term + secondary_term

# ===== 注意力模块 =====
class GLAM(nn.Module):
    """Global-local attention over a ``[B, C, H, W]`` feature map.

    The channels are split in two groups of ``split_dim`` channels.  The first
    group goes through full self-attention over all spatial positions (the
    "local" branch).  The second group provides queries that attend to
    keys/values computed at three resolutions (original, pooled by 2, pooled
    by 4); the three results are fused by a 1x1 convolution (the "global"
    branch).  Both branch outputs are concatenated and projected back to
    ``dim`` channels.

    Args:
        dim: number of output channels (and nominal input channels).
        num_heads: nominal head count; each branch uses ``max(1, num_heads // 2)``.
        window_size: stored on the module; not used by the computation.
        attn_drop: dropout probability applied to the attention weights.
    """

    def __init__(self, dim, num_heads=8, window_size=3, attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.attn_drop = attn_drop

        # Round the per-branch width down to a multiple of the head count.
        self.split_dim = dim // 2
        self.num_heads_local = max(1, num_heads // 2)  # at least one head
        self.head_dim_local = self.split_dim // self.num_heads_local
        self.split_dim = self.head_dim_local * self.num_heads_local

        self.window_size = window_size

        # Local branch
        self.local_qkv = nn.Conv2d(self.split_dim, self.split_dim*3, kernel_size=1)
        self.local_dropout = nn.Dropout(attn_drop)

        # Global branch: one query projection, one key/value projection per scale
        self.global_q = nn.Conv2d(self.split_dim, self.split_dim, kernel_size=1)
        self.global_kv1 = nn.Conv2d(self.split_dim, self.split_dim*2, kernel_size=1)  # original scale
        self.global_kv2 = nn.Conv2d(self.split_dim, self.split_dim*2, kernel_size=1)  # pooled by 2
        self.global_kv3 = nn.Conv2d(self.split_dim, self.split_dim*2, kernel_size=1)  # pooled by 4
        self.global_dropout = nn.Dropout(attn_drop)

        # Fuses the three per-scale attention outputs
        self.scale_fusion = nn.Conv2d(self.split_dim*3, self.split_dim, kernel_size=1)

        # Output projection back to `dim` channels
        self.proj = nn.Conv2d(self.split_dim*2, dim, kernel_size=1)
        self.proj_drop = nn.Dropout(0.1)

    @staticmethod
    def _downsample(x, factor):
        """Average-pool `x` by `factor`; fall back to a 1x1 global average
        when the map is smaller than the pooling window (avg_pool2d would
        raise in that case)."""
        if x.shape[2] >= factor and x.shape[3] >= factor:
            return F.avg_pool2d(x, factor, stride=factor)
        return F.adaptive_avg_pool2d(x, 1)

    def forward(self, x):
        """Apply global-local attention to `x` ([B, C, H, W]) and return [B, dim, H, W]."""
        B, C, H, W = x.shape
        x_local, x_global = torch.split(x, [self.split_dim, C - self.split_dim], dim=1)

        # The global branch needs exactly split_dim channels: drop surplus
        # channels or zero-pad missing ones.
        surplus = x_global.shape[1] - self.split_dim
        if surplus > 0:
            x_global = x_global[:, :self.split_dim, :, :]
        elif surplus < 0:
            # new_zeros keeps the input's device *and* dtype.
            x_global = torch.cat([x_global, x.new_zeros(B, -surplus, H, W)], dim=1)

        # ---- Local branch: self-attention across all H*W positions ----
        qkv = self.local_qkv(x_local)  # [B, 3*split_dim, H, W]
        qkv = qkv.reshape(B, 3, self.num_heads_local, self.head_dim_local, H * W)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]  # each [B, heads, head_dim, H*W]

        attn = torch.matmul(q.transpose(-2, -1), k) / math.sqrt(self.head_dim_local)  # [B, heads, H*W, H*W]
        attn = F.softmax(attn, dim=-1)
        attn = self.local_dropout(attn)

        local_out = torch.matmul(v, attn.transpose(-2, -1))  # [B, heads, head_dim, H*W]
        local_out = local_out.reshape(B, self.split_dim, H, W)

        # ---- Global branch: queries attend to multi-scale keys/values ----
        q_g = self.global_q(x_global)  # [B, split_dim, H, W]
        kv1 = self.global_kv1(x_global)
        kv2 = self.global_kv2(self._downsample(x_global, 2))
        kv3 = self.global_kv3(self._downsample(x_global, 4))

        # Pass the real key/value map size rather than assuming H//2, H//4.
        global_out1 = self._process_scale(q_g, kv1, kv1.shape[2], kv1.shape[3], "scale1")
        global_out2 = self._process_scale(q_g, kv2, kv2.shape[2], kv2.shape[3], "scale2")
        global_out3 = self._process_scale(q_g, kv3, kv3.shape[2], kv3.shape[3], "scale3")

        multi_scale_out = torch.cat([global_out1, global_out2, global_out3], dim=1)
        global_out = self.scale_fusion(multi_scale_out)  # [B, split_dim, H, W]

        # ---- Merge branches ----
        out = torch.cat([local_out, global_out], dim=1)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

    def _process_scale(self, q, kv, h, w, scale_name):
        """Cross-attend full-resolution queries to one scale of keys/values.

        Args:
            q: query map [B, split_dim, H, W].
            kv: stacked key/value map [B, 2*split_dim, h, w].
            h, w: spatial size of `kv`.
            scale_name: label of the scale; not used by the computation.

        Returns:
            Attention output [B, split_dim, H, W].
        """
        B = q.shape[0]

        # Separate keys and values, flattening the spatial axes.
        kv = kv.reshape(B, 2, self.num_heads_local, self.head_dim_local, h * w)
        k_g, v_g = kv[:, 0], kv[:, 1]  # each [B, heads, head_dim, h*w]

        q_g = q.reshape(B, self.num_heads_local, self.head_dim_local, -1)  # [B, heads, head_dim, H*W]

        attn_g = torch.matmul(q_g.transpose(-2, -1), k_g) / math.sqrt(self.head_dim_local)  # [B, heads, H*W, h*w]
        attn_g = F.softmax(attn_g, dim=-1)
        attn_g = self.global_dropout(attn_g)

        out = torch.matmul(v_g, attn_g.transpose(-2, -1))  # [B, heads, head_dim, H*W]
        out = out.reshape(B, self.split_dim, q.shape[2], q.shape[3])

        return out

# 跨时相注意力模块
class CrossTemporalAttention(nn.Module):
    """Bidirectional multi-head cross-attention between two feature maps.

    Each date's features (after a date-specific linear projection) act as
    queries against the other date's features, which serve as keys and
    values.  The Q/K/V/output projections are shared by both directions.
    """

    def __init__(self, dim, num_heads=8, attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Direction-specific query pre-projections
        self.t1_to_t2_proj = nn.Linear(dim, dim)
        self.t2_to_t1_proj = nn.Linear(dim, dim)

        # Shared attention projections
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x1, x2):
        """Cross-attend the two dates.

        Args:
            x1: first-date features [B, C, H, W].
            x2: second-date features [B, C, H, W].

        Returns:
            ``(t1_to_t2, t2_to_t1)``, both [B, C, H, W]: the result of x1
            querying x2, and of x2 querying x1.
        """
        B, C, H, W = x1.shape

        # [B, C, H, W] -> [B, H*W, C]
        seq1 = x1.flatten(2).transpose(1, 2)
        seq2 = x2.flatten(2).transpose(1, 2)

        # Date 1 queries date 2, then date 2 queries date 1.
        attended_1 = self._compute_cross_attention(self.t1_to_t2_proj(seq1), seq2, seq2)
        attended_2 = self._compute_cross_attention(self.t2_to_t1_proj(seq2), seq1, seq1)

        def to_map(seq):
            # [B, H*W, C] -> [B, C, H, W]
            return seq.transpose(1, 2).reshape(B, C, H, W)

        return to_map(attended_1), to_map(attended_2)

    def _compute_cross_attention(self, q_input, k_input, v_input):
        """Scaled dot-product attention on [B, L, C] sequences; returns [B, L, C]."""
        B, L, C = q_input.shape

        def split_heads(t):
            # [B, L, C] -> [B, num_heads, L, head_dim]
            return t.reshape(B, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        q = split_heads(self.q_proj(q_input))
        k = split_heads(self.k_proj(k_input))
        v = split_heads(self.v_proj(v_input))

        scores = (q @ k.transpose(-2, -1)) * self.scale  # [B, num_heads, L, L]
        weights = self.attn_drop(F.softmax(scores, dim=-1))

        # Merge the heads back into the channel axis.
        merged = (weights @ v).transpose(1, 2).reshape(B, L, C)
        return self.out_proj(merged)

# 改进的通道门控融合网络
class CGFN(nn.Module):
    """Channel-gated feed-forward network.

    Pipeline: squeeze-and-excitation channel re-weighting -> 1x1 expansion to
    ``dim * expansion`` channels -> three depthwise branches (3x3, 5x5, 7x7),
    each batch-normalised and multiplied by its own sigmoid gate -> concat ->
    1x1 projection back to ``dim`` -> dropout -> batch norm, plus a 1x1-conv
    skip path taken from the SE-weighted input.

    Args:
        dim: input/output channel count.
        expansion: width multiplier of the hidden layer.
        dropout: dropout probability used in the gates and after projection.
    """

    def __init__(self, dim, expansion=4, dropout=0.1):
        super().__init__()
        hidden_dim = dim * expansion

        # Squeeze-and-excitation.  The bottleneck is dim // 16 but never less
        # than one channel, so the block stays functional for dim < 16.
        se_dim = max(1, dim // 16)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, se_dim, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(se_dim, dim, 1),
            nn.Sigmoid()
        )

        # 1x1 convolution expanding dim -> hidden_dim
        self.conv1 = nn.Conv2d(dim, hidden_dim, 1)

        # Depthwise convolutions with three receptive-field sizes
        self.dwconv3 = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1, groups=hidden_dim)
        self.dwconv5 = nn.Conv2d(hidden_dim, hidden_dim, 5, padding=2, groups=hidden_dim)
        self.dwconv7 = nn.Conv2d(hidden_dim, hidden_dim, 7, padding=3, groups=hidden_dim)

        # One batch norm per branch
        self.norm3 = nn.BatchNorm2d(hidden_dim)
        self.norm5 = nn.BatchNorm2d(hidden_dim)
        self.norm7 = nn.BatchNorm2d(hidden_dim)

        # One gate per branch
        self.gate1 = self._make_gate(hidden_dim, dropout)
        self.gate2 = self._make_gate(hidden_dim, dropout)
        self.gate3 = self._make_gate(hidden_dim, dropout)

        # Skip path
        self.skip_conn = nn.Conv2d(dim, dim, 1)

        # Final projection
        self.conv2 = nn.Conv2d(hidden_dim*3, dim, 1)
        self.dropout = nn.Dropout(dropout)
        self.final_norm = nn.BatchNorm2d(dim)

    @staticmethod
    def _make_gate(hidden_dim, dropout):
        """Build one sigmoid gate: 1x1 conv -> GELU -> dropout -> 1x1 conv -> sigmoid."""
        return nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_dim, hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Transform `x` ([B, dim, H, W]); the output has the same shape."""
        # Channel re-weighting
        se_weight = self.se(x)
        x = x * se_weight

        # Skip path is taken after the SE weighting
        residual = self.skip_conn(x)

        # Expand channels
        x = self.conv1(x)

        # Multi-kernel depthwise branches
        x1 = self.norm3(self.dwconv3(x))
        x2 = self.norm5(self.dwconv5(x))
        x3 = self.norm7(self.dwconv7(x))

        # Each branch is modulated by a gate computed from itself
        x1 = x1 * self.gate1(x1)
        x2 = x2 * self.gate2(x2)
        x3 = x3 * self.gate3(x3)

        # Fuse and project back
        x = torch.cat([x1, x2, x3], dim=1)
        x = self.conv2(x)
        x = self.dropout(x)
        x = self.final_norm(x)

        return x + residual

# 多尺度金字塔融合模块
class MultiScalePyramidFusion(nn.Module):
    """Fuse features computed at four spatial scales (1, 1/2, 1/4, 1/8).

    Each scale is average-pooled, passed through its own 3x3 convolution
    producing ``dim // 4`` channels, and bilinearly resized back to the input
    size.  The four results are concatenated (``dim`` channels), re-weighted
    by a sigmoid attention map, projected, normalised, activated and added to
    the input.

    Args:
        dim: channel count; must be divisible by 4 so the concatenated
            branches add up to ``dim`` channels.

    Raises:
        ValueError: if ``dim`` is not divisible by 4.
    """

    def __init__(self, dim):
        super().__init__()
        if dim % 4 != 0:
            # Otherwise the concat below yields fewer than `dim` channels and
            # the first 1x1 convolution fails inside forward().
            raise ValueError(f"dim must be divisible by 4, got {dim}")

        # Per-scale feature extractors
        self.scale1_conv = nn.Conv2d(dim, dim//4, kernel_size=3, padding=1)
        self.scale2_conv = nn.Conv2d(dim, dim//4, kernel_size=3, padding=1)
        self.scale3_conv = nn.Conv2d(dim, dim//4, kernel_size=3, padding=1)
        self.scale4_conv = nn.Conv2d(dim, dim//4, kernel_size=3, padding=1)

        # Attention over the concatenated multi-scale features
        self.attn_fusion = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
            nn.Sigmoid()
        )

        # Final fusion
        self.fusion_conv = nn.Conv2d(dim, dim, kernel_size=1)
        self.norm = nn.BatchNorm2d(dim)
        self.act = nn.GELU()

    @staticmethod
    def _pooled_branch(x, conv, factor):
        """Pool `x` by `factor`, apply `conv`, and resize back to x's size.

        When the map is smaller than the pooling window (avg_pool2d would
        raise), a 1x1 global average is used as the coarsest available scale.
        """
        if x.shape[2] >= factor and x.shape[3] >= factor:
            pooled = F.avg_pool2d(x, kernel_size=factor, stride=factor)
        else:
            pooled = F.adaptive_avg_pool2d(x, 1)
        return F.interpolate(conv(pooled), size=x.shape[2:], mode='bilinear', align_corners=False)

    def forward(self, x):
        """Fuse multi-scale context into `x` ([B, dim, H, W]); shape is preserved."""
        x_scale1 = self.scale1_conv(x)  # original resolution
        x_scale2 = self._pooled_branch(x, self.scale2_conv, 2)
        x_scale3 = self._pooled_branch(x, self.scale3_conv, 4)
        x_scale4 = self._pooled_branch(x, self.scale4_conv, 8)

        multi_scale_feat = torch.cat([x_scale1, x_scale2, x_scale3, x_scale4], dim=1)

        # Element-wise attention re-weighting
        attn = self.attn_fusion(multi_scale_feat)
        fused = multi_scale_feat * attn

        out = self.fusion_conv(fused)
        out = self.norm(out)
        out = self.act(out)

        return out + x  # residual connection

# GLAFormer Block - 改进版
class GLAFormerBlock(nn.Module):
    """One pre-norm block: GLAM attention followed by a CGFN feed-forward stage.

    Both sub-layers are wrapped as ``x + dropout(sublayer(LayerNorm(x)))``.
    In training mode Gaussian noise (std ``feature_noise``) is added between
    the two sub-layers when ``apply_noise`` is set.
    """

    def __init__(self, dim, num_heads, window_size=3, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.glam = GLAM(dim, num_heads, window_size, attn_drop=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.cgfn = CGFN(dim, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

        # Feature-noise regularisation settings
        self.feature_noise = 0.05
        self.apply_noise = True

    @staticmethod
    def _channel_norm(norm, feat):
        """Apply a LayerNorm over the channel axis of a [B, C, H, W] tensor.

        LayerNorm normalises the last axis, so channels are moved last and
        then moved back.
        """
        channels_last = feat.permute(0, 2, 3, 1)  # [B, H, W, C]
        return norm(channels_last).permute(0, 3, 1, 2)  # back to [B, C, H, W]

    def forward(self, x):
        """Run the block on `x` ([B, C, H, W]); the shape is preserved."""
        # Attention sub-layer with residual connection
        attended = self.glam(self._channel_norm(self.norm1, x))
        x = x + self.dropout(attended)

        # Training-only noise injection
        if self.training and self.apply_noise:
            x = x + torch.randn_like(x) * self.feature_noise

        # Feed-forward sub-layer with residual connection
        transformed = self.cgfn(self._channel_norm(self.norm2, x))
        x = x + self.dropout(transformed)

        return x

# 时空双流注意力融合模块
class DualStreamFusion(nn.Module):
    """Fuse the feature maps of two dates into a single feature map.

    Each date is first refined by its own depthwise/pointwise conv stack, the
    two refined maps exchange information through cross-temporal attention
    (added residually), the results are concatenated and reduced to ``dim``
    channels, and a multi-scale pyramid module refines the fused map.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()

        # Per-date refinement
        self.t1_enhance = self._make_enhancer(dim)
        self.t2_enhance = self._make_enhancer(dim)

        # Cross-date attention
        self.cross_attn = CrossTemporalAttention(dim, num_heads)

        # Concatenation -> dim channels
        self.fusion = nn.Sequential(
            nn.Conv2d(dim*2, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
            nn.GELU()
        )

        # Multi-scale refinement of the fused map
        self.pyramid = MultiScalePyramidFusion(dim)

    @staticmethod
    def _make_enhancer(dim):
        """Build a refinement stack: depthwise 3x3 -> BN -> GELU -> 1x1 -> BN."""
        return nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim),
            nn.BatchNorm2d(dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.BatchNorm2d(dim)
        )

    def forward(self, x1, x2):
        """Fuse `x1` and `x2` (both [B, dim, H, W]) into one [B, dim, H, W] map."""
        refined_1 = self.t1_enhance(x1)
        refined_2 = self.t2_enhance(x2)

        # Each date receives context attended from the other date.
        cross_1, cross_2 = self.cross_attn(refined_1, refined_2)

        stacked = torch.cat([refined_1 + cross_1, refined_2 + cross_2], dim=1)
        return self.pyramid(self.fusion(stacked))

# 完整模型 - 改进版GLAFormer
class GLAFormer(nn.Module):
    """Bi-temporal patch classifier built from GLAFormer blocks.

    Pipeline: per-date conv embedding -> bi-temporal fusion -> stack of
    ``GLAFormerBlock`` with multi-level feature fusion -> convolutional head
    producing a 2-channel map -> weighted blend of the globally pooled logits
    and the centre-pixel logits.

    Args:
        in_channels: number of spectral bands per date.
        dim: embedding width.
        num_blocks: number of GLAFormer blocks.
        num_heads: attention head count passed to the sub-modules.
        patch_size: spatial size of the input patches; determines which pixel
            is read as the centre.
        dropout: dropout probability used throughout.
        use_dual_stream: fuse the dates with ``DualStreamFusion`` when true,
            otherwise with concatenation + 1x1 convolution.
    """

    def __init__(self, in_channels=198, dim=256, num_blocks=4, num_heads=8, patch_size=9,
                 dropout=0.1, use_dual_stream=True):
        super().__init__()
        self.patch_size = patch_size
        self.dim = dim
        self.use_dual_stream = use_dual_stream

        # Centre index derived from the patch size (kept as a buffer so it is
        # part of the module state).
        self.register_buffer('center', torch.tensor(patch_size // 2, dtype=torch.long))

        # Input embedding, one per date
        self.embed1 = nn.Sequential(
            nn.Conv2d(in_channels, dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        self.embed2 = nn.Sequential(
            nn.Conv2d(in_channels, dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Bi-temporal fusion
        if use_dual_stream:
            self.dual_stream = DualStreamFusion(dim, num_heads)
        else:
            self.concat_conv = nn.Conv2d(dim*2, dim, kernel_size=1)

        # Backbone
        self.blocks = nn.Sequential(
            *[GLAFormerBlock(dim, num_heads, dropout=dropout) for _ in range(num_blocks)]
        )

        # One 1x1 projection per block output for multi-level fusion
        self.multi_level_fusion = nn.ModuleList([
            nn.Conv2d(dim, dim, kernel_size=1)
            for _ in range(num_blocks)
        ])

        # Classification head
        self.conv_block = nn.Sequential(
            nn.Conv2d(dim, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 2, kernel_size=1)
        )

        # Global pooling of the logit map
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Small MLP applied to both the global and the centre logits
        self.feat_enhance = nn.Sequential(
            nn.Linear(2, 32),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(32, 2)
        )

        # Learnable blend between global and centre logits (softmax-normalised)
        self.fusion_weights = nn.Parameter(torch.FloatTensor([0.6, 0.4]))

    def forward(self, x1, x2):
        """Classify a batch of patch pairs.

        Args:
            x1: first-date patches [B, in_channels, H, W].
            x2: second-date patches, same shape.

        Returns:
            Logits of shape [B, 2].

        Raises:
            ValueError: if either input is not 4-dimensional.
        """
        if x1.dim() != 4 or x2.dim() != 4:
            raise ValueError(f"输入维度不正确: x1={x1.shape}, x2={x2.shape}")

        # Embedding
        feat1 = self.embed1(x1)  # [B, dim, H, W]
        feat2 = self.embed2(x2)  # [B, dim, H, W]

        # Bi-temporal fusion
        if self.use_dual_stream:
            x = self.dual_stream(feat1, feat2)
        else:
            x = self.concat_conv(torch.cat([feat1, feat2], dim=1))

        # Backbone, collecting a projected copy of every block output
        multi_level_features = []
        for block, level_proj in zip(self.blocks, self.multi_level_fusion):
            x = block(x)
            multi_level_features.append(level_proj(x))

        # Multi-level fusion: first level with weight 1, later levels weighted.
        if len(multi_level_features) > 1:
            n_levels = len(multi_level_features)
            if self.training:
                # Training keeps the original behaviour: weights are a softmax
                # of freshly drawn Gaussian noise (drawn on CPU, then moved).
                # NOTE(review): these weights are not learned; confirm that
                # stochastic weighting is intended rather than a placeholder.
                level_weights = F.softmax(torch.randn(n_levels), dim=0).to(
                    device=x.device, dtype=x.dtype)
            else:
                # Inference must be deterministic: use the expected value of
                # the random weights, which is 1/n for every level.
                level_weights = x.new_full((n_levels,), 1.0 / n_levels)

            fused = multi_level_features[0]
            for i in range(1, n_levels):
                fused = fused + level_weights[i] * multi_level_features[i]
            x = fused

        # Per-pixel logits
        x = self.conv_block(x)  # [B, 2, H, W]

        # Global logits
        global_feat = self.global_pool(x).flatten(1)  # [B, 2]

        # Centre-pixel logits
        center = self.center.item()
        center_feat = x[:, :, center, center]  # [B, 2]

        # Shared refinement MLP
        global_feat = self.feat_enhance(global_feat)
        center_feat = self.feat_enhance(center_feat)

        # Normalised blend weights
        fusion_weights = F.softmax(self.fusion_weights, dim=0)

        output = fusion_weights[0] * global_feat + fusion_weights[1] * center_feat  # [B, 2]

        return output

# ===== 评估和训练函数 =====

# 计算混淆矩阵和详细指标
def compute_metrics(y_true, y_pred, threshold=0.5):
    """Compute the binary confusion matrix and derived metrics.

    Args:
        y_true: ground-truth labels (0/1); sequence, array or tensor.
        y_pred: either a ``[N, 2]`` tensor of logits, which is converted to
            labels by thresholding the softmax probability of class 1, or
            already-binarised labels (sequence, array or tensor).
        threshold: decision threshold on the class-1 probability.

    Returns:
        dict with ``confusion_matrix`` (2x2, rows = true class),
        ``accuracy`` (in percent), ``precision``, ``recall``, ``f1``,
        ``kappa`` and ``iou`` (all for class 1).
    """
    if isinstance(y_pred, torch.Tensor):
        if y_pred.dim() == 2 and y_pred.shape[1] == 2:
            # Logits: threshold the positive-class probability.
            probs = F.softmax(y_pred, dim=1)[:, 1]
            y_pred_binary = (probs > threshold).long().cpu().numpy()
        else:
            # Tensor of labels: hand a NumPy array to scikit-learn.
            y_pred_binary = y_pred.detach().cpu().numpy()
    else:
        # Already labels.
        y_pred_binary = y_pred

    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu().numpy()
    y_true = np.array(y_true)

    # Fixing the label set keeps the matrix 2x2 even when only one class
    # occurs, so no special-casing of degenerate inputs is needed.
    cm = confusion_matrix(y_true, y_pred_binary, labels=[0, 1])
    tn, fp, fn, tp = (int(v) for v in cm.ravel())
    total = tp + tn + fp + fn

    accuracy = (tp + tn) / total if total > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    # Cohen's kappa: agreement corrected by the chance agreement pe.
    pe = ((tn + fp) * (tn + fn) + (fn + tp) * (fp + tp)) / (total ** 2) if total > 0 else 0
    kappa = (accuracy - pe) / (1 - pe) if (1 - pe) > 0 else 0

    # Intersection over union of the positive class.
    iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0

    return {
        'confusion_matrix': cm,
        'accuracy': accuracy * 100,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'kappa': kappa,
        'iou': iou
    }

# 自适应寻找最佳阈值
def find_optimal_threshold(model, data_loader, device, num_thresholds=10):
    """Grid-search the decision threshold that maximises F1 on `data_loader`.

    The model is put in eval mode, class-1 probabilities are collected for
    every sample, and `num_thresholds` evenly spaced thresholds in
    [0.1, 0.9] are scored.  Returns the best threshold (0.5 when no
    candidate reaches an F1 above zero).
    """
    model.eval()
    scores = []
    targets = []

    with torch.no_grad():
        for before_patch, after_patch, labels, _ in data_loader:
            logits = model(before_patch.to(device), after_patch.to(device))
            positive_prob = F.softmax(logits, dim=1)[:, 1]  # probability of class 1

            scores.extend(positive_prob.cpu().numpy())
            targets.extend(labels.cpu().numpy())

    # Convert once; every candidate threshold reuses the same array.
    score_array = np.array(scores)

    best_threshold = 0.5
    best_f1 = 0
    for threshold in np.linspace(0.1, 0.9, num_thresholds):
        candidate_f1 = f1_score(targets, (score_array > threshold).astype(int))
        if candidate_f1 > best_f1:
            best_f1, best_threshold = candidate_f1, threshold

    print(f"找到最佳阈值: {best_threshold:.3f}，F1分数: {best_f1:.4f}")
    return best_threshold

# 评估函数
def evaluate(model, data_loader, criterion, device, threshold=0.5, mixup_eval=False):
    """Evaluate `model` on `data_loader`.

    Args:
        model: network called as ``model(before, after)``.
        data_loader: yields ``(before, after, labels, mixup_params)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: device the batches are moved to.
        threshold: decision threshold forwarded to ``compute_metrics``.
        mixup_eval: accepted for interface compatibility; batches are
            processed identically whatever its value.

    Returns:
        ``(mean_loss, metrics, all_outputs, all_labels)`` where the loss is
        averaged over ``len(data_loader.dataset)``, `all_outputs` is the
        concatenated logit tensor and `all_labels` a list of labels.
    """
    model.eval()
    running_loss = 0.0
    batch_logits = []
    collected_labels = []

    with torch.no_grad():
        for batch in data_loader:
            # The fourth element (MixUp parameters) is not used in evaluation.
            before_patch, after_patch, labels, _ = batch
            before_patch = before_patch.to(device)
            after_patch = after_patch.to(device)
            labels = labels.to(device)

            outputs = model(before_patch, after_patch)
            loss = criterion(outputs, labels)

            # Weight the batch loss by its size so the final mean is per-sample.
            running_loss += loss.item() * before_patch.size(0)
            batch_logits.append(outputs)
            collected_labels.extend(labels.cpu().numpy())

    val_loss = running_loss / len(data_loader.dataset)
    all_outputs = torch.cat(batch_logits, dim=0)

    metrics = compute_metrics(collected_labels, all_outputs, threshold)

    return val_loss, metrics, all_outputs, collected_labels

# 学习率预热和余弦退火调度器
class WarmupCosineScheduler:
    """Epoch-driven learning-rate schedule: warmup ramp, then cosine decay.

    Call ``step(epoch)`` once per epoch. Every parameter group's ``lr`` is set
    to ``base_lr * scale + min_lr``, where ``base_lr`` is the rate the group
    had when the scheduler was constructed.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry).
        warmup_epochs: Number of epochs in the warmup ramp.
        total_epochs: Epoch index at which the cosine decay reaches zero scale.
        min_lr: Constant added to the scaled rate (so the peak rate is
            ``base_lr + min_lr``).
        warmup_method: ``'exp'`` selects a quadratic ramp
            ``(epoch / warmup_epochs) ** 2``; any other value gives a linear ramp.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs, min_lr=1e-6, warmup_method='linear'):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.min_lr = min_lr
        self.warmup_method = warmup_method
        # Snapshot the configured rates so every step scales from the same base.
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]

    def step(self, epoch):
        """Set the learning rate of every parameter group for ``epoch``."""
        if epoch < self.warmup_epochs:
            # Warmup: the scale is 0 at epoch 0, so the first epoch runs at min_lr.
            ramp = epoch / self.warmup_epochs
            lr_scale = ramp ** 2 if self.warmup_method == 'exp' else ramp
        else:
            decay_span = self.total_epochs - self.warmup_epochs
            if decay_span > 0:
                # Clamp so epochs past total_epochs stay at the floor instead of
                # climbing back up the next half-period of the cosine.
                progress = min(1.0, (epoch - self.warmup_epochs) / decay_span)
            else:
                # No decay phase configured: treat the schedule as finished
                # rather than dividing by zero.
                progress = 1.0
            lr_scale = 0.5 * (1.0 + math.cos(math.pi * progress))

        for i, group in enumerate(self.optimizer.param_groups):
            group['lr'] = self.base_lrs[i] * lr_scale + self.min_lr

def custom_collate_fn(batch):
    """Collate ``(before, after, label, mixup_params)`` samples into a batch.

    Args:
        batch: Sequence of 4-tuples. ``before``/``after`` are tensors of equal
            shape, ``label`` is a scalar, and ``mixup_params`` is either
            ``None`` or a dict with ``'lam'`` and ``'other_label'`` entries.

    Returns:
        ``(before_patches, after_patches, labels, mixup_params)``. The first
        two are stacked along a new leading dimension and ``labels`` is a 1-D
        tensor. ``mixup_params`` is ``None`` when no sample carried parameters;
        otherwise it is a dict of two tensors with exactly one entry per
        sample, so index ``i`` always refers to sample ``i`` of the batch.
    """
    before_patches, after_patches, labels, mixup_params = zip(*batch)

    before_batch = torch.stack(before_patches)
    after_batch = torch.stack(after_patches)
    label_batch = torch.tensor(labels)

    if all(p is None for p in mixup_params):
        return before_batch, after_batch, label_batch, None

    # Keep the parameters aligned with the batch: samples without mixup get
    # lam = 1.0 and their own label as 'other_label'.
    # NOTE(review): this assumes lam weights the sample's own label (the usual
    # mixup convention), making the filler a no-op — confirm against the
    # dataset's mixup code.
    lam_values = [1.0 if p is None else p['lam'] for p in mixup_params]
    other_labels = [
        own if p is None else p['other_label']
        for own, p in zip(labels, mixup_params)
    ]
    collated_params = {
        'lam': torch.tensor(lam_values),
        'other_label': torch.tensor(other_labels),
    }
    return before_batch, after_batch, label_batch, collated_params

def main():
    """Train GLAFormer, select the best checkpoint by validation F1, then test.

    The pipeline builds train/val/test datasets from hard-coded Google Drive
    paths, trains with AdamW under a warmup + cosine schedule, re-tunes the
    decision threshold on the validation set every 5 epochs, stops early after
    ``patience`` epochs without a validation-F1 improvement, and finally
    prints the test metrics of the best saved model.

    Returns ``None``; returns early (after printing the error) if the datasets
    or loaders cannot be created.
    """
    # Fix the seeds for reproducibility.
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
    np.random.seed(42)

    # Configuration.
    in_channels = 198  # fallback band count; replaced below by the value read from a sample
    patch_size = 9
    batch_size = 256
    num_epochs = 100
    learning_rate = 0.0005
    weight_decay = 5e-5
    warmup_epochs = 5  # epochs of learning-rate warmup
    model_save_path = 'best_glaformer_model.pth'
    dropout = 0.2
    use_dual_stream = True  # enable the dual-stream (bi-temporal) module
    use_focal_loss = True   # Focal Loss instead of weighted cross-entropy
    use_mixup = True        # forwarded to evaluate() as mixup_eval
    use_sampler = True      # class-balanced weighted sampling for training

    before_path = '/content/drive/MyDrive/dataset zuixin/river_before.mat'
    after_path = '/content/drive/MyDrive/dataset zuixin/river_after.mat'
    gt_path = '/content/drive/MyDrive/dataset zuixin/groundtruth.mat'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # Load the datasets and build the loaders.
    try:
        train_dataset = HSIChangeDetectionDataset(
            before_path, after_path, gt_path,
            patch_size=patch_size,
            mode='train',
            augment=True
        )
        val_dataset = HSIChangeDetectionDataset(
            before_path, after_path, gt_path,
            patch_size=patch_size,
            mode='val'
        )
        test_dataset = HSIChangeDetectionDataset(
            before_path, after_path, gt_path,
            patch_size=patch_size,
            mode='test'
        )

        # Read the real channel count from the first sample.
        sample = train_dataset[0]
        if isinstance(sample, tuple) and len(sample) >= 3:
            in_channels = sample[0].shape[0]
            print(f"检测到输入通道数: {in_channels}")
            print(f"数据样本形状: before={sample[0].shape}, after={sample[1].shape}, label={sample[2]}")

        loader_kwargs = dict(
            batch_size=batch_size,
            num_workers=2,
            pin_memory=True,
            collate_fn=custom_collate_fn,
        )

        if use_sampler:
            # NOTE(review): this runs the dataset's full __getitem__ once per
            # training sample just to read the label.
            train_labels = [train_dataset[i][2] for i in range(len(train_dataset))]

            # Inverse-frequency weights: the rarer class is drawn more often.
            label_counts = Counter(train_labels)
            num_samples = len(train_labels)
            sampling_weights = {
                class_id: num_samples / (len(label_counts) * count)
                for class_id, count in label_counts.items()
            }
            sample_weights = [sampling_weights[label] for label in train_labels]
            sampler = WeightedRandomSampler(
                weights=sample_weights,
                num_samples=len(train_dataset),
                replacement=True
            )
            train_loader = DataLoader(train_dataset, sampler=sampler, **loader_kwargs)
            print("使用加权采样器处理类别不平衡")
        else:
            train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

        val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
        test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

        print(f"训练样本数: {len(train_dataset)}")
        print(f"验证样本数: {len(val_dataset)}")
        print(f"测试样本数: {len(test_dataset)}")

    except Exception as e:
        print(f"加载数据集时出错: {e}")
        return

    # Model.
    model = GLAFormer(
        in_channels=in_channels,
        dim=256,
        num_blocks=4,
        num_heads=8,
        patch_size=patch_size,
        dropout=dropout,
        use_dual_stream=use_dual_stream
    ).to(device)

    print(f"模型初始化完成，参数量: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # Loss class weights from the training split: positive weight = #negative / #positive.
    neg_weight = 1.0
    pos_weight = train_dataset.subset_class_counts[0] / max(1, train_dataset.subset_class_counts[1])
    class_weights = torch.tensor([neg_weight, pos_weight]).to(device)
    print(f"类别权重: {class_weights.cpu().numpy()}")

    if use_focal_loss:
        criterion = FocalLoss(alpha=class_weights, gamma=2.0)
        print("使用Focal Loss")
    else:
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        print("使用加权交叉熵损失")

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    lr_scheduler = WarmupCosineScheduler(
        optimizer,
        warmup_epochs=warmup_epochs,
        total_epochs=num_epochs,
        warmup_method='exp'
    )

    # Training loop state.
    best_val_f1 = 0.0
    patience = 15          # early-stopping patience, in epochs
    counter = 0            # epochs since the last validation-F1 improvement
    best_threshold = 0.5   # decision threshold, re-tuned every 5 epochs
    checkpoint_saved = False  # True once this run has written model_save_path

    for epoch in range(num_epochs):
        lr_scheduler.step(epoch)
        current_lr = optimizer.param_groups[0]['lr']

        model.train()
        train_loss = 0.0
        batch_count = 0

        # NOTE(review): the mixup parameters delivered by the loader are not
        # used in the loss; every batch is trained against `labels` only.
        for before_patch, after_patch, labels, _ in train_loader:
            before_patch = before_patch.to(device)
            after_patch = after_patch.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(before_patch, after_patch)
            loss = criterion(outputs, labels)
            loss.backward()

            # Clip gradients to guard against exploding updates.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            train_loss += loss.item() * before_patch.size(0)
            batch_count += 1

            if batch_count % 20 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_count}/{len(train_loader)}, Loss: {loss.item():.4f}')

        train_loss /= len(train_loader.dataset)

        # Validation.
        val_loss, val_metrics, _, _ = evaluate(
            model, val_loader, criterion, device, threshold=best_threshold, mixup_eval=use_mixup
        )

        # Every 5 epochs, re-tune the threshold and re-score with it.
        if (epoch + 1) % 5 == 0:
            best_threshold = find_optimal_threshold(model, val_loader, device)
            _, val_metrics, _, _ = evaluate(
                model, val_loader, criterion, device, threshold=best_threshold, mixup_eval=use_mixup
            )

        print(f'Epoch {epoch+1}/{num_epochs}, LR: {current_lr:.6f}, 阈值: {best_threshold:.3f}')
        print(f'  Train Loss: {train_loss:.4f}')
        print(f'  Val Loss: {val_loss:.4f}, Acc: {val_metrics["accuracy"]:.2f}%, F1: {val_metrics["f1"]:.4f}, Kappa: {val_metrics["kappa"]:.4f}')
        print(f'  Val Confusion Matrix:\n{val_metrics["confusion_matrix"]}')

        # Keep the checkpoint with the best validation F1.
        if val_metrics["f1"] > best_val_f1:
            # Remember the old best before overwriting it so the log line
            # reports the real previous value.
            previous_best_f1 = best_val_f1
            best_val_f1 = val_metrics["f1"]
            counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_f1': val_metrics["f1"],
                'val_metrics': val_metrics,
                'threshold': best_threshold,
            }, model_save_path)
            checkpoint_saved = True
            print(f'  模型已保存: val_f1 从 {previous_best_f1:.4f} 提升到 {val_metrics["f1"]:.4f}')
        else:
            counter += 1
            print(f'  F1未提升: {counter}/{patience}')

        if counter >= patience:
            print(f'早停: 验证F1已经{patience}个epoch没有提升')
            break

    # Test phase: restore the best checkpoint written by this run, if any.
    if checkpoint_saved:
        # The checkpoint holds a plain metrics dict next to the tensors, so it
        # needs full unpickling; it was written by this process and is trusted.
        checkpoint = torch.load(model_save_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        best_threshold = checkpoint['threshold']
        print(f"加载最佳模型（epoch {checkpoint['epoch']+1}，验证F1: {checkpoint['val_f1']:.4f}，阈值: {best_threshold:.3f}）")
    else:
        # Validation F1 never rose above 0, so nothing was saved; do not load a
        # possibly stale file left over from an earlier run.
        print("未保存任何检查点，使用当前模型权重进行测试")

    test_loss, test_metrics, _, _ = evaluate(
        model, test_loader, criterion, device, threshold=best_threshold, mixup_eval=False
    )

    print("\n最终测试结果:")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_metrics['accuracy']:.2f}%")
    print(f"Test Precision: {test_metrics['precision']:.4f}")
    print(f"Test Recall: {test_metrics['recall']:.4f}")
    print(f"Test F1 Score: {test_metrics['f1']:.4f}")
    print(f"Test Kappa: {test_metrics['kappa']:.4f}")
    print(f"Test IoU: {test_metrics['iou']:.4f}")
    print(f"Test Confusion Matrix:\n{test_metrics['confusion_matrix']}")

# Entry point: run the training / validation / test pipeline only when this
# module is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

使用设备: cuda
标签中的唯一值: [  0 255]
类别分布 - 未变化: 96694, 变化: 9321
train 集合 - 未变化: 67685, 变化: 6524
标签中的唯一值: [  0 255]
类别分布 - 未变化: 96694, 变化: 9321
val 集合 - 未变化: 14504, 变化: 1398
标签中的唯一值: [  0 255]
类别分布 - 未变化: 96694, 变化: 9321
test 集合 - 未变化: 14505, 变化: 1399
检测到输入通道数: 198
数据样本形状: before=torch.Size([198, 9, 9]), after=torch.Size([198, 9, 9]), label=1
使用加权采样器处理类别不平衡
训练样本数: 74209
验证样本数: 15902
测试样本数: 15904
模型初始化完成，参数量: 34,125,222
类别权重: [ 1.      10.37477]
使用Focal Loss
Epoch 1/100, Batch 20/290, Loss: 3.9334
Epoch 1/100, Batch 40/290, Loss: 3.7414
Epoch 1/100, Batch 60/290, Loss: 2.8498
Epoch 1/100, Batch 80/290, Loss: 3.5809
Epoch 1/100, Batch 100/290, Loss: 3.2674
Epoch 1/100, Batch 120/290, Loss: 3.1210
Epoch 1/100, Batch 140/290, Loss: 3.1948
Epoch 1/100, Batch 160/290, Loss: 2.8224
Epoch 1/100, Batch 180/290, Loss: 2.7188
Epoch 1/100, Batch 200/290, Loss: 2.6221
Epoch 1/100, Batch 220/290, Loss: 2.8675
Epoch 1/100, Batch 240/290, Loss: 2.4869
Epoch 1/100, Batch 260/290, Loss: 2.3072
Epoch 1/100, Bat