In [1]:
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from tqdm import tqdm

from model import Model

In [2]:
# Hyperparameters
epochs = 100  # number of training epochs
learning_rate = 0.001  # optimizer learning rate
momentum = 0.5  # optimizer momentum
batch_size_train = 512  # training batch size
batch_size_test = 1000  # test batch size

In [3]:
# Seed PyTorch's RNG for reproducibility and pick the compute device
# (first CUDA GPU when available, otherwise the CPU).
random_seed = 1
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [4]:
# train transform, dataset, dataloader
# Light geometric augmentation (random shift up to 10% of the image size, random
# rotation in [-10, 10] degrees), then tensor conversion and normalization with
# the commonly used MNIST mean/std values.
train_transform = transforms.Compose([
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomRotation((-10, 10)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST(
    root='../Practice2/MNIST',
    train=True,
    download=False,  # the dataset must already exist under `root`
    transform=train_transform,
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size_train,  # honour the configured hyperparameter (was hard-coded 512)
    num_workers=0,
    shuffle=True,
    # Pinned host memory only speeds up host->GPU copies; on a CPU-only
    # machine it is useless and PyTorch emits a warning.
    pin_memory=torch.cuda.is_available(),
)

In [5]:
# test transform, dataset, dataloader
# No augmentation at evaluation time: only tensor conversion and the same
# normalization constants as the training pipeline.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

test_dataset = datasets.MNIST(
    root='../Practice2/MNIST',
    train=False,
    download=False,  # the dataset must already exist under `root`
    transform=test_transform,
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size_test,  # honour the configured hyperparameter (was hard-coded 1024)
    num_workers=0,
    shuffle=False,  # deterministic evaluation order
    # Pinned host memory only speeds up host->GPU copies; skip it on CPU-only machines.
    pin_memory=torch.cuda.is_available(),
)

In [6]:
model = Model().to(device)
optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, alpha=0.99, momentum=momentum)
# `train()` steps this scheduler with the epoch's mean training loss, a quantity that
# should go DOWN, so plateaus must be detected in 'min' mode. With 'max' the loss
# would never count as an improvement and the LR would be halved roughly every
# `patience + 1` epochs regardless of progress.
# (`verbose` is omitted: it is deprecated and False is already the default.)
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    threshold=0.00005,
    threshold_mode='rel',
    cooldown=0,
    min_lr=0,
    eps=1e-08,
)
# Per the original note, Model's last layer is LogSoftmax, so NLLLoss on its
# output is equivalent to CrossEntropyLoss on raw logits.
loss_fn = nn.NLLLoss()

In [7]:
test_accuracies = list()  # test-set accuracy of every epoch, appended by `test`

In [8]:
def train(epoch):  # single epoch
    """Run one training epoch over ``train_loader``, then step the LR scheduler.

    Relies on the notebook-level globals ``model``, ``train_loader``, ``optimizer``,
    ``scheduler``, ``loss_fn``, ``device`` and ``epochs``. Progress (running mean
    loss, running accuracy, reserved GPU memory) is shown on a tqdm bar.

    Args:
        epoch: 1-based epoch index, used only for the progress-bar label.
    """
    correct = torch.zeros(1, device=device)
    total = torch.zeros(1, device=device)
    mloss = torch.zeros(1, device=device)  # running mean of the batch losses

    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch}/{epochs}', unit='batches')
    model.train()
    for i, (imgs, labels) in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        preds = model(imgs)
        loss = loss_fn(preds, labels)
        preds_ = torch.argmax(preds, dim=1)  # predicted class per sample

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Detach before accumulating: with the raw `loss`, every iteration would
        # extend an autograd graph hanging off `mloss` for the whole epoch.
        mloss = (mloss * i + loss.detach()) / (i + 1)
        total += labels.size(0)  # plain int add; no per-batch tensor allocation
        correct += (preds_ == labels).sum()
        accuracy = (correct / total).item()

        mem = f'{torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0:.3g}G'  # GPU_mem
        pbar.set_postfix(loss=mloss.item(), GPU_mem=mem, accuracy=round(accuracy, 3))
    # The plateau scheduler is driven by the epoch's mean training loss.
    scheduler.step(mloss.item())

