In [7]:
import os
import cv2
import copy
import torch
import torchvision.datasets
import albumentations

import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR, CosineAnnealingLR
from sklearn.model_selection import train_test_split, StratifiedKFold
from PIL import Image
from albumentations.pytorch.transforms import ToTensorV2

In [8]:
# 创建模型保存目录
os.makedirs('../models', exist_ok=True)

In [9]:
# 读取训练数据和测试数据
train = pd.read_csv('../dataset/train.csv')
test = pd.read_csv('../dataset/test.csv')
# 对标签列进行数字编码
train['number'], labels_unique = pd.factorize(train['label'])
# 保存编码结果，便于查看
train.to_csv('../dataset/train_num_label4.csv', index=False)
# 保存标签映射关系
pd.Series(labels_unique).to_csv('../dataset/label_mapping.csv', index=False)

In [11]:
# 数据增强
transforms_train = albumentations.Compose(
    [
        albumentations.Resize(320, 320),            # 调整图像尺寸到 320x320
        albumentations.HorizontalFlip(p=0.5),       # 概率50%水平翻转
        albumentations.VerticalFlip(p=0.5),         # 概率50%垂直翻转
        albumentations.Rotate(limit=180, p=0.7),    # 随机翻转（ ±180°，概率70% ）
        albumentations.RandomBrightnessContrast(),  # 随机调整亮度和对比度
        albumentations.Affine(
            translate_percent = (-0.25, 0.25),  # 平移范围±25%
            scale = (0.9, 1.1),                 # 缩放范围±10%
            rotate = 0,                         # 无旋转
            p = 0.5                             # 应用概率50%
        ),
        albumentations.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], max_pixel_value=255.0),
        ToTensorV2(p=1)  # 将图像从 numpy 数组转换为 PyTorch 张量
    ]
)

# 不添加随机增强，保证评估结果的一致性
transforms_test = albumentations.Compose(
    [
        albumentations.Resize(320, 320),
        albumentations.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], max_pixel_value=255.0),
        ToTensorV2(p=1.0)
    ]
)

In [16]:
class Leaf_Dataset(Dataset):
    def __init__(self, train_csv, transform=None, test_bool=False):
        '''
        train_csv: 记录图像路径及标号的csv文件
        transform: 图像变换
        test_bool: 是否为测试集
        '''
        super().__init__()
        self.train_csv = train_csv
        self.transform = transform
        self.test_bool = test_bool
        self.image_path = list(self.train_csv['image'])  # 提取所有图像路径
        # 如果不是测试集，加载标签信息
        if not test_bool:
            self.label_nums = list(self.train_csv['number'])

    def __getitem__(self, idx):
        '''
        获取单个样本
        idx: 图像索引
        return: image, label
        '''
        # 读取图像
        image = cv2.imread(os.path.join('../dataset', self.image_path[idx]))  
        # 转换颜色空间（OpenCV默认BGR，PyTorch需要RGB）
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  
        # 图像变换
        if self.transform is not None:                   
            image = self.transform(image=image)['image']
        # 测试集只返回图像，训练集返回图像和标签
        if not self.test_bool:
            label = self.label_nums[idx]
            return image, label
        else:
            return image

    def __len__(self):
        return len(self.image_path)

