In [1]:
! pip install d2l==0.17

Collecting d2l==0.17
  Downloading d2l-0.17.0-py3-none-any.whl.metadata (346 bytes)
Collecting jupyter (from d2l==0.17)
  Downloading jupyter-1.1.1-py2.py3-none-any.whl.metadata (2.0 kB)
Downloading d2l-0.17.0-py3-none-any.whl (83 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m83.2/83.2 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jupyter-1.1.1-py2.py3-none-any.whl (2.7 kB)
Installing collected packages: jupyter, d2l
Successfully installed d2l-0.17.0 jupyter-1.1.1


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import models, transforms
from torchvision.transforms import AutoAugment, AutoAugmentPolicy
from d2l import torch as d2l
%matplotlib inline

In [3]:
import numpy as np

In [4]:
import os
# Request "expandable segments" from PyTorch's CUDA caching allocator.
# NOTE(review): this cell runs after `import torch`; presumably the variable
# only needs to be set before the first CUDA allocation — confirm, or move this
# assignment above the torch import to be safe.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [5]:
def get_data_transforms(img_size):
    """Build the preprocessing pipelines for training and evaluation.

    Args:
        img_size: Side length (pixels) of the square images fed to the model.

    Returns:
        A ``(train_transform, test_transform)`` tuple of ``transforms.Compose``
        objects. The training pipeline resizes, then applies random crop (with
        16 px padding), horizontal flip, rotation, colour jitter, grayscale and
        Gaussian blur on the image, converts to a tensor, applies random
        erasing and finally normalises. The evaluation pipeline only resizes,
        centre-crops, converts to a tensor and normalises.
    """
    # ImageNet channel statistics, shared by both pipelines.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    target_hw = (img_size, img_size)

    train_steps = [
        transforms.Resize(target_hw),
        transforms.RandomCrop(img_size, padding=16),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomGrayscale(p=0.1),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 0.5)),
        transforms.ToTensor(),
        # Erasing operates on the tensor, before normalisation.
        transforms.RandomErasing(p=0.2),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    test_steps = [
        transforms.Resize(target_hw),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(train_steps), transforms.Compose(test_steps)

In [6]:
# ---- Plotting and hardware ----
d2l.set_figsize()
devices = d2l.try_all_gpus()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Two ids (-> DataParallel in build_resnet50) when at least two devices were
# found, otherwise a single id.
device_ids = [0] if len(devices) < 2 else [0, 1]

# ---- Data ----
batch_size = 128
img_size = 224  # input size fed to ResNet50 (32x32 images are resized to 224x224)
num_classes = 10

# ---- Optimisation ----
epochs = 80
init_lr = 1e-3
weight_decay = 1e-4

# ---- Output ----
save_path = "best_resnet50_cifar10.pth"

In [7]:
def build_resnet50(num_classes, device_ids):
    """Build an ImageNet-pretrained ResNet-50 with a fresh classification head.

    The first 40 parameter tensors (in ``model.parameters()`` order) are frozen
    and the original ``fc`` layer is replaced by
    Dropout(0.5) -> Linear(in_features, 512) -> ReLU -> BatchNorm1d(512)
    -> Dropout(0.3) -> Linear(512, num_classes). The new Linear layers use
    Xavier-uniform weights and zero biases.

    Args:
        num_classes: Output size of the final Linear layer.
        device_ids: Device ids; when more than one is given the model is
            wrapped in ``nn.DataParallel`` over these ids.

    Returns:
        The model (possibly DataParallel-wrapped), moved to the module-level
        ``device``.
    """
    # Load ImageNet-pretrained weights. `pretrained=True` is deprecated since
    # torchvision 0.13 in favour of `weights=`; IMAGENET1K_V1 is the checkpoint
    # that `pretrained=True` resolves to. Fall back for older torchvision.
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is not None:
        model = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet50(pretrained=True)

    # Freeze the first 40 parameter tensors; they stay frozen until the
    # caller re-enables gradients.
    for param in list(model.parameters())[:40]:
        param.requires_grad = False

    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(in_features, 512),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512),  # for stability
        nn.Dropout(0.3),
        nn.Linear(512, num_classes)
    )
    # Only the newly created head is re-initialised.
    for m in model.fc.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    if len(device_ids) > 1:
        model = nn.DataParallel(model, device_ids=device_ids)
    model = model.to(device)
    return model
    
