In [None]:
#安装预训练裤
!pip install torch torchvision timm

## 随机种子

In [1]:
import torch
import numpy as np
import random

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

seed = 42
set_seed(seed)

# ResNet-VIT模型

In [2]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from timm.models import vit_base_patch16_224 # timm库提供了很多预训练模型，包括ViT

resnet = models.resnet18(pretrained=True)
for param in resnet.parameters():
    param.requires_grad = True

num_classes = 9
resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)

vit = vit_base_patch16_224(pretrained=True)

vit.head = nn.Linear(vit.head.in_features, num_classes)

class CombinedModel(nn.Module):
    def __init__(self, resnet, vit):
        super(CombinedModel, self).__init__()
        self.resnet = resnet
        self.vit = vit
        self.classifier = nn.Linear(resnet.fc.out_features + vit.head.out_features, num_classes)

    def forward(self, x):
        resnet_features = self.resnet(x)
        vit_features = self.vit(x)
        combined_features = torch.cat((resnet_features, vit_features), dim=1)
        output = self.classifier(combined_features)
        return output

model = CombinedModel(resnet, vit)



## 导入工具包

In [3]:
import time
import os
from tqdm import tqdm

import pandas as pd
import numpy as np

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
%matplotlib inline

# 忽略烦人的红色提示
import warnings
warnings.filterwarnings("ignore")

# 获取计算硬件
# 有 GPU 就用 GPU，没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)


device cuda:0


In [4]:
#定义一些超参数
BATCH_SIZE = 36
#训练轮数，可以不用设的太高
EPOCHS = 72
learning_rate = 1e-4

## 图像预处理

In [5]:
from torchvision import transforms

# 训练集图像预处理：缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     ])

# 测试集图像预处理-RCTN：缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
                                    ])

## 载入数据集

In [6]:
# 数据集文件夹路径，要换成自己图像的路径
dataset_dir = 'D:\mushrooms_split'

In [7]:
train_path = os.path.join(dataset_dir, 'train')
test_path = os.path.join(dataset_dir, 'val')
print('训练集路径', train_path)
print('测试集路径', test_path)

from torchvision import datasets
# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)
# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)
print('训练集图像数量', len(train_dataset))
print('类别个数', len(train_dataset.classes))
print('各类别名称', train_dataset.classes)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)

训练集路径 D:\mushrooms_split\train
测试集路径 D:\mushrooms_split\val
训练集图像数量 5375
类别个数 9
各类别名称 ['Agaricus-双孢蘑菇', 'Amanita-毒蝇伞', 'Boletus-丽柄牛肝菌', 'Cortinarius-掷丝膜菌', 'Entoloma-霍氏粉褶菌', 'Hygrocybe-浅黄褐湿伞', 'Lactarius-松乳菇', 'Russula-褪色红菇', 'Suillus-乳牛肝菌']
测试集图像数量 1339
类别个数 9
各类别名称 ['Agaricus-双孢蘑菇', 'Amanita-毒蝇伞', 'Boletus-丽柄牛肝菌', 'Cortinarius-掷丝膜菌', 'Entoloma-霍氏粉褶菌', 'Hygrocybe-浅黄褐湿伞', 'Lactarius-松乳菇', 'Russula-褪色红菇', 'Suillus-乳牛肝菌']


## 类别和索引号 映射字典

In [8]:
# 各类别名称
class_names = train_dataset.classes
n_class = len(class_names)
# 映射关系：类别 到 索引号
train_dataset.class_to_idx
# 映射关系：索引号 到 类别
idx_to_labels = {y:x for x,y in train_dataset.class_to_idx.items()}

In [9]:
idx_to_labels

{0: 'Agaricus-双孢蘑菇',
 1: 'Amanita-毒蝇伞',
 2: 'Boletus-丽柄牛肝菌',
 3: 'Cortinarius-掷丝膜菌',
 4: 'Entoloma-霍氏粉褶菌',
 5: 'Hygrocybe-浅黄褐湿伞',
 6: 'Lactarius-松乳菇',
 7: 'Russula-褪色红菇',
 8: 'Suillus-乳牛肝菌'}

In [10]:
# 保存为本地的 npy 文件
np.save('idx_to_labels.npy', idx_to_labels)
np.save('labels_to_idx.npy', train_dataset.class_to_idx)