In [17]:
def train_model(train_loader, valid_loader, test, fold_n, device=torch.device('cuda:0')):
    # 模型初始化
    net = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
    in_features = net.fc.in_features      # 获取全连接层的输入特征维度
    net.fc = nn.Linear(in_features, 176)  # 替换全连接层以适应176类树叶分类任务
    net = net.to(device)

    # 训练参数设置
    epoch = 30
    best_epoch = 0
    best_score = 0.0
    best_model_state = None   # 保存最佳模型状态
    early_stopping_round = 3  # 早停轮数
    losses = []               # 记录每轮的训练损失

    # 优化器、损失函数和调度器
    optimizer = optim.Adam(net.parameters(), lr=0.0001, weight_decay=1e-5)
    loss = nn.CrossEntropyLoss(reduction='mean')
    scheduler = ExponentialLR(optimizer, gamma=0.9)

    # 训练循环
    for i in range(epoch):
        acc = 0       # 累计训练准确数
        loss_sum = 0  # 累计训练损失

        # 训练阶段
        net.train()
        for x, y in tqdm(train_loader):
            # 准备输入数据
            x = torch.as_tensor(x, dtype=torch.float).to(device)
            y = y.to(device)

            # 前向传播
            y_hat = net(x)

            # 计算损失
            loss_temp = loss(y_hat, y)
            loss_sum += loss_temp

            # 反向传播和优化
            optimizer.zero_grad()
            loss_temp.backward()
            optimizer.step()

            # 计算准确数
            acc += torch.sum(y_hat.argmax(dim=1).type(y.dtype) == y)

        # 更新学习率
        scheduler.step()

        # 记录平均损失
        losses.append(loss_sum.cpu().detach().numpy() / len(train_loader))

        # 打印训练结果
        print('epoch: ', i,
             'loss: ', loss_sum.item(),
             'train acc: ', (acc / (len(train_loader) * train_loader.batch_size)).item(), end='')

        # 验证阶段
        valid_acc = 0.0
        net.eval()
        for x, y in tqdm(valid_loader):
            # 准备输入数据
            x = torch.as_tensor(x, dtype=torch.float).to(device)
            y = y.to(device)

            # 前向传播
            with torch.no_grad():
                y_hat = net(x)

            # 计算准确数
            valid_acc += torch.sum(y_hat.argmax(dim=1).type(y.dtype) == y)

        # 计算并打印验证准确率
        print('valid acc: ', (valid_acc / (len(valid_loader) * valid_loader.batch_size)).item())

        # 模型保存与早停
        if valid_acc > best_score:
            best_model_state = copy.deepcopy(net.state_dict())
            best_score = valid_acc
            best_epoch = i
            print('best epoch save!')

        if i - best_epoch >= early_stopping_round:
            print(f'Early stopping at epoch {i}')
            break

    # 保存最佳模型
    model_path = f'../models/fold_{fold_n}_best_model.pth'
    torch.save({
            'model_state_dict': best_model_state,
            'num_classes': len(labels_unique)
        }, model_path)
    print(f'Saved model for fold {fold_n} at {model_path}')
    
    # 加载最佳模型
    net.load_state_dict(best_model_state)

    # 加载测试数据
    testset = Leaf_Dataset(test, transform=transforms_test, test_bool=True)
    test_loader = DataLoader(testset, batch_size=64, shuffle=False, drop_last=False)

    # 执行预测
    predictions = []
    with torch.no_grad():
        for x in tqdm(test_loader):
            # 准备输入数据
            x = torch.as_tensor(x, dtype=torch.float).to(device)

            # 前向传播
            y_hat = net(x)

            # 获取预测结果
            predict = torch.argmax(y_hat, dim=1).reshape(-1)
            predict = list(predict.cpu().detach().numpy())
            predictions.extend(predict)
    return predictions

In [18]:
# 初始化分层K折交叉验证器
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2025)

# 创建空 DataFrame 存储各折的预测结果
prediction_KFold = pd.DataFrame()

# 开始K折交叉验证
for fold_n, (train_idx, val_idx) in enumerate(skf.split(train, train['number'])):
    print(f'fold {fold_n} training...')

    # 划分训练集和验证集
    train_data = train.iloc[train_idx]
    valid_data = train.iloc[val_idx]

    # 创建数据集对象
    trainset = Leaf_Dataset(train_data, transform=transforms_train)
    validset = Leaf_Dataset(valid_data, transform=transforms_test)

    # 创建数据加载器
    train_loader = DataLoader(trainset, batch_size=32, shuffle=True, drop_last=False)
    valid_loader = DataLoader(validset, batch_size=32, shuffle=False, drop_last=False)

    # 训练模型并在测试集上预测
    predictions = train_model(train_loader, valid_loader, test, fold_n)

    # 存储当前折的预测结果
    prediction_KFold[f'fold {fold_n}'] = predictions

