# PyTorch实现CNN

## 网络定义

In [4]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module): # 必须集成nn.Module
    def __init__(self):
        super(Net, self).__init__() # 必须调用父类的构造函数，传入类名和self
        # 输入是1个通道(灰度图)，卷积feature map的个数是6，大小是5x5，无padding，stride是1。
        self.conv1 = nn.Conv2d(3, 6, 5)
        # 第二个卷积层feature map个数是16，大小还是5*5，无padding，stride是1。
        self.conv2 = nn.Conv2d(6, 16, 5)
        #  仿射层y = Wx + b，ReLu层没有参数，因此不在这里定义
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # 卷积然后Relu然后2x2的max pooling
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # 再一层卷积relu和max pooling
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # 把batch x channel x width x height 展开成 batch x all_nodes
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # 除了batchSize之外的其它维度
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [5]:
print(torch.__version__)

0.4.0


In [6]:
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
transform = transforms.Compose([
        transforms.ToTensor(), # 转为Tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 归一化
                             ])
# 因为原始数据链接https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz可能无法访问
# 因此手动下载

# 请先下载 https://media.githubusercontent.com/media/fancyerii/fancyerii.github.io/master/assets/cifar-10-python.tar.gz   
# 然后解压到和这个ipynb到相同的目录下
# 训练集
trainset = datasets.CIFAR10(
                    root='./', 
                    train=True, 
                    download=False,
                    transform=transform)

trainloader = torch.utils.data.DataLoader(
                    trainset, 
                    batch_size=4,
                    shuffle=True, 
                    num_workers=2)

# 测试集
testset = datasets.CIFAR10(
                    './',
                    train=False, 
                    download=False, 
                    transform=transform)

testloader = torch.utils.data.DataLoader(
                    testset,
                    batch_size=4, 
                    shuffle=False,
                    num_workers=2)

In [7]:
def test(): 
    correct = 0 # 预测正确的图片数
    total = 0 # 总共的图片数
    for data in testloader:
        images, labels = data
        outputs = net(Variable(images))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
        
    print('准确率为: %f %%' % (100 * correct / total))

In [None]:
from torch import optim
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(100):  
     
    for i, data in enumerate(trainloader):
        
        # 输入数据
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)
        # 梯度清零
        optimizer.zero_grad()
        
        # forward + backward 
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()   
        
        # 更新参数 
        optimizer.step()
        if i%1000==0:
            print("loss: %f" %(loss))
 
    test()
print('Finished Training')


loss: 2.341669
loss: 2.313134
loss: 2.070405
loss: 1.813186
loss: 1.748052
loss: 1.716123
loss: 1.444413
loss: 1.138906
loss: 1.225269
loss: 1.764292
loss: 2.043734
loss: 2.384421
loss: 1.968518
准确率为: 47.000000 %
loss: 1.031336
loss: 1.682770
loss: 1.323722
loss: 2.045666
loss: 1.470216
loss: 0.447384
loss: 2.814883
loss: 0.975228
loss: 1.103358
loss: 1.007044
loss: 2.966924
loss: 1.874554
loss: 1.401181
准确率为: 53.000000 %
loss: 0.918949
loss: 0.982429
loss: 0.547143
loss: 0.350398
loss: 1.226486
loss: 1.343273
loss: 1.867988
loss: 1.170456
loss: 0.954924
loss: 0.624940
loss: 1.046091
loss: 1.086277
loss: 1.553882
准确率为: 57.000000 %
loss: 0.236300
loss: 0.575509
loss: 1.398829
loss: 1.091057
loss: 0.983441
loss: 0.928330
loss: 0.478113
loss: 1.160415
loss: 1.125568
loss: 0.836833
loss: 0.282486
loss: 0.418222
loss: 1.156843
准确率为: 60.000000 %
loss: 1.026640
loss: 0.857518
loss: 0.974910
loss: 0.733361
loss: 1.156264
loss: 2.583951
loss: 1.370304
loss: 0.930291
loss: 0.444114
loss: 0.82373