In [1]:
from copy import deepcopy

import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import LeNet5

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Train

In [3]:
train_model = LeNet5().to(device)
optimizer = torch.optim.Adam(params=train_model.parameters(), lr=1e-5)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=20, gamma=0.5)
loss_fn = nn.CrossEntropyLoss()

In [4]:
train_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(
    root='MNIST',
    train=True,
    download=False,
    transform=train_transform,
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=128,
    num_workers=4,
    shuffle=True,
)

In [5]:
epochs = 300
for epoch in range(epochs):
    train_model.train()
    mloss = torch.zeros(1, device=device)  # mean_loss
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch}/{epochs}', unit='batches')

    for i, (imgs, labels) in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        preds = train_model(imgs)
        loss = loss_fn(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        mloss = (mloss * i + loss) / (i + 1)
        mem = f'{torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0:.3g}G'  # GPU_mem
        pbar.set_postfix(loss=mloss.item(), GPU_mem=mem)

    ckpt = {  # checkpoint
        'epoch': epoch,
        'model': deepcopy(train_model).half(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(ckpt, 'LeNet5.pt')

Epoch 0/300: 100%|██████████| 469/469 [00:06<00:00, 73.34batches/s, GPU_mem=0.0503G, loss=2.25] 
Epoch 1/300: 100%|██████████| 469/469 [00:04<00:00, 105.11batches/s, GPU_mem=0.0503G, loss=1.98]
Epoch 2/300: 100%|██████████| 469/469 [00:04<00:00, 104.18batches/s, GPU_mem=0.0503G, loss=1.5] 
Epoch 3/300: 100%|██████████| 469/469 [00:04<00:00, 104.23batches/s, GPU_mem=0.0503G, loss=1.17]
Epoch 4/300: 100%|██████████| 469/469 [00:04<00:00, 105.46batches/s, GPU_mem=0.0503G, loss=1]   
Epoch 5/300: 100%|██████████| 469/469 [00:06<00:00, 77.84batches/s, GPU_mem=0.0503G, loss=0.902] 
Epoch 6/300: 100%|██████████| 469/469 [00:06<00:00, 70.31batches/s, GPU_mem=0.0503G, loss=0.838] 
Epoch 7/300: 100%|██████████| 469/469 [00:06<00:00, 74.93batches/s, GPU_mem=0.0503G, loss=0.791] 
Epoch 8/300: 100%|██████████| 469/469 [00:05<00:00, 80.92batches/s, GPU_mem=0.0503G, loss=0.754] 
Epoch 9/300: 100%|██████████| 469/469 [00:06<00:00, 76.56batches/s, GPU_mem=0.0503G, loss=0.724] 
Epoch 10/300: 100%|██████

# Test

In [6]:
ckpt = torch.load('LeNet5.pt')
test_model = ckpt['model'].to(device).float()
test_model.eval()

LeNet5(
  (backbone): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): Tanh()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (3): Tanh()
    (4): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (5): Tanh()
    (6): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (7): Tanh()
    (8): Conv2d(16, 120, kernel_size=(5, 5), stride=(1, 1))
    (9): Tanh()
    (10): Flatten(start_dim=1, end_dim=-1)
    (11): Linear(in_features=120, out_features=84, bias=True)
    (12): Linear(in_features=84, out_features=10, bias=True)
  )
)

In [7]:
test_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
])

test_dataset = datasets.MNIST(
    root='MNIST',
    train=False,
    download=False,
    transform=test_transform,
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=128,
    num_workers=4,
    shuffle=False,
)

In [8]:
correct = torch.zeros(1, device=device)
total = torch.zeros(1, device=device)

with torch.no_grad():
    pbar = tqdm(enumerate(test_loader), total=len(test_loader), desc='Test', unit='batches')

    for i, (imgs, labels) in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        preds = test_model(imgs)
        preds = torch.argmax(nn.Softmax(dim=1)(preds), dim=1)  # 将预测结果经softmax后取最大值的序号为预测标签

        total += torch.tensor(labels.size(0))
        correct += (preds == labels).sum().item()

Test: 100%|██████████| 79/79 [00:02<00:00, 30.54batches/s]


In [9]:
accuracy = round((correct / total).item(), 3)
accuracy

0.974