In [54]:
import os
import cv2
import copy
import torch
import torchvision.datasets
import albumentations

import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR, CosineAnnealingLR
from sklearn.model_selection import train_test_split, StratifiedKFold
from PIL import Image
from albumentations.pytorch.transforms import ToTensorV2

In [55]:
"""
为每个标签创建一个唯一的数字编号，从 0 开始
并将这个数字编号作为新的列添加到 DataFrame 中, 最后保存为 train_num_label2.csv
"""
train = pd.read_csv('../dataset/train.csv')
# 对标签列进行数字编码
# pd.factorize() 的编码规则是按标签首次出现的顺序分配编号
# 元组的第一个元素 [0] 就是所需的数值编码序列, 第二个元素[1]是所有出现过的文本标签
train['number'], labels_unique = pd.factorize(train['label'])
train.to_csv('../dataset/train_num_label2.csv', index=False)

In [56]:
# 读取测试集
test = pd.read_csv('../dataset/test.csv')
# 将数据集 train 分割为训练集和验证集
# stratify：按数字标签分层抽样, 确保训练集和验证集中各类别的比例相同
# train_data, eval_data = train_test_split(train, test_size=0.2, stratify=train['number'])

In [57]:
class Leaf_Dataset(Dataset):
    def __init__(self, train_csv, transform=None, test_bool=False):
        '''
        train_csv : 记录图像路径及其标号的csv文件
        transform : 图像变换
        test_bool : 是否为测试集模式
        '''
        super().__init__()
        self.train_csv = train_csv
        self.image_path = list(self.train_csv['image'])  # 提取所有图像路径
        self.transform = transform
        self.test_bool = test_bool
        # 如果不是测试集, 加载标签信息
        if not self.test_bool:
            self.label_nums = list(self.train_csv['number'])  

    def __getitem__(self, idx):
        '''
        获取单个样本
        idx : 样本索引
        return : image, label
        '''
        # 读取图像
        image = cv2.imread(os.path.join('../dataset', self.image_path[idx]))
        # 转换颜色空间 BGR -> RGB（ OpenCV默认BGR, PyTorch需要RGB ）
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform != None:
            image = self.transform(image=image)['image']  # albumentations 同时处理多种类型的数据，以字典形式返回
        # 测试集只返回图像, 训练集返回图像+标签
        if not self.test_bool:
            label = self.label_nums[idx]
            return image, label
        else:
            return image

    def __len__(self):
        ''' 返回数据集大小 '''
        return len(self.image_path)

In [58]:
# 数据增强
transforms_train = albumentations.Compose(
    [
        albumentations.Resize(320, 320),            # 调整图像尺寸到 320x320 像素
        albumentations.HorizontalFlip(p=0.5),       # 概率50%水平翻转
        albumentations.VerticalFlip(p=0.5),         # 概率50%垂直翻转
        albumentations.Rotate(limit=180, p=0.7),    # 随机翻转( ±180°, 概率70%)
        albumentations.RandomBrightnessContrast(),  # 随机调整亮度和对比度
        albumentations.Affine(
            translate_percent = (-0.25, 0.25),  # 平移范围 -25% 到 +25%
            scale = (0.9, 1.1),                 # 缩放范围 90% 到 110%
            rotate = 0,                         # 无旋转
            p = 0.5),                           # 应用概率50%

        # 图像归一化（使用ImageNet预训练模型的均值和标准差）标准化像素值分布，加速模型收敛
        # 均值: [0.485, 0.456, 0.406]
        # 标准差: [0.229, 0.224, 0.225]
        # max_pixel_value=255.0: 输入图像像素值范围0-255
        albumentations.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], max_pixel_value=255.0),

        # 将图像从 numpy 数组转换为 PyTorch 张量
        ToTensorV2(p=1.0)
    ]
)

# 不添加随机增强, 保证评估结果的一致性
transforms_test = albumentations.Compose(
    [
        albumentations.Resize(320, 320),
        albumentations.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], max_pixel_value=255.0),
        ToTensorV2(p=1.0)
    ]
)

