In [None]:
# 导入lib
# 导入树叶数据集
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader,random_split
from torchvision import transforms
from tqdm import tqdm
import os
import csv
import torchvision.utils as vutils
from torchvision.transforms import AutoAugment, AutoAugmentPolicy
import matplotlib.pyplot as plt

In [None]:
def load_data(batch_size=128, root_dir="./datasets/train", train_ratio=0.8,
              num_workers=4, seed=42):
    """Build the training and validation DataLoaders from one image-folder dataset.

    The images under ``root_dir`` are split once into a training part and a
    held-out part. The training part is read through an augmenting pipeline
    (random resized crop, horizontal flip, AutoAugment with the ImageNet
    policy); the held-out part through a deterministic resize + center crop.
    Both pipelines normalize with the ImageNet channel statistics.

    Args:
        batch_size: Number of samples per batch for both loaders.
        root_dir: Directory in ``torchvision.datasets.ImageFolder`` layout.
        train_ratio: Fraction of the images assigned to the training split;
            must lie strictly between 0 and 1.
        num_workers: Worker processes per loader. ``0`` loads in the main process.
        seed: Seed for the split so that it is identical across kernel
            restarts (otherwise a reloaded checkpoint would be evaluated on
            images it was trained on). Pass ``None`` for an unseeded split.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.

    Raises:
        ValueError: If ``train_ratio`` is not strictly between 0 and 1.
    """
    if not 0.0 < train_ratio < 1.0:
        raise ValueError(f"train_ratio must be in (0, 1), got {train_ratio}")

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    train_trans = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3./4., 4./3.)),
        transforms.RandomHorizontalFlip(p=0.5),
        AutoAugment(policy=AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    ])

    test_trans = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    ])

    # Two views of the same folder, one per transform pipeline. The folder is
    # scanned twice here instead of the three scans the earlier version needed.
    train_source = torchvision.datasets.ImageFolder(root=root_dir, transform=train_trans)
    test_source = torchvision.datasets.ImageFolder(root=root_dir, transform=test_trans)

    train_size = int(train_ratio * len(train_source))
    test_size = len(train_source) - train_size

    # A dedicated generator makes the split reproducible without touching the
    # global RNG state used elsewhere (augmentation, weight init, shuffling).
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_dataset, test_dataset = random_split(
        train_source, [train_size, test_size], **split_kwargs)
    # Point the held-out subset at the non-augmenting view. Both views read the
    # same root, so the subset indices are expected to address the same files.
    test_dataset.dataset = test_source

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    if num_workers > 0:
        # These two options are only valid when worker processes are used.
        loader_kwargs.update(prefetch_factor=2, persistent_workers=True)

    return (DataLoader(train_dataset, shuffle=True, **loader_kwargs),
            DataLoader(test_dataset, shuffle=False, **loader_kwargs))
    
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the loaders once at cell level; the training and evaluation cells reuse them.
train_iter, test_iter = load_data(batch_size=64)
# Optional sanity check (kept disabled): preview the first 8 images of one training batch.
# for i, (X, y) in enumerate(train_iter):  # slow here; how to pick a suitable num_workers?
#     # X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
#     X, y = X[:8], y[:8]
#     print("Batch shape:", X.shape, "Labels shape:", y)
#     grid_img = vutils.make_grid(X, nrow=4, normalize=True)  # nrow=4 puts 4 images on each row
#     plt.figure(figsize=(8, 4))  # adjust the figure size
#     plt.imshow(grid_img.permute(1, 2, 0))  # reorder channels (C, H, W) -> (H, W, C)
#     plt.axis("off")
#     plt.show()
#     break


In [None]:
# Network: ImageNet-pretrained ResNet-18 whose final layer is replaced by a new linear head.
# An explicit weights enum is used instead of the deprecated boolean form `weights=True`,
# which torchvision resolves to these same IMAGENET1K_V1 weights with a deprecation warning.
finetune_net = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
# NOTE(review): the head is fixed at 10 outputs; this has to match the number of class
# folders that ImageFolder finds in the dataset — confirm.
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 10)
# Xavier-initialise the new head's weight; the trailing ';' hides the cell's tensor repr.
nn.init.xavier_uniform_(finetune_net.fc.weight);

In [None]:
def _build_optimizer(net, lr, param_group):
    """Create the AdamW optimizer, optionally giving the ``fc`` head a 10x learning rate."""
    if not param_group:
        return torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=0.001)
    head_names = ("fc.weight", "fc.bias")
    backbone_params = [param for name, param in net.named_parameters()
                       if name not in head_names]
    return torch.optim.AdamW(
        [{'params': backbone_params},
         {'params': net.fc.parameters(), 'lr': lr * 10}],
        lr=lr, weight_decay=0.001)


