In [1]:
#导入必要的库
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

In [2]:
# Hyperparameters and reproducibility settings.
# Seed the global torch RNG first: weight initialization, dropout masks and
# DataLoader shuffling all draw from it, so fixing it makes
# "Restart Kernel -> Run All" reproducible.
seed = 42
torch.manual_seed(seed)

batch_size = 64
learning_rate = 0.001
num_epochs = 10

In [3]:
# Preprocessing pipeline: convert each image to a tensor, then normalize it
# with a fixed single-channel mean/std (0.1307 / 0.3081 -- the values commonly
# quoted for the MNIST training set).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)


In [4]:
# Load the MNIST training split, then the test split, both stored under ./data
# (download=True lets torchvision fetch the files) and both run through the
# same `transform`.
train_data, test_data = (
    MNIST(root='./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

In [5]:
# Create the DataLoaders.
# Only the training set is shuffled. Shuffling the test set does not change
# the measured accuracy; it only makes evaluation order non-deterministic and
# consumes extra draws from the global RNG.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [6]:
# Network definition: a small fully connected classifier.
class SimpleNet(nn.Module):
    """Multi-layer perceptron: flatten -> [Linear -> ReLU -> Dropout] * k -> Linear.

    With the default arguments this builds exactly the original architecture
    (784 -> 128 -> 64 -> 10, dropout 0.5) with the same submodule layout
    (``fc.0`` ... ``fc.6``), so previously saved state_dicts still load.

    Args:
        in_features: Number of input values per sample after flattening
            (default 28*28, i.e. one 28x28 image).
        hidden_sizes: Widths of the hidden layers, in order.
        num_classes: Number of output logits (default 10).
        dropout: Dropout probability applied after every hidden ReLU.
    """

    def __init__(self, in_features=28 * 28, hidden_sizes=(128, 64), num_classes=10, dropout=0.5):
        super().__init__()
        self.in_features = in_features
        layers = []
        prev_width = in_features
        for width in hidden_sizes:
            layers += [nn.Linear(prev_width, width), nn.ReLU(), nn.Dropout(p=dropout)]
            prev_width = width
        # Output layer: one logit per class (no softmax; CrossEntropyLoss expects logits).
        layers.append(nn.Linear(prev_width, num_classes))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits, one row per ``in_features``-sized chunk of ``x``.

        For a batch shaped (N, 1, 28, 28) this yields (N, num_classes).
        ``reshape`` is used instead of ``view`` so non-contiguous inputs
        (e.g. transposed tensors) are accepted instead of raising.
        """
        x = x.reshape(-1, self.in_features)
        return self.fc(x)

# Instantiate the network, the classification loss and the Adam optimizer
# over all of the network's parameters.
model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [7]:
# Training loop
for epoch in range(num_epochs):
    # Training mode enables dropout; the evaluation cell later switches to eval mode.
    model.train()
    for step, (images, labels) in enumerate(train_loader, start=1):
        logits = model(images)
        loss = criterion(logits, labels)

        # Clear stale gradients, backpropagate, then apply the update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the current mini-batch loss every 100 steps.
        if step % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{step}], Loss: {loss.item():.4f}')

# Evaluation on the test set
model.eval()  # disable dropout for inference
correct = 0
total = 0
with torch.no_grad():  # no autograd graph is needed for inference
    for images, labels in test_loader:
        outputs = model(images)
        # Predicted class = index of the largest logit (no need for `.data` under no_grad).
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Guard against an empty test set instead of raising ZeroDivisionError.
accuracy = 100 * correct / total if total else float('nan')
print('Test Accuracy of the model on {} test images: {:.2f} %'.format(total, accuracy))

Epoch [1/10], Step [100], Loss: 0.7927
Epoch [1/10], Step [200], Loss: 0.5853
Epoch [1/10], Step [300], Loss: 0.6704
Epoch [1/10], Step [400], Loss: 0.2987
Epoch [1/10], Step [500], Loss: 0.3050
Epoch [1/10], Step [600], Loss: 0.4666
Epoch [1/10], Step [700], Loss: 0.4651
Epoch [1/10], Step [800], Loss: 0.3782
Epoch [1/10], Step [900], Loss: 0.4698
Epoch [2/10], Step [100], Loss: 0.3279
Epoch [2/10], Step [200], Loss: 0.1600
Epoch [2/10], Step [300], Loss: 0.1242
Epoch [2/10], Step [400], Loss: 0.2654
Epoch [2/10], Step [500], Loss: 0.3670
Epoch [2/10], Step [600], Loss: 0.2862
Epoch [2/10], Step [700], Loss: 0.2958
Epoch [2/10], Step [800], Loss: 0.2446
Epoch [2/10], Step [900], Loss: 0.4469
Epoch [3/10], Step [100], Loss: 0.1754
Epoch [3/10], Step [200], Loss: 0.1586
Epoch [3/10], Step [300], Loss: 0.1119
Epoch [3/10], Step [400], Loss: 0.2658
Epoch [3/10], Step [500], Loss: 0.2807
Epoch [3/10], Step [600], Loss: 0.2507
Epoch [3/10], Step [700], Loss: 0.3555
Epoch [3/10], Step [800],