# In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

# CNN classifier built on a residual network (ResNet-18) for transfer learning.
class ResNetTransferLearning(nn.Module):
    """ResNet-18 backbone with a freshly initialised classification head.

    By default torchvision's pretrained ResNet-18 weights are loaded, every
    backbone parameter is frozen, and ``resnet.fc`` is replaced by a new,
    trainable ``nn.Linear`` that produces ``num_classes`` logits.

    Args:
        num_classes: Number of outputs of the replacement head.
        pretrained: Load torchvision's default pretrained ResNet-18 weights.
            Pass ``False`` for random initialisation (e.g. offline use).
        freeze_backbone: If ``True``, set ``requires_grad=False`` on every
            backbone parameter so only the new head is trained.
    """

    def __init__(self, num_classes=10, *, pretrained=True, freeze_backbone=True):
        super().__init__()
        self.resnet = self._build_backbone(pretrained)

        if freeze_backbone:
            # This freezes *all* backbone parameters (convolutions, BatchNorm
            # affine weights and the original fc), not just the conv layers.
            for param in self.resnet.parameters():
                param.requires_grad = False

        # The replacement head is created after freezing, so its parameters
        # keep the default requires_grad=True and are the ones that train.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_features, num_classes)

    @staticmethod
    def _build_backbone(pretrained):
        """Create ResNet-18, preferring the ``weights=`` API over ``pretrained=``."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the boolean flag.
            return models.resnet18(pretrained=pretrained)
        # DEFAULT is the same ImageNet checkpoint that pretrained=True loaded,
        # without the deprecation warning newer torchvision emits.
        return models.resnet18(weights=weights_enum.DEFAULT if pretrained else None)

    def forward(self, x):
        """Return the wrapped ResNet's logits; the last dim has size ``num_classes``."""
        return self.resnet(x)

# Build the model instance on the GPU when CUDA is available, else on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ResNetTransferLearning()
model = model.to(device)

# Print the module hierarchy for inspection.
print(model)

# Training utilities.
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; optimise the model when ``optimizer`` is given.

    Args:
        model: Module mapping a batch of inputs to scores with classes on dim 1.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss module. It is assumed to use mean reduction (the
            ``nn.CrossEntropyLoss`` default), so it is re-weighted by batch size.
        device: Device each batch is moved to.
        optimizer: If given, the pass trains (train mode, backward, step).
            Otherwise it evaluates in eval mode with gradients disabled.

    Returns:
        ``(mean_loss, accuracy)`` over the samples the loader actually
        yielded. Both values are ``nan`` when it yielded none.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is exactly what model.eval() does

    total_loss = 0.0
    correct = 0
    seen = 0
    # Gradients are only needed for the training pass.
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # Undo the per-batch mean so the epoch loss is a per-sample mean
            # even when the last batch is smaller.
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size

    if seen == 0:
        return float("nan"), float("nan")
    # Normalise by samples actually seen, not len(loader.dataset): the two
    # differ with drop_last=True or a subset sampler.
    return total_loss / seen, correct / seen


def train_model(model, train_loader, val_loader, num_epochs=5, device='cuda', lr=0.001):
    """Train ``model`` with Adam and cross-entropy, validating after every epoch.

    Args:
        model: Classifier returning logits with classes along dim 1. Batches
            are moved to ``device`` but the model is not, so it must already
            live there.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        device: Device each batch is moved to before the forward pass.
        lr: Adam learning rate (defaults to the previously hard-coded 0.001).

    Returns:
        ``(model, train_losses, train_accs, val_losses, val_accs)``. Each list
        holds one value per epoch. Metrics are computed over the samples each
        loader actually yields, so plain lists of batches are accepted too.
    """
    criterion = nn.CrossEntropyLoss().to(device)
    # Only trainable parameters go to the optimiser. Frozen ones
    # (requires_grad=False) would never receive gradients anyway.
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=lr
    )

    # Per-epoch training metrics.
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []

    print("开始训练优化模型...")

    for epoch in range(num_epochs):
        # NOTE(review): train mode applies to the whole model, so BatchNorm
        # layers in a frozen backbone still update their running statistics.
        # Keep them in eval mode if that is not intended.
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # Validation phase: eval mode, no gradients, no parameter updates.
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    return model, train_losses, train_accs, val_losses, val_accs

# Train for 10 epochs. train_loader / val_loader are not defined in this
# cell; presumably they come from an earlier notebook cell — verify.
model, train_losses, train_accs, val_losses, val_accs = train_model(
    model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=10,
    device=device,
)