In [11]:
import torch
import torchvision
from torch import optim, nn
import torch.nn.functional as F
from tqdm import tqdm
import argparse
import os

In [12]:

class LeNet5(nn.Module):
    """LeNet-5 convolutional network.

    Input: ``(batch, in_channels, 32, 32)`` images (single-channel grayscale by default).
    Output: ``(batch, num_classes)`` unnormalized logits.

    Args:
        in_channels: Number of input image channels (default 1).
        num_classes: Number of output classes (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        # C1: convolution, in_channels -> 6 feature maps, 5x5 kernel (32x32 -> 28x28)
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)
        # S2: 2x2 average pooling (28x28 -> 14x14)
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        # C3: convolution, 6 -> 16 feature maps, 5x5 kernel (14x14 -> 10x10)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        # S4: 2x2 average pooling (10x10 -> 5x5)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        # F5: fully connected, flattened 16x5x5 features -> 120
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        # F6: fully connected, 120 -> 84
        self.fc2 = nn.Linear(120, 84)
        # Output layer: 84 -> num_classes raw logits (no softmax; meant for CrossEntropyLoss)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map a batch of images to class logits."""
        # Feature extraction
        x = self.pool1(torch.tanh(self.conv1(x)))
        x = self.pool2(torch.tanh(self.conv2(x)))
        # Classifier. flatten(x, 1) preserves the batch dimension, so an input of
        # the wrong spatial size fails loudly in fc1 instead of being silently
        # regrouped into a different number of rows (as view(-1, 400) could).
        x = torch.flatten(x, 1)
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.fc3(x)

In [13]:

def train(model, device, train_loader, optimizer, criterion):
    """Run one training epoch over ``train_loader``.

    Args:
        model: Module to train; switched to training mode.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters once per batch.
        criterion: Loss function ``criterion(output, target)``; assumed to use
            mean reduction (as ``nn.CrossEntropyLoss`` does by default).

    Returns:
        Mean per-sample training loss over the epoch, or ``nan`` if the loader
        yielded no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    for data, target in tqdm(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # Re-weight the batch-mean loss by batch size so a short final batch
        # does not skew the epoch average.
        batch_size = target.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    if num_samples == 0:
        return float('nan')
    return total_loss / num_samples

In [14]:

def test(model, device, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    return correct / len(test_loader.dataset)

In [15]:
# Configuration. parse_known_args() ignores arguments it does not recognize
# (e.g. the "-f <connection file>" a Jupyter kernel is launched with).
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--save_dir', type=str, default='model')
parser.add_argument('--seed', type=int, default=42)
args, _ = parser.parse_known_args()

# Seed torch's RNG (all devices) so weight init and shuffling are reproducible.
torch.manual_seed(args.seed)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data loading: MNIST digits are 28x28; resize to the 32x32 input LeNet5 expects.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size)

# Model, optimizer, and loss initialization
model = LeNet5().to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.CrossEntropyLoss()

# Training loop: one training epoch followed by evaluation on the test set.
for epoch in range(1, args.epochs + 1):
    epoch_loss = train(model, device, train_loader, optimizer, criterion)
    test_acc = test(model, device, test_loader)
    print(f'Epoch {epoch}: Loss={epoch_loss:.4f}, Accuracy={test_acc:.2%}')

# Save the learned parameters. An empty --save_dir means "current directory";
# os.makedirs('') would raise, so only create the directory when one is given.
if args.save_dir:
    os.makedirs(args.save_dir, exist_ok=True)
save_path = os.path.join(args.save_dir, 'lenet.pth')
torch.save(model.state_dict(), save_path)

  0%|          | 0/938 [00:00<?, ?it/s]

100%|██████████| 938/938 [00:36<00:00, 25.81it/s]


Epoch 1: Loss=0.2981, Accuracy=96.25%


100%|██████████| 938/938 [00:39<00:00, 23.85it/s]


Epoch 2: Loss=0.1019, Accuracy=97.53%


100%|██████████| 938/938 [00:56<00:00, 16.54it/s]


Epoch 3: Loss=0.0672, Accuracy=98.31%


100%|██████████| 938/938 [00:45<00:00, 20.67it/s]


Epoch 4: Loss=0.0520, Accuracy=97.97%


100%|██████████| 938/938 [00:39<00:00, 23.81it/s]


Epoch 5: Loss=0.0397, Accuracy=98.30%


100%|██████████| 938/938 [00:30<00:00, 30.91it/s]


Epoch 6: Loss=0.0335, Accuracy=98.46%


100%|██████████| 938/938 [00:41<00:00, 22.70it/s]


Epoch 7: Loss=0.0279, Accuracy=98.53%


100%|██████████| 938/938 [00:34<00:00, 27.57it/s]


Epoch 8: Loss=0.0237, Accuracy=98.40%


100%|██████████| 938/938 [00:35<00:00, 26.68it/s]


Epoch 9: Loss=0.0194, Accuracy=98.34%


100%|██████████| 938/938 [00:28<00:00, 32.37it/s]


Epoch 10: Loss=0.0170, Accuracy=98.34%
