In [40]:
import os
import copy
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets 
import torchvision.transforms as transforms
import pandas as pd
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts,ExponentialLR
from sklearn.model_selection import train_test_split
from PIL import Image
import cv2
import albumentations
from albumentations.pytorch.transforms import ToTensorV2

In [41]:
"""
为每个标签创建一个唯一的数字编号（从0开始），
并将这个数字编号作为新的列添加到DataFrame中，最后保存为新的CSV文件。
"""
train = pd.read_csv("../dataset/train.csv")
labels = list(train['label'])  # 提取标签列并转换为列表
labels_unique = list(set(labels))   # set去除重复标签，list转换为唯一标签列表

# 创建数字标签列表
label_nums = []
for i in range(len(labels)):
    # 将文本标签转换为数字编号：
    # 当前标签在唯一标签列表中的索引作为其数字标签
    label_nums.append(labels_unique.index(labels[i]))
    
train['number'] = label_nums
train.to_csv("../dataset/train_num_label.csv", index=0)  # 记录对应关系

In [42]:
# 读取测试集
test_df = pd.read_csv("../dataset/test.csv")  
# 将数据集 train 分割成训练集和验证集
# stratify：按数字标签进行分层抽样，确保训练集和验证集中各类别的比例相同
train_data, eval_data = train_test_split(train, test_size=0.2, stratify=train['number'])  

In [43]:
class Leaf_Dataset(Dataset):
    '''
    树叶数据集的训练集 自定义Dataset
    '''
    def __init__(self, train_csv, transform=None, test=False):
        '''
        train_csv : 传入记录图像路径及其标号的csv文件
        transform : 对图像进行的变换
        test : 是否为测试集模式
        '''
        super().__init__()
        self.train_csv = train_csv
        self.image_path = list(self.train_csv['image'])  # 提取所有图像路径
        self.transform = transform  
        self.test = test
        # 如果不是测试集，加载标签信息
        if not self.test:
            self.label_nums = list(self.train_csv['number'])  # 提取数字标签列表

    def __getitem__(self, idx):
        '''
        获取单个样本
        idx : 样本索引
        return : image, label
        '''
        # 读取图像
        image = cv2.imread(os.path.join('../dataset', self.image_path[idx]))
        # 转换颜色空间 BGR→RGB (OpenCV默认BGR，PyTorch需要RGB)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # image = Image.open(os.path.join('../dataset', self.image_path[idx]))
        if self.transform != None:
            image = self.transform(image=image)['image']
        # 返回结果：测试集只返回图像，训练集返回图像+标签
        if not self.test:
            label = self.label_nums[idx]
            return image, label
        else:
            return image

    def __len__(self):
        '''返回数据集大小（图像数量）'''
        return len(self.image_path)

In [44]:
# 数据增强
transforms_train = albumentations.Compose(
    [
        albumentations.Resize(320, 320),  # 调整图像尺寸到320x320像素
        albumentations.HorizontalFlip(p=0.5),  # 水平翻转（概率50%）
        albumentations.VerticalFlip(p=0.5),  # 垂直翻转（概率50%）
        albumentations.Rotate(limit=180, p=0.7),  # 随机旋转（-180度到+180度，概率70%）
        albumentations.RandomBrightnessContrast(),  # 随机调整亮度和对比度
        
        albumentations.Affine(
            translate_percent=(-0.25, 0.25),   # 平移范围 -25% 到 +25%
            scale=(0.9, 1.1),                  # 缩放范围 90% 到 110%
            rotate=0,                          # 无旋转
            p=0.5                              # 应用概率 50%
        ),
        
        # 图像归一化（使用ImageNet预训练模型的均值和标准差）标准化像素值分布，加速模型收敛
        # 均值: [0.485, 0.456, 0.406]
        # 标准差: [0.229, 0.224, 0.225]
        # max_pixel_value=255.0: 输入图像像素值范围0-255
        albumentations.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], max_pixel_value=255.0),

        # 将图像从numpy数组转换为PyTorch张量
        # p=1.0: 100%概率应用（总是应用）
        ToTensorV2(p=1.0)
    ]
)

# 不添加随机增强：保证评估结果的一致性
transforms_test = albumentations.Compose(
    [
        albumentations.Resize(320, 320),
        albumentations.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], max_pixel_value=255.0),
        ToTensorV2(p=1.0)
    ]
)

