# 利用多层全连接神经网络及简单卷积神经网络对MNIST数据集分类

导入相关包

In [1]:
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn

import torch.nn.functional as F

选择运算设备：如有可用的GPU则使用GPU，否则使用CPU

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

定义多层全连接神经网络

In [3]:
class My_Model(nn.Module):
    """Fully connected network: input_dim -> 32 -> 16 -> 8 -> output_dim.

    A ReLU follows every linear layer except the last one.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Number of units in the final linear layer. Defaults to 1,
            which reproduces the original scalar-output network; pass the
            number of classes (e.g. 10) to obtain one score per class.
    """

    def __init__(self, input_dim, output_dim=1):
        super(My_Model, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, output_dim),
        )

    def forward(self, x):
        """Apply the layer stack to ``x`` and drop a size-1 output dimension.

        Args:
            x: Tensor whose last dimension has size ``input_dim``; for a
                batch this is ``(B, input_dim)``.

        Returns:
            Tensor of shape ``(B,)`` when ``output_dim == 1``, otherwise
            ``(B, output_dim)``.
        """
        x = self.layers(x)
        # squeeze(1) removes dim 1 only when its size is 1:
        # (B, 1) -> (B), while (B, C) with C > 1 is returned unchanged.
        x = x.squeeze(1)
        return x

定义LeNet神经网络