In [59]:
def train_model(train_loader, valid_loader, test, device=torch.device('cuda:0')):
    # 模型初始化
    net = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
    in_features = net.fc.in_features      # 获取全连接层的输入特征维度
    net.fc = nn.Linear(in_features, 176)  # 替换全连接层以适应176类树叶分类任务
    net = net.to(device)

    # 训练参数设置
    epoch = 30
    best_epoch = 0
    best_score = 0.0
    best_model_state = None   # 保存最佳模型状态
    early_stopping_round = 3  # 早停轮数
    losses = []               # 记录每轮训练损失

    # 优化器和损失函数
    optimizer = optim.Adam(net.parameters(), lr=0.0001, weight_decay=1e-5)
    loss = nn.CrossEntropyLoss(reduction='mean')  # 返回所有样本损失的平均值
    scheduler = CosineAnnealingLR(optimizer, T_max=epoch)

    # 训练循环
    for i in range(epoch):
        acc = 0       # 累计训练准确数
        loss_sum = 0  # 累计训练损失 

        # 训练阶段
        net.train()
        for x, y in tqdm(train_loader):
            # 准备输入数据
            x = torch.as_tensor(x, dtype=torch.float).to(device)
            y = y.to(device)

            # 前向传播
            y_hat = net(x)
            # 计算损失
            loss_temp = loss(y_hat, y)
            loss_sum += loss_temp

            # 反向传播和优化
            optimizer.zero_grad()  # 清空梯度
            loss_temp.backward()   # 反向传播计算梯度
            optimizer.step()       # 更新模型参数

            # 计算准确数
            acc += torch.sum(y_hat.argmax(dim=1).type(y.dtype) == y)

        # 更新学习率
        scheduler.step()
        # 记录平均损失
        losses.append(loss_sum.cpu().detach().numpy() / len(train_loader))

        # 打印训练结果
        print('epoch: ', i,
             'loss = ', loss_sum.item(),
             'train acc: ', (acc / (len(train_loader) * train_loader.batch_size)).item(), end='')

        # 验证阶段
        valid_acc = 0  # 累计验证准确数
        net.eval()
        for x, y in tqdm(valid_loader):
            # 准备输入数据
            x = torch.as_tensor(x, dtype=torch.float).to(device)
            y = y.to(device)

            # 前向传播
            with torch.no_grad():
                y_hat = net(x)

            # 计算准确数
            valid_acc += torch.sum(y_hat.argmax(dim=1).type(y.dtype) == y)

        # 计算并打印验证准确率
        print('valid acc: ', (valid_acc / (len(valid_loader) * valid_loader.batch_size)).item())

        # 模型保存与早停
        if valid_acc > best_score:
            best_model_state = copy.deepcopy(net.state_dict())
            best_score = valid_acc
            best_epoch = i
            print('best epoch save!')
            
        if i - best_epoch >= early_stopping_round:
            print(f'Early stopping at epoch {i}')
            break
        

    # 加载最佳模型
    net.load_state_dict(best_model_state)

    # 加载测试数据
    testset = Leaf_Dataset(test, transform=transforms_test, test_bool=True)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, drop_last=False)

    # 执行预测
    predictions = []
    with torch.no_grad():
        for x in tqdm(test_loader):
            x = torch.as_tensor(x, dtype=torch.float).to(device)

            # 前向传播
            y_hat = net(x)

            # 获取预测结果
            predict = torch.argmax(y_hat, dim=1).reshape(-1)
            predict = list(predict.cpu().detach().numpy())
            predictions.extend(predict)

    return predictions

In [60]:
# # 可视化 skf.split(train, train['number']) 
# from sklearn.model_selection import StratifiedKFold
# import pandas as pd
# import numpy as np

# # 示例数据
# train = pd.DataFrame({
#     'image': [f'img_{i}' for i in range(10)],
#     'number': [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]  # 二分类标签
# })

# # 初始化分层K折（假设2折）
# skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)

# for fold_n, (train_idx, val_idx) in enumerate(skf.split(train, train['number'])):
#     print(f"\nFold {fold_n}:")
#     print("Train indices:", train_idx)
#     print("Val indices:  ", val_idx)
#     print("Train data:\n", train.iloc[train_idx])
#     print("Val data:\n", train.iloc[val_idx])

In [61]:
# 初始化分层K折交叉验证器
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2025)

# 创建空 DataFrame 存储各折的预测结果
prediction_KFold = pd.DataFrame()

# 开始K折交叉验证
for fold_n, (train_idx, val_idx) in enumerate(skf.split(train, train['number'])):
    print(f'fold {fold_n} training...')

    # 划分训练集和验证集
    train_data = train.iloc[train_idx]
    eval_data = train.iloc[val_idx]

    # 创建数据集对象
    trainset = Leaf_Dataset(train_data, transform=transforms_train)
    evalset = Leaf_Dataset(eval_data, transform=transforms_test)

    # 创建数据加载器
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, drop_last=False)
    eval_loader = torch.utils.data.DataLoader(evalset, batch_size=32, shuffle=False, drop_last=False)

    # 训练模型并在测试集上预测
    predictions = train_model(train_loader, eval_loader, test)

    # 存储当前折的预测结果
    prediction_KFold[f'fold_{fold_n}'] = predictions