In [45]:
def train_model(train_loader, valid_loader, test_df, device=torch.device('cuda:0')):
    # 1. 模型初始化
    # 使用预训练的ResNet50模型
    net = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
    # 获取全连接层的输入特征维度
    in_features = net.fc.in_features
    # 替换全连接层以适应176类树叶分类任务
    net.fc = nn.Linear(in_features, 176)
    net = net.to(device)

    # 2. 训练参数设置
    epoch = 30
    best_epoch = 0
    best_score = 0.0
    best_model_state = None  # 保存最佳模型状态
    early_stopping_round = 3  # 早停轮数（验证集性能不再提升时停止）
    losses = []  # 记录每轮训练损失

    # 3. 优化器和损失函数
    optimizer = optim.Adam(net.parameters(), lr=0.0001, weight_decay=1e-5)  # Adam优化器
    loss = nn.CrossEntropyLoss(reduction='mean')  # 交叉熵损失函数
    scheduler = ExponentialLR(optimizer, gamma=0.9)  # 学习率指数衰减调度器

    # 4. 训练循环
    for i in range(epoch):
        acc = 0  # 累计训练准确数
        loss_sum = 0  # 累计训练损失

        # 训练阶段
        net.train()  # 设置模型为训练模式
        for x, y in tqdm(train_loader):  # 遍历训练数据批次
            # 准备输入数据
            x = torch.as_tensor(x, dtype=torch.float).to(device)
            y = y.to(device)

            # 前向传播
            y_hat = net(x)
            # 计算损失
            loss_temp = loss(y_hat, y)
            loss_sum += loss_temp

            # 反向传播和优化
            optimizer.zero_grad()  # 清空梯度
            loss_temp.backward()  # 反向传播计算梯度
            optimizer.step()  # 更新参数

            # 计算准确数
            acc += torch.sum(y_hat.argmax(dim=1).type(y.dtype) == y)
            
        # 更新学习率
        scheduler.step()
        # 记录平均损失
        losses.append(loss_sum.cpu().detach().numpy() / len(train_loader))

        # 打印训练结果
        print("epoch: ", i, 
              "loss=", loss_sum.item(), 
              "train acc: ", (acc/(len(train_loader)*train_loader.batch_size)).item(), end='')

        # 验证阶段
        test_acc = 0  # 累计验证准确数
        net.eval()  # 设置模型为评估模式
        for x, y in tqdm(valid_loader):  # 遍历验证数据批次
            # 准备输入数据
            x = torch.as_tensor(x, dtype=torch.float).to(device)
            y = y.to(device)

            # 前向传播（无梯度计算）
            with torch.no_grad():
                y_hat = net(x)

            # 计算准确数
            test_acc += torch.sum(y_hat.argmax(dim=1).type(y.dtype) == y)

        # 计算并打印验证准确率
        print("valid acc: ", (test_acc / (len(valid_loader)*valid_loader.batch_size)).item())

        # 5. 模型保存与早停
        # 如果当前模型在验证集上表现更好，保存为最佳模型
        if test_acc > best_score:
            best_model_state = copy.deepcopy(net.state_dict())  # 深度复制模型状态
            best_score = test_acc
            best_epoch = i
            print('best epoch save!')

        # 早停机制：如果连续early_stopping_round轮没有提升，停止训练
        if i - best_epoch >= early_stopping_round:
            print(f'Early stopping at epoch {i}')
            break

    # 6. 加载最佳模型
    net.load_state_dict(best_model_state)

    # 7. 在测试集上进行预测
    testset = Leaf_Dataset(test_df, transform=transforms_test, test=True)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, drop_last=False)

    # 进行预测
    predictions = []
    with torch.no_grad():
        for x in tqdm(test_loader):  # 遍历测试数据批次
            x = x.to(device)
            x = torch.as_tensor(x, dtype=torch.float)

            # 前向传播
            y_hat = net(x)

            # 获取预测结果
            predict = torch.argmax(y_hat, dim=1).reshape(-1)
            predict = list(predict.cpu().detach().numpy())
            predictions.extend(predict)
            
    return predictions

In [46]:
# 导入分层K折交叉验证工具
from sklearn.model_selection import StratifiedKFold

# 1. 初始化分层K折交叉验证器
# n_splits=5: 将数据分为5折
# shuffle=True: 打乱数据顺序
# random_state=2025: 随机种子（保证每次划分结果相同）
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2025)

# 2. 创建空 DataFrame 存储各折的预测结果
prediction_df = pd.DataFrame()

