In [1]:
import os

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [2]:
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression statement in a cell,
# not only the last one (IPython's default is "last_expr").
InteractiveShell.ast_node_interactivity = "all"

In [3]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device


device(type='cpu')

In [4]:
# Hyperparameters
EPOCH = 10            # number of full passes over the training set
NUM_CLASSES = 10      # width of the network's final (output) layer
BATCH_SIZE = 64       # mini-batch size used by both data loaders
LEARNING_RATE = 0.01  # learning rate handed to the optimizer
SEED = 42             # fixed seed so repeated runs start from the same RNG state

# Seed PyTorch's global RNG so that parameter initialisation and shuffled
# batch order are repeatable across "Restart & Run All".
# The trailing ';' keeps the returned Generator object out of the cell output.
torch.manual_seed(SEED);

In [104]:
# Datasets: MNIST train/test splits stored under ./data, images converted to tensors.
# Only the training split requests a download.
train_dataset = torchvision.datasets.MNIST(
    "./data",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    "./data",
    train=False,
    transform=transforms.ToTensor(),
)

# Loaders: the training batches are reshuffled every epoch, the test batches
# are served in a fixed order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [116]:
class Net(nn.Module):
    """
    LeNet-style convolutional classifier.

    Layout: two convolution -> ReLU -> 2x2 max-pool stages, then three fully
    connected layers. The first two fully connected layers are each followed
    by BatchNorm1d and ReLU; the last one emits ``num_classes`` raw scores.

    The first linear layer expects 16 * 5 * 5 flattened features, which is
    what the two convolutional stages produce for a 1 x 28 x 28 input.
    """

    def __init__(self, num_classes=10):
        """
        Build the layers.

        :param num_classes: number of output scores produced by the last layer
        """
        super().__init__()

        # Stage 1: 1 -> 6 channels, 3x3 kernel, stride 1, padding 2, then 2x2 pooling.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Stage 2: 6 -> 16 channels, 5x5 kernel, no padding, then 2x2 pooling.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Classifier head: 400 -> 120 -> 84 -> num_classes.
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=16 * 5 * 5, out_features=120),
            nn.BatchNorm1d(num_features=120),
            nn.ReLU(),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.BatchNorm1d(num_features=84),
            nn.ReLU(),
        )
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        """
        Forward pass.

        :param x: batch of input images
        :return: one row of ``num_classes`` unnormalised class scores per image
        """
        features = x
        for conv_stage in (self.conv1, self.conv2):
            features = conv_stage(features)

        # Collapse everything except the batch dimension before the linear layers.
        batch_size = features.size(0)
        hidden = features.view(batch_size, -1)

        for dense_stage in (self.fc1, self.fc2):
            hidden = dense_stage(hidden)
        return self.fc3(hidden)

In [106]:
# Build the network and move its parameters and buffers to the selected device.
leNet = Net(num_classes=NUM_CLASSES)
leNet = leNet.to(device)
leNet