# Build the model once at notebook level.
# NOTE(review): the training cell further down calls build_resnet50 again and
# rebinds `model`, so this instance is replaced before training starts.
model = build_resnet50(num_classes, device_ids)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 218MB/s]


In [8]:
def load_cifar10(is_train, augs, batch_size):
    """Create a DataLoader over one split of CIFAR-10.

    The dataset is stored under ``../data`` and downloaded if necessary.

    Args:
        is_train: ``True`` for the training split, ``False`` for the test
            split. Also controls shuffling: only the training split is shuffled.
        augs: Transform applied to every image.
        batch_size: Number of samples per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` with pinned memory and the worker
        count reported by ``d2l.get_dataloader_workers()``.
    """
    cifar = torchvision.datasets.CIFAR10(
        root="../data", train=is_train, transform=augs, download=True
    )
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=d2l.get_dataloader_workers(),
        pin_memory=True,
    )
    return torch.utils.data.DataLoader(cifar, **loader_kwargs)

In [9]:
from torch.optim.lr_scheduler import CosineAnnealingLR

In [10]:
class EarlyStopping:
    """Track the best accuracy seen so far and flag when to stop training.

    Calling the instance with the latest accuracy returns ``True`` when that
    value beats the best so far by more than ``min_delta`` and ``False``
    otherwise. After ``patience`` consecutive non-improving calls,
    ``early_stop`` is set to ``True`` (it is never reset afterwards).
    """

    def __init__(self, patience=10, min_delta=1e-4):
        # Number of consecutive non-improving calls tolerated.
        self.patience = patience
        # Minimum margin over `best_acc` that counts as an improvement.
        self.min_delta = min_delta
        self.best_acc = 0.0
        self.counter = 0
        self.early_stop = False

    def __call__(self, current_acc):
        """Record ``current_acc``; return whether it is a new best."""
        improved = current_acc > self.best_acc + self.min_delta
        if not improved:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False

        # New best: remember it and restart the patience window.
        self.best_acc = current_acc
        self.counter = 0
        return True

# Notebook-level tracker that tolerates 10 consecutive non-improving evaluations.
# NOTE(review): the training cell creates its own EarlyStopping(patience=10)
# and rebinds this name before the loop starts.
early_stopping = EarlyStopping(patience=10)

In [11]:
def train_one_epoch(model, train_iter, criterion, optimizer, device, use_mixup=True):
    """Run one optimisation pass over ``train_iter``.

    With ``use_mixup`` each batch is blended with a shuffled copy of itself
    using a coefficient ``lam`` drawn from Beta(0.2, 0.2); the loss and the
    correct-prediction count are the ``lam``-weighted combination of the
    values for the original and the shuffled labels.

    Args:
        model: Network to train; it is switched to training mode.
        train_iter: Iterable yielding ``(X, y)`` batches.
        criterion: Loss callable taking ``(predictions, targets)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        use_mixup: Whether to apply mixup to every batch.

    Returns:
        ``(mean_loss, accuracy)``, both averaged per sample over the epoch.
    """
    model.train()
    loss_sum, correct_sum, seen = 0.0, 0, 0

    for X, y in train_iter:
        X = X.to(device)
        y = y.to(device)
        n = X.shape[0]
        optimizer.zero_grad()

        if not use_mixup:
            logits = model(X)
            loss = criterion(logits, y)
            hits = (logits.argmax(dim=1) == y).sum().item()
        else:
            lam = np.random.beta(0.2, 0.2)
            perm = torch.randperm(X.size(0)).to(device)
            y_perm = y[perm]
            logits = model(lam * X + (1 - lam) * X[perm])
            loss = lam * criterion(logits, y) + (1 - lam) * criterion(logits, y_perm)
            pred = logits.argmax(dim=1)
            # Soft count: a prediction gets credit for each of the two mixed labels.
            hits = lam * (pred == y).sum().item() + (1 - lam) * (pred == y_perm).sum().item()

        loss.backward()
        optimizer.step()

        # `loss` is a batch mean, so weight it by the batch size.
        loss_sum += loss.item() * n
        correct_sum += hits
        seen += n

    return loss_sum / seen, correct_sum / seen

