In [1]:
"""
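=========================== New architecture overview ===========================
Dual-path fine-grained recognition network:
1. Global path: extracts whole-image features
2. Local path: automatically discovers 2 discriminative regions and extracts local features
3. Adaptive fusion: dynamically weights and fuses the global and local features
4. Multi-granularity supervision: main loss + auxiliary losses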
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')

# ===================== Global configuration =====================
# Checkpoint file written by train_model and read back by evaluate_model.
MODEL_PATH = "best_flower_model.pth"
# Not referenced in the visible code; presumably consumed by
# save_training_metrics / load_training_metrics -- TODO confirm.
METRICS_PATH = "training_metrics.npy"
# Class labels; the list position is used as the class index when reporting.
CLASS_NAMES = ["daisy", "dandelion", "rose", "sunflower", "tulip"]
NUM_CLASSES = len(CLASS_NAMES)
# BATCH_SIZE and IMAGE_SIZE are not referenced in the visible code;
# presumably consumed by load_data -- TODO confirm.
BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 1e-4
IMAGE_SIZE = (224, 224)
NUM_REGIONS = 2  # number of local regions to extract


# ===================== 新模型组件 =====================
class RegionProposalModule(nn.Module):
    """Region proposal head that turns a feature map into attention maps.

    A 3x3 conv block followed by a 1x1 conv produces ``num_regions``
    channels; a softmax across that channel axis makes the region scores at
    every spatial position sum to 1.
    """

    def __init__(self, in_channels, num_regions):
        super().__init__()
        self.num_regions = num_regions
        # Layer order is kept unchanged so state_dict keys ("conv.0", ...)
        # stay compatible with existing checkpoints.
        layers = [
            nn.Conv2d(in_channels, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, num_regions, kernel_size=1),
            nn.Softmax(dim=1),  # distribution over regions at each position
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Map features ``[B, C, H, W]`` to attention ``[B, num_regions, H, W]``."""
        return self.conv(x)


class LocalFeatureExtractor(nn.Module):
    """Extract a compact feature vector from an attention-weighted input.

    Args:
        backbone: Module whose children, minus the last two, are used as the
            convolutional feature extractor. The child modules are reused,
            not copied, so extractors built from the same backbone share
            their parameters.
        feature_dim: Channel count produced by the truncated backbone.
    """

    def __init__(self, backbone, feature_dim):
        super().__init__()
        # Drop the last two children of the backbone and keep the rest.
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(feature_dim, 256)  # compress the local feature

    def forward(self, x, attention_map):
        """Return a ``[B, 256]`` feature for the attended region.

        Args:
            x: Input of shape ``[B, C, H, W]``.
            attention_map: Attention of shape ``[B, 1, H, W]`` or
                ``[B, H, W]``.
        """
        # Only add the channel axis when it is missing. Unconditionally
        # unsqueezing a [B, 1, H, W] map made the product broadcast to a 5-D
        # [B, B, C, H, W] tensor, which the convolutional backbone rejects.
        if attention_map.dim() == x.dim() - 1:
            attention_map = attention_map.unsqueeze(1)
        weighted_x = x * attention_map  # [B, C, H, W]

        features = self.backbone(weighted_x)
        pooled = self.global_pool(features).flatten(1)  # [B, feature_dim]
        return self.fc(pooled)  # [B, 256]


class AdaptiveFusionModule(nn.Module):
    """Fuse one global feature with several local features using learned weights.

    A small MLP looks at the concatenation of all features and emits
    ``1 + num_regions`` softmax weights (global first, then one per region).
    The fused feature is the weighted sum of the global feature and the local
    features, each local feature first being projected to ``global_dim``.

    Args:
        global_dim: Size of the global feature vector.
        local_dim: Size of each local feature vector.
        num_regions: Number of local features supplied to ``forward``.
    """

    def __init__(self, global_dim, local_dim, num_regions):
        super().__init__()
        total_dim = global_dim + local_dim * num_regions
        self.weight_generator = nn.Sequential(
            nn.Linear(total_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1 + num_regions),  # global weight + one per region
            nn.Softmax(dim=1)  # weights sum to 1
        )
        # Local features must live in the same space as the global feature
        # before they can be summed. When the sizes already match, Identity
        # keeps the computation and the parameter set unchanged.
        self.local_projections = nn.ModuleList([
            nn.Identity() if local_dim == global_dim
            else nn.Linear(local_dim, global_dim)
            for _ in range(num_regions)
        ])

    def forward(self, global_feat, local_feats):
        """Return ``(fused_feature, weights)``.

        Args:
            global_feat: Tensor of shape ``[B, global_dim]``.
            local_feats: Sequence of ``num_regions`` tensors, each
                ``[B, local_dim]``.

        Returns:
            ``fused_feature`` of shape ``[B, global_dim]`` and ``weights`` of
            shape ``[B, 1 + num_regions]``.
        """
        local_feats = list(local_feats)

        # The weights are predicted from the raw (unprojected) features.
        concat_feats = torch.cat([global_feat, *local_feats], dim=1)  # [B, total_dim]
        weights = self.weight_generator(concat_feats)  # [B, 1+num_regions]

        fused_feature = global_feat * weights[:, 0:1]
        for idx, (project, feat) in enumerate(
                zip(self.local_projections, local_feats), start=1):
            fused_feature = fused_feature + project(feat) * weights[:, idx:idx + 1]

        return fused_feature, weights


class FineGrainedFlowerModel(nn.Module):
    """Dual-path fine-grained flower classifier.

    The global path pools whole-image ResNet-50 features. The local path
    predicts ``num_regions`` attention maps from mid-level features and
    extracts one 256-d feature per attended region. An adaptive fusion module
    mixes the global and local features, and auxiliary classifiers on the
    global and local features provide additional supervision.

    Args:
        num_classes: Number of output classes.
        num_regions: Number of attended local regions.
    """

    def __init__(self, num_classes, num_regions=2):
        super().__init__()
        self.num_regions = num_regions

        # ResNet-50 backbone.
        # NOTE(review): recent torchvision releases prefer the ``weights=``
        # argument over ``pretrained=True`` -- confirm the installed version.
        backbone = models.resnet50(pretrained=True)

        # Global extractor: everything except the final fc layer.
        self.global_extractor = nn.Sequential(*list(backbone.children())[:-1])

        # Mid-level extractor (first 7 children) feeding the region proposal.
        self.mid_feature_extractor = nn.Sequential(*list(backbone.children())[:7])

        self.region_proposal = RegionProposalModule(1024, num_regions)

        # Every extractor here is built from the same ``backbone`` object, so
        # the convolutional layers are shared between the global, mid-level
        # and local paths rather than duplicated.
        self.local_extractors = nn.ModuleList([
            LocalFeatureExtractor(backbone, 2048) for _ in range(num_regions)
        ])

        self.fusion_module = AdaptiveFusionModule(
            global_dim=2048,
            local_dim=256,
            num_regions=num_regions
        )

        # Classifiers: auxiliary heads plus the final head on the fused feature.
        self.global_classifier = nn.Linear(2048, num_classes)
        self.local_classifiers = nn.ModuleList([
            nn.Linear(256, num_classes) for _ in range(num_regions)
        ])
        self.final_classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

        # Freeze the first 100 parameter tensors of the mid-level extractor.
        # Because the layers are shared, they are frozen for every path.
        for param in list(self.mid_feature_extractor.parameters())[:100]:
            param.requires_grad = False

    def forward(self, x):
        """Run both paths and return all logits plus diagnostics.

        Args:
            x: Image batch ``[B, 3, H, W]``.

        Returns:
            Dict with ``final_logits``, ``global_logits``, ``local_logits``
            (list, one entry per region), ``attention_maps`` (upsampled to the
            input resolution), ``fusion_weights`` and ``fused_feature``.
        """
        # ===== Global path =====
        global_feat_flat = torch.flatten(self.global_extractor(x), 1)  # [B, 2048]

        # ===== Local path =====
        mid_features = self.mid_feature_extractor(x)
        attention_maps = self.region_proposal(mid_features)  # [B, num_regions, h, w]

        # Upsample all maps once; the result is both what the local
        # extractors consume and what is returned for visualisation.
        attention_maps_full = F.interpolate(
            attention_maps, size=x.shape[2:], mode='bilinear', align_corners=False
        )

        # Pass the raw input: the extractor applies the attention itself, so
        # pre-multiplying here would apply the attention map twice.
        local_feats = [
            extractor(x, attention_maps_full[:, i:i + 1])
            for i, extractor in enumerate(self.local_extractors)
        ]

        # ===== Adaptive fusion =====
        fused_feature, fusion_weights = self.fusion_module(global_feat_flat, local_feats)

        # ===== Classification heads =====
        global_logits = self.global_classifier(global_feat_flat)
        local_logits = [
            classifier(feat)
            for classifier, feat in zip(self.local_classifiers, local_feats)
        ]
        final_logits = self.final_classifier(fused_feature)

        return {
            'final_logits': final_logits,
            'global_logits': global_logits,
            'local_logits': local_logits,
            'attention_maps': attention_maps_full,
            'fusion_weights': fusion_weights,
            'fused_feature': fused_feature
        }


# ===================== 修改后的模型构建函数 =====================
def build_model(device):
    """Create the dual-path model on ``device`` and print a parameter summary.

    Returns:
        The constructed ``FineGrainedFlowerModel``.
    """
    model = FineGrainedFlowerModel(
        num_classes=NUM_CLASSES,
        num_regions=NUM_REGIONS,
    ).to(device)

    # Tally parameter counts in a single pass.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    summary = (
        f"模型已加载至: {next(model.parameters()).device}",
        f"总参数量: {total_params:,}",
        f"可训练参数量: {trainable_params:,}",
        f"使用区域数量: {NUM_REGIONS}",
    )
    for line in summary:
        print(line)

    return model


# ===================== 修改后的训练逻辑 =====================
def _multi_granularity_loss(outputs, labels, criterion, lambda_global, lambda_local):
    """Return the main loss plus the weighted auxiliary branch losses."""
    loss_final = criterion(outputs['final_logits'], labels)
    loss_global = criterion(outputs['global_logits'], labels)
    local_logits = outputs['local_logits']
    # The local losses are averaged over regions before being weighted, so
    # each local branch contributes lambda_local / num_regions to the total.
    if local_logits:
        loss_local = sum(criterion(logits, labels) for logits in local_logits) / len(local_logits)
    else:
        loss_local = 0.0
    return loss_final + lambda_global * loss_global + lambda_local * loss_local


def train_model(device):
    """Train the dual-path model from scratch and report the results.

    Loads the data, trains for ``EPOCHS`` epochs with the multi-granularity
    loss, saves the checkpoint with the best test accuracy to ``MODEL_PATH``,
    stores the metric curves, and finally runs the visualisations and the
    evaluation.

    Args:
        device: Device on which the model and the batches are placed.
    """
    train_loader, test_loader, train_dataset, test_dataset = load_data()

    model = build_model(device)

    # Only parameters left trainable by the model are optimised.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LEARNING_RATE
    )
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    train_losses = []
    train_accs = []
    test_losses = []
    test_accs = []
    best_test_acc = 0.0

    # Auxiliary loss weights.
    lambda_global = 0.3  # weight of the global branch loss
    lambda_local = 0.2   # weight of the (region-averaged) local branch loss

    print("\n========== 开始训练（双路径架构）==========")
    print(f"辅助损失权重: 全局={lambda_global}, 每个局部={lambda_local}")

    for epoch in range(EPOCHS):
        # ===== Training phase =====
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{EPOCHS} [Train]")
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            final_logits = outputs['final_logits']
            total_loss = _multi_granularity_loss(
                outputs, labels, criterion, lambda_global, lambda_local
            )

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item() * images.size(0)
            predicted = final_logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            pbar.set_postfix({
                "loss": running_loss / total,
                "acc": correct / total,
                "fusion_w": outputs['fusion_weights'][0].detach().cpu().numpy().round(2)
            })

        # Normalise by the number of samples actually processed so the loss
        # and the accuracy use the same denominator (this also stays correct
        # when the loader drops or resamples items).
        train_loss = running_loss / total
        train_acc = correct / total

        # ===== Validation phase =====
        model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            pbar = tqdm(test_loader, desc="Validation")
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                final_logits = outputs['final_logits']

                # Monitoring only; nothing is back-propagated here.
                total_loss = _multi_granularity_loss(
                    outputs, labels, criterion, lambda_global, lambda_local
                )

                running_loss += total_loss.item() * images.size(0)
                predicted = final_logits.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                pbar.set_postfix({
                    "loss": running_loss / total,
                    "acc": correct / total
                })

        test_loss = running_loss / total
        test_acc = correct / total

        # Learning-rate decay.
        scheduler.step()

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)

        # Keep the checkpoint with the best test accuracy so far.
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'best_test_acc': best_test_acc,
                'epoch': epoch
            }, MODEL_PATH)
            print(f"保存最优模型，测试准确率: {best_test_acc:.4f}")

        # ``outputs`` is the last batch processed above; detach so the
        # conversion to NumPy cannot fail on a tensor that tracks gradients.
        fusion_example = outputs['fusion_weights'][0].detach().cpu().numpy().round(3)
        print(f"Epoch {epoch + 1}")
        print(f"  训练损失: {train_loss:.4f}, 训练准确率: {train_acc:.4f}")
        print(f"  测试损失: {test_loss:.4f}, 测试准确率: {test_acc:.4f}")
        print(f"  融合权重示例: {fusion_example}\n")

    save_training_metrics(train_losses, train_accs, test_losses, test_accs)
    print(f"训练完成！最优测试准确率: {best_test_acc:.4f}")

    # These visualisations use the weights from the final epoch.
    visualize_results(train_losses, train_accs, test_losses, test_accs, train_dataset, test_dataset)
    visualize_attention_maps(model, test_loader, device)
    plot_confusion_matrix(model, test_loader, device)

    # Final evaluation reloads the best checkpoint from MODEL_PATH.
    evaluate_model(device, test_loader)


def visualize_attention_maps(model, test_loader, device, num_samples=5):
    """Plot the region attention maps for the first images of one batch.

    Each row shows the de-normalised input image, one panel per attention
    region (titled with that region's fusion weight), and a fused panel.
    The figure is saved to ``attention_visualization.png`` and shown.

    Args:
        model: Model whose output dict provides ``attention_maps`` and
            ``fusion_weights``.
        test_loader: Loader yielding ``(images, labels)`` batches; only the
            first batch is used.
        device: Device on which the forward pass runs.
        num_samples: Maximum number of images to plot; capped at the size of
            the first batch.
    """
    model.eval()

    plt.rcParams["font.sans-serif"] = ["SimHei"]
    plt.rcParams["axes.unicode_minus"] = False

    with torch.no_grad():
        images, labels = next(iter(test_loader))
        # A batch may hold fewer images than requested.
        num_samples = min(num_samples, images.size(0))
        images = images[:num_samples].to(device)
        labels = labels[:num_samples]
        outputs = model(images)

    attention_maps = outputs['attention_maps'].cpu()  # [B, num_regions, H, W]
    fusion_weights = outputs['fusion_weights'].cpu()  # [B, 1+num_regions]
    # Use the region count the model actually produced.
    num_regions = attention_maps.size(1)
    images = images.cpu()

    # Constants used to undo input normalisation; assumes load_data
    # normalised with these mean/std values -- TODO confirm.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    fig, axes = plt.subplots(
        num_samples, num_regions + 2,
        figsize=(15, 3 * num_samples), squeeze=False
    )

    for i in range(num_samples):
        # Original image.
        img = (images[i] * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()
        axes[i, 0].imshow(img)
        axes[i, 0].set_title(f"原始图像\n{CLASS_NAMES[int(labels[i])]}")
        axes[i, 0].axis('off')

        # One panel per region.
        for j in range(num_regions):
            weight = fusion_weights[i, j + 1].item()
            axes[i, j + 1].imshow(attention_maps[i, j].numpy(), cmap='hot')
            axes[i, j + 1].set_title(f"区域{j+1}\n权重: {weight:.3f}")
            axes[i, j + 1].axis('off')

        # Fused panel: weight each map by its fusion weight. A plain sum is
        # uninformative because the maps come from a softmax over the region
        # channel and therefore add up to 1 at every pixel.
        region_weights = fusion_weights[i, 1:].view(-1, 1, 1)
        combined_att = (attention_maps[i] * region_weights).sum(dim=0).numpy()
        axes[i, -1].imshow(combined_att, cmap='hot')
        axes[i, -1].set_title("融合注意力")
        axes[i, -1].axis('off')

    plt.suptitle("区域注意力可视化", fontsize=16)
    plt.tight_layout()
    plt.savefig("attention_visualization.png", dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


# ===================== 修改后的评估函数 =====================
def evaluate_model(device, test_loader=None):
    """Evaluate the checkpoint stored at ``MODEL_PATH`` on the test set.

    Prints the overall loss and accuracy, the average fusion weights and the
    per-class accuracy, then attempts the visualisations.

    Args:
        device: Device on which the model and the batches are placed.
        test_loader: Optional loader of ``(images, labels)`` batches; when
            omitted, the loader returned by ``load_data`` is used.

    Raises:
        FileNotFoundError: If no checkpoint exists at ``MODEL_PATH``.
    """
    # Fail fast before any data is loaded or any model is built.
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"未找到模型文件: {MODEL_PATH}，请先执行训练模式！")

    if test_loader is None:
        _, test_loader, _, _ = load_data()

    model = build_model(device)

    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # loaded on another (e.g. a CPU-only machine).
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    criterion = nn.CrossEntropyLoss()
    running_loss = 0.0
    correct = 0
    total = 0
    class_correct = [0] * NUM_CLASSES
    class_total = [0] * NUM_CLASSES

    print("\n========== 开始测试（双路径架构）==========")

    # Fusion weights are collected for the summary printed below.
    all_fusion_weights = []

    with torch.no_grad():
        pbar = tqdm(test_loader, desc="Testing")
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            final_logits = outputs['final_logits']

            all_fusion_weights.append(outputs['fusion_weights'].cpu())

            loss = criterion(final_logits, labels)

            running_loss += loss.item() * images.size(0)
            predicted = final_logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Convert once to plain ints instead of indexing the Python
            # lists with individual device tensors.
            for label, pred in zip(labels.tolist(), predicted.tolist()):
                class_total[label] += 1
                if label == pred:
                    class_correct[label] += 1

            pbar.set_postfix({"loss": running_loss / total, "acc": correct / total})

    # Same denominator for loss and accuracy: the samples actually seen.
    final_loss = running_loss / total
    final_acc = correct / total

    all_fusion_weights = torch.cat(all_fusion_weights, dim=0)
    avg_weights = all_fusion_weights.mean(dim=0).numpy()

    print("\n测试完成！")
    print(f"整体测试损失: {final_loss:.4f}, 整体测试准确率: {final_acc:.4f}")
    print(f"平均融合权重: 全局={avg_weights[0]:.3f}, " +
          ", ".join([f"区域{i+1}={avg_weights[i+1]:.3f}" for i in range(NUM_REGIONS)]))

    print("\n各类别准确率:")
    for i in range(NUM_CLASSES):
        acc = class_correct[i] / class_total[i] if class_total[i] > 0 else 0
        print(f"{CLASS_NAMES[i]}: {acc:.4f} (正确数: {class_correct[i]}, 总数: {class_total[i]})")

    # Visualisations are best-effort: a FileNotFoundError raised by any of
    # these calls skips the remaining ones and only prints a warning.
    try:
        train_losses, train_accs, test_losses, test_accs = load_training_metrics()
        _, _, train_dataset, test_dataset = load_data()
        visualize_results(train_losses, train_accs, test_losses, test_accs, train_dataset, test_dataset)

        visualize_attention_maps(model, test_loader, device)

        plot_confusion_matrix(model, test_loader, device)
    except FileNotFoundError as e:
        print(f"\n警告：{e}，仅展示测试结果，跳过可视化！")


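# ===================== Other helper functions (unchanged) =====================
# check_gpu_availability, load_data, save_training_metrics,
# load_training_metrics, visualize_results, plot_confusion_matrix
#
# NOTE: these helpers are NOT defined in this listing. Keep the original definitions (only the sections above were modified) and make sure they are defined before the entry point below runs; otherwise it fails with "NameError: name 'check_gpu_availability' is not defined".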

# ===================== 主程序入口保持不变 =====================
if __name__ == '__main__':
    device = check_gpu_availability()

    # Banner and menu, printed one line per entry.
    banner_lines = (
        "\n========== 细粒度花卉分类模型（双路径架构）==========",
        "架构特点:",
        f"1. 全局路径 + {NUM_REGIONS}个局部路径",
        "2. 自适应融合模块",
        "3. 多粒度监督训练",
        "\n请选择运行模式：",
        "1. 重新训练模型（双路径架构）",
        "2. 仅测试（使用已训练的双路径模型）",
    )
    for line in banner_lines:
        print(line)

    # Keep prompting until the user enters one of the valid mode numbers.
    valid_choices = (1, 2)
    choice = None
    while choice not in valid_choices:
        try:
            choice = int(input("输入选择（1/2）："))
        except ValueError:
            print("请输入有效的数字（1/2）！")
            continue
        if choice not in valid_choices:
            print("请输入1或2！")

    # Dispatch to the selected mode: 1 trains, 2 only evaluates.
    runners = {1: train_model, 2: evaluate_model}
    runners[choice](device)

NameError: name 'check_gpu_availability' is not defined