fold 0 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss:  1316.7943115234375 train acc:  0.35443899035453796

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.6247282028198242
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss:  549.9865112304688 train acc:  0.6856617331504822

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.783152163028717
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss:  332.0941467285156 train acc:  0.8047385215759277

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8706521391868591
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss:  228.48272705078125 train acc:  0.8641748428344727

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8970108032226562
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss:  166.0929412841797 train acc:  0.898080050945282

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9154890775680542
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss:  131.86460876464844 train acc:  0.9190495610237122

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9317934513092041
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss:  106.69075012207031 train acc:  0.9330065250396729

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.939945638179779
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss:  89.86566925048828 train acc:  0.9417892098426819

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9456521272659302
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss:  76.9225082397461 train acc:  0.9500952959060669

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9554347395896912
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss:  64.85164642333984 train acc:  0.9568355083465576

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9489129781723022


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss:  56.970672607421875 train acc:  0.9634395241737366

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9538043141365051


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss:  50.00904083251953 train acc:  0.9652096629142761

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9529891014099121
Early stopping at epoch 11
Saved model for fold 0 at ../models/fold_0_best_model.pth


  0%|          | 0/138 [00:00<?, ?it/s]

fold 1 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss:  1306.993408203125 train acc:  0.3542347550392151

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.6070652008056641
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss:  551.3101806640625 train acc:  0.6817129254341125

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.7608695030212402
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss:  331.1969299316406 train acc:  0.8035130500793457

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8543477654457092
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss:  232.3572235107422 train acc:  0.8601579070091248

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8766303658485413
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss:  165.43922424316406 train acc:  0.8982843160629272

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9144021272659302
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss:  128.0765838623047 train acc:  0.9202069640159607

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9160325527191162
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss:  105.54917907714844 train acc:  0.9330065250396729

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9271738529205322
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss:  88.90028381347656 train acc:  0.9430827498435974

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9404891133308411
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss:  75.34486389160156 train acc:  0.9502995610237122

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9440217018127441
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss:  64.41315460205078 train acc:  0.9577886462211609

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9554347395896912
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss:  54.8652458190918 train acc:  0.9635075926780701

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9573369026184082
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss:  48.025108337402344 train acc:  0.9685457348823547

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9581521153450012
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  12 loss:  44.1237907409668 train acc:  0.9698392748832703

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9630434513092041
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  13 loss:  40.128780364990234 train acc:  0.9708605408668518

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9589673280715942


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  14 loss:  36.873626708984375 train acc:  0.9743327498435974

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9633151888847351
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  15 loss:  30.530569076538086 train acc:  0.9788942933082581

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9660325646400452
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  16 loss:  29.84308624267578 train acc:  0.9789624214172363

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9665760397911072
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  17 loss:  29.480342864990234 train acc:  0.9793028235435486

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.960869550704956


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  18 loss:  27.40502166748047 train acc:  0.9794389605522156

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9611412882804871


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  19 loss:  26.359664916992188 train acc:  0.9791666269302368

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9682064652442932
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  20 loss:  22.25103187561035 train acc:  0.9829111695289612

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9682064652442932


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  21 loss:  21.894020080566406 train acc:  0.9825707674026489

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9690216779708862
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  22 loss:  21.364192962646484 train acc:  0.9839324355125427

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9682064652442932


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  23 loss:  20.864154815673828 train acc:  0.9829111695289612

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9682064652442932


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  24 loss:  18.333070755004883 train acc:  0.9864515066146851

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9698368906974792
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  25 loss:  18.639936447143555 train acc:  0.9854983687400818

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9706521034240723
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  26 loss:  18.306499481201172 train acc:  0.985838770866394

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9687499403953552


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  27 loss:  16.464168548583984 train acc:  0.9860429763793945

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9690216779708862


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  28 loss:  16.779991149902344 train acc:  0.985974907875061

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9695651531219482
Early stopping at epoch 28
Saved model for fold 1 at ../models/fold_1_best_model.pth


  0%|          | 0/138 [00:00<?, ?it/s]