# 3. 开始K折交叉验证循环
# enumerate(skf.split(train, train['number'])):
#   - train: 整个训练集DataFrame
#   - train['number']: 数字标签（用于分层抽样）
for fold_n, (train_idx, val_idx) in enumerate(skf.split(train, train['number'])):
    print(f'fold {fold_n} training...')  # 打印当前折信息

    # 4. 划分训练集和验证集
    # 根据索引获取当前折的训练数据和验证数据
    train_data = train.iloc[train_idx]
    eval_data = train.iloc[val_idx]

    # 5. 创建数据集对象
    # 训练集应用训练时的数据增强
    trainset = Leaf_Dataset(train_data, transform=transforms_train)
    # 验证集应用测试时的变换（无增强）
    evalset = Leaf_Dataset(eval_data, transform=transforms_test)

    # 6. 创建数据加载器
    # 训练集 DataLoader：打乱顺序，用于训练
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, drop_last=False)
    eval_loader = torch.utils.data.DataLoader(evalset, batch_size=32, shuffle=False, drop_last=False)

    # 7. 训练模型并在测试集上预测
    predictions = train_model(train_loader, eval_loader, test_df)

    # 8. 存储当前折的预测结果
    prediction_df[f'fold_{fold_n}'] = predictions

fold 0 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss= 1290.000244140625 train acc:  0.3598856031894684

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.5986412763595581
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss= 546.6947631835938 train acc:  0.6814405918121338

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.7994564771652222
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss= 333.58599853515625 train acc:  0.7997685074806213

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8519021272659302
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss= 234.7442169189453 train acc:  0.8562772274017334

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8861412405967712
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss= 175.50152587890625 train acc:  0.8884803652763367

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9163042902946472
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss= 132.03565979003906 train acc:  0.9202069640159607

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9252716898918152
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss= 110.65531158447266 train acc:  0.930419385433197

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9320651888847351
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss= 87.03434753417969 train acc:  0.945942223072052

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9372282028198242
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss= 76.56438446044922 train acc:  0.9494144916534424

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9540760517120361
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss= 66.35782623291016 train acc:  0.9571078419685364

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9499999284744263


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss= 55.73731994628906 train acc:  0.9639841914176941

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9529891014099121


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss= 50.99494171142578 train acc:  0.9665713310241699

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9570651650428772
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  12 loss= 46.55691146850586 train acc:  0.9678648710250854

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9532608389854431


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  13 loss= 40.029998779296875 train acc:  0.9728349447250366

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9603260159492493
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  14 loss= 35.550941467285156 train acc:  0.9758986830711365

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9586955904960632


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  15 loss= 34.58921432495117 train acc:  0.9732434153556824

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9663043022155762
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  16 loss= 28.296247482299805 train acc:  0.9805963635444641

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9581521153450012


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  17 loss= 29.294584274291992 train acc:  0.9786219596862793

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9619565010070801


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  18 loss= 27.172334671020508 train acc:  0.9811410307884216

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9641304016113281
Early stopping at epoch 18


  0%|          | 0/138 [00:00<?, ?it/s]

fold 1 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss= 1319.0223388671875 train acc:  0.3513752520084381

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.5760869383811951
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss= 554.51416015625 train acc:  0.6790577173233032

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.792934775352478
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss= 335.9294738769531 train acc:  0.8016067147254944

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.824999988079071
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss= 233.9556121826172 train acc:  0.860838770866394

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8847825527191162
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss= 170.17950439453125 train acc:  0.8952205777168274

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9051629900932312
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss= 134.01654052734375 train acc:  0.9155772924423218

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9138586521148682
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss= 106.93794250488281 train acc:  0.9318491220474243

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9203804135322571
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss= 88.06417846679688 train acc:  0.9426062107086182

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9402173757553101
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss= 77.35823822021484 train acc:  0.9502995610237122

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9429347515106201
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss= 66.6692123413086 train acc:  0.9564950466156006

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9527173638343811
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss= 55.723228454589844 train acc:  0.9626225233078003

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9535325765609741
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss= 49.71709060668945 train acc:  0.9656181931495667

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9551630020141602
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  12 loss= 43.09606170654297 train acc:  0.9699074029922485

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9641304016113281
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  13 loss= 40.79607009887695 train acc:  0.9722222089767456

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9567934274673462


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  14 loss= 35.98540115356445 train acc:  0.9752178192138672

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9603260159492493


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  15 loss= 33.287574768066406 train acc:  0.9746050834655762

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9619565010070801
Early stopping at epoch 15


  0%|          | 0/138 [00:00<?, ?it/s]

