In [None]:
# 导入相关依赖库
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import numpy as np
import matplotlib.pyplot as plt

# Seed PyTorch's global random number generator so that runs started from a
# fresh kernel are repeatable (weight initialisation and DataLoader shuffling
# both draw from it). Done before the final print so the cell emits no stray
# Generator repr as its output.
SEED = 42
torch.manual_seed(SEED)

# Select the compute device: use the GPU when CUDA is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

In [None]:
# Preprocessing applied to every sample:
#   * Resize   - force a 28x28 image (already the MNIST default; kept for compatibility).
#   * ToTensor - convert to a Tensor, scale pixel values into [0, 1] and
#                rearrange the channel order.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
])

# Load (downloading if necessary) the training and test splits.
# torchvision handles the download and extraction of MNIST automatically.
# If the raw .idx files are already available locally, download=True can be
# switched to False, provided the directory layout under `root` matches what
# torchvision expects (usually ./MNIST/raw/). Leaving download=True is the
# convenient choice because existing files are detected automatically.
training_data = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_data = datasets.MNIST(root="data", train=False, download=True, transform=transform)

# Wrap both splits in DataLoaders; only the training split is shuffled.
batch_size = 128
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

print('训练数据集数量： ', len(training_data))
print('测试数据集数量： ', len(test_data))

# Peek at the first test batch to report tensor shapes and the label dtype.
first_batch = next(iter(test_dataloader), None)
if first_batch is not None:
    X, y = first_batch
    print('Shape of X [N, C, H, W]: ', X.shape)
    print('Shape of y: ', y.shape, y.dtype)

In [None]:
# Show the first 10 test images together with their labels in a 2 x 5 grid.
cols, rows = 5, 2
figure = plt.figure(figsize=(10, 4))

# Take one batch from the test DataLoader via its iterator and draw samples from it.
data_iter = iter(test_dataloader)
images, labels = next(data_iter)

for i in range(1, cols * rows + 1):
    # PyTorch images are laid out (C, H, W); squeeze() removes the singleton
    # channel axis, (1, 28, 28) -> (28, 28), so imshow receives a 2-D array.
    img = images[i - 1].squeeze().numpy()
    label = labels[i - 1].item()

    # Subplot positions are 1-based, hence the loop starting at 1.
    ax = figure.add_subplot(rows, cols, i)
    ax.set_title(f"Number: {label}")
    ax.axis("off")
    ax.imshow(img, cmap="gray")

plt.show()

In [None]:
# Model definition
class ForwardNN(nn.Module):
    """Fully connected feed-forward classifier.

    The input is flattened and passed through a stack of ``Linear -> ReLU``
    hidden layers followed by a final ``Linear`` layer that produces one raw
    score (logit) per class. No softmax is applied here because
    ``nn.CrossEntropyLoss`` handles that itself.

    With the default arguments the network is ``784 -> 512 -> 128 -> 10``,
    and the parameter names in ``state_dict()`` are
    ``linear_relu_stack.{0,2,4}.{weight,bias}``.

    Args:
        in_features: Number of values per sample after flattening
            (default ``28 * 28``).
        hidden_sizes: Iterable with the width of each hidden layer, in order
            (default ``(512, 128)``). May be empty, which yields a single
            linear layer.
        num_classes: Number of output logits (default ``10``).
    """

    def __init__(self, in_features=28 * 28, hidden_sizes=(512, 128), num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()

        layers = []
        prev_width = in_features
        for width in hidden_sizes:
            layers.append(nn.Linear(prev_width, width))
            layers.append(nn.ReLU())
            prev_width = width
        # Output layer: raw logits, deliberately left un-normalised.
        layers.append(nn.Linear(prev_width, num_classes))

        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` and return the logits produced by the layer stack."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

# Build the network, move it onto the selected device, then print its layer layout.
model = ForwardNN()
model = model.to(device)
print(model)

In [None]:
# Optimizer: Adam over all model parameters with a fixed learning rate.
lr = 0.001
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

# Loss: cross-entropy, which combines softmax and the log-likelihood term,
# so the model is expected to output raw logits.
loss_fn = nn.CrossEntropyLoss()

In [None]:
def train(dataloader, model, loss_fn, optimizer, device=None):
    """Run one training epoch over ``dataloader``.

    For every batch: forward pass, loss computation, gradient reset,
    back-propagation and one optimizer step. Every 100th batch (starting with
    the first) a progress line with the current batch loss and the number of
    samples processed so far is printed.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches; must expose
            ``.dataset`` with a length.
        model: The ``nn.Module`` to train; switched to training mode.
        loss_fn: Callable taking ``(pred, y)`` and returning a scalar tensor.
        optimizer: Optimizer holding the parameters of ``model``.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if the model has no
            parameters), so the function does not depend on a notebook-level
            global.

    Returns:
        None.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    size = len(dataloader.dataset)
    seen = 0  # running count of samples already used for an update
    model.train()  # enable training-mode behaviour
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Forward pass and loss.
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backward pass.
        optimizer.zero_grad()  # clear gradients left from the previous step
        loss.backward()        # compute gradients
        optimizer.step()       # update parameters

        # Count this batch before logging so the progress figure includes it
        # (and stays exact even when a batch is smaller than the others).
        seen += len(X)
        if batch % 100 == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval() # 设置为评估模式
    test_loss, correct = 0, 0
    with torch.no_grad(): # 不计算梯度
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # 这里的 pred 是 logits，argmax 获取最大概率的索引
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# Training run: each epoch performs one pass over the training data followed
# by an evaluation on the test data.
epochs = 10
print("============== Starting Training ==============")
for t in range(epochs):
    epoch_banner = f"Epoch {t+1}\n-------------------------------"
    print(epoch_banner)
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

# Save the model: only the learned parameters (state_dict) are written to disk.
checkpoint_path = "model_pytorch.pth"
weights = model.state_dict()
torch.save(weights, checkpoint_path)
print(f"Saved PyTorch Model State to {checkpoint_path}")