Net(
  (conv1): Sequential(
    (0): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (fc2): Sequential(
    (0): Linear(in_features=120, out_features=84, bias=True)
    (1): BatchNorm1d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [107]:
# Loss function and optimizer: cross-entropy on the raw class scores,
# Adam over all of the model's parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=leNet.parameters(), lr=LEARNING_RATE)

In [108]:
# Training
total_step = len(train_loader)  # number of mini-batches per epoch
for epoch in range(EPOCH):
    print("epoch: {}".format(epoch))

    # Make sure BatchNorm layers use batch statistics and update their running
    # estimates. Without this, re-running the cell after the evaluation cell
    # (which calls leNet.eval()) would silently train in inference mode.
    # Placed inside the loop so the returned module is not echoed as cell output.
    leNet.train()

    for batch_idx, (batch_x, batch_y) in enumerate(train_loader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # forward pass
        outputs = leNet(batch_x)
        loss = criterion(outputs, batch_y)

        # backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the current mini-batch loss every 100 steps.
        if (batch_idx + 1) % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}"
                  .format(epoch + 1, EPOCH, batch_idx + 1, total_step, loss.item()))


epoch: 0


Epoch [1/10], Step [100/938], Loss: 0.1329


Epoch [1/10], Step [200/938], Loss: 0.0586


Epoch [1/10], Step [300/938], Loss: 0.0426


Epoch [1/10], Step [400/938], Loss: 0.1361


Epoch [1/10], Step [500/938], Loss: 0.3304


Epoch [1/10], Step [600/938], Loss: 0.1750


Epoch [1/10], Step [700/938], Loss: 0.0338


Epoch [1/10], Step [800/938], Loss: 0.0511


Epoch [1/10], Step [900/938], Loss: 0.0264


epoch: 1


Epoch [2/10], Step [100/938], Loss: 0.0285


Epoch [2/10], Step [200/938], Loss: 0.1233


Epoch [2/10], Step [300/938], Loss: 0.0166


Epoch [2/10], Step [400/938], Loss: 0.0055


Epoch [2/10], Step [500/938], Loss: 0.0133


Epoch [2/10], Step [600/938], Loss: 0.0750


Epoch [2/10], Step [700/938], Loss: 0.0059


Epoch [2/10], Step [800/938], Loss: 0.1493


Epoch [2/10], Step [900/938], Loss: 0.0168


epoch: 2


Epoch [3/10], Step [100/938], Loss: 0.0498


Epoch [3/10], Step [200/938], Loss: 0.0162


Epoch [3/10], Step [300/938], Loss: 0.0174


Epoch [3/10], Step [400/938], Loss: 0.0387


Epoch [3/10], Step [500/938], Loss: 0.0327


Epoch [3/10], Step [600/938], Loss: 0.0418


Epoch [3/10], Step [700/938], Loss: 0.0156


Epoch [3/10], Step [800/938], Loss: 0.1146


Epoch [3/10], Step [900/938], Loss: 0.1290


epoch: 3


Epoch [4/10], Step [100/938], Loss: 0.0088


Epoch [4/10], Step [200/938], Loss: 0.0075


Epoch [4/10], Step [300/938], Loss: 0.0067


Epoch [4/10], Step [400/938], Loss: 0.0049


Epoch [4/10], Step [500/938], Loss: 0.0116


Epoch [4/10], Step [600/938], Loss: 0.0091


Epoch [4/10], Step [700/938], Loss: 0.0264


Epoch [4/10], Step [800/938], Loss: 0.0585


Epoch [4/10], Step [900/938], Loss: 0.0063


epoch: 4


Epoch [5/10], Step [100/938], Loss: 0.0782


Epoch [5/10], Step [200/938], Loss: 0.0908


Epoch [5/10], Step [300/938], Loss: 0.0176


Epoch [5/10], Step [400/938], Loss: 0.0051


Epoch [5/10], Step [500/938], Loss: 0.0293


Epoch [5/10], Step [600/938], Loss: 0.0287


Epoch [5/10], Step [700/938], Loss: 0.0238


Epoch [5/10], Step [800/938], Loss: 0.0683


Epoch [5/10], Step [900/938], Loss: 0.1018


epoch: 5


Epoch [6/10], Step [100/938], Loss: 0.0058


Epoch [6/10], Step [200/938], Loss: 0.0010


Epoch [6/10], Step [300/938], Loss: 0.1713


Epoch [6/10], Step [400/938], Loss: 0.1923


Epoch [6/10], Step [500/938], Loss: 0.0143


Epoch [6/10], Step [600/938], Loss: 0.0093


Epoch [6/10], Step [700/938], Loss: 0.0051


Epoch [6/10], Step [800/938], Loss: 0.0011


Epoch [6/10], Step [900/938], Loss: 0.0120


epoch: 6


Epoch [7/10], Step [100/938], Loss: 0.0290


Epoch [7/10], Step [200/938], Loss: 0.0019


Epoch [7/10], Step [300/938], Loss: 0.0055


Epoch [7/10], Step [400/938], Loss: 0.0019


Epoch [7/10], Step [500/938], Loss: 0.0106


Epoch [7/10], Step [600/938], Loss: 0.0083


Epoch [7/10], Step [700/938], Loss: 0.1138


Epoch [7/10], Step [800/938], Loss: 0.0528


Epoch [7/10], Step [900/938], Loss: 0.0255


epoch: 7


Epoch [8/10], Step [100/938], Loss: 0.0622


Epoch [8/10], Step [200/938], Loss: 0.0807


Epoch [8/10], Step [300/938], Loss: 0.0095


Epoch [8/10], Step [400/938], Loss: 0.0690


Epoch [8/10], Step [500/938], Loss: 0.0072


Epoch [8/10], Step [600/938], Loss: 0.0029


Epoch [8/10], Step [700/938], Loss: 0.1217


Epoch [8/10], Step [800/938], Loss: 0.0093


Epoch [8/10], Step [900/938], Loss: 0.0053


epoch: 8


Epoch [9/10], Step [100/938], Loss: 0.0022


Epoch [9/10], Step [200/938], Loss: 0.0019


Epoch [9/10], Step [300/938], Loss: 0.1117


Epoch [9/10], Step [400/938], Loss: 0.0303


Epoch [9/10], Step [500/938], Loss: 0.1450


Epoch [9/10], Step [600/938], Loss: 0.0271


Epoch [9/10], Step [700/938], Loss: 0.0106


Epoch [9/10], Step [800/938], Loss: 0.0275


Epoch [9/10], Step [900/938], Loss: 0.0689


epoch: 9


Epoch [10/10], Step [100/938], Loss: 0.0122


Epoch [10/10], Step [200/938], Loss: 0.0026


Epoch [10/10], Step [300/938], Loss: 0.0002


Epoch [10/10], Step [400/938], Loss: 0.0004


Epoch [10/10], Step [500/938], Loss: 0.0010


Epoch [10/10], Step [600/938], Loss: 0.0006


Epoch [10/10], Step [700/938], Loss: 0.0150


Epoch [10/10], Step [800/938], Loss: 0.0200


Epoch [10/10], Step [900/938], Loss: 0.0015


In [113]:
# Evaluate the trained model on the test loader.
# eval() switches BatchNorm to its stored running statistics.
leNet.eval()

correct = 0      # number of predictions that match the label
total = 0        # number of labels seen
total_test = 0   # number of images seen (reported in the message below)

with torch.no_grad():  # no gradients are needed for inference
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        scores = leNet(images)
        # Predicted class = index of the largest score in each row.
        predicted = scores.data.max(dim=1).indices

        total = total + labels.size(0)
        correct = correct + (predicted == labels).sum().item()
        total_test = total_test + len(images)

accuracy = 100 * correct / total
print("Test Accuracy of the model on the {} test images: {}%"
      .format(total_test, accuracy))

Net(
  (conv1): Sequential(
    (0): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (fc2): Sequential(
    (0): Linear(in_features=120, out_features=84, bias=True)
    (1): BatchNorm1d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

Test Accuracy of the model on the 10000 test images: 98.58%


In [114]:
# Save the model's state_dict (parameters and buffers only) to ./model/leNet.ckpt.
# torch.save does not create missing parent directories, so create ./model
# first; exist_ok=True makes re-running this cell safe.
os.makedirs("./model", exist_ok=True)
torch.save(leNet.state_dict(), "./model/leNet.ckpt")