In [12]:
def evaluate_with_tta(model, test_iter, criterion, device, tta_times=5):
    """Evaluate ``model`` with horizontal-flip test-time augmentation (TTA).

    For every batch the model is run ``tta_times`` times. Before each run the
    *whole batch* is mirrored left-right with probability 0.5 (one draw per
    run, not per image), which is what ``RandomHorizontalFlip(p=0.5)`` does
    when applied to a batched tensor. The outputs are averaged and the
    averaged output is used for both the loss and the prediction.

    The output width is taken from the model itself, so the function does not
    depend on any module-level ``num_classes``. The model is put in eval mode
    for the duration of the call and its previous train/eval mode is restored
    afterwards.

    Args:
        model: Network to evaluate.
        test_iter: Iterable yielding ``(X, y)`` batches; ``X`` must have the
            image width as its last dimension.
        criterion: Loss callable taking ``(predictions, targets)``.
        device: Device the batches are moved to.
        tta_times: Number of forward passes averaged per batch (>= 1).

    Returns:
        ``(mean_loss, accuracy)``, both averaged per sample.

    Raises:
        ValueError: If ``tta_times`` is smaller than 1.
    """
    if tta_times < 1:
        raise ValueError(f"tta_times must be >= 1, got {tta_times}")

    was_training = model.training
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0

    with torch.no_grad():
        for X, y in test_iter:
            X, y = X.to(device), y.to(device)
            n = X.shape[0]

            # Average the outputs of several (possibly flipped) views.
            y_hat_sum = None
            for _ in range(tta_times):
                X_tta = X.flip(-1) if torch.rand(1) < 0.5 else X
                y_hat = model(X_tta)
                y_hat_sum = y_hat if y_hat_sum is None else y_hat_sum + y_hat
            y_hat_avg = y_hat_sum / tta_times

            loss = criterion(y_hat_avg, y)
            total_loss += loss.item() * n
            total_correct += (y_hat_avg.argmax(dim=1) == y).sum().item()
            total_samples += n

    # Hand the model back in the mode it arrived in.
    model.train(was_training)
    return total_loss / total_samples, total_correct / total_samples