fold 2 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss= 1306.3377685546875 train acc:  0.35525599122047424

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.585326075553894
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss= 544.5120239257812 train acc:  0.6921296119689941

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.7486412525177002
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss= 333.6719055175781 train acc:  0.7987472414970398

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8524456024169922
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss= 227.04466247558594 train acc:  0.8627451062202454

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8858695030212402
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss= 170.54348754882812 train acc:  0.8963099122047424

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9054347276687622
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss= 134.63429260253906 train acc:  0.9159177541732788

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9203804135322571
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss= 107.89347839355469 train acc:  0.9313044548034668

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9331521391868591
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss= 90.518310546875 train acc:  0.9416530132293701

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9339673519134521
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss= 77.7706527709961 train acc:  0.9505037665367126

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9513586759567261
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss= 62.35662078857422 train acc:  0.9583333134651184

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9475542902946472


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss= 56.0413703918457 train acc:  0.962282121181488

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9595108032226562
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss= 50.66645431518555 train acc:  0.9666393995285034

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.960869550704956
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  12 loss= 45.02617645263672 train acc:  0.9692265391349792

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9627717137336731
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  13 loss= 42.7098503112793 train acc:  0.9703839421272278

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9649456143379211
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  14 loss= 37.902793884277344 train acc:  0.9745370149612427

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9652173519134521
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  15 loss= 33.72583770751953 train acc:  0.976443350315094

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9668477773666382
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  16 loss= 31.016061782836914 train acc:  0.9771241545677185

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9657608270645142


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  17 loss= 29.352880477905273 train acc:  0.9781454205513

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9728260636329651
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  18 loss= 26.30277442932129 train acc:  0.9805282950401306

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9684782028198242


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  19 loss= 24.485157012939453 train acc:  0.9814814329147339

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9671195149421692


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  20 loss= 22.473047256469727 train acc:  0.983523964881897

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9714673757553101
Early stopping at epoch 20


  0%|          | 0/138 [00:00<?, ?it/s]

fold 3 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss= 1315.7860107421875 train acc:  0.35253268480300903

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.5554347634315491
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss= 569.63671875 train acc:  0.6655773520469666

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.7788043022155762
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss= 342.25067138671875 train acc:  0.7958877682685852

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8388586640357971
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss= 238.70567321777344 train acc:  0.8536900877952576

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8733695149421692
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss= 176.0605926513672 train acc:  0.8916802406311035

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9010869264602661
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss= 137.25892639160156 train acc:  0.9146241545677185

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.919021725654602
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss= 114.60398864746094 train acc:  0.9269471168518066

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9301630258560181
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss= 88.9895248413086 train acc:  0.9428104162216187

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9364129900932312
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss= 78.27359008789062 train acc:  0.9504356980323792

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9467390775680542
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss= 65.63002014160156 train acc:  0.9576525092124939

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9559782147407532
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss= 56.473575592041016 train acc:  0.9627587199211121

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9491847157478333


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss= 51.79374694824219 train acc:  0.9656862616539001

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9576086401939392
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  12 loss= 46.55669021606445 train acc:  0.9682053327560425

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9595108032226562
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  13 loss= 41.27650833129883 train acc:  0.9708605408668518

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9584238529205322


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  14 loss= 36.256412506103516 train acc:  0.9734476804733276

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9649456143379211
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  15 loss= 32.74489212036133 train acc:  0.976443350315094

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.960869550704956


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  16 loss= 31.300424575805664 train acc:  0.977056086063385

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9657608270645142
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  17 loss= 28.481853485107422 train acc:  0.9789624214172363

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9665760397911072
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  18 loss= 26.2423038482666 train acc:  0.9806644916534424

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9673912525177002
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  19 loss= 24.98631477355957 train acc:  0.9808006286621094

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9682064652442932
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  20 loss= 23.585763931274414 train acc:  0.9825026988983154

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9665760397911072


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  21 loss= 22.074758529663086 train acc:  0.9835920333862305

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9690216779708862
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  22 loss= 21.32380485534668 train acc:  0.9840686321258545

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9692934155464172
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  23 loss= 20.10523223876953 train acc:  0.9846813678741455

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9692934155464172


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  24 loss= 19.015336990356445 train acc:  0.9849537014961243

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9684782028198242


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  25 loss= 19.58958625793457 train acc:  0.9844090342521667

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9698368906974792
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  26 loss= 17.517620086669922 train acc:  0.9863153100013733

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9714673757553101
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  27 loss= 17.25578498840332 train acc:  0.9867238402366638

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9714673757553101


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  28 loss= 17.300249099731445 train acc:  0.9867238402366638

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9711955785751343


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  29 loss= 15.784106254577637 train acc:  0.9871323108673096

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9698368906974792
Early stopping at epoch 29


  0%|          | 0/138 [00:00<?, ?it/s]