In [10]:
## 定义数据加载器DataLoader
from torch.utils.data import DataLoader

In [11]:
# 训练集的数据加载器
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=8#处理器的数量
                         )

# 测试集的数据加载器
test_loader = DataLoader(test_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=8
                        )

## 训练工具包

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import models
import torch.optim as optim
from torch.optim import lr_scheduler

In [13]:
# 定义 ResNet-ViT 模型
model = model.to(device)

# 定义优化器和损失函数
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), learning_rate)

# 学习率调度器
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

## 进行训练

In [14]:
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score

In [15]:
def train_one_batch(images, labels):
    '''
    运行一个 batch 的训练，返回当前 batch 的训练日志
    '''

    # 获得一个 batch 的数据和标注
    images = images.to(device)
    labels = labels.to(device)

    outputs = model(images) # 输入模型，执行前向预测
    loss = criterion(outputs, labels) # 计算当前 batch 中，每个样本的平均交叉熵损失函数值

    # 优化更新权重
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 获取当前 batch 的标签类别和预测类别
    _, preds = torch.max(outputs, 1) # 获得当前 batch 所有图像的预测类别
    preds = preds.cpu().numpy()
    loss = loss.detach().cpu().numpy()
    outputs = outputs.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()

    log_train = {}
    log_train['epoch'] = epoch
    log_train['batch'] = batch_idx
    # 计算分类评估指标
    log_train['train_loss'] = loss
    log_train['train_accuracy'] = accuracy_score(labels, preds)
    log_train['train_precision'] = precision_score(labels, preds, average='macro')
    log_train['train_recall'] = recall_score(labels, preds, average='macro')
    log_train['train_f1-score'] = f1_score(labels, preds, average='macro')

    return log_train

## 在整个测试集上进行评估

In [16]:
def evaluate_testset():
    '''
    在整个测试集上评估，返回分类评估指标日志
    '''

    loss_list = []
    labels_list = []
    preds_list = []

    with torch.no_grad():
        for images, labels in test_loader: # 生成一个 batch 的数据和标注
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images) # 输入模型，执行前向预测

            # 获取整个测试集的标签类别和预测类别
            _, preds = torch.max(outputs, 1) # 获得当前 batch 所有图像的预测类别
            preds = preds.cpu().numpy()
            loss = criterion(outputs, labels) # 由 logit，计算当前 batch 中，每个样本的平均交叉熵损失函数值
            loss = loss.detach().cpu().numpy()
            outputs = outputs.detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()

            loss_list.append(loss)
            labels_list.extend(labels)
            preds_list.extend(preds)

    log_test = {}
    log_test['epoch'] = epoch

    # 计算分类评估指标
    log_test['test_loss'] = np.mean(loss)
    log_test['test_accuracy'] = accuracy_score(labels_list, preds_list)
    log_test['test_precision'] = precision_score(labels_list, preds_list, average='macro')
    log_test['test_recall'] = recall_score(labels_list, preds_list, average='macro')
    log_test['test_f1-score'] = f1_score(labels_list, preds_list, average='macro')

    return log_test

## 日志记录

In [17]:
epoch = 0
batch_idx = 0
best_test_accuracy = 0

In [18]:
# 训练日志-训练集
df_train_log = pd.DataFrame()
log_train = {}
log_train['epoch'] = 0
log_train['batch'] = 0
images, labels = next(iter(train_loader))
log_train.update(train_one_batch(images, labels))
df_train_log = df_train_log.append(log_train, ignore_index=True)

In [19]:
df_train_log

Unnamed: 0,epoch,batch,train_loss,train_accuracy,train_precision,train_recall,train_f1-score
0,0,0,2.239984,0.083333,0.028571,0.047619,0.035714


In [20]:
# 训练日志-测试集
df_test_log = pd.DataFrame()
log_test = {}
log_test['epoch'] = 0
log_test.update(evaluate_testset())
df_test_log = df_test_log.append(log_test, ignore_index=True)

In [21]:
df_test_log

Unnamed: 0,epoch,test_loss,test_accuracy,test_precision,test_recall,test_f1-score
0,0.0,2.158765,0.143391,0.090115,0.113611,0.087708


## 训练