In [13]:
# Training script: build data loaders and model, train for up to `epochs`
# epochs with a two-phase schedule (partially frozen backbone first, everything
# trainable from `unfreeze_epoch` on), evaluate after every epoch, and save
# checkpoints.
# NOTE(review): early stopping and "best model" selection below are driven by
# accuracy on the CIFAR-10 test split (test_iter is built with train=False),
# so the reported best accuracy is not an unbiased estimate.
if __name__ == "__main__":

    train_transform, test_transform = get_data_transforms(img_size)
    train_iter = load_cifar10(True, train_transform, batch_size)
    test_iter = load_cifar10(False, test_transform, batch_size)
    
    model = build_resnet50(num_classes, device_ids)
    
    # Phase-1 optimizer/scheduler: AdamW at `init_lr`, cosine annealing that
    # restarts every 10 scheduler steps (one step per epoch, see loop end).
    optimizer = optim.AdamW(model.parameters(), lr=init_lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  
    
    early_stopping = EarlyStopping(patience=10)
    
    # Per-epoch history of the metrics returned by training and evaluation.
    train_losses, train_accs = [], []
    test_losses, test_accs = [], []
    unfreeze_epoch = 20  # 0-based epoch index at which all parameters become trainable
    
    for epoch in range(epochs):

        if epoch == unfreeze_epoch:
            # Phase 2: make every parameter trainable and start over with a
            # freshly constructed optimizer and scheduler at a 10x lower base LR.
            print("解冻所有层，降低学习率微调")
            for param in model.parameters():
                param.requires_grad = True
            optimizer = optim.AdamW(model.parameters(), lr=init_lr * 0.1, weight_decay=weight_decay)
            scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
        
        # Train, then evaluate
        train_loss, train_acc = train_one_epoch(model, train_iter, criterion, optimizer, device)
        test_loss, test_acc = evaluate_with_tta(model, test_iter, criterion, device)
        
        # Record metrics
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        
        # Log
        print(f"Epoch [{epoch+1}/{epochs}]")
        print(f"训练集：损失={train_loss:.4f} | 准确率={train_acc:.4f}")
        print(f"测试集：损失={test_loss:.4f} | 准确率={test_acc:.4f}\n")
        
        # Early stopping + model saving: the tracker returns True on a new best
        # accuracy, in which case the weights overwrite `save_path`. Otherwise a
        # numbered checkpoint is written on every 5th epoch.
        # NOTE(review): when build_resnet50 wrapped the model in nn.DataParallel,
        # the saved state_dict keys carry a "module." prefix — keep that in mind
        # when loading into an unwrapped model.
        if early_stopping(test_acc):
            torch.save(model.state_dict(), save_path)
            print(f"保存最佳模型，当前最佳测试准确率：{early_stopping.best_acc:.4f}")
        elif (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f"checkpoint_epoch_{epoch+1}.pth")
        
        # LR schedule: stepped once per epoch, before the early-stop check
        scheduler.step()
        if early_stopping.early_stop:
            print("早停触发，终止训练")
            break

100%|██████████| 170M/170M [00:02<00:00, 80.4MB/s]


Epoch [1/80]
训练集：损失=1.6976 | 准确率=0.5370
测试集：损失=1.0865 | 准确率=0.7702

保存最佳模型，当前最佳测试准确率：0.7702
Epoch [2/80]
训练集：损失=1.3324 | 准确率=0.6852
测试集：损失=0.9864 | 准确率=0.7980

保存最佳模型，当前最佳测试准确率：0.7980
Epoch [3/80]
训练集：损失=1.1645 | 准确率=0.7542
测试集：损失=0.8687 | 准确率=0.8510

保存最佳模型，当前最佳测试准确率：0.8510
Epoch [4/80]
训练集：损失=1.1217 | 准确率=0.7711
测试集：损失=0.7748 | 准确率=0.8984

保存最佳模型，当前最佳测试准确率：0.8984
Epoch [5/80]
训练集：损失=1.0915 | 准确率=0.7823
测试集：损失=0.7487 | 准确率=0.9029

保存最佳模型，当前最佳测试准确率：0.9029
Epoch [6/80]
训练集：损失=1.0529 | 准确率=0.7951
测试集：损失=0.7032 | 准确率=0.9274

保存最佳模型，当前最佳测试准确率：0.9274
Epoch [7/80]
训练集：损失=0.9815 | 准确率=0.8189
测试集：损失=0.6924 | 准确率=0.9328

保存最佳模型，当前最佳测试准确率：0.9328
Epoch [8/80]
训练集：损失=0.9383 | 准确率=0.8421
测试集：损失=0.6517 | 准确率=0.9458

保存最佳模型，当前最佳测试准确率：0.9458
Epoch [9/80]
训练集：损失=0.9559 | 准确率=0.8354
测试集：损失=0.6484 | 准确率=0.9490

保存最佳模型，当前最佳测试准确率：0.9490
Epoch [10/80]
训练集：损失=0.9524 | 准确率=0.8350
测试集：损失=0.6475 | 准确率=0.9535

保存最佳模型，当前最佳测试准确率：0.9535
Epoch [11/80]
训练集：损失=1.0897 | 准确率=0.7766
测试集：损失=0.8471 | 准确率=0.8652

Epoch [12/