fold 2 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss:  1328.2191162109375 train acc:  0.34259259700775146

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.573369562625885
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss:  556.0230102539062 train acc:  0.6787853837013245

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.7472825646400452
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss:  339.1282653808594 train acc:  0.7985430359840393

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8410325646400452
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss:  230.47659301757812 train acc:  0.86036217212677

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8725543022155762
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss:  174.0442657470703 train acc:  0.8940631747245789

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8752716779708862
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss:  136.14634704589844 train acc:  0.9138071537017822

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9214673638343811
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss:  110.09455871582031 train acc:  0.9306917190551758

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9366847276687622
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss:  92.766357421875 train acc:  0.9407679438591003

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9342390894889832


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss:  76.21405792236328 train acc:  0.9504356980323792

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9451086521148682
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss:  65.09365844726562 train acc:  0.9579248428344727

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9527173638343811
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss:  57.92766189575195 train acc:  0.960716187953949

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9519021511077881


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss:  52.096656799316406 train acc:  0.9640522599220276

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9589673280715942
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  12 loss:  45.309173583984375 train acc:  0.969362735748291

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9584238529205322


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  13 loss:  40.43522262573242 train acc:  0.9725626111030579

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9592390656471252
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  14 loss:  37.30217361450195 train acc:  0.9720179438591003

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9616847634315491
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  15 loss:  33.983890533447266 train acc:  0.976443350315094

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.960869550704956


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  16 loss:  31.268062591552734 train acc:  0.9768518209457397

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9614130258560181


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  17 loss:  29.182727813720703 train acc:  0.9795751571655273

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9654890894889832
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  18 loss:  26.310840606689453 train acc:  0.9817537665367126

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9663043022155762
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  19 loss:  24.882444381713867 train acc:  0.9824346303939819

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9641304016113281


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  20 loss:  25.533842086791992 train acc:  0.9809368252754211

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9652173519134521


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  21 loss:  22.34005355834961 train acc:  0.9831154346466064

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9676629900932312
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  22 loss:  20.09132194519043 train acc:  0.9840005040168762

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9682064652442932
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  23 loss:  20.96208381652832 train acc:  0.9837281703948975

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9701086282730103
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  24 loss:  19.254730224609375 train acc:  0.9857707023620605

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9671195149421692


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  25 loss:  19.257232666015625 train acc:  0.9852941036224365

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9682064652442932


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  26 loss:  18.04496192932129 train acc:  0.9854983687400818

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9701086282730103
Early stopping at epoch 26
Saved model for fold 2 at ../models/fold_2_best_model.pth


  0%|          | 0/138 [00:00<?, ?it/s]