In [9]:
def test(epoch):  # single epoch
    """Evaluate ``model`` on ``test_loader`` and checkpoint it on a new best accuracy.

    Appends this epoch's accuracy to the global ``test_accuracies``. When it is
    greater than or equal to every previously recorded accuracy, writes
    ``{'epoch', 'model', 'optimizer'}`` to ``Model.pt``. The model is stored as a
    half-precision deep copy, so the live model used for training is not modified.

    Relies on the notebook-level globals ``model``, ``test_loader``, ``optimizer``,
    ``device`` and ``test_accuracies``.

    Args:
        epoch: 1-based epoch index, stored in the checkpoint.

    Raises:
        RuntimeError: if ``test_loader`` yields no samples (accuracy is undefined).
    """
    correct = torch.zeros(1, device=device)
    total = torch.zeros(1, device=device)

    model.eval()
    with torch.no_grad():
        pbar = tqdm(test_loader, desc='Test', unit='batches')
        for imgs, labels in pbar:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = torch.argmax(model(imgs), dim=1)  # forward pass, then predicted class

            total += labels.size(0)  # plain int add; no per-batch tensor allocation
            correct += (preds == labels).sum()
            pbar.set_postfix(accuracy=round((correct / total).item(), 3))

    # Compute the final accuracy outside the loop so an empty loader fails with a
    # clear message instead of an UnboundLocalError on `accuracy`.
    if total.item() == 0:
        raise RuntimeError('test_loader yielded no samples; cannot compute accuracy')
    accuracy = (correct / total).item()
    test_accuracies.append(accuracy)

    if accuracy >= max(test_accuracies):  # new best (ties overwrite the checkpoint)
        ckpt = {  # checkpoint
            'epoch': epoch,
            'model': deepcopy(model).half(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(ckpt, 'Model.pt')

In [10]:
# Main loop: for each 1-based epoch, one training pass followed by one evaluation pass.
for epoch in range(1, epochs + 1):
    for step in (train, test):
        step(epoch)
print(f'max accuracy: {100 * max(test_accuracies):.1f}%')

Epoch 1/100: 100%|██████████| 118/118 [00:18<00:00,  6.36batches/s, GPU_mem=0.484G, accuracy=0.735, loss=0.894]
Test: 100%|██████████| 10/10 [00:02<00:00,  4.04batches/s, accuracy=0.941]
Epoch 2/100: 100%|██████████| 118/118 [00:26<00:00,  4.50batches/s, GPU_mem=0.614G, accuracy=0.948, loss=0.189]
Test: 100%|██████████| 10/10 [00:00<00:00, 11.56batches/s, accuracy=0.969]
Epoch 3/100: 100%|██████████| 118/118 [00:13<00:00,  8.85batches/s, GPU_mem=0.614G, accuracy=0.961, loss=0.141]
Test: 100%|██████████| 10/10 [00:00<00:00, 11.25batches/s, accuracy=0.973]
Epoch 4/100: 100%|██████████| 118/118 [00:13<00:00,  8.77batches/s, GPU_mem=0.614G, accuracy=0.968, loss=0.117]
Test: 100%|██████████| 10/10 [00:00<00:00, 11.47batches/s, accuracy=0.991]
Epoch 5/100: 100%|██████████| 118/118 [00:13<00:00,  8.83batches/s, GPU_mem=0.614G, accuracy=0.972, loss=0.1]   
Test: 100%|██████████| 10/10 [00:00<00:00, 11.64batches/s, accuracy=0.989]
Epoch 6/100: 100%|██████████| 118/118 [00:13<00:00,  8.89batches

max accuracy: 99.7%



