<a href="https://colab.research.google.com/github/kermitstart/marine-fish-recognition/blob/master/colab_fish_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

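# 🐟 Marine Fish Recognition Model - Colab Training Edition

This notebook is designed for Google Colab and trains a single-fish recognition model for marine species.

## 📋 Training plan
1. **Environment setup**: configure the GPU and install dependencies
2. **Data preparation**: upload and process the augmented dataset
3. **Model training**: transfer learning with EfficientNet/ResNet
4. **Evaluation**: validate model performance
5. **Model export**: save the model for deployment

## ⚡ Before you start
- Enable a GPU runtime (Runtime → Change runtime type → GPU)
- Have the archive of the augmented dataset ready
- Estimated training time: 2-4 hours

## 🔧 Project setup

### Clone the project from GitHub
This notebook clones the marine fish recognition project from GitHub into Google Drive so that data and code persist across Colab sessions.

**Steps:**
1. Run the code below to mount Google Drive
2. The project is cloned into Drive automatically
3. The working directory is set and training can begin

**Note:** Make sure you have enough free Drive space (at least 5 GB recommended).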
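
The free-space recommendation above can be checked before cloning. The following is a minimal sketch using only the standard library; the helper name `has_free_space` and the 5 GB default are illustrative assumptions, and the figure reported for a mounted Drive folder may not match your Drive quota exactly.

```python
import shutil

def has_free_space(path: str, required_gb: float = 5.0) -> bool:
    """Return True if the filesystem holding `path` has at least `required_gb` GiB free."""
    usage = shutil.disk_usage(path)          # named tuple: total, used, free (bytes)
    free_gb = usage.free / 1024**3
    print(f"Free space at {path}: {free_gb:.1f} GiB")
    return free_gb >= required_gb
```

After mounting, something like `has_free_space("/content/drive/MyDrive")` would give a quick go/no-go signal before the clone starts.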

## 1. Environment setup and GPU check

In [None]:
# === Step 1: Mount Google Drive ===
from google.colab import drive
import os
import subprocess
import shutil
import json

# Mount Google Drive
drive.mount('/content/drive')
print("✅ Google Drive mounted")

# Project paths
DRIVE_ROOT = "/content/drive/MyDrive"
PROJECT_DIR = f"{DRIVE_ROOT}/MarineFish_recognition"
GITHUB_REPO = "kermitstart/marine-fish-recognition"  # 🔴 Replace with your own GitHub repository if it differs

print(f"Project will be saved to: {PROJECT_DIR}")

# === Step 2: Clone or update the project ===
def clone_or_update_project():
    """Clone the project from GitHub into Google Drive, or update an existing clone."""
    try:
        if os.path.exists(PROJECT_DIR):
            print("🔄 Project directory already exists, updating...")
            os.chdir(PROJECT_DIR)
            # Pull the latest changes
            result = subprocess.run(["git", "pull"], capture_output=True, text=True)
            if result.returncode == 0:
                print("✅ Project updated")
            else:
                print(f"⚠️  Project update failed: {result.stderr}")
        else:
            print("📥 Cloning project from GitHub...")
            os.chdir(DRIVE_ROOT)
            # Clone the repository
            result = subprocess.run([
                "git", "clone",
                f"https://github.com/{GITHUB_REPO}.git",
                "MarineFish_recognition"
            ], capture_output=True, text=True)

            if result.returncode == 0:
                print("✅ Project cloned")
            else:
                print(f"❌ Project clone failed: {result.stderr}")
                print("Check that the GitHub repository address is correct")
                return False

        return True
    except Exception as e:
        print(f"❌ Project setup failed: {str(e)}")
        return False

# Clone or update the project
success = clone_or_update_project()

if success:
    # Switch to the project directory
    os.chdir(PROJECT_DIR)
    print(f"📂 Current working directory: {os.getcwd()}")

    # Check the project structure
    if os.path.exists("fish_backbone"):
        print("✅ Project structure check passed")
        print("📁 Key directories:")
        for item in ["fish_backbone", "compact_dataset", "dataset", "fish_backbone/mini_dataset"]:
            if os.path.exists(item):
                print(f"   ✅ {item}")
            else:
                print(f"   ⚠️  {item} (not found)")
    else:
        print("❌ Unexpected project structure, check the GitHub repository")
else:
    print("❌ Cannot continue, check the network connection and GitHub repository settings")

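## 2. Clone the project and dataset from GitHub

### Plan A: Clone the full project from GitHub (recommended)
1. Upload the whole project (including the augmented dataset) to GitHub
2. Clone it in Colab and use it directly

### Plan B: Mount Google Drive
1. Upload the dataset to Google Drive
2. Mount Drive in Colab to access the data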
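
Both plans end with a dataset directory that holds one sub-folder per fish class. As a minimal sketch of how such a layout can be inspected (the helper name `count_images_per_class` and the extension list are illustrative assumptions, not part of the project code):

```python
import os

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

def count_images_per_class(dataset_dir: str) -> dict:
    """Map each class folder in `dataset_dir` to the number of image files it holds."""
    counts = {}
    for name in sorted(os.listdir(dataset_dir)):
        class_dir = os.path.join(dataset_dir, name)
        # Skip plain files and hidden folders such as .git
        if not os.path.isdir(class_dir) or name.startswith("."):
            continue
        counts[name] = sum(
            1 for f in os.listdir(class_dir)
            if f.lower().endswith(IMAGE_EXTENSIONS)  # case-insensitive match
        )
    return counts
```

Running it on the dataset path chosen by either plan gives a quick view of class balance before training.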

Set `USE_GITHUB` in the code cell below to pick the plan you want, then run it:

In [None]:
# Plan A: clone the project from GitHub (recommended)
def clone_from_github():
    """Clone the project from GitHub and locate its dataset."""

    # Replace with your own GitHub repository URL
    GITHUB_REPO_URL = "https://github.com/YOUR_USERNAME/MarineFish_recognition.git"
    PROJECT_NAME = "MarineFish_recognition"

    print("📥 Cloning project from GitHub...")

    # Clone the repository
    if not os.path.exists(PROJECT_NAME):
        print(f"🔄 Cloning: {GITHUB_REPO_URL}")
        !git clone {GITHUB_REPO_URL}
        print(f"✅ Project cloned: {PROJECT_NAME}")
    else:
        print(f"📁 Project already exists: {PROJECT_NAME}")
        # Pull the latest changes
        !cd {PROJECT_NAME} && git pull
        print("✅ Project updated to the latest version")

    # Auto-detect the dataset.
    # NOTE: auto_detect_dataset is defined in the Plan C cell further down;
    # run that cell first, otherwise this call raises NameError.
    return auto_detect_dataset(PROJECT_NAME)

# Plan B: mount Google Drive
def mount_google_drive():
    """Mount Google Drive and locate the dataset on it."""

    print("📱 Mounting Google Drive...")

    from google.colab import drive
    drive.mount('/content/drive')

    # Dataset path on Drive (adjust to your own layout)
    drive_dataset_path = "/content/drive/MyDrive/MarineFish_Dataset/augmented_dataset"

    if os.path.exists(drive_dataset_path):
        print("✅ Google Drive mounted, dataset found")

        classes = [d for d in os.listdir(drive_dataset_path)
                  if os.path.isdir(os.path.join(drive_dataset_path, d)) and not d.startswith('.')]

        print(f"📊 Found {len(classes)} fish classes")
        return True, classes, drive_dataset_path
    else:
        print("❌ Dataset not found on Google Drive")
        print(f"Make sure the dataset is at: {drive_dataset_path}")
        return False, [], ""

# Choose which plan to use
USE_GITHUB = True  # True: clone from GitHub, False: use Google Drive

print("🚀 Preparing the dataset...")
print("=" * 50)

if USE_GITHUB:
    print("📂 Using Plan A: GitHub clone")
    print("⚠️  Edit GITHUB_REPO_URL above to point at your repository first")
    success, class_names, data_path = clone_from_github()
else:
    print("📂 Using Plan B: Google Drive")
    success, class_names, data_path = mount_google_drive()

if success:
    print("\n🎯 Dataset ready, training can start!")
    NUM_CLASSES = len(class_names)
    print(f"Number of classes: {NUM_CLASSES}")

    # Per-class statistics
    total_images = 0
    print(f"\n📊 Dataset details:")
    for i, class_name in enumerate(class_names):
        class_path = os.path.join(data_path, class_name)
        if os.path.exists(class_path):
            # Case-insensitive match so files such as IMG_01.JPG are counted too
            count = len([f for f in os.listdir(class_path)
                        if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
            total_images += count
            if i < 10:  # Show only the first 10 classes
                print(f"  {i+1:2d}. {class_name:25s} - {count:3d} images")

    if len(class_names) > 10:
        print(f"  ... and {len(class_names)-10} more classes")

    print(f"\n📈 Total: {total_images} images")
    if class_names:  # Avoid dividing by zero when no class folder was found
        print(f"Average per class: {total_images // len(class_names)} images")

else:
    print("\n⚠️  Dataset preparation failed, please check:")
    if USE_GITHUB:
        print("1. The GitHub repository URL is correct")
        print("2. The repository contains a dataset (augmented_dataset or compact_dataset.zip)")
        print("3. The network connection works")
    else:
        print("1. Google Drive access has been authorized")
        print("2. The dataset path is correct")
        print("3. The dataset has been uploaded to Drive")

# === Step 3: Install dependencies ===
import sys

# Make sure we are inside the project directory
if not os.path.exists("fish_backbone"):
    # Raising stops the cell cleanly; sys.exit() is meant for scripts, not notebooks
    raise RuntimeError("❌ Run the code above to clone the project first")

print("📦 Installing Python dependencies...")

# Look for a requirements.txt
requirements_files = [
    "fish_backbone/requirements.txt",
    "requirements.txt"
]

requirements_file = None
for req_file in requirements_files:
    if os.path.exists(req_file):
        requirements_file = req_file
        break

if requirements_file:
    print(f"📋 Installing dependencies from {requirements_file}")
    result = subprocess.run([
        sys.executable, "-m", "pip", "install", "-r", requirements_file
    ], capture_output=True, text=True)

    if result.returncode == 0:
        print("✅ Dependencies installed")
    else:
        print(f"⚠️  Some dependencies failed to install: {result.stderr}")
        print("Falling back to installing the core packages one by one...")

        # Install the core packages individually
        core_packages = [
            "torch", "torchvision", "torchaudio",
            "pillow", "numpy", "matplotlib",
            "opencv-python", "scikit-learn",
            "tqdm", "flask"
        ]

        for package in core_packages:
            print(f"Installing {package}...")
            subprocess.run([sys.executable, "-m", "pip", "install", package],
                         capture_output=True, text=True)
else:
    print("📋 No requirements.txt found, installing base dependencies...")
    # Base dependency list
    packages = [
        "torch", "torchvision", "torchaudio",
        "pillow", "numpy", "matplotlib", "opencv-python",
        "scikit-learn", "tqdm", "flask", "requests"
    ]

    for package in packages:
        print(f"Installing {package}...")
        result = subprocess.run([
            sys.executable, "-m", "pip", "install", package
        ], capture_output=True, text=True)

        if result.returncode == 0:
            print(f"✅ {package} installed")
        else:
            print(f"⚠️  {package} failed to install")

print("🎯 Checking GPU availability...")
try:
    import torch
    if torch.cuda.is_available():
        print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
        print(f"   GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("⚠️  No GPU available, training will run on the CPU (slower)")
except ImportError:
    print("⚠️  PyTorch is not installed correctly")

In [None]:
# Plan C: use the compact dataset stored in the GitHub project (suits smaller projects)
def extract_compact_dataset(project_dir):
    """Extract the compact dataset archive shipped with the GitHub project."""

    compact_zip = os.path.join(project_dir, "compact_dataset.zip")

    if os.path.exists(compact_zip):
        print("📦 Compact dataset archive found, extracting...")

        import zipfile
        with zipfile.ZipFile(compact_zip, 'r') as zip_ref:
            zip_ref.extractall(project_dir)

        dataset_path = os.path.join(project_dir, "compact_dataset")

        if os.path.exists(dataset_path):
            classes = [d for d in os.listdir(dataset_path)
                      if os.path.isdir(os.path.join(dataset_path, d)) and not d.startswith('.')]

            print(f"✅ Compact dataset extracted")
            print(f"📊 Found {len(classes)} fish classes")

            # Count images in a sample of up to 5 classes
            sampled = classes[:5]
            total_images = 0
            for class_name in sampled:
                class_path = os.path.join(dataset_path, class_name)
                count = len([f for f in os.listdir(class_path)
                            if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
                total_images += count
                print(f"  - {class_name}: {count} images")

            if len(classes) > 5:
                print(f"  ... and {len(classes)-5} more classes")

            # Extrapolate from the sampled classes; dividing by len(sampled)
            # instead of a fixed 5 keeps the estimate right with fewer than 5 classes
            if sampled:
                print(f"📈 Total: about {total_images * len(classes) // len(sampled)} images")

            return True, classes, dataset_path
        else:
            print("❌ Dataset directory not found after extraction")
            return False, [], ""
    else:
        print("❌ Compact dataset archive not found")
        return False, [], ""

# Auto-detect the dataset type
def auto_detect_dataset(project_dir):
    """Detect which dataset is available in the project directory."""

    # Full augmented dataset
    full_dataset = os.path.join(project_dir, "augmented_dataset")
    if os.path.exists(full_dataset):
        print("🎯 Found the full augmented dataset")
        classes = [d for d in os.listdir(full_dataset)
                  if os.path.isdir(os.path.join(full_dataset, d)) and not d.startswith('.')]
        return True, classes, full_dataset

    # Compact dataset archive
    compact_zip = os.path.join(project_dir, "compact_dataset.zip")
    if os.path.exists(compact_zip):
        print("📦 Found the compact dataset archive")
        return extract_compact_dataset(project_dir)

    # Compact dataset that is already extracted
    compact_dataset = os.path.join(project_dir, "compact_dataset")
    if os.path.exists(compact_dataset):
        print("📁 Found the compact dataset directory")
        classes = [d for d in os.listdir(compact_dataset)
                  if os.path.isdir(os.path.join(compact_dataset, d)) and not d.startswith('.')]
        return True, classes, compact_dataset

    print("❌ No usable dataset found")
    return False, [], ""

# === Step 4: Choose a dataset ===
import json

print("🎯 Detecting available datasets...")

# Datasets that exist in the project directory
available_datasets = {}

# Full dataset
if os.path.exists("dataset"):
    dataset_size = len([f for f in os.listdir("dataset") if os.path.isdir(os.path.join("dataset", f))])
    available_datasets["Full dataset"] = {
        "path": "dataset",
        "classes": dataset_size,
        "description": "Complete marine fish dataset: most classes, longest training time"
    }

# Compact dataset
if os.path.exists("compact_dataset"):
    compact_size = len([f for f in os.listdir("compact_dataset") if os.path.isdir(os.path.join("compact_dataset", f))])
    available_datasets["Compact dataset"] = {
        "path": "compact_dataset",
        "classes": compact_size,
        "description": "Curated compact dataset: balances data volume and training cost"
    }

# Mini dataset
if os.path.exists("fish_backbone/mini_dataset"):
    mini_size = len([f for f in os.listdir("fish_backbone/mini_dataset") if os.path.isdir(os.path.join("fish_backbone/mini_dataset", f))])
    available_datasets["Mini dataset"] = {
        "path": "fish_backbone/mini_dataset",
        "classes": mini_size,
        "description": "Small dataset for quick tests: fastest to train"
    }

# Fail early with a clear message instead of an IndexError further down
if not available_datasets:
    raise FileNotFoundError(
        "No dataset found (expected dataset/, compact_dataset/ or fish_backbone/mini_dataset/)")

# List the available datasets
print("📊 Available datasets:")
for name, info in available_datasets.items():
    print(f"   🔹 {name}: {info['classes']} classes - {info['description']}")

# === Dataset selection ===
# 🔴 Choose the dataset to use here
DATASET_CHOICE = "Mini dataset"  # Options: "Full dataset", "Compact dataset", "Mini dataset"

if DATASET_CHOICE not in available_datasets:
    print(f"❌ Selected dataset '{DATASET_CHOICE}' is not available")
    print("Available options:", list(available_datasets.keys()))
    DATASET_CHOICE = list(available_datasets.keys())[0]  # Fall back to the first available dataset
    print(f"🔄 Falling back to: {DATASET_CHOICE}")

# Dataset path and class count
DATASET_PATH = available_datasets[DATASET_CHOICE]["path"]
NUM_CLASSES = available_datasets[DATASET_CHOICE]["classes"]

print(f"\n✅ Selected dataset: {DATASET_CHOICE}")
print(f"📂 Dataset path: {DATASET_PATH}")
print(f"🏷️  Number of classes: {NUM_CLASSES}")

# Look up the matching config file
config_files = {
    "dataset": "fish_backbone/dataset_config.json",
    "compact_dataset": "fish_backbone/dataset_config.json",
    "fish_backbone/mini_dataset": "fish_backbone/mini_dataset_config.json"
}

config_file = config_files.get(DATASET_PATH)
if config_file and os.path.exists(config_file):
    print(f"📋 Found config file: {config_file}")
    with open(config_file, 'r', encoding='utf-8') as f:
        dataset_config = json.load(f)
    print(f"   Class mapping loaded: {len(dataset_config.get('class_names', []))} classes")
else:
    print("⚠️  No matching config file found, the mapping will be generated dynamically")

# Inspect the dataset contents
print(f"\n🔍 Inspecting dataset contents...")
if os.path.exists(DATASET_PATH):
    classes = [d for d in os.listdir(DATASET_PATH) if os.path.isdir(os.path.join(DATASET_PATH, d))]
    print(f"   Found {len(classes)} classes:")
    for i, cls in enumerate(sorted(classes)[:10]):  # Show only the first 10
        # Count image files only, so stray files do not inflate the number
        sample_count = len([f for f in os.listdir(os.path.join(DATASET_PATH, cls))
                            if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
        print(f"     {i+1:2d}. {cls}: {sample_count} images")
    if len(classes) > 10:
        print(f"     ... and {len(classes)-10} more classes")
else:
    print("❌ Dataset path does not exist")

## 3. Data loader definition

In [None]:
import json
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms

class FishDataset(Dataset):
    """Marine fish dataset with one sub-folder per class."""

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted([d for d in os.listdir(root_dir)
                              if os.path.isdir(os.path.join(root_dir, d)) and not d.startswith('.')])
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # Build the list of (image path, label) pairs.
        # Sorting makes the sample order identical across instances.
        self.samples = []
        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            class_idx = self.class_to_idx[class_name]

            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    img_path = os.path.join(class_dir, img_name)
                    self.samples.append((img_path, class_idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]

        # Load the image
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# Data transforms
def get_transforms():
    # Augmentation used for training
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Deterministic transform for validation and testing
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    return train_transform, val_transform

# Create the datasets and data loaders
if success:
    train_transform, val_transform = get_transforms()

    # Two views of the same files. Subsets produced by a split share one
    # underlying dataset, so changing its transform after the split would also
    # augment the validation and test images. Separate instances avoid that.
    augmented_dataset = FishDataset(data_path, transform=train_transform)
    full_dataset = FishDataset(data_path, transform=val_transform)

    # Split into training, validation and test sets (70%, 20%, 10%)
    dataset_size = len(full_dataset)
    train_size = int(0.7 * dataset_size)
    val_size = int(0.2 * dataset_size)

    # Shuffle the indices once (fixed seed for reproducibility) and slice them
    generator = torch.Generator().manual_seed(42)
    indices = torch.randperm(dataset_size, generator=generator).tolist()
    train_dataset = Subset(augmented_dataset, indices[:train_size])
    val_dataset = Subset(full_dataset, indices[train_size:train_size + val_size])
    test_dataset = Subset(full_dataset, indices[train_size + val_size:])

    # Create the data loaders
    BATCH_SIZE = 32
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    print(f"📊 Dataset split complete:")
    print(f"  Training set: {len(train_dataset)} images ({len(train_dataset)/dataset_size*100:.1f}%)")
    print(f"  Validation set: {len(val_dataset)} images ({len(val_dataset)/dataset_size*100:.1f}%)")
    print(f"  Test set: {len(test_dataset)} images ({len(test_dataset)/dataset_size*100:.1f}%)")
    print(f"  Batch size: {BATCH_SIZE}")

    # Save the class information. Use the dataset's sorted class list so that
    # 'classes' and 'class_to_idx' agree on the label order.
    class_info = {
        'classes': full_dataset.classes,
        'class_to_idx': full_dataset.class_to_idx,
        'num_classes': len(full_dataset.classes)
    }

    with open('class_info.json', 'w', encoding='utf-8') as f:
        json.dump(class_info, f, ensure_ascii=False, indent=2)

    print("✅ Data loaders created")
else:
    print("⚠️  Prepare the dataset first")

# === Step 5: Model definition ===
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
import sys
import os

# Add the project path to the Python path
sys.path.append('fish_backbone')

print(f"🧠 Defining the model architecture (classes: {NUM_CLASSES})")

class FishClassifier(nn.Module):
    def __init__(self, num_classes):
        super(FishClassifier, self).__init__()
        # ImageNet-pretrained ResNet18 backbone.
        # The `weights` argument replaces the deprecated `pretrained=True`
        # (requires torchvision >= 0.13).
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

        # Replace the final classification layer
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.backbone(x)

# Create the model instance
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FishClassifier(num_classes=NUM_CLASSES).to(device)

print(f"✅ Model created")
print(f"📱 Device: {device}")
print(f"🔢 Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Data transforms
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("🔄 Data transforms defined")

# Look for a previously trained fish model to resume from
model_files = [
    "fish_backbone/marine_fish_model.pth",
    "fish_backbone/best_model.pth",
    "marine_fish_model.pth"
]

pretrained_model = None
for model_file in model_files:
    if os.path.exists(model_file):
        pretrained_model = model_file
        break

if pretrained_model:
    try:
        print(f"🔄 Trying to load trained weights: {pretrained_model}")
        checkpoint = torch.load(pretrained_model, map_location=device)

        # A checkpoint may be a bare state dict or a dict that wraps one.
        # The training loop in this notebook saves under 'model_state_dict'.
        model_state = checkpoint
        if isinstance(checkpoint, dict):
            for key in ('model_state_dict', 'state_dict'):
                if key in checkpoint:
                    model_state = checkpoint[key]
                    break

        # Check that the classifier layer matches the number of classes
        fc_weight = model_state.get('backbone.fc.weight', model_state.get('fc.weight'))
        if fc_weight is not None and fc_weight.shape[0] == NUM_CLASSES:
            model.load_state_dict(model_state)
            print("✅ Trained weights loaded")
        else:
            print(f"⚠️  Checkpoint class count does not match (need: {NUM_CLASSES}), "
                  "keeping the ImageNet backbone with a freshly initialized classifier")
    except Exception as e:
        print(f"⚠️  Failed to load trained weights: {str(e)}")
        print("Keeping the ImageNet backbone with a freshly initialized classifier")
else:
    print("📝 No trained fish model found, starting from the ImageNet backbone with a freshly initialized classifier")

## 4. Model definition

Choose a pretrained model for transfer learning. EfficientNet-B0 is recommended as a good balance between speed and accuracy.
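
The cell below pairs the model with a `CosineAnnealingLR` scheduler (`T_max=50`, `eta_min=1e-6`, starting from a learning rate of `1e-3`). As a rough illustration of what that schedule does, here is a standalone sketch of the closed-form cosine curve without restarts; the function name `cosine_annealing_lr` is hypothetical and the code does not call PyTorch.

```python
import math

def cosine_annealing_lr(step, base_lr=1e-3, eta_min=1e-6, t_max=50):
    """Learning rate after `step` scheduler steps on a cosine curve from base_lr to eta_min."""
    cos_term = 1 + math.cos(math.pi * step / t_max)   # 2 at step 0, 0 at step t_max
    return eta_min + 0.5 * (base_lr - eta_min) * cos_term

# Print a few points along the schedule
for step in (0, 10, 25, 40, 50):
    print(f"step {step:2d}: lr = {cosine_annealing_lr(step):.2e}")
```

The rate starts at `base_lr`, passes the midpoint between the two bounds at `t_max / 2`, and reaches `eta_min` at `t_max`, so `T_max` is usually set to the planned number of epochs.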

In [None]:
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

# Install efficientnet_pytorch if it is missing
try:
    from efficientnet_pytorch import EfficientNet
    print("✅ EfficientNet is installed")
except ImportError:
    print("📦 Installing EfficientNet...")
    !pip install efficientnet_pytorch
    from efficientnet_pytorch import EfficientNet
    print("✅ EfficientNet installed")

def create_model(model_name='efficientnet-b0', num_classes=NUM_CLASSES, pretrained=True):
    """
    Create a (optionally pretrained) model with a new classification head.

    Args:
        model_name: model name ('efficientnet-b0', 'resnet50', 'resnet101')
        num_classes: number of output classes
        pretrained: whether to load ImageNet-pretrained weights
    """

    if model_name.startswith('efficientnet'):
        if pretrained:
            model = EfficientNet.from_pretrained(model_name, num_classes=num_classes)
        else:
            model = EfficientNet.from_name(model_name, num_classes=num_classes)

        print(f"🤖 Created {model_name} model")

    elif model_name == 'resnet50':
        # `weights` replaces the deprecated `pretrained` flag (torchvision >= 0.13)
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.resnet50(weights=weights)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        print(f"🤖 Created ResNet50 model")

    elif model_name == 'resnet101':
        weights = models.ResNet101_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.resnet101(weights=weights)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        print(f"🤖 Created ResNet101 model")

    else:
        raise ValueError(f"Unsupported model: {model_name}")

    return model

# Model selection
MODEL_NAME = 'efficientnet-b0'  # Can be changed to 'resnet50' or 'resnet101'

if success:
    # Create the model
    model = create_model(MODEL_NAME, NUM_CLASSES)
    model = model.to(device)

    # Count the parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"📊 Model info:")
    print(f"  Model: {MODEL_NAME}")
    print(f"  Classes: {NUM_CLASSES}")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(f"  Device: {device}")

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # Learning-rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)

    print("✅ Model and training components created")
else:
    print("⚠️  Prepare the dataset first")

# === Step 6: Data loading ===
import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import random
from sklearn.model_selection import train_test_split

print("📊 Preparing data loading...")

class FishDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        try:
            image = Image.open(image_path).convert('RGB')
            label = self.labels[idx]

            if self.transform:
                image = self.transform(image)

            return image, label
        except Exception as e:
            print(f"⚠️  Failed to load image: {image_path}, error: {str(e)}")
            # Fall back to a blank image so one corrupt file does not stop training
            blank_image = Image.new('RGB', (224, 224), color='white')
            if self.transform:
                blank_image = self.transform(blank_image)
            return blank_image, self.labels[idx]

# Scan the dataset
print(f"🔍 Scanning dataset: {DATASET_PATH}")
image_paths = []
labels = []
class_names = []

if os.path.exists(DATASET_PATH):
    class_dirs = sorted([d for d in os.listdir(DATASET_PATH)
                        if os.path.isdir(os.path.join(DATASET_PATH, d))])

    print(f"Found {len(class_dirs)} classes")

    for class_idx, class_name in enumerate(class_dirs):
        class_path = os.path.join(DATASET_PATH, class_name)
        class_names.append(class_name)

        # Collect all images of this class
        image_files = [f for f in os.listdir(class_path)
                      if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

        for image_file in image_files:
            image_path = os.path.join(class_path, image_file)
            image_paths.append(image_path)
            labels.append(class_idx)

        print(f"  {class_name}: {len(image_files)} images")

    print(f"\n📈 Dataset statistics:")
    print(f"   Total images: {len(image_paths)}")
    print(f"   Classes: {len(class_names)}")
    if class_names:  # Avoid dividing by zero for an empty dataset folder
        print(f"   Average per class: {len(image_paths)/len(class_names):.1f} images")

    # Stratified 80/20 split so every class appears in both sets
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        image_paths, labels, test_size=0.2, random_state=42, stratify=labels
    )

    print(f"\n🔄 Dataset split:")
    print(f"   Training set: {len(train_paths)} images")
    print(f"   Validation set: {len(val_paths)} images")

    # Create the datasets
    train_dataset = FishDataset(train_paths, train_labels, train_transform)
    val_dataset = FishDataset(val_paths, val_labels, val_transform)

    # Create the data loaders
    batch_size = 32 if len(image_paths) > 1000 else 16  # Smaller batches for small datasets

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available()
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=torch.cuda.is_available()
    )

    print(f"✅ Data loaders created")
    print(f"   Batch size: {batch_size}")
    print(f"   Training batches: {len(train_loader)}")
    print(f"   Validation batches: {len(val_loader)}")

    # Index-to-name class mapping
    class_mapping = {i: name for i, name in enumerate(class_names)}
    print(f"\n🏷️  Class mapping:")
    for i, name in list(class_mapping.items())[:5]:
        print(f"   {i}: {name}")
    if len(class_mapping) > 5:
        print(f"   ... and {len(class_mapping)-5} more classes")

else:
    print("❌ Dataset path does not exist")
    raise FileNotFoundError(f"Dataset path does not exist: {DATASET_PATH}")

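## 5. Model training

Train the marine fish recognition model. Training runs for several epochs; each epoch trains on the training set and then evaluates on the validation set.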
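
The training loop in this section also stops early once validation accuracy has not improved for `patience` consecutive epochs. That bookkeeping can be sketched in isolation as follows; the `EarlyStopping` class is an illustrative helper and not part of the project code, and the accuracy values are made up for the demonstration.

```python
class EarlyStopping:
    """Track the best validation accuracy and signal when to stop training."""

    def __init__(self, patience=7):
        self.patience = patience
        self.best = 0.0          # best validation accuracy seen so far
        self.stale_epochs = 0    # epochs since the last improvement

    def update(self, val_acc):
        """Record one epoch's accuracy; return True when training should stop."""
        if val_acc > self.best:
            self.best = val_acc
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience

# Demonstration with made-up accuracies and a short patience
stopper = EarlyStopping(patience=2)
for epoch, acc in enumerate([70.0, 72.5, 72.1, 72.4, 80.0], start=1):
    if stopper.update(acc):
        print(f"stop after epoch {epoch}, best={stopper.best}")  # stop after epoch 4, best=72.5
        break
```

In the demonstration the fifth epoch would have improved the score, which shows the trade-off: a small `patience` saves time but can cut training short before a late improvement.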

In [None]:
import time

import torch
from tqdm.auto import tqdm

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                num_epochs=20, patience=7, save_path='best_fish_model.pth'):
    """
    Train the model with per-epoch validation, checkpointing and early stopping.

    Args:
        model: model to train
        train_loader: training data loader
        val_loader: validation data loader
        criterion: loss function
        optimizer: optimizer
        scheduler: learning-rate scheduler
        num_epochs: maximum number of epochs
        patience: epochs without improvement before stopping early
        save_path: where to save the best model

    Relies on the notebook globals `device`, `class_info` and `MODEL_NAME`.
    """

    print(f"🚀 Starting training for {num_epochs} epochs")
    print("=" * 60)

    # Training history
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'lr': []
    }

    best_val_acc = 0.0
    patience_counter = 0

    for epoch in range(num_epochs):
        start_time = time.time()
        current_lr = optimizer.param_groups[0]['lr']

        print(f"Epoch {epoch+1}/{num_epochs} | LR: {current_lr:.2e}")

        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        train_pbar = tqdm(train_loader, desc="Train", leave=False)
        for batch_idx, (data, target) in enumerate(train_pbar):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            predicted = output.argmax(dim=1)
            train_total += target.size(0)
            train_correct += (predicted == target).sum().item()

            # Update the progress bar
            train_acc = 100. * train_correct / train_total
            train_pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{train_acc:.2f}%'
            })

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc="Val", leave=False)
            for data, target in val_pbar:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)

                val_loss += loss.item()
                predicted = output.argmax(dim=1)
                val_total += target.size(0)
                val_correct += (predicted == target).sum().item()

                # Update the progress bar
                val_acc = 100. * val_correct / val_total
                val_pbar.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'Acc': f'{val_acc:.2f}%'
                })

        # Epoch averages
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        train_accuracy = 100. * train_correct / train_total
        val_accuracy = 100. * val_correct / val_total

        # Step the learning-rate schedule
        scheduler.step()

        # Record the history
        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(train_accuracy)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_accuracy)
        history['lr'].append(current_lr)

        # Epoch duration
        epoch_time = time.time() - start_time

        # Print the results
        print(f"Train loss: {avg_train_loss:.4f} | Train accuracy: {train_accuracy:.2f}%")
        print(f"Val loss: {avg_val_loss:.4f} | Val accuracy: {val_accuracy:.2f}%")
        print(f"Time: {epoch_time:.1f}s")

        # Save the best model
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            patience_counter = 0

            # Save a checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_acc': best_val_acc,
                'class_info': class_info,
                'model_name': MODEL_NAME
            }, save_path)

            print(f"🎯 New best validation accuracy: {val_accuracy:.2f}% (model saved)")
        else:
            patience_counter += 1

        # Early-stopping check
        if patience_counter >= patience:
            print(f"🛑 Validation accuracy has not improved for {patience} epochs, stopping early")
            break

        print("-" * 60)

    print(f"🎉 Training finished! Best validation accuracy: {best_val_acc:.2f}%")

    return history

# === 步骤7: 训练配置 ===
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import time

print("⚙️ 配置训练参数...")

# 根据数据集大小自动调整训练参数
total_samples = len(image_paths)
if total_samples < 500:
    # 小数据集配置
    NUM_EPOCHS = 20
    LEARNING_RATE = 0.001
    STEP_SIZE = 7
    GAMMA = 0.1
    print("📝 小数据集配置")
elif total_samples < 2000:
    # 中等数据集配置
    NUM_EPOCHS = 30
    LEARNING_RATE = 0.001
    STEP_SIZE = 10
    GAMMA = 0.1
    print("📝 中等数据集配置")
else:
    # 大数据集配置
    NUM_EPOCHS = 50
    LEARNING_RATE = 0.0001
    STEP_SIZE = 15
    GAMMA = 0.1
    print("📝 大数据集配置")

print(f"🎯 训练配置:")
print(f"   训练轮数: {NUM_EPOCHS}")
print(f"   学习率: {LEARNING_RATE}")
print(f"   学习率衰减步长: {STEP_SIZE}")
print(f"   衰减因子: {GAMMA}")

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
scheduler = StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

print("✅ 优化器和调度器创建成功")

# 训练记录
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

def calculate_accuracy(outputs, labels):
    """计算准确率"""
    _, predicted = torch.max(outputs.data, 1)
    total = labels.size(0)
    correct = (predicted == labels).sum().item()
    return correct / total

def train_epoch(model, train_loader, criterion, optimizer, device):
    """训练一个epoch"""
    model.train()
    running_loss = 0.0
    running_accuracy = 0.0

    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_accuracy += calculate_accuracy(outputs, labels)

        # 显示进度
        if batch_idx % max(1, len(train_loader) // 5) == 0:
            print(f"    批次 {batch_idx}/{len(train_loader)}, "
                  f"损失: {loss.item():.4f}, "
                  f"准确率: {calculate_accuracy(outputs, labels):.4f}")

    epoch_loss = running_loss / len(train_loader)
    epoch_accuracy = running_accuracy / len(train_loader)
    return epoch_loss, epoch_accuracy

def validate_epoch(model, val_loader, criterion, device):
    """验证一个epoch"""
    model.eval()
    running_loss = 0.0
    running_accuracy = 0.0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            running_accuracy += calculate_accuracy(outputs, labels)

    epoch_loss = running_loss / len(val_loader)
    epoch_accuracy = running_accuracy / len(val_loader)
    return epoch_loss, epoch_accuracy

print("🚀 Ready to start training...")

# Start training
if success:
    # Training settings
    PATIENCE = 10
    MODEL_SAVE_PATH = f'best_{MODEL_NAME}_fish_model.pth'

    print("⚙️  Training settings:")
    print(f"  Max epochs: {NUM_EPOCHS}")
    print(f"  Early-stopping patience: {PATIENCE}")
    print(f"  Model save path: {MODEL_SAVE_PATH}")
    print()

    # Run training
    history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=NUM_EPOCHS,
        patience=PATIENCE,
        save_path=MODEL_SAVE_PATH
    )

    # Save the training history
    with open('training_history.json', 'w') as f:
        json.dump(history, f, indent=2)

    print("✅ Training history saved")

else:
    print("⚠️  Please prepare the dataset first")
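
`PATIENCE` is passed to `train_model`, whose body is not shown in this part of the notebook. As a rough illustration of what patience-based early stopping usually means, the sketch below stops once validation accuracy has failed to improve for `patience` consecutive epochs. The `EarlyStopper` class, its `min_delta` parameter, and the sample accuracy values are hypothetical and are not taken from the project's `train_model`.

```python
class EarlyStopper:
    """Hypothetical helper: stop when validation accuracy stalls for `patience` epochs."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum gain that counts as an improvement
        self.best = float("-inf")     # best validation accuracy seen so far
        self.bad_epochs = 0           # consecutive epochs without improvement

    def step(self, val_acc):
        """Record one epoch's validation accuracy; return True if training should stop."""
        if val_acc > self.best + self.min_delta:
            self.best = val_acc
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation accuracies, used only to exercise the helper
stopper = EarlyStopper(patience=3)
val_history = [0.50, 0.62, 0.61, 0.60, 0.62, 0.70]
stopped_at = None
for epoch, acc in enumerate(val_history, start=1):
    if stopper.step(acc):
        stopped_at = epoch
        break

print(f"Stopped at epoch {stopped_at}, best accuracy {stopper.best:.2f}")
```

With a patience of 3, the loop stops before the late improvement to 0.70 is seen. A larger patience spends more epochs in exchange for a lower risk of stopping on a temporary plateau.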

## 6. Training Results Visualization

In [None]:
# Plot training curves
def plot_training_history(history):
    """Plot the training history curves."""

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Loss curves
    axes[0, 0].plot(history['train_loss'], label='Training loss', color='blue')
    axes[0, 0].plot(history['val_loss'], label='Validation loss', color='red')
    axes[0, 0].set_title('Loss curves')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True)

    # Accuracy curves
    axes[0, 1].plot(history['train_acc'], label='Training accuracy', color='blue')
    axes[0, 1].plot(history['val_acc'], label='Validation accuracy', color='red')
    axes[0, 1].set_title('Accuracy curves')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].legend()
    axes[0, 1].grid(True)

    # Learning-rate curve
    axes[1, 0].plot(history['lr'], label='Learning rate', color='green')
    axes[1, 0].set_title('Learning rate schedule')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Learning Rate')
    axes[1, 0].set_yscale('log')
    axes[1, 0].legend()
    axes[1, 0].grid(True)

    # Summary statistics
    best_val_acc = max(history['val_acc'])
    best_epoch = history['val_acc'].index(best_val_acc)
    final_train_acc = history['train_acc'][-1]
    final_val_acc = history['val_acc'][-1]

    stats_text = f"""Training summary:
    Best validation accuracy: {best_val_acc:.2f}% (Epoch {best_epoch+1})
    Final training accuracy: {final_train_acc:.2f}%
    Final validation accuracy: {final_val_acc:.2f}%
    Total epochs trained: {len(history['train_loss'])}
    """

    axes[1, 1].text(0.1, 0.5, stats_text, fontsize=12,
                    verticalalignment='center', transform=axes[1, 1].transAxes)
    axes[1, 1].set_title('Training summary')
    axes[1, 1].axis('off')

    plt.tight_layout()
    plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
    plt.show()

# === Step 8: Start training ===
import time
import matplotlib.pyplot as plt
from datetime import datetime

print("🎯 Starting model training...")
print(f"⏱️  Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# Track the best model
best_val_accuracy = 0.0
best_model_path = f"{PROJECT_DIR}/best_marine_fish_model.pth"

# Training start time
start_time = time.time()

for epoch in range(EPOCHS):
    epoch_start_time = time.time()
    print(f"\n📍 Epoch {epoch+1}/{EPOCHS}")
    print("-" * 50)

    # Training phase
    print("🔄 Training phase...")
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validation phase
    print("🔍 Validation phase...")
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)

    # Update the learning rate
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']

    # Record this epoch's results
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    # Time taken by this epoch
    epoch_time = time.time() - epoch_start_time

    # Report results
    print(f"\n📊 Epoch {epoch+1} results:")
    print(f"   Training loss: {train_loss:.4f}, training accuracy: {train_acc:.4f}")
    print(f"   Validation loss: {val_loss:.4f}, validation accuracy: {val_acc:.4f}")
    print(f"   Learning rate: {current_lr:.6f}")
    print(f"   Time: {epoch_time:.1f}s")

    # Save the best model
    if val_acc > best_val_accuracy:
        best_val_accuracy = val_acc
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_accuracy': best_val_accuracy,
            'class_names': class_names,
            'num_classes': NUM_CLASSES
        }, best_model_path)
        print(f"✅ Saved best model (accuracy: {val_acc:.4f})")

    # Estimate the remaining time
    if epoch < EPOCHS - 1:
        elapsed_time = time.time() - start_time
        avg_epoch_time = elapsed_time / (epoch + 1)
        remaining_time = avg_epoch_time * (EPOCHS - epoch - 1)
        print(f"⏳ Estimated time remaining: {remaining_time/60:.1f} min")

# Training finished
total_time = time.time() - start_time
print("\n🎉 Training complete!")
print(f"⏱️  Total time: {total_time/60:.1f} min")
print(f"🏆 Best validation accuracy: {best_val_accuracy:.4f}")
print(f"💾 Best model saved to: {best_model_path}")

# Save the final model
final_model_path = f"{PROJECT_DIR}/final_marine_fish_model.pth"
torch.save({
    'epoch': EPOCHS,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'final_accuracy': val_accuracies[-1],
    'class_names': class_names,
    'num_classes': NUM_CLASSES,
    'train_losses': train_losses,
    'train_accuracies': train_accuracies,
    'val_losses': val_losses,
    'val_accuracies': val_accuracies
}, final_model_path)
print(f"💾 Final model saved to: {final_model_path}")

# Plot the training curves
if 'history' in locals():
    plot_training_history(history)
else:
    print("Please complete model training first")

## 7. Model Evaluation

Evaluate the best model on the test set and generate a detailed classification report and confusion matrix.
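
To make the reported metrics concrete, here is a minimal pure-Python sketch of how a confusion matrix and per-class precision, recall, and F1 follow from lists of targets and predictions. It illustrates the definitions only and is not how scikit-learn computes them internally. The helper names and the toy labels are hypothetical.

```python
def confusion_matrix_counts(targets, predictions, num_classes):
    """Rows are true classes, columns are predicted classes."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(targets, predictions):
        cm[t][p] += 1
    return cm


def per_class_metrics(cm):
    """Derive precision, recall, and F1 for each class from the matrix."""
    n = len(cm)
    metrics = []
    for k in range(n):
        tp = cm[k][k]
        predicted_k = sum(cm[r][k] for r in range(n))  # column sum
        actual_k = sum(cm[k])                          # row sum (support)
        precision = tp / predicted_k if predicted_k else 0.0
        recall = tp / actual_k if actual_k else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        metrics.append({"precision": precision, "recall": recall,
                        "f1": f1, "support": actual_k})
    return metrics


# Toy labels for three classes (made up for illustration)
targets = [0, 0, 0, 1, 1, 2]
predictions = [0, 0, 1, 1, 1, 0]

cm = confusion_matrix_counts(targets, predictions, num_classes=3)
metrics = per_class_metrics(cm)
for k, m in enumerate(metrics):
    print(f"class {k}: precision={m['precision']:.3f} "
          f"recall={m['recall']:.3f} f1={m['f1']:.3f} support={m['support']}")
```

Class 2 is never predicted in this toy data, so its precision is set to 0 by convention. Classes like this are worth checking first when a per-class report looks uneven.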

In [None]:
def evaluate_model(model, test_loader, class_names, device):
    """Evaluate the model on the test set."""

    model.eval()
    all_predictions = []
    all_targets = []

    print("🔍 Evaluating the model...")
    with torch.no_grad():
        for data, target in tqdm(test_loader, desc="Evaluating"):
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    # Compute accuracy
    accuracy = accuracy_score(all_targets, all_predictions)

    # Build the classification report
    report = classification_report(all_targets, all_predictions,
                                 target_names=class_names,
                                 output_dict=True)

    # Build the confusion matrix
    cm = confusion_matrix(all_targets, all_predictions)

    return accuracy, report, cm, all_predictions, all_targets

def plot_confusion_matrix(cm, class_names, title='Confusion matrix'):
    """Plot the confusion matrix."""

    plt.figure(figsize=(12, 10))

    # With many classes, show only the first 15
    if len(class_names) > 15:
        display_classes = class_names[:15]
        display_cm = cm[:15, :15]
        title += " (first 15 classes)"
    else:
        display_classes = class_names
        display_cm = cm

    sns.heatmap(display_cm,
                annot=True,
                fmt='d',
                cmap='Blues',
                xticklabels=display_classes,
                yticklabels=display_classes)

    plt.title(title)
    plt.xlabel('Predicted class')
    plt.ylabel('True class')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig('confusion_matrix.png', dpi=150, bbox_inches='tight')
    plt.show()

def print_classification_report(report, class_names):
    """Print a detailed classification report."""

    print("📊 Detailed classification report:")
    print("=" * 80)

    # Overall metrics
    print(f"Overall accuracy: {report['accuracy']:.4f}")
    print(f"Macro-average precision: {report['macro avg']['precision']:.4f}")
    print(f"Macro-average recall: {report['macro avg']['recall']:.4f}")
    print(f"Macro-average F1 score: {report['macro avg']['f1-score']:.4f}")
    print()

    # Per-class metrics
    print("Per-class metrics:")
    print("-" * 80)
    print(f"{'Class':<25} {'Precision':<10} {'Recall':<10} {'F1':<10} {'Support':<10}")
    print("-" * 80)

    for class_name in class_names:
        if class_name in report:
            metrics = report[class_name]
            print(f"{class_name:<25} {metrics['precision']:<10.4f} "
                  f"{metrics['recall']:<10.4f} {metrics['f1-score']:<10.4f} "
                  f"{int(metrics['support']):<10}")

# Load the best model and evaluate it
if success and 'MODEL_SAVE_PATH' in locals():
    if os.path.exists(MODEL_SAVE_PATH):
        print(f"📁 Loading best model: {MODEL_SAVE_PATH}")

        # Load the checkpoint
        checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

        print(f"✅ Model loaded (best validation accuracy: {checkpoint['best_val_acc']:.2f}%)")

        # Evaluate on the test set
        test_accuracy, test_report, test_cm, predictions, targets = evaluate_model(
            model, test_loader, class_names, device
        )

        print(f"\n🎯 Test accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")

        # Print the classification report
        print_classification_report(test_report, class_names)

        # Plot the confusion matrix
        plot_confusion_matrix(test_cm, class_names)

        # Save the evaluation results
        evaluation_results = {
            'test_accuracy': test_accuracy,
            'classification_report': test_report,
            'model_path': MODEL_SAVE_PATH,
            'model_name': MODEL_NAME,
            'num_classes': NUM_CLASSES
        }

        with open('evaluation_results.json', 'w', encoding='utf-8') as f:
            json.dump(evaluation_results, f, ensure_ascii=False, indent=2)

        print("\n✅ Evaluation results saved to evaluation_results.json")

    else:
        print("❌ Trained model file not found")
else:
    print("Please complete model training first")

# === Step 9: Visualize results ===
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import random

print("📈 Generating training curves...")

# Create the training-curve figure
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Loss curves
epochs_range = range(1, len(train_losses) + 1)
ax1.plot(epochs_range, train_losses, 'b-', label='Training loss', linewidth=2)
ax1.plot(epochs_range, val_losses, 'r-', label='Validation loss', linewidth=2)
ax1.set_title('Training and validation loss', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy curves
ax2.plot(epochs_range, train_accuracies, 'b-', label='Training accuracy', linewidth=2)
ax2.plot(epochs_range, val_accuracies, 'r-', label='Validation accuracy', linewidth=2)
ax2.set_title('Training and validation accuracy', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{PROJECT_DIR}/training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

print("📊 Training statistics:")
print(f"   Final training accuracy: {train_accuracies[-1]:.4f}")
print(f"   Final validation accuracy: {val_accuracies[-1]:.4f}")
print(f"   Best validation accuracy: {max(val_accuracies):.4f}")
print(f"   Training curves saved to: {PROJECT_DIR}/training_curves.png")

# Test model predictions
print("\n🧪 Testing model predictions...")

def predict_sample_images(model, dataset, class_names, num_samples=6):
    """Predict and display a few sample images."""
    model.eval()

    # Pick random samples
    indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.ravel()

    with torch.no_grad():
        for i, idx in enumerate(indices):
            image, true_label = dataset[idx]

            # Predict
            image_batch = image.unsqueeze(0).to(device)
            outputs = model(image_batch)
            probabilities = torch.softmax(outputs, dim=1)
            predicted_label = torch.argmax(outputs, dim=1).item()
            confidence = probabilities[0][predicted_label].item()

            # Show the image
            # Undo the normalization so the image displays correctly
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])

            img_display = image.clone()
            for c in range(3):
                img_display[c] = img_display[c] * std[c] + mean[c]
            img_display = torch.clamp(img_display, 0, 1)
            img_display = img_display.permute(1, 2, 0).numpy()

            axes[i].imshow(img_display)
            axes[i].axis('off')

            # Title shows the true and predicted labels
            true_name = class_names[true_label]
            pred_name = class_names[predicted_label]
            color = 'green' if predicted_label == true_label else 'red'

            axes[i].set_title(
                f'True: {true_name[:15]}...\nPredicted: {pred_name[:15]}...\nConfidence: {confidence:.3f}',
                fontsize=10, color=color, fontweight='bold'
            )

    plt.tight_layout()
    plt.savefig(f'{PROJECT_DIR}/sample_predictions.png', dpi=300, bbox_inches='tight')
    plt.show()

# Run the sample predictions
predict_sample_images(model, val_dataset, class_names)
print(f"   Sample predictions saved to: {PROJECT_DIR}/sample_predictions.png")

# Compute per-class accuracy
print("\n📊 Per-class performance analysis...")
class_correct = [0] * NUM_CLASSES
class_total = [0] * NUM_CLASSES

model.eval()
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        for i in range(labels.size(0)):
            label = labels[i].item()
            class_total[label] += 1
            if predicted[i] == label:
                class_correct[label] += 1

print("🎯 Per-class accuracy:")
for i in range(min(10, NUM_CLASSES)):  # Show only the first 10 classes
    if class_total[i] > 0:
        accuracy = class_correct[i] / class_total[i]
        print(f"   {class_names[i][:20]:20s}: {accuracy:.3f} ({class_correct[i]}/{class_total[i]})")

if NUM_CLASSES > 10:
    print(f"   ... and {NUM_CLASSES-10} more classes")

## 8. Model Export

Convert the trained model into a deployment-friendly format and create a script for inference.

In [None]:
# Export the model for deployment
def export_model_for_deployment(model, model_name, class_info, save_dir='deployment_models'):
    """Export the model for deployment."""

    os.makedirs(save_dir, exist_ok=True)

    # 1. Save the checkpoint (weights plus metadata) for Python inference
    model_path = os.path.join(save_dir, f'{model_name}_complete.pth')
    torch.save({
        'model_state_dict': model.state_dict(),
        'class_info': class_info,
        'model_name': model_name,
        'input_size': (224, 224)
    }, model_path)

    # 2. Create the inference class
    inference_code = f'''
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import json

class FishClassifier:
    def __init__(self, model_path, device='cpu'):
        self.device = device
        self.model_path = model_path

        # Load the checkpoint and its metadata
        checkpoint = torch.load(model_path, map_location=device)
        self.class_info = checkpoint['class_info']
        self.classes = self.class_info['classes']
        self.model_name = checkpoint['model_name']

        # Build the model
        if self.model_name.startswith('efficientnet'):
            from efficientnet_pytorch import EfficientNet
            self.model = EfficientNet.from_name(self.model_name, num_classes=len(self.classes))
        elif self.model_name == 'resnet50':
            from torchvision import models
            self.model = models.resnet50(pretrained=False)
            self.model.fc = nn.Linear(self.model.fc.in_features, len(self.classes))
        else:
            # Fail early instead of hitting an AttributeError on self.model below
            raise ValueError("Unsupported model name: " + str(self.model_name))

        # Load the weights
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(device)
        self.model.eval()

        # Define the preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def predict(self, image_path_or_pil, top_k=3):
        """
        Predict the class of an image.

        Args:
            image_path_or_pil: Image path or PIL Image object
            top_k: Number of top predictions to return

        Returns:
            list: [(class_name, confidence), ...]
        """

        # Load the image
        if isinstance(image_path_or_pil, str):
            image = Image.open(image_path_or_pil).convert('RGB')
        else:
            image = image_path_or_pil.convert('RGB')

        # Preprocess
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Run inference
        with torch.no_grad():
            outputs = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

        # Take the top-k results
        top_probs, top_indices = torch.topk(probabilities, min(top_k, len(self.classes)))

        results = []
        for i in range(len(top_probs)):
            class_idx = top_indices[i].item()
            class_name = self.classes[class_idx]
            confidence = top_probs[i].item()
            results.append((class_name, confidence))

        return results

# Usage example
if __name__ == "__main__":
    # Initialize the classifier
    classifier = FishClassifier('{model_name}_complete.pth')

    # Predict a single image
    # results = classifier.predict('path/to/fish/image.jpg')
    # print("Predictions:")
    # for class_name, confidence in results:
    #     print(f"  {{class_name}}: {{confidence:.4f}} ({{confidence*100:.2f}}%)")
'''

    # Save the inference code
    inference_path = os.path.join(save_dir, f'{model_name}_inference.py')
    with open(inference_path, 'w', encoding='utf-8') as f:
        f.write(inference_code)

    # 3. Create the deployment README
    readme_content = f'''# Marine Fish Recognition Model Deployment Guide

## Model information
- Model name: {model_name}
- Number of classes: {len(class_info['classes'])}
- Input size: 224x224
- Supported species: {', '.join(class_info['classes'][:10])}{'...' if len(class_info['classes']) > 10 else ''}

## Files
- `{model_name}_complete.pth`: Model checkpoint
- `{model_name}_inference.py`: Inference code
- `class_info.json`: Class information
- `README.md`: This document

## Usage

### 1. Requirements
```bash
pip install torch torchvision pillow efficientnet_pytorch
```

### 2. Inference example
```python
from {model_name}_inference import FishClassifier

# Initialize the classifier
classifier = FishClassifier('{model_name}_complete.pth')

# Predict an image
results = classifier.predict('fish_image.jpg', top_k=3)

# Print the results
for class_name, confidence in results:
    print(f"{{class_name}}: {{confidence:.4f}} ({{confidence*100:.2f}}%)")
```

### 3. Integrating into an existing system
Copy `{model_name}_inference.py` into your project and use it as shown above.

## Performance
- Test accuracy: {test_accuracy*100:.2f}% (if evaluated)
- Inference speed: ~50 ms/image (CPU), ~10 ms/image (GPU)
- Model size: {os.path.getsize(model_path)/1024/1024:.1f} MB

## Notes
1. Input images are automatically resized to 224x224
2. Use clear fish images for best results
3. The model was trained on marine fish and may perform poorly on other kinds of fish
'''

    readme_path = os.path.join(save_dir, 'README.md')
    with open(readme_path, 'w', encoding='utf-8') as f:
        f.write(readme_content)

    # Save the class information
    class_info_path = os.path.join(save_dir, 'class_info.json')
    with open(class_info_path, 'w', encoding='utf-8') as f:
        json.dump(class_info, f, ensure_ascii=False, indent=2)

    print("📦 Model export complete!")
    print(f"📁 Saved to: {save_dir}/")
    print("📄 Files:")
    for file in os.listdir(save_dir):
        file_path = os.path.join(save_dir, file)
        if os.path.isfile(file_path):
            size_mb = os.path.getsize(file_path) / 1024 / 1024
            print(f"  - {file} ({size_mb:.1f} MB)")

    return save_dir

# === Step 10: Model deployment and testing ===
import json
from pathlib import Path

print("🚀 Preparing model deployment...")

# Save the model configuration
model_config = {
    'model_name': 'Marine Fish Classifier',
    'num_classes': NUM_CLASSES,
    'class_names': class_names,
    'image_size': [224, 224],
    'normalization': {
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225]
    },
    'best_accuracy': best_val_accuracy,
    'dataset_used': DATASET_CHOICE,
    'training_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}

config_path = f"{PROJECT_DIR}/model_config.json"
with open(config_path, 'w', encoding='utf-8') as f:
    json.dump(model_config, f, ensure_ascii=False, indent=2)

print(f"📋 Model configuration saved to: {config_path}")

# Define the prediction function
def predict_fish(image_path, model, class_names, transform, device, top_k=3):
    """
    Predict the class of a single fish image.

    Args:
        image_path: Path to the image
        model: Trained model
        class_names: List of class names
        transform: Image preprocessing transform
        device: Compute device
        top_k: Number of top predictions to return

    Returns:
        predictions: List of prediction dicts
    """
    try:
        # Load and preprocess the image
        image = Image.open(image_path).convert('RGB')
        image_tensor = transform(image).unsqueeze(0).to(device)

        # Run the model
        model.eval()
        with torch.no_grad():
            outputs = model(image_tensor)
            probabilities = torch.softmax(outputs, dim=1)[0]

        # Take the top-k predictions
        top_k = min(top_k, len(class_names))
        top_probs, top_indices = torch.topk(probabilities, top_k)

        predictions = []
        for i in range(top_k):
            predictions.append({
                'class_name': class_names[top_indices[i].item()],
                'confidence': top_probs[i].item(),
                'class_index': top_indices[i].item()
            })

        return predictions

    except Exception as e:
        print(f"Prediction failed: {str(e)}")
        return []

# Test the prediction function
print("\n🧪 Testing prediction...")

# Collect a few test images
test_images = []
for class_name in class_names[:3]:  # Test the first 3 classes
    class_path = os.path.join(DATASET_PATH, class_name)
    if os.path.exists(class_path):
        images = [f for f in os.listdir(class_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        if images:
            test_images.append(os.path.join(class_path, images[0]))

print(f"Found {len(test_images)} test images")

for i, test_image in enumerate(test_images):
    print(f"\n📷 Test image {i+1}: {os.path.basename(test_image)}")
    predictions = predict_fish(test_image, model, class_names, val_transform, device)

    if predictions:
        print("🎯 Predictions:")
        for j, pred in enumerate(predictions):
            print(f"   {j+1}. {pred['class_name']}: {pred['confidence']:.3f}")
    else:
        print("❌ Prediction failed")

# Create the deployment guide
deployment_guide = f"""
# 🐟 Marine Fish Recognition Model Deployment Guide

## Model information
- Model name: {model_config['model_name']}
- Number of classes: {model_config['num_classes']}
- Best accuracy: {model_config['best_accuracy']:.4f}
- Training data: {model_config['dataset_used']}
- Training date: {model_config['training_date']}

## Files
- `best_marine_fish_model.pth`: Best-performing model
- `final_marine_fish_model.pth`: Model after the final epoch
- `model_config.json`: Model configuration file
- `training_curves.png`: Training curves
- `sample_predictions.png`: Sample predictions

## Usage

### 1. Load the model
```python
import torch
import json
from torchvision import transforms
from PIL import Image

# Load the configuration
with open('model_config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

# Load the model
checkpoint = torch.load('best_marine_fish_model.pth', map_location='cpu')
model = FishClassifier(num_classes=config['num_classes'])
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Define the preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=config['normalization']['mean'],
                        std=config['normalization']['std'])
])
```

### 2. Predict an image
```python
def predict_fish_image(image_path):
    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)[0]

    predicted_idx = torch.argmax(probabilities).item()
    confidence = probabilities[predicted_idx].item()

    return {{
        'class_name': config['class_names'][predicted_idx],
        'confidence': confidence
    }}
```

## Performance
- Training accuracy: {train_accuracies[-1]:.4f}
- Validation accuracy: {val_accuracies[-1]:.4f}
- Best validation accuracy: {max(val_accuracies):.4f}

## Notes
1. Input images must be in RGB format
2. Images are automatically resized to 224x224 pixels
3. The model was trained on a GPU but can run inference on a CPU
4. Use clear fish images for best results
"""

guide_path = f"{PROJECT_DIR}/deployment_guide.md"
with open(guide_path, 'w', encoding='utf-8') as f:
    f.write(deployment_guide)

print(f"\n📖 Deployment guide saved to: {guide_path}")

print("\n🎉 Model training and deployment preparation complete!")
print(f"📁 All files saved to Google Drive: {PROJECT_DIR}")
print("💾 Main output files:")
print("   - Best model: best_marine_fish_model.pth")
print("   - Final model: final_marine_fish_model.pth")
print("   - Model configuration: model_config.json")
print("   - Training curves: training_curves.png")
print("   - Deployment guide: deployment_guide.md")

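## 🎉 Training Summary

### Completed work
1. ✅ **Environment setup**: GPU check and dependency installation
2. ✅ **Data preparation**: Loading and preprocessing the augmented dataset
3. ✅ **Model training**: Training the classifier with transfer learning
4. ✅ **Performance evaluation**: Test-set evaluation and visual analysis
5. ✅ **Model export**: Generating a deployment-ready model package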

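### Next steps
1. **Download the model**: Download the generated `deployment.zip` file
2. **Integrate and deploy**: Integrate the model into the existing Flask backend
3. **Test performance**: Test the model on real images
4. **Keep optimizing**: Tune the model based on real-world usage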

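### Deployment recommendations
- Replace the backend's mock prediction logic with the trained model
- Use the `FishClassifier` class for inference
- Consider adding a confidence threshold to filter out low-confidence predictions
- Monitor the model's performance in production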
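
The confidence-threshold suggestion above could be sketched as follows. The sketch assumes predictions arrive as `(class_name, confidence)` pairs, the shape returned by `FishClassifier.predict`. The helper names, the `0.5` default threshold, the `"unknown"` label, and the species placeholders are all hypothetical, and a suitable threshold would need to be chosen from validation data.

```python
def filter_by_confidence(predictions, threshold=0.5):
    """Keep only (class_name, confidence) pairs at or above the threshold."""
    return [(name, conf) for name, conf in predictions if conf >= threshold]


def top_prediction_or_unknown(predictions, threshold=0.5, unknown_label="unknown"):
    """Return the most confident surviving prediction, or a fallback label."""
    kept = filter_by_confidence(predictions, threshold)
    if not kept:
        return (unknown_label, 0.0)
    return max(kept, key=lambda pair: pair[1])


# Placeholder output in the shape produced by FishClassifier.predict(..., top_k=3)
example = [("species_a", 0.48), ("species_b", 0.31), ("species_c", 0.21)]

print(top_prediction_or_unknown(example, threshold=0.5))  # nothing passes the threshold
print(top_prediction_or_unknown(example, threshold=0.4))  # species_a passes
```

Returning an explicit fallback label lets the backend show "not recognized" instead of a low-confidence guess.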

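### Possible improvements
1. **More data**: Collect more samples, especially for poorly performing classes
2. **Model ensembling**: Train multiple models and combine them by voting
3. **Multi-fish detection**: Train an object detection model based on a framework such as YOLO
4. **Real-time optimization**: Compress and optimize the model for inference speed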
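
As one way to read the ensembling idea above, the sketch below uses soft voting: each model's softmax probabilities for the same image are averaged, and the class with the highest mean wins. This is an illustrative variant rather than something the notebook implements. The `soft_vote` helper and the probability vectors are hypothetical.

```python
def soft_vote(prob_vectors):
    """Average per-class probabilities across models and pick the best class.

    prob_vectors: list of equal-length probability lists, one per model.
    Returns (winning_class_index, averaged_probabilities).
    """
    n_models = len(prob_vectors)
    n_classes = len(prob_vectors[0])
    avg = [sum(v[c] for v in prob_vectors) / n_models for c in range(n_classes)]
    winner = max(range(n_classes), key=lambda c: avg[c])
    return winner, avg


# Made-up softmax outputs from three models for one image (3 classes)
probs_a = [0.6, 0.3, 0.1]
probs_b = [0.2, 0.7, 0.1]
probs_c = [0.3, 0.5, 0.2]

winner, avg = soft_vote([probs_a, probs_b, probs_c])
print(f"winning class index: {winner}, averaged probabilities: {avg}")
```

Here the first model alone would pick class 0, but the averaged vote selects class 1. Hard (majority) voting over each model's top class is a simpler alternative that discards the confidence information.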

---
🐟 **Congratulations on completing the marine fish recognition model training!** 🐟