fold 0 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss =  1316.62060546875 train acc:  0.35246458649635315

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.5885869264602661
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss =  563.4143676757812 train acc:  0.6655092239379883

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.7902173399925232
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss =  347.95941162109375 train acc:  0.790168821811676

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8361412882804871
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss =  247.69253540039062 train acc:  0.8464732766151428

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8592391014099121
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss =  196.0199432373047 train acc:  0.8728893995285034

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8926630020141602
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss =  160.21218872070312 train acc:  0.897467315196991

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9192934632301331
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss =  127.32304382324219 train acc:  0.9189133644104004

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9277173280715942
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss =  118.04581451416016 train acc:  0.922726035118103

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9252716898918152


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss =  99.58284759521484 train acc:  0.9307597875595093

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9342390894889832
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss =  86.33109283447266 train acc:  0.9413806796073914

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9418478012084961
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss =  77.77699279785156 train acc:  0.9443082809448242

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9404891133308411


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss =  69.10450744628906 train acc:  0.9515250325202942

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9510869383811951
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  12 loss =  61.59609603881836 train acc:  0.9576525092124939

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9592390656471252
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  13 loss =  48.73257827758789 train acc:  0.9656862616539001

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9546195268630981


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  14 loss =  50.38199234008789 train acc:  0.9634395241737366

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9521738886833191


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  15 loss =  44.36927795410156 train acc:  0.9682053327560425

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9586955904960632
Early stopping at epoch 15


  0%|          | 0/138 [00:00<?, ?it/s]

fold 1 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss =  1316.48388671875 train acc:  0.3528049886226654

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.5766304135322571
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss =  558.2635498046875 train acc:  0.6760620474815369

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.783423900604248
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss =  345.5888366699219 train acc:  0.7903050184249878

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.835326075553894
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss =  233.8436737060547 train acc:  0.8577069640159607

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8608695268630981
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss =  187.35643005371094 train acc:  0.8824210166931152

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9048912525177002
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss =  154.6842041015625 train acc:  0.9003267884254456

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9078803658485413
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss =  129.3643341064453 train acc:  0.9138071537017822

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9252716898918152
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss =  110.4753189086914 train acc:  0.9270833134651184

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9233695268630981


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss =  96.3366470336914 train acc:  0.9349809288978577

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9432064890861511
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss =  81.1338119506836 train acc:  0.94471675157547

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9423912763595581


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss =  77.76331329345703 train acc:  0.9460103511810303

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9432064890861511


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss =  66.31366729736328 train acc:  0.9525462985038757

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9377716779708862
Early stopping at epoch 11


  0%|          | 0/138 [00:00<?, ?it/s]

fold 2 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss =  1317.433837890625 train acc:  0.35103484988212585

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.6119564771652222
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss =  561.385009765625 train acc:  0.6709558367729187

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.7698369026184082
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss =  343.05706787109375 train acc:  0.7976579070091248

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8222825527191162
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss =  246.9237823486328 train acc:  0.845452070236206

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8548912405967712
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss =  188.2266845703125 train acc:  0.8805146813392639

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8785325884819031
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss =  156.35809326171875 train acc:  0.9000544548034668

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8994565010070801
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss =  129.232177734375 train acc:  0.914692223072052

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9187499284744263
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss =  114.54090118408203 train acc:  0.9242919087409973

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9141303896903992


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss =  96.05045318603516 train acc:  0.9370914697647095

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9163042902946472


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss =  87.94209289550781 train acc:  0.9379765391349792

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9345108270645142
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss =  79.63701629638672 train acc:  0.9434231519699097

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9307065010070801


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss =  70.18844604492188 train acc:  0.9524781703948975

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9532608389854431
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  12 loss =  61.09904479980469 train acc:  0.9582652449607849

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9385868906974792


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  13 loss =  55.26669692993164 train acc:  0.9598311185836792

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9535325765609741
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  14 loss =  46.21421813964844 train acc:  0.9681372046470642

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9597825407981873
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  15 loss =  42.60511016845703 train acc:  0.9679329991340637

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9567934274673462


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  16 loss =  38.590572357177734 train acc:  0.9728349447250366

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9578803777694702


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  17 loss =  32.75934982299805 train acc:  0.9741966128349304

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9551630020141602
Early stopping at epoch 17


  0%|          | 0/138 [00:00<?, ?it/s]

