In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义CNN模型
class CNNNet(nn.Module):
    def __init__(self):
        super(CNNNet, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)
        self.fc1 = nn.Linear(7*7*64, 1024)
        self.fc2 = nn.Linear(1024, 10)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 7*7*64)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 定义训练函数
def train(model, device, data_loader, optimizer, criterion, epoch):
    watch_batch_size = 100
    model.train()
    avg_loss = 0.0
    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        avg_loss += loss.item()
        if batch_idx % watch_batch_size == (watch_batch_size - 1):  # 每100个batch打印一次
            avg_loss =  avg_loss / watch_batch_size
            print(f'Epoch {epoch + 1}, Batch [{batch_idx + 1}/{len(data_loader)}], AvgLoss: {avg_loss:.4f}')
            avg_loss = 0.0

# 定义验证函数
def evaluate(model, device, data_loader, criterion):
    data_len = 0

    model.eval()
    loss = 0.0
    correct = 0.0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss += criterion(output, target).item()  # 累积损失
            data_len += 1
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    loss = loss / data_len
    accuracy = 100. * correct / data_len
    return accuracy, loss

def main():
    # 设置设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')

    # 实例化模型
    model = CNNNet().to(device)

    # 加载MNIST数据集
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    validate_dataset = datasets.MNIST('./data', train=False, transform=transform)
    validate_loader = DataLoader(validate_dataset, batch_size=64, shuffle=True)

    # 初始化模型、损失函数和优化器
    model = CNNNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    # 训练和验证循环
    num_epochs = 10
    for epoch in range(num_epochs):
        train(model, device, train_loader, optimizer, criterion, epoch)
        train_accuracy, train_loss = evaluate(model, device, train_loader, criterion)
        validate_accuracy, validate_loss = evaluate(model, device, validate_loader, criterion)
        print(f'Epoch {epoch + 1}, train_accuracy = {train_accuracy:.4f}%, validate_accuracy = {validate_accuracy:.4f}%')
        print(f'Epoch {epoch + 1}, train_loss = {train_loss:.4f}, validate_loss = {validate_loss:.4f}')

    torch.save(model.state_dict(), 'mnist_cnn.pth')
    print('Train finished')
    
if __name__ == "__main__":
    main()

Using device: cpu
Epoch 1, Batch [100/938], AvgLoss: 0.4027
Epoch 1, Batch [200/938], AvgLoss: 0.1167
Epoch 1, Batch [300/938], AvgLoss: 0.0737
Epoch 1, Batch [400/938], AvgLoss: 0.0712
Epoch 1, Batch [500/938], AvgLoss: 0.0606
Epoch 1, Batch [600/938], AvgLoss: 0.0806
Epoch 1, Batch [700/938], AvgLoss: 0.0542
Epoch 1, Batch [800/938], AvgLoss: 0.0619
Epoch 1, Batch [900/938], AvgLoss: 0.0559
Epoch 1, train_accuracy = 98.5867%, validate_accuracy = 98.3600%
Epoch 1, train_loss = 0.0007, validate_loss = 0.0007
Epoch 2, Batch [100/938], AvgLoss: 0.0362
Epoch 2, Batch [200/938], AvgLoss: 0.0404
Epoch 2, Batch [300/938], AvgLoss: 0.0356
Epoch 2, Batch [400/938], AvgLoss: 0.0385
Epoch 2, Batch [500/938], AvgLoss: 0.0419
Epoch 2, Batch [600/938], AvgLoss: 0.0418
Epoch 2, Batch [700/938], AvgLoss: 0.0395
Epoch 2, Batch [800/938], AvgLoss: 0.0379
Epoch 2, Batch [900/938], AvgLoss: 0.0305
Epoch 2, train_accuracy = 98.8850%, validate_accuracy = 98.7800%
Epoch 2, train_loss = 0.0005, validate_loss

KeyboardInterrupt: 