In [9]:
class LeNet(nn.Module):
    """LeNet-style CNN: two 5x5 convolutions followed by three linear layers.

    Each convolution is followed by ReLU and 2x2 max pooling. The first
    linear layer expects 16 * 4 * 4 features, which is what the shape
    comments in ``forward`` give for inputs of shape [batch, 1, 28, 28].

    Args:
        output_dim: Number of units produced by the last linear layer.
    """

    def __init__(self, output_dim):
        super().__init__()

        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)

        # Fully connected head.
        self.fc_1 = nn.Linear(16 * 4 * 4, 120)
        self.fc_2 = nn.Linear(120, 84)
        self.fc_3 = nn.Linear(84, output_dim)

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            A pair ``(out, h)``: ``out`` is the output of the last linear
            layer, ``h`` is the flattened feature map fed to the first
            linear layer.
        """
        # [B, 1, 28, 28] -> conv1 -> [B, 6, 24, 24] -> relu + pool -> [B, 6, 12, 12]
        feat = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2)

        # [B, 6, 12, 12] -> conv2 -> [B, 16, 8, 8] -> relu + pool -> [B, 16, 4, 4]
        feat = F.max_pool2d(F.relu(self.conv2(feat)), kernel_size=2)

        # Flatten to [B, 16*4*4 = 256]; this tensor is also returned to the caller.
        h = feat.view(feat.shape[0], -1)

        # [B, 256] -> [B, 120]; the activation here is the one to experiment with.
        hidden = F.relu(self.fc_1(h))

        # [B, 120] -> [B, 84]; the activation here is the one to experiment with.
        hidden = F.relu(self.fc_2(hidden))

        # [B, 84] -> [B, output_dim]
        out = self.fc_3(hidden)

        return out, h

超参数设置

In [5]:
# Training hyperparameters.
EPOCH = 10        # number of passes over the training set
BATCH_SIZE = 128  # number of samples per mini-batch
LR = 1e-4         # learning rate

定义数据预处理方式

In [6]:
# Transform applied to every image when it is read from a dataset.
transform = transforms.ToTensor()

# Training split and its shuffled mini-batch loader.
trainset = tv.datasets.MNIST(root='/data/', train=True, transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=BATCH_SIZE)

# Test split and its loader; batches are served in dataset order.
testset = tv.datasets.MNIST(root='/data/', train=False, transform=transform, download=True)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=BATCH_SIZE)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /data/MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 9080842.52it/s] 


Extracting /data/MNIST\raw\train-images-idx3-ubyte.gz to /data/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /data/MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 674685.28it/s]


Extracting /data/MNIST\raw\train-labels-idx1-ubyte.gz to /data/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /data/MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4092890.06it/s]


Extracting /data/MNIST\raw\t10k-images-idx3-ubyte.gz to /data/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /data/MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3991311.29it/s]

Extracting /data/MNIST\raw\t10k-labels-idx1-ubyte.gz to /data/MNIST\raw






实例化网络，如有GPU，将模型加载至GPU运算

In [11]:
# Build a LeNet with 10 outputs, then move it to the selected device.
model = LeNet(output_dim=10)
model = model.to(device)

定义损失函数和优化方式，尝试SGD、Adam等优化器，尝试运用权重衰减

In [21]:
# Loss used by the training loop for both the training and evaluation passes.
criterion = nn.CrossEntropyLoss() # Define your loss function, do not modify this.

# Take the learning rate from the LR hyperparameter instead of repeating its
# literal value, so that editing the hyperparameter cell actually changes the
# optimizer. A very small weight decay (1e-8) is applied.
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-8)

模型训练及测试

In [22]:
from tqdm import tqdm
# Counter that the training loop increments once per training batch;
# it is not reset between epochs.
step = 0

In [23]:
for epoch in range(EPOCH):
    # ---- Training pass over the whole training loader ----
    model.train()
    loss_record = []
    train_pbar = tqdm(trainloader, position=0, leave=True)
    # The epoch label does not change within an epoch, so set it once here
    # instead of on every batch.
    train_pbar.set_description(f'Epoch [{epoch+1}/{EPOCH}]')

    for x, y in train_pbar:
        optimizer.zero_grad()               # Set gradient to zero.
        x, y = x.to(device), y.to(device)   # Move your data to device.
        pred, _ = model(x)                  # Second return value (features) is unused.
        loss = criterion(pred, y)
        loss.backward()                     # Compute gradient(backpropagation).
        optimizer.step()

        # Convert the loss to a Python float once and reuse it for both the
        # record and the progress bar.
        batch_loss = loss.item()
        loss_record.append(batch_loss)
        step += 1

        # Display the current loss on the tqdm progress bar.
        train_pbar.set_postfix({'loss': batch_loss})

    # Mean of the per-batch training losses for this epoch.
    mean_train_loss = sum(loss_record) / len(loss_record)

    # ---- Evaluation pass over the test loader (reported as "Valid") ----
    model.eval() # Set your model to evaluation mode.
    loss_record = []
    correct = 0
    total = 0
    # No gradients are needed anywhere in evaluation, so disable them once
    # for the whole loop.
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            pred, _ = model(x)
            loss = criterion(pred, y)
            # Index of the largest score in each row is the predicted class.
            predicted = pred.argmax(dim=1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
            loss_record.append(loss.item())

    # Mean of the per-batch evaluation losses and accuracy in percent.
    mean_valid_loss = sum(loss_record) / len(loss_record)
    accuracy = 100 * correct / total
    print(f'Epoch [{epoch+1}/{EPOCH}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}, Accuracy: {accuracy:.2f}%')

Epoch [1/50]: 100%|██████████| 469/469 [00:08<00:00, 57.02it/s, loss=0.000306]


Epoch [1/50]: Train loss: 0.0013, Valid loss: 0.0369, Accuracy: 99.18%


Epoch [2/50]: 100%|██████████| 469/469 [00:08<00:00, 57.59it/s, loss=0.000114]


Epoch [2/50]: Train loss: 0.0006, Valid loss: 0.0383, Accuracy: 99.18%


Epoch [3/50]: 100%|██████████| 469/469 [00:08<00:00, 57.61it/s, loss=0.000969]


Epoch [3/50]: Train loss: 0.0005, Valid loss: 0.0368, Accuracy: 99.19%


Epoch [4/50]: 100%|██████████| 469/469 [00:08<00:00, 56.77it/s, loss=9.79e-5] 


Epoch [4/50]: Train loss: 0.0004, Valid loss: 0.0374, Accuracy: 99.20%


Epoch [5/50]: 100%|██████████| 469/469 [00:08<00:00, 56.86it/s, loss=3.82e-6] 


Epoch [5/50]: Train loss: 0.0003, Valid loss: 0.0382, Accuracy: 99.24%


Epoch [6/50]: 100%|██████████| 469/469 [00:08<00:00, 57.37it/s, loss=3.21e-5] 


Epoch [6/50]: Train loss: 0.0002, Valid loss: 0.0393, Accuracy: 99.21%


Epoch [7/50]: 100%|██████████| 469/469 [00:08<00:00, 58.15it/s, loss=6.33e-6] 


Epoch [7/50]: Train loss: 0.0002, Valid loss: 0.0398, Accuracy: 99.23%


Epoch [8/50]: 100%|██████████| 469/469 [00:08<00:00, 57.90it/s, loss=0.000424]


Epoch [8/50]: Train loss: 0.0002, Valid loss: 0.0407, Accuracy: 99.21%


Epoch [9/50]: 100%|██████████| 469/469 [00:08<00:00, 58.14it/s, loss=1.04e-5] 


Epoch [9/50]: Train loss: 0.0001, Valid loss: 0.0421, Accuracy: 99.23%


Epoch [10/50]: 100%|██████████| 469/469 [00:08<00:00, 57.69it/s, loss=1.73e-5] 


KeyboardInterrupt: 