In [24]:
for epoch in range(1, EPOCHS+1):

    print(f'Epoch {epoch}/{EPOCHS}')

    ## 训练阶段
    model.train()
    for images, labels in tqdm(train_loader): # 获得一个 batch 的数据和标注
        batch_idx += 1
        log_train = train_one_batch(images, labels)
        df_train_log = df_train_log.append(log_train, ignore_index=True)

    lr_scheduler.step()

    ## 测试阶段
    model.eval()
    log_test = evaluate_testset()
    df_test_log = df_test_log.append(log_test, ignore_index=True)

    # 保存最新的最佳模型文件
    if log_test['test_accuracy'] > best_test_accuracy:
        # 删除旧的最佳模型文件(如有)
        old_best_checkpoint_path = 'checkpoints/seed42_resnet-vit_best-{:.3f}.pth'.format(best_test_accuracy)
        if os.path.exists(old_best_checkpoint_path):
            os.remove(old_best_checkpoint_path)
        # 保存新的最佳模型文件
        new_best_checkpoint_path = 'checkpoints/seed42_resnet-vit_best-{:.3f}.pth'.format(log_test['test_accuracy'])
        torch.save(model, new_best_checkpoint_path)
        print('保存新的最佳模型', 'checkpoints/seed42_resnet-vit_best-{:.3f}.pth'.format(best_test_accuracy))
        best_test_accuracy = log_test['test_accuracy']

df_train_log.to_csv('resnet-vit_训练日志1-训练集.csv', index=False)
df_test_log.to_csv('resnet-vit_训练日志1-测试集.csv', index=False)

Epoch 1/36


100%|██████████| 448/448 [02:26<00:00,  3.06it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.000.pth
Epoch 2/36


100%|██████████| 448/448 [02:25<00:00,  3.08it/s]


Epoch 3/36


100%|██████████| 448/448 [02:24<00:00,  3.09it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.785.pth
Epoch 4/36


100%|██████████| 448/448 [02:26<00:00,  3.06it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.793.pth
Epoch 5/36


100%|██████████| 448/448 [02:27<00:00,  3.03it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.843.pth
Epoch 6/36


100%|██████████| 448/448 [02:28<00:00,  3.02it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.866.pth
Epoch 7/36


100%|██████████| 448/448 [02:29<00:00,  3.00it/s]


Epoch 8/36


100%|██████████| 448/448 [02:28<00:00,  3.01it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.904.pth
Epoch 9/36


100%|██████████| 448/448 [02:28<00:00,  3.02it/s]


Epoch 10/36


100%|██████████| 448/448 [02:28<00:00,  3.02it/s]


Epoch 11/36


100%|██████████| 448/448 [02:28<00:00,  3.03it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.907.pth
Epoch 12/36


100%|██████████| 448/448 [02:28<00:00,  3.01it/s]


Epoch 13/36


100%|██████████| 448/448 [02:26<00:00,  3.05it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.924.pth
Epoch 14/36


100%|██████████| 448/448 [02:27<00:00,  3.04it/s]


Epoch 15/36


100%|██████████| 448/448 [02:26<00:00,  3.05it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.930.pth
Epoch 16/36


100%|██████████| 448/448 [02:23<00:00,  3.13it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.936.pth
Epoch 17/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


Epoch 18/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


Epoch 19/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


Epoch 20/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.940.pth
Epoch 21/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


Epoch 22/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.942.pth
Epoch 23/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.946.pth
Epoch 24/36


100%|██████████| 448/448 [02:22<00:00,  3.13it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.948.pth
Epoch 25/36


100%|██████████| 448/448 [02:22<00:00,  3.13it/s]


Epoch 26/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


Epoch 27/36


100%|██████████| 448/448 [02:23<00:00,  3.13it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.950.pth
Epoch 28/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


Epoch 29/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


Epoch 30/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.951.pth
Epoch 31/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.952.pth
Epoch 32/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


保存新的最佳模型 checkpoints/resnet-vit_best-0.953.pth
Epoch 33/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


Epoch 34/36


100%|██████████| 448/448 [02:22<00:00,  3.14it/s]


Epoch 35/36


100%|██████████| 448/448 [02:25<00:00,  3.08it/s]


Epoch 36/36


100%|██████████| 448/448 [02:29<00:00,  2.99it/s]


## 测试集评价

In [None]:
# 载入最佳模型作为当前模型
model = torch.load('checkpoints/seed42_resnet-vit_best-{:.3f}.pth'.format(best_test_accuracy))
model.eval()
print(evaluate_testset())