def _evaluate_accuracy(net, data_iter, device):
    """Return the top-1 accuracy of ``net`` over ``data_iter`` (NaN when it yields no samples)."""
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X, y in data_iter:
            X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
            correct += (net(X).argmax(dim=1) == y).sum().item()
            total += X.shape[0]
    return correct / total if total else float("nan")


def train_fromKK(net, train_iter, test_iter, num_epochs, lr, device, param_group=True,
                 save_path='test_best.pth'):
    """Fine-tune ``net`` and checkpoint the weights with the best held-out accuracy.

    Every epoch trains on ``train_iter`` with cross-entropy loss and AdamW
    (weight decay 0.001), evaluates on ``test_iter``, prints a summary line,
    and saves ``net.state_dict()`` to ``save_path`` whenever the held-out
    accuracy improves on the best value seen so far.

    Args:
        net: Model to train. When ``param_group`` is true it must expose its
            classification head as ``net.fc``.
        train_iter: Iterable of ``(X, y)`` training batches.
        test_iter: Iterable of ``(X, y)`` held-out batches.
        num_epochs: Number of passes over ``train_iter``.
        lr: Base learning rate.
        device: Device the model and every batch are moved to.
        param_group: If true, the ``fc`` parameters train with ``lr * 10`` and
            all other parameters with ``lr``; otherwise every parameter uses ``lr``.
        save_path: File the best ``state_dict`` is written to.

    Returns:
        None.
    """
    print('training on', device)
    net.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = _build_optimizer(net, lr, param_group)
    # Start below any reachable accuracy so the first evaluated epoch always
    # produces a checkpoint, even when its accuracy is 0.
    best_acc = float("-inf")

    for epoch in range(num_epochs):
        net.train()
        train_loss_sum, train_correct, num_samples = 0.0, 0, 0
        with tqdm(train_iter, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for X, y in pbar:
                optimizer.zero_grad()
                # non_blocking only has an effect for pinned host memory copied
                # to CUDA; in every other case it is an ordinary copy.
                X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
                y_hat = net(X)
                batch_loss = loss_fn(y_hat, y)
                batch_loss.backward()
                optimizer.step()

                batch_len = X.shape[0]
                loss_value = batch_loss.item()
                # The loss is a per-batch mean, so weight it by the batch size
                # to obtain a correct per-sample average over the epoch.
                train_loss_sum += loss_value * batch_len
                train_correct += (y_hat.argmax(dim=1) == y).sum().item()
                num_samples += batch_len
                pbar.set_postfix(loss=loss_value, acc=train_correct / max(num_samples, 1))

        # An empty training loader reports NaN instead of raising ZeroDivisionError.
        if num_samples:
            train_loss = train_loss_sum / num_samples
            train_acc = train_correct / num_samples
        else:
            train_loss = train_acc = float("nan")

        test_acc = _evaluate_accuracy(net, test_iter, device)
        # A NaN accuracy (empty held-out loader) compares false and is never saved.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(net.state_dict(), save_path)
        print(f"______ | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")

In [None]:
# Training: fine-tune the pretrained network with the hyper-parameters below.
lr = 5e-5
num_epochs = 100
net = finetune_net
train_fromKK(net, train_iter, test_iter, num_epochs=num_epochs, lr=lr, device=device)

ResNet18
test_best_5epoch.pt
______ | Train Loss: 0.5319 | Train Acc: 0.8173 | Test Acc: 0.8885
验证集精度：0.8885

ResNet50
______ | Train Loss: 0.1763 | Train Acc: 0.9406 | Test Acc: 0.8997
验证集精度：0.9064


In [None]:
# Restore a saved checkpoint and measure its accuracy on the held-out loader.
# NOTE(review): train_fromKK writes 'test_best.pth' by default, while this cell reads a
# differently named file — presumably renamed by hand; confirm it exists before running.
# map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
net.load_state_dict(torch.load('test_best_100epoch_resnet18.pth', map_location=device))
# Make sure the model lives on the same device as the batches, even when the
# training cell (which is the only other place that moves it) was skipped.
net.to(device)
net.eval()  # evaluation mode
test_acc_sum, test_samples = 0, 0
with torch.no_grad():
    for X, y in test_iter:  # slow on the first pass; much faster on consecutive runs
        X, y = X.to(device), y.to(device)
        y_hat = net(X)
        test_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
        test_samples += X.shape[0]
# Report NaN rather than raising ZeroDivisionError if the loader yielded nothing.
test_acc = test_acc_sum / test_samples if test_samples else float("nan")
print(f'验证集精度：{test_acc:.4f}')