fold 3 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss:  1300.73486328125 train acc:  0.35375815629959106

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.6081521511077881
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss:  545.1376342773438 train acc:  0.6807597875595093

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.7883151769638062
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss:  327.0072021484375 train acc:  0.8057597875595093

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8448368906974792
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss:  227.43359375 train acc:  0.8632897138595581

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.877445638179779
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss:  169.66119384765625 train acc:  0.8983523845672607

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9019021391868591
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss:  133.19317626953125 train acc:  0.9185729622840881

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9225543141365051
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss:  103.52729034423828 train acc:  0.9328022599220276

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9334238767623901
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss:  88.90583801269531 train acc:  0.9436274170875549

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9410325884819031
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss:  75.19413757324219 train acc:  0.951797366142273

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9478260278701782
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss:  63.52288818359375 train acc:  0.9577205777168274

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9486412405967712
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss:  55.69141387939453 train acc:  0.963030993938446

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9576086401939392
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss:  48.550228118896484 train acc:  0.9686138033866882

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9581521153450012
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  12 loss:  44.56276321411133 train acc:  0.970452070236206

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9600542783737183
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  13 loss:  38.21683120727539 train acc:  0.9747412800788879

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9614130258560181
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  14 loss:  34.388511657714844 train acc:  0.9763752222061157

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9619565010070801
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  15 loss:  33.56085968017578 train acc:  0.9761710166931152

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9638586640357971
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  16 loss:  30.51240348815918 train acc:  0.9776007533073425

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9668477773666382
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  17 loss:  26.478900909423828 train acc:  0.9810048937797546

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9676629900932312
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  18 loss:  26.303218841552734 train acc:  0.9805963635444641

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9657608270645142


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  19 loss:  23.981693267822266 train acc:  0.9818218946456909

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9684782028198242
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  20 loss:  22.77022933959961 train acc:  0.9825707674026489

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9698368906974792
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  21 loss:  21.68136215209961 train acc:  0.9840686321258545

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9690216779708862


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  22 loss:  20.420888900756836 train acc:  0.9846813678741455

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9728260636329651
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  23 loss:  19.9406795501709 train acc:  0.9851579070091248

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9722825884819031


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  24 loss:  18.517305374145508 train acc:  0.9861111044883728

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9703803658485413


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  25 loss:  18.975135803222656 train acc:  0.9852941036224365

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9711955785751343
Early stopping at epoch 25
Saved model for fold 3 at ../models/fold_3_best_model.pth


  0%|          | 0/138 [00:00<?, ?it/s]

fold 4 training...


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  0 loss:  1329.9984130859375 train acc:  0.3423202633857727

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.604347825050354
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  1 loss:  543.00927734375 train acc:  0.6834149956703186

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.7752717137336731
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  2 loss:  331.3494567871094 train acc:  0.8084149956703186

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.845652163028717
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  3 loss:  227.65411376953125 train acc:  0.8650599122047424

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.8820651769638062
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  4 loss:  173.06040954589844 train acc:  0.8944035768508911

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9209238886833191
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  5 loss:  133.09568786621094 train acc:  0.9169389605522156

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9271738529205322
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  6 loss:  108.37774658203125 train acc:  0.9323937892913818

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9328804016113281
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  7 loss:  91.28524017333984 train acc:  0.9409722089767456

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9432064890861511
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  8 loss:  74.24410247802734 train acc:  0.952273964881897

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9519021511077881
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  9 loss:  66.56438446044922 train acc:  0.9569035768508911

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9616847634315491
best epoch save!


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  10 loss:  58.682498931884766 train acc:  0.9616012573242188

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9565216898918152


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  11 loss:  49.35604476928711 train acc:  0.9668436646461487

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9600542783737183


  0%|          | 0/459 [00:00<?, ?it/s]

epoch:  12 loss:  44.628440856933594 train acc:  0.969362735748291

  0%|          | 0/115 [00:00<?, ?it/s]

valid acc:  0.9578803777694702
Early stopping at epoch 12
Saved model for fold 4 at ../models/fold_4_best_model.pth


  0%|          | 0/138 [00:00<?, ?it/s]

In [19]:
# 查看各折预测结果
print(prediction_KFold)

      fold 0  fold 1  fold 2  fold 3  fold 4
0          7       7       7       7       7
1        134     134     134     134     134
2        136     136     136     136     136
3         51      51      51      51      51
4        136     136     136     136     136
...      ...     ...     ...     ...     ...
8795     173     173     173     173     173
8796     174     174     174     174     174
8797      95     173     173     173     173
8798     175     175     175     175     175
8799     175     175     175     175     175

[8800 rows x 5 columns]


In [20]:
# 最终预测结果采取众数投票选出
prediction_final = list(prediction_KFold.mode(axis=1)[0].astype(int))

# 数字标签转换回文本标签
test['label'] = [labels_unique[i] for i in prediction_final]

# 保存结果文件
test.to_csv('../dataset/result_practice2.csv', index=False)