卷积神经网络实现MNIST手写数字分类

加载数据

In [2]:
# Mini-batch size shared by the training and test DataLoaders created below.
batch_size = 64

In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Preprocessing pipeline: convert each image to a tensor, then normalise the
# single grayscale channel with mean 0.5 and std 0.5.
data_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# `train` selects which split is loaded: True -> training set, False -> test set.
# download=True fetches the files into `root` when they are not there yet, so
# the notebook also runs on a machine without a pre-populated ./data folder;
# files that already exist are reused rather than downloaded again.
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf, download=True)

# Shuffle only the training data; keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

网络结构

In [4]:
from torch import nn

In [8]:
class CNN(nn.Module):
    """Four-stage convolutional classifier for 28x28 images.

    Every stage is ``Conv2d(3x3, no padding) -> BatchNorm2d -> ReLU``; the
    second and fourth stages additionally end with a 2x2 max-pool. For a
    28x28 input the feature-map sizes are::

        input   in_channels @ 28x28
        layer1  16  @ 26x26
        layer2  32  @ 24x24 -> pool -> 32  @ 12x12
        layer3  64  @ 10x10
        layer4  128 @ 8x8   -> pool -> 128 @ 4x4

    The flattened 128*4*4 features go through a three-layer fully connected
    head that produces ``num_classes`` raw logits (no softmax is applied).

    Args:
        num_classes: Number of output logits. Defaults to 10.
        in_channels: Number of channels of the input images. Defaults to 1.

    Note:
        The first ``Linear`` layer is sized for 128*4*4 features, so the
        spatial input size must be 28x28.
    """

    def __init__(self, num_classes=10, in_channels=1):
        super(CNN, self).__init__()
        # Attribute names and the order of the sub-modules inside each
        # Sequential are kept stable so state_dict keys do not change.
        self.layer1 = self._conv_block(in_channels, 16)
        self.layer2 = self._conv_block(16, 32, pool=True)
        self.layer3 = self._conv_block(32, 64)
        self.layer4 = self._conv_block(64, 128, pool=True)

        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 128),
            nn.ReLU(True),
            nn.Linear(128, num_classes),
        )

    @staticmethod
    def _conv_block(in_ch, out_ch, pool=False):
        """Build one Conv(3x3) -> BatchNorm -> ReLU stage, optionally followed by a 2x2 max-pool."""
        layers = [
            nn.Conv2d(in_ch, out_ch, 3),
            # BatchNorm2d takes the number of channels of its input
            # (here the conv's output channels), not the number of images.
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
        ]
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits of shape (batch, num_classes) for a batch of images ``x``."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        # Collapse (channels, height, width) into one feature vector per sample.
        x = x.view(x.size(0), -1)
        return self.fc(x)

模型训练

In [13]:
from torch.autograd import Variable
from torch import optim

In [14]:
# Network, loss function and optimizer used by the training cell below.
model = CNN()
# Cross-entropy loss computed between the model output and the labels.
criterion = nn.CrossEntropyLoss()
# Plain SGD over every model parameter, learning rate 0.01.
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [15]:
# Single pass over the training set.
model.train()  # put the model (including its BatchNorm layers) in training mode

# NOTE: despite its name, `epoch` counts processed mini-batches, not passes
# over the dataset. The name and the printed label are kept as they were so
# the recorded output below and any later use of `epoch` stay valid.
epoch = 0
for epoch, (img, label) in enumerate(train_loader, start=1):
    # Tensors are passed to the model directly; wrapping them in the
    # deprecated `Variable` is no longer needed.
    out = model(img)
    loss = criterion(out, label)

    # Clear old gradients, back-propagate, then update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Report the loss of the current batch every 100 batches.
    if epoch % 100 == 0:
        print('epoch: {}, loss: {:.6f}'.format(epoch,loss.item()))

epoch: 100, loss: 1.893072
epoch: 200, loss: 0.730214
epoch: 300, loss: 0.377038
epoch: 400, loss: 0.256213
epoch: 500, loss: 0.203365
epoch: 600, loss: 0.137815
epoch: 700, loss: 0.156686
epoch: 800, loss: 0.108450
epoch: 900, loss: 0.176699


In [25]:
from torch import max as tmax
import torch

In [27]:
# Measure classification accuracy on the whole test set.
model.eval()  # switch the model to evaluation mode
acc = 0  # running count of correctly classified test samples
# Gradients are not needed for evaluation; disabling autograd avoids building
# the computation graph. `no_grad` is a context manager on the torch module,
# not a method on tensors.
with torch.no_grad():
    for img, label in test_loader:
        out = model(img)
        # Index of the largest logit along dim 1 is the predicted class.
        _, pred = tmax(out, 1)
        correct = (pred == label).sum()
        acc += correct.item()
print(acc/len(test_dataset))

AttributeError: 'Tensor' object has no attribute 'no_grad'

In [19]:
# Test accuracy: `acc` is the count of correct predictions accumulated by the
# evaluation loop, divided here by the number of test samples.
print(acc/len(test_dataset))

0.9746


In [28]:
# Look at the raw model output (one row of logits per image) for the first
# test batch only.
data = next(iter(test_loader))
img, label = data

# Make the cell independent of earlier cells: use evaluation mode and skip
# gradient tracking, since this is inference only.
model.eval()
with torch.no_grad():
    out = model(img)
print(out)

tensor([[ 5.4759e-01, -7.1116e-01,  2.0367e+00,  1.0010e+00, -5.0946e+00,
         -2.6976e+00, -9.2522e+00,  1.0232e+01, -1.4421e+00,  2.2433e+00],
        [ 2.7320e+00,  2.6999e-01,  6.3447e+00,  1.2400e+00, -4.4484e+00,
         -1.9520e+00,  2.8867e+00, -3.0929e+00,  1.5398e+00, -6.0209e+00],
        [-2.2988e+00,  7.0427e+00, -9.0539e-01, -2.1142e+00,  1.7290e+00,
         -1.2376e+00, -5.8222e-01,  4.7535e-01, -1.9419e-01, -2.0531e+00],
        [ 8.8201e+00, -4.6206e+00,  5.1141e-01, -3.0814e+00, -4.4891e+00,
         -5.5719e-02,  1.8532e+00, -2.8189e-01, -2.1902e-01, -1.0543e+00],
        [-2.3116e+00, -1.9736e+00, -2.6354e+00, -4.3019e+00,  8.3027e+00,
         -2.3012e+00, -5.4642e-01, -1.6279e+00, -9.8473e-01,  3.3866e+00],
        [-2.4552e+00,  7.5266e+00, -1.0971e+00, -2.4439e+00,  1.7729e+00,
         -1.5307e+00, -1.2322e+00,  1.0633e+00, -1.0890e-01, -1.8033e+00],
        [-5.2435e+00,  2.0876e-01, -4.2011e+00, -2.4665e+00,  7.0182e+00,
         -4.2022e-01, -3.8480e+0