fold 3 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss =  1309.6649169921875 train acc:  0.35191991925239563

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.5766304135322571
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss =  567.3193359375 train acc:  0.6642837524414062

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.7641304135322571
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss =  356.0361022949219 train acc:  0.7822712063789368

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8222825527191162
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss =  251.7965087890625 train acc:  0.8428649306297302

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.876902163028717
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss =  195.51132202148438 train acc:  0.873910665512085

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8861412405967712
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss =  161.3388671875 train acc:  0.8942673802375793

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8942934274673462
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss =  136.24508666992188 train acc:  0.9089732766151428

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9138586521148682
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss =  112.09037017822266 train acc:  0.9264025092124939

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9149456024169922
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss =  107.30073547363281 train acc:  0.9282407164573669

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9331521391868591
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss =  81.44024658203125 train acc:  0.9458741545677185

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9464673399925232
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss =  76.18907165527344 train acc:  0.9463507533073425

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9345108270645142


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss =  68.70964050292969 train acc:  0.9528186321258545

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9355977773666382


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  12 loss =  61.65922164916992 train acc:  0.9570397138595581

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9516304135322571
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  13 loss =  54.80375289916992 train acc:  0.9605119824409485

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9491847157478333


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  14 loss =  48.00424575805664 train acc:  0.9654139280319214

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9519021511077881
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  15 loss =  42.0115966796875 train acc:  0.9689542055130005

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9586955904960632
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  16 loss =  37.72837829589844 train acc:  0.9720860123634338

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9584238529205322


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  17 loss =  34.203704833984375 train acc:  0.9739242792129517

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9586955904960632


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  18 loss =  28.203380584716797 train acc:  0.9787581562995911

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9660325646400452
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  19 loss =  27.25482177734375 train acc:  0.9789624214172363

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9709238409996033
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  20 loss =  25.240076065063477 train acc:  0.9804602265357971

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9652173519134521


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  21 loss =  20.420642852783203 train acc:  0.983047366142273

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9725543260574341
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  22 loss =  18.391077041625977 train acc:  0.9850898385047913

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9755434393882751
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  23 loss =  18.38484764099121 train acc:  0.9843409061431885

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9749999642372131


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  24 loss =  17.306766510009766 train acc:  0.9849537014961243

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9711955785751343


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  25 loss =  15.016822814941406 train acc:  0.9878812432289124

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9701086282730103
Early stopping at epoch 25


  0%|          | 0/138 [00:00<?, ?it/s]

fold 4 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss =  1313.68408203125 train acc:  0.35103484988212585

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.6198369264602661
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss =  559.6613159179688 train acc:  0.6701388955116272

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.7874999642372131
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss =  342.8634948730469 train acc:  0.7949346303939819

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8638586401939392
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss =  240.38873291015625 train acc:  0.8534858226776123

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8380434513092041


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss =  188.3905792236328 train acc:  0.8810593485832214

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.866847813129425
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss =  153.45269775390625 train acc:  0.9033223986625671

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9097825884819031
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss =  126.71725463867188 train acc:  0.9156454205513

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9274455904960632
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss =  116.07733917236328 train acc:  0.9242238402366638

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9277173280715942
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss =  94.63565826416016 train acc:  0.9360702633857727

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9290760159492493
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss =  87.50334930419922 train acc:  0.938725471496582

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9364129900932312
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss =  76.30065155029297 train acc:  0.946282684803009

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9478260278701782
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss =  69.58679962158203 train acc:  0.9500952959060669

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9499999284744263
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  12 loss =  60.982933044433594 train acc:  0.9572439789772034

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9559782147407532
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  13 loss =  49.87460708618164 train acc:  0.9654139280319214

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9540760517120361


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  14 loss =  49.70308303833008 train acc:  0.9639841914176941

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9638586640357971
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  15 loss =  43.146202087402344 train acc:  0.9684776663780212

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9624999761581421


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  16 loss =  38.76523971557617 train acc:  0.9718136787414551

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9660325646400452
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  17 loss =  33.4554557800293 train acc:  0.9748093485832214

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9665760397911072
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  18 loss =  30.15144920349121 train acc:  0.9767837524414062

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9701086282730103
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  19 loss =  27.847536087036133 train acc:  0.9779411554336548

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9698368906974792


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  20 loss =  23.750389099121094 train acc:  0.9811410307884216

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9747282266616821
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  21 loss =  21.396425247192383 train acc:  0.9814133644104004

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9703803658485413


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  22 loss =  19.041709899902344 train acc:  0.9838643670082092

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9744564890861511


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  23 loss =  17.9996395111084 train acc:  0.984885573387146

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9720108509063721
Early stopping at epoch 23


  0%|          | 0/138 [00:00<?, ?it/s]

In [None]:
# 查看各折预测结果
print(prediction_KFold)

In [None]:
# 最终预测结果采取众数投票选出
final_prediction = list(prediction_KFold.mode(axis=1)[0].astype(int))

# 将数字标签转换回文本标签
test['label'] = [labels_unique[i] for i in final_prediction]

# 保存结果文件
test.to_csv('../dataset/result2.csv', index=False)