In [None]:
# 导入相关依赖库
import os
import struct
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt

# Pick the compute device: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")
print("STEP1 finish.")

In [None]:
# =========================================================================
# Custom Dataset that reads MNIST IDX files from a local directory layout
# =========================================================================
class LocalMNIST(Dataset):
    """MNIST dataset backed by raw (uncompressed) IDX files on local disk.

    Expected layout::

        <root_dir>/train/train-images.idx3-ubyte
        <root_dir>/train/train-labels.idx1-ubyte
        <root_dir>/test/t10k-images.idx3-ubyte
        <root_dir>/test/t10k-labels.idx1-ubyte

    Both files of the selected split are loaded fully into memory when the
    dataset is constructed.
    """

    # Magic numbers that open an IDX file of unsigned bytes:
    # 0x00000803 for the 3-D image tensor, 0x00000801 for the 1-D label vector.
    _IMAGE_MAGIC = 2051
    _LABEL_MAGIC = 2049

    def __init__(self, root_dir, train=True, transform=None):
        """
        Args:
            root_dir (string): Dataset root directory, e.g. "./MNIST".
            train (bool): True loads the ``train`` sub-folder, False loads
                the ``test`` sub-folder.
            transform (callable, optional): Preprocessing applied to each
                sample after it has been converted to a PIL image.

        Raises:
            FileNotFoundError: If one of the expected IDX files is missing.
            ValueError: If a file is not a valid IDX file of the expected
                kind, or the image and label counts disagree.
        """
        self.transform = transform

        # Training files live in <root>/train, test files in <root>/test.
        sub_dir = "train" if train else "test"
        path = os.path.join(root_dir, sub_dir)

        if train:
            img_file = "train-images.idx3-ubyte"
            lbl_file = "train-labels.idx1-ubyte"
        else:
            img_file = "t10k-images.idx3-ubyte"
            lbl_file = "t10k-labels.idx1-ubyte"

        self.images = self._read_images(os.path.join(path, img_file))
        self.labels = self._read_labels(os.path.join(path, lbl_file))

        # Fail at construction time rather than with an IndexError in the
        # middle of an epoch.
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"Image/label count mismatch in {path}: "
                f"{len(self.images)} images vs {len(self.labels)} labels"
            )

    def _read_images(self, path):
        """Read an IDX3 image file and return a uint8 array of shape (N, H, W).

        Raises:
            ValueError: If the header is truncated, the magic number is not
                the IDX3/uint8 one, or the payload size does not match the
                dimensions declared in the header.
        """
        with open(path, 'rb') as f:
            # Header: magic, image count, rows, cols (big-endian uint32 each).
            header = f.read(16)
            if len(header) != 16:
                raise ValueError(f"{path}: truncated IDX image header")
            magic, num, rows, cols = struct.unpack(">IIII", header)
            if magic != self._IMAGE_MAGIC:
                raise ValueError(
                    f"{path}: bad magic number {magic} "
                    f"(expected {self._IMAGE_MAGIC} for an IDX image file)"
                )
            # Everything after the header is raw pixel data.
            data = np.fromfile(f, dtype=np.uint8)

        expected = num * rows * cols
        if data.size != expected:
            raise ValueError(
                f"{path}: expected {expected} pixel bytes "
                f"({num}x{rows}x{cols}), found {data.size}"
            )
        return data.reshape(num, rows, cols)

    def _read_labels(self, path):
        """Read an IDX1 label file and return a 1-D uint8 array of length N.

        Raises:
            ValueError: If the header is truncated, the magic number is not
                the IDX1/uint8 one, or the number of labels differs from the
                count declared in the header.
        """
        with open(path, 'rb') as f:
            # Header: magic, label count (big-endian uint32 each).
            header = f.read(8)
            if len(header) != 8:
                raise ValueError(f"{path}: truncated IDX label header")
            magic, num = struct.unpack(">II", header)
            if magic != self._LABEL_MAGIC:
                raise ValueError(
                    f"{path}: bad magic number {magic} "
                    f"(expected {self._LABEL_MAGIC} for an IDX label file)"
                )
            data = np.fromfile(f, dtype=np.uint8)

        if data.size != num:
            raise ValueError(
                f"{path}: header declares {num} labels, found {data.size}"
            )
        return data

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        With a transform, the image is whatever the transform returns for the
        grayscale PIL image. Without one, it is a float32 tensor of shape
        (1, H, W) scaled to [0, 1]. The label is always a Python ``int``.
        """
        img = self.images[idx]  # (H, W) uint8 array
        label = self.labels[idx]

        if self.transform is not None:
            # torchvision transforms operate on PIL images. Imported lazily so
            # Pillow is only required when a transform is actually supplied.
            import PIL.Image as Image
            # A 2-D uint8 array is mapped to mode "L" automatically; the
            # explicit `mode=` argument is deprecated in recent Pillow.
            img = Image.fromarray(img)
            img = self.transform(img)
        else:
            # No transform: scale to [0, 1] and add the channel dimension.
            img = torch.tensor(img, dtype=torch.float32) / 255.0
            img = img.unsqueeze(0)

        return img, int(label)
print("STEP2 finish.")

In [None]:
# 1. Preprocessing pipeline
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),  # scales pixels to [0, 1] and returns (C, H, W)
])

# 2. Instantiate the custom datasets.
# Assumes a MNIST folder in the current directory with train/ and test/
# sub-folders; change `current_dir` if your data lives elsewhere.
current_dir = './MNIST'  # dataset root directory

# Only dataset construction touches the filesystem, so only that is guarded.
try:
    train_data = LocalMNIST(root_dir=current_dir, train=True, transform=transform)
    test_data = LocalMNIST(root_dir=current_dir, train=False, transform=transform)
except FileNotFoundError as e:
    print("错误：未找到文件。请确认你的目录结构如下：")
    print(f"{current_dir}/train/train-images.idx3-ubyte")
    print(f"{current_dir}/train/train-labels.idx1-ubyte")
    print(f"{current_dir}/test/t10k-images.idx3-ubyte")
    print(f"{current_dir}/test/t10k-labels.idx1-ubyte")
    print(f"详细错误信息: {e}")
    # Re-raise so execution stops here; swallowing the error would only move
    # the failure to a later cell as a NameError on the missing dataloaders.
    raise

# 3. Build the DataLoaders
batch_size = 128
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

print('成功加载本地数据集！')
print('训练数据集数量： ', len(train_data))
print('测试数据集数量： ', len(test_data))

# 4. Sanity check: pull one training batch and inspect its shape
data_iter = iter(train_dataloader)
images, labels = next(data_iter)

print('图像 Tensor 形状:', images.shape)  # expected: [128, 1, 28, 28]

# Peek at the first test batch only
for X, y in test_dataloader:
    print('Shape of X [N, C, H, W]: ', X.shape)
    print('Shape of y: ', y.shape, y.dtype)
    break

print("STEP3 finish.")

In [None]:
# Display the first 10 test images together with their labels.
figure = plt.figure(figsize=(10, 4))
cols, rows = 5, 2

# Take the first batch straight from the test DataLoader.
data_iter = iter(test_dataloader)
images, labels = next(data_iter)

for position in range(cols * rows):
    # Tensors are (C, H, W); squeeze() drops the single channel so that
    # imshow receives a plain (H, W) array.
    img = images[position].squeeze().numpy()
    label = labels[position].item()

    # Subplot indices are 1-based.
    ax = figure.add_subplot(rows, cols, position + 1)
    ax.set_title(f"Number: {label}")
    ax.axis("off")
    ax.imshow(img, cmap="gray")

plt.show()
print("STEP4 finish.")

In [None]:
# Loss: cross-entropy. It applies log-softmax internally, so the model is
# expected to output raw logits.
# `torch.nn` is referenced through the already-imported `torch` package; the
# bare name `nn` is not imported by any cell visible in this notebook.
loss_fn = torch.nn.CrossEntropyLoss()

# Optimizer: Adam
# NOTE(review): `model` is not defined in any visible cell — it must be
# created (and presumably moved to `device`) before this cell runs; confirm.
lr = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
print("STEP5 finish.")

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches; must expose a
            ``dataset`` attribute supporting ``len()``.
        model: ``torch.nn.Module`` to optimise.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            scalar tensor.
        optimizer: Optimizer holding ``model``'s parameters.

    Prints the loss and the number of samples processed so far every 100
    batches. Returns nothing.
    """
    size = len(dataloader.dataset)

    # Send batches to wherever the model's parameters live instead of relying
    # on a notebook-global `device`; a model without parameters leaves the
    # batches where the DataLoader produced them.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    model.train()  # enable training-mode behaviour (dropout, batch-norm, ...)
    for batch, (X, y) in enumerate(dataloader):
        if target_device is not None:
            X, y = X.to(target_device), y.to(target_device)

        # Forward pass and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backward pass
        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()        # compute gradients
        optimizer.step()       # update parameters

        if batch % 100 == 0:
            # (batch + 1) so the counter includes the batch just processed.
            current = (batch + 1) * len(X)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval() # 设置为评估模式
    test_loss, correct = 0, 0
    with torch.no_grad(): # 不计算梯度
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # 这里的 pred 是 logits，argmax 获取最大概率的索引
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# Run the training loop: one train pass followed by one evaluation pass
# per epoch.
epochs = 10
print("============== Starting Training ==============")
for t in range(epochs):
    epoch_banner = f"Epoch {t+1}\n-------------------------------"
    print(epoch_banner)
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)

print("Done!")

# Persist the model's state_dict (not the whole module object).
checkpoint_state = model.state_dict()
torch.save(checkpoint_state, "model_pytorch.pth")
print("STEP6 finish.")
print("Saved PyTorch Model State to model_pytorch.pth")