fold 4 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss= 1312.33544921875 train acc:  0.35532405972480774

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.6008151769638062
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss= 543.9176025390625 train acc:  0.6836192607879639

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.7698369026184082
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss= 335.3684387207031 train acc:  0.8030364513397217

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8494564890861511
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss= 229.05911254882812 train acc:  0.8596813678741455

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8714673519134521
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss= 168.6392364501953 train acc:  0.8986247181892395

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9122282266616821
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss= 129.6028594970703 train acc:  0.9180963635444641

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9171195030212402
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss= 107.14385986328125 train acc:  0.9333469271659851

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9388586282730103
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss= 88.74978637695312 train acc:  0.9416530132293701

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9374999403953552


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss= 74.25508880615234 train acc:  0.9526143670082092

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9508152008056641
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss= 64.20774841308594 train acc:  0.9576525092124939

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9546195268630981
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss= 54.347137451171875 train acc:  0.9641203284263611

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9584238529205322
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss= 48.442893981933594 train acc:  0.9666393995285034

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9581521153450012


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  12 loss= 43.3475227355957 train acc:  0.9707244038581848

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9641304016113281
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  13 loss= 39.512752532958984 train acc:  0.9728349447250366

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9641304016113281


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  14 loss= 36.40730667114258 train acc:  0.9758986830711365

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9633151888847351


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  15 loss= 32.5826301574707 train acc:  0.9780092239379883

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9644021391868591
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  16 loss= 31.56509780883789 train acc:  0.9779411554336548

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9660325646400452
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  17 loss= 29.011211395263672 train acc:  0.9790304899215698

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9682064652442932
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  18 loss= 25.478715896606445 train acc:  0.9818218946456909

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9668477773666382


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  19 loss= 24.460115432739258 train acc:  0.9832516312599182

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9654890894889832


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  20 loss= 24.2436466217041 train acc:  0.9826388359069824

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9671195149421692
Early stopping at epoch 20


  0%|          | 0/138 [00:00<?, ?it/s]

In [47]:
print(prediction_df)

      fold_0  fold_1  fold_2  fold_3  fold_4
0         37      37      37      37      37
1        159     159     159     159     159
2        133     133     133     133     133
3        117     117     117     117     117
4        133     133     133     133     133
...      ...     ...     ...     ...     ...
8795     168     168     168     168     116
8796      59      59      59      59      59
8797     168     168     168     168     168
8798     121     121     121     121     121
8799     121     121     121     121     121

[8800 rows x 5 columns]


In [48]:
# 1. 获取各交叉验证折的最终预测结果（众数投票）
# prediction_df包含各折的预测结果（每列代表一折的预测）
# mode(axis=1)[0]：对每行取众数, 当存在多个众数时，取第一个出现的众数
# astype(int)：将结果转换为整数类型（类别索引）
all_predictions = list(prediction_df.mode(axis=1)[0].astype(int))

# 2. 将数字标签转换回原始文本标签
predict_label = []
for i in range(len(all_predictions)):
    # 使用之前创建的 labels_unique 列表（包含所有唯一类别名称）
    # 将数字索引转换为对应的类别名称
    predict_label.append(labels_unique[all_predictions[i]])

# 3. 准备提交文件
# 读取测试集原始CSV（包含image列）
submission = pd.read_csv('../dataset/test.csv')

# 4. 添加预测结果列
# 将预测的文本标签添加到 DataFrame 的新列'label'中
submission['label'] = predict_label

# 5. 保存结果文件
# 不保存索引列（index=False）
submission.to_csv('../dataset/result.csv', index=False)

In [49]:
submission

Unnamed: 0,image,label
0,images/18353.jpg,asimina_triloba
1,images/18354.jpg,betula_nigra
2,images/18355.jpg,platanus_acerifolia
3,images/18356.jpg,pinus_bungeana
4,images/18357.jpg,platanus_acerifolia
...,...,...
8795,images/27148.jpg,pinus_thunbergii
8796,images/27149.jpg,crataegus_crus-galli
8797,images/27150.jpg,pinus_thunbergii
8798,images/27151.jpg,juniperus_virginiana
