# 第4课：图像分类实战

## 学习目标
- 掌握图像数据预处理
- 学会使用预训练模型
- 理解迁移学习
- 完成 CIFAR-10 分类任务

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Select the compute device: CUDA when PyTorch can see a GPU, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"使用设备: {device}")

## 1. 数据加载与预处理

In [None]:
# CIFAR-10 类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

# 数据增强
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

# 加载数据集
train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10('./data', train=False, transform=test_transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2)

print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")

In [None]:
# Visualize the effect of data augmentation.
def show_augmentation(dataset, idx=0):
    """Plot one CIFAR-10 sample next to 9 independently augmented views of it.

    Args:
        dataset: a ``datasets.CIFAR10`` whose transform ends with the CIFAR-10
            ``Normalize`` used in this notebook. Each ``dataset[idx]`` call
            re-runs its random transforms, giving a new augmented view.
        idx: index of the sample to visualize.
    """
    # Inverse of Normalize(mean, std), used to map images back into [0, 1].
    mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)

    fig, axes = plt.subplots(2, 5, figsize=(15, 6))

    # Load the un-augmented image from the same root and split as `dataset`,
    # so the "original" panel shows the very sample being augmented (a
    # hard-coded training split would show the wrong image for test data).
    original_dataset = datasets.CIFAR10(
        getattr(dataset, 'root', './data'),
        train=getattr(dataset, 'train', True),
        transform=transforms.ToTensor(),
    )
    original_img, label = original_dataset[idx]

    axes[0, 0].imshow(original_img.permute(1, 2, 0))
    axes[0, 0].set_title(f'原始: {classes[label]}')
    axes[0, 0].axis('off')

    # Remaining 9 cells: fresh augmented draws of the same sample.
    for i in range(1, 10):
        ax = axes[i // 5, i % 5]
        img, _ = dataset[idx]
        img = torch.clamp(img * std + mean, 0, 1)
        ax.imshow(img.permute(1, 2, 0))  # CHW -> HWC for matplotlib
        ax.set_title(f'增强 {i}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

show_augmentation(train_dataset)

## 2. 构建自定义 CNN

In [None]:
class CIFAR10CNN(nn.Module):
    """Three-stage VGG-style CNN for 32x32 RGB images such as CIFAR-10.

    Each stage applies two (3x3 conv -> BatchNorm -> ReLU) layers followed by a
    2x2 max-pool that halves the spatial size (32 -> 16 -> 8 -> 4). The
    256x4x4 feature map is flattened and classified by a two-layer MLP with
    dropout.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def stage(in_ch, out_ch):
            # 2 x (conv3x3 -> BN -> ReLU), then a 2x2 max-pool.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            )

        self.conv1 = stage(3, 64)     # 32 -> 16
        self.conv2 = stage(64, 128)   # 16 -> 8
        self.conv3 = stage(128, 256)  # 8 -> 4

        # The first Linear expects exactly 256 * 4 * 4 features, i.e. 32x32 input.
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        for block in (self.conv1, self.conv2, self.conv3):
            x = block(x)
        features = x.view(x.size(0), -1)  # (N, 256*4*4)
        return self.classifier(features)

# Build the custom CNN, move it to the selected device, and report its size.
model = CIFAR10CNN()
model.to(device)  # nn.Module.to moves parameters in place and returns self
print(f"参数量: {sum(p.numel() for p in model.parameters()):,}")

## 3. 训练模型

In [None]:
# Training setup: cross-entropy loss, AdamW with weight decay, and a cosine
# learning-rate schedule that is stepped once per epoch.
num_epochs = 10  # keep in sync with the epoch count of the training loop
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
# T_max is the number of scheduler.step() calls over which the LR anneals from
# its initial value down to eta_min. It must equal the number of epochs
# actually trained; with T_max=20 over only 10 epochs the schedule stops
# halfway and ends at half the initial LR.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    Args:
        model: network to train; it is switched to train mode here.
        train_loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss callable ``criterion(logits, targets)``; assumed to use
            mean reduction (the ``nn.CrossEntropyLoss`` default).
        optimizer: optimizer over the model's trainable parameters.
        device: device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)`` averaged per sample over the whole epoch.
        Raises ZeroDivisionError if the loader yields no samples.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Undo the batch-mean so the epoch loss is a per-sample average even
        # when the last batch is smaller.
        loss_sum += batch_loss.item() * inputs.size(0)
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += targets.size(0)

    return loss_sum / n_seen, n_correct / n_seen

def evaluate(model, test_loader, criterion, device):
    """Measure the average loss and accuracy of ``model`` on ``test_loader``.

    The model is put in eval mode and run without autograd. ``criterion`` is
    assumed to return a batch-mean loss, so it is re-weighted by batch size to
    obtain a per-sample average.

    Returns:
        ``(avg_loss, accuracy)`` over all samples. Raises ZeroDivisionError if
        the loader yields no samples.
    """
    model.eval()
    # One (weighted loss, correct count, sample count) triple per batch.
    per_batch = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            per_batch.append((
                criterion(logits, targets).item() * inputs.size(0),
                (logits.argmax(dim=1) == targets).sum().item(),
                targets.size(0),
            ))

    loss_sum = sum(entry[0] for entry in per_batch)
    n_correct = sum(entry[1] for entry in per_batch)
    n_seen = sum(entry[2] for entry in per_batch)
    return loss_sum / n_seen, n_correct / n_seen

In [None]:
# Training loop: each epoch trains once, evaluates on the test set, advances
# the LR scheduler, and records the four metrics in `history` for plotting.
num_epochs = 10
history = {name: [] for name in ('train_loss', 'train_acc', 'test_loss', 'test_acc')}

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    scheduler.step()

    epoch_metrics = {
        'train_loss': train_loss,
        'train_acc': train_acc,
        'test_loss': test_loss,
        'test_acc': test_acc,
    }
    for name, value in epoch_metrics.items():
        history[name].append(value)

    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
    print(f'  Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
    # Read after scheduler.step(), so this is the LR the next epoch will use.
    print(f'  LR: {scheduler.get_last_lr()[0]:.6f}')

In [None]:
# Plot the per-epoch loss and accuracy curves recorded in `history`.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (axis, train key, test key, y-axis label, title) for each panel.
panels = (
    (axes[0], 'train_loss', 'test_loss', 'Loss', '损失曲线'),
    (axes[1], 'train_acc', 'test_acc', 'Accuracy', '准确率曲线'),
)
for ax, train_key, test_key, ylabel, title in panels:
    ax.plot(history[train_key], label='训练')
    ax.plot(history[test_key], label='测试')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.set_title(title)

plt.tight_layout()
plt.show()

## 4. 迁移学习

In [None]:
# Transfer learning from a (optionally ImageNet-pretrained) ResNet-18.
def create_resnet_model(num_classes=10, pretrained=True):
    """Build a ResNet-18 adapted to 32x32 CIFAR-10 images.

    The stem is replaced by a 3x3 stride-1 convolution and the max-pool is
    removed, so small images are not downsampled before the residual stages.
    The final fully connected layer is replaced to emit ``num_classes`` logits.

    Note: the new ``conv1`` and ``fc`` are freshly initialized; only the other
    layers carry ImageNet weights when ``pretrained`` is True.
    """
    weights = 'IMAGENET1K_V1' if pretrained else None
    net = models.resnet18(weights=weights)

    # CIFAR-sized stem: keep the full 32x32 resolution going into layer1.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()

    # New classification head sized for this task.
    in_features = net.fc.in_features
    net.fc = nn.Linear(in_features, num_classes)
    return net

resnet_model = create_resnet_model().to(device)
print(f"ResNet18 参数量: {sum(p.numel() for p in resnet_model.parameters()):,}")

In [None]:
# Freeze part of the network for fine-tuning.
def freeze_layers(model, num_layers_to_freeze, keep_trainable=()):
    """Freeze the parameters of the leading child modules of ``model``.

    Args:
        model: module whose direct children, in registration order, are
            considered.
        num_layers_to_freeze: how many leading children to freeze (slice
            semantics, i.e. ``children[:num_layers_to_freeze]``).
        keep_trainable: names of children inside that prefix to leave
            trainable, e.g. layers that were freshly re-initialized and still
            need to learn.

    Prints the number of trainable parameters versus the total.
    """
    leading = list(model.named_children())[:num_layers_to_freeze]
    for name, layer in leading:
        if name in keep_trainable:
            continue
        for param in layer.parameters():
            param.requires_grad = False

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"可训练参数: {trainable:,} / {total:,}")

# Freeze the first 6 children of the ResNet (conv1, bn1, relu, maxpool,
# layer1, layer2). conv1 was replaced by a randomly initialized layer in
# create_resnet_model, so freezing it would lock a random stem in place. Keep
# it, and bn1 which normalizes its output, trainable; freeze only the
# pretrained residual stages.
freeze_layers(resnet_model, 6, keep_trainable=('conv1', 'bn1'))

In [None]:
# Fine-tune the transfer-learning model. Its optimizer and scheduler get their
# own names so the `optimizer` / `scheduler` that trained the custom CNN are
# not overwritten (the checkpoint in section 6 saves `optimizer` together with
# `model`).
resnet_epochs = 5
resnet_optimizer = optim.AdamW(
    (p for p in resnet_model.parameters() if p.requires_grad),  # frozen params excluded
    lr=0.001, weight_decay=1e-4,
)
# NOTE(review): StepLR(step_size=5) only decays after the 5th scheduler.step(),
# i.e. after the final epoch here, so it has no effect within this run. Lower
# step_size or train longer if an LR drop is intended.
resnet_scheduler = optim.lr_scheduler.StepLR(resnet_optimizer, step_size=5, gamma=0.1)

for epoch in range(resnet_epochs):
    train_loss, train_acc = train_one_epoch(resnet_model, train_loader, criterion, resnet_optimizer, device)
    test_loss, test_acc = evaluate(resnet_model, test_loader, criterion, device)
    resnet_scheduler.step()

    print(f'Epoch {epoch+1}/{resnet_epochs}')
    print(f'  Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

## 5. 模型评估与分析

In [None]:
# 混淆矩阵
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

def get_predictions(model, test_loader, device):
    """Collect predicted and true labels for every sample in ``test_loader``.

    Runs ``model`` in eval mode without autograd. Predictions are the argmax
    over dim 1 of the model output. Only the inputs are moved to ``device``;
    targets are read directly on the CPU.

    Returns:
        ``(preds, labels)`` as NumPy arrays, in loader order.
    """
    model.eval()
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            logits = model(inputs.to(device))
            pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
            label_chunks.append(targets.numpy())

    flat_preds = [p for chunk in pred_chunks for p in chunk]
    flat_labels = [t for chunk in label_chunks for t in chunk]
    return np.array(flat_preds), np.array(flat_labels)

# Test-set predictions of the custom CNN (`model`).
preds, labels = get_predictions(model, test_loader, device)

# Pin the label set to every class id so the matrix is always
# len(classes) x len(classes) and lines up with the tick labels, even if a
# class is absent from both `labels` and `preds`.
class_ids = list(range(len(classes)))

# Confusion matrix: rows are true classes, columns are predicted classes.
cm = confusion_matrix(labels, preds, labels=class_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=classes, yticklabels=classes)
plt.xlabel('预测类别')
plt.ylabel('真实类别')
plt.title('混淆矩阵')
plt.show()

# Per-class precision / recall / F1.
print("分类报告:")
print(classification_report(labels, preds, labels=class_ids, target_names=classes))

In [None]:
# Visualize misclassified test samples.
def show_misclassified(model, test_loader, device, num_images=10):
    """Plot up to ``num_images`` test samples that ``model`` misclassifies.

    Batches are scanned in loader order, and samples are collected until
    ``num_images`` have been found. They are drawn on a fixed 2x5 grid with
    their true and predicted class names. Cells without a sample are left
    blank, and at most 10 samples are shown.
    """
    model.eval()
    found = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            preds = model(data).argmax(dim=1)
            # 1-D tensor of in-batch positions where the prediction is wrong.
            wrong = (preds != target).nonzero(as_tuple=True)[0]

            for pos in wrong:
                if len(found) >= num_images:
                    break
                found.append({
                    'image': data[pos].cpu(),
                    'true': target[pos].item(),
                    'pred': preds[pos].item()
                })

            if len(found) >= num_images:
                break

    # Inverse of the CIFAR-10 Normalize so images display in [0, 1].
    mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)

    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    for i, ax in enumerate(axes.flat):
        if i < len(found):
            sample = found[i]
            img = torch.clamp(sample['image'] * std + mean, 0, 1)
            ax.imshow(img.permute(1, 2, 0))  # CHW -> HWC for matplotlib
            ax.set_title(f'真实: {classes[sample["true"]]}\n预测: {classes[sample["pred"]]}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

show_misclassified(model, test_loader, device)

## 6. 保存和加载模型

In [None]:
# Save a checkpoint of the custom CNN: weights, optimizer state, epoch count
# and training history.
# NOTE(review): `optimizer` must be the optimizer that trained `model`
# (CIFAR10CNN). If that name was rebound to another model's optimizer earlier
# in the notebook, that optimizer's state would be saved here instead.
checkpoint_path = 'cifar10_model.pth'
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': num_epochs,
    'history': history
}, checkpoint_path)

print("模型已保存")

# Load the checkpoint back. map_location remaps tensors saved from a GPU
# session onto the current device, so this also works on CPU-only machines.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print("模型已加载")

## 7. 练习题

### 练习：尝试不同的模型架构和超参数

In [None]:
# 在这里编写代码
# 1. 尝试使用其他预训练模型（如 VGG、DenseNet）
# 2. 调整数据增强策略
# 3. 尝试不同的优化器和学习率调度
# 4. 比较不同配置的性能


## 8. 本课小结

1. **数据增强**：RandomHorizontalFlip、RandomCrop、ColorJitter（仅用于训练集，测试集只做归一化）
2. **模型设计**：卷积块、BatchNorm、Dropout
3. **迁移学习**：使用预训练模型、替换输入层与分类头以适配 32×32 图像、冻结层
4. **训练技巧**：学习率调度、权重衰减
5. **评估分析**：混淆矩阵、错误样本分析