# PyTorch实现CNN

## 网络定义

In [23]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):  # custom networks must inherit from nn.Module
    """LeNet-style convolutional classifier with 3 input channels and 10 outputs.

    Two 5x5 convolutions (3 -> 6 -> 16 channels, no padding, stride 1), each
    followed by ReLU and 2x2 max pooling, feed three fully connected layers
    (400 -> 120 -> 84 -> 10). The 16 * 5 * 5 input size of ``fc1`` matches
    32x32 input images: 32 -> 28 -> 14 -> 10 -> 5 per spatial side.
    """

    def __init__(self):
        """Create the layers that own learnable parameters."""
        # The parent constructor must run before any sub-module is assigned.
        super(Net, self).__init__()
        # First convolution: 3 input channels, 6 feature maps, 5x5 kernels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        # Second convolution: 6 -> 16 feature maps, again with 5x5 kernels.
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Affine layers y = Wx + b. ReLU has no parameters, so it is applied
        # functionally in forward() instead of being declared here.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a batch of images to a (batch, 10) tensor of raw class scores."""
        # Stage 1: convolution -> ReLU -> 2x2 max pooling.
        stage1 = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # Stage 2: same pattern with the second convolution.
        stage2 = F.max_pool2d(F.relu(self.conv2(stage1)), 2)
        # Collapse every non-batch dimension into one feature axis.
        flat = stage2.view(-1, self.num_flat_features(stage2))
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        # No activation on the last layer: the caller receives raw scores.
        return self.fc3(hidden)

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` except the first (batch) one."""
        count = 1
        for axis in range(1, x.dim()):
            count *= x.size(axis)
        return count


# Build the model instance shared by the later cells and print its layer
# summary so the configured layer sizes can be checked at a glance.
net = Net()
print(net)

Net(
  (conv1): Conv2d (3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d (6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120)
  (fc2): Linear(in_features=120, out_features=84)
  (fc3): Linear(in_features=84, out_features=10)
)


In [24]:
# Show which PyTorch build produced the outputs recorded in this notebook.
print(torch.__version__)

0.3.0.post4


In [25]:
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# Settings shared by the training and test pipelines, declared once so the two
# cannot drift apart.
# NOTE: DATA_ROOT is a machine-specific absolute path; change it here when
# running the notebook on another machine.
DATA_ROOT = '/home/mc/PycharmProjects/Py35/pytorch-tutorial/data/'
BATCH_SIZE = 4
NUM_WORKERS = 2

# Preprocessing applied to every image: convert it to a Tensor, then normalize
# each of the 3 channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split (downloaded into DATA_ROOT when it is not already there).
trainset = datasets.CIFAR10(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=transform)

# Training batches are reshuffled on every pass over the data.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS)

# Test split, read from the same root directory as the training split.
testset = datasets.CIFAR10(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=transform)

# Evaluation keeps the dataset order, so shuffling is disabled.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS)

Files already downloaded and verified
Files already downloaded and verified


In [26]:
def test(): 
    correct = 0 # 预测正确的图片数
    total = 0 # 总共的图片数
    for data in testloader:
        images, labels = data
        outputs = net(Variable(images))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
        
    print('准确率为: %f %%' % (100 * correct / total))

In [30]:
from torch import optim
# Cross-entropy loss on the network's class scores, minimized with SGD plus
# momentum over all parameters of `net`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(100):
    for i, data in enumerate(trainloader):
        # Each batch is an (inputs, labels) pair; wrap both for autograd.
        inputs, labels = map(Variable, data)

        # Clear the gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass.
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # Apply the parameter update.
        optimizer.step()

        # Report the loss of the current mini-batch every 1000 iterations.
        if not i % 1000:
            print("loss: %f" % loss)

    # Evaluate on the test set once per epoch.
    test()
print('Finished Training')


loss: 1.931017
loss: 1.720093
loss: 1.899214
loss: 2.254998
loss: 1.867488
loss: 1.977913
loss: 1.639656
loss: 1.575049
loss: 1.457016
loss: 1.671834
loss: 1.774192
loss: 2.607725
loss: 2.129179
准确率为: 29.520000 %
loss: 1.509401
loss: 1.587983
loss: 1.845316
loss: 2.607810
loss: 1.513253
loss: 1.561820
loss: 1.415683
loss: 1.992751
loss: 1.787836
loss: 2.252830
loss: 1.969576
loss: 1.675066
loss: 2.491286
准确率为: 31.060000 %
loss: 1.373281
loss: 1.361114
loss: 1.697119
loss: 2.348906
loss: 1.526389
loss: 1.925153
loss: 2.125213
loss: 1.852802
loss: 1.721456
loss: 1.897223
loss: 2.017206
loss: 1.406271
loss: 1.601987
准确率为: 30.770000 %
loss: 1.425790
loss: 1.775990
loss: 1.562111
loss: 1.308941
loss: 1.118886
loss: 2.375990
loss: 2.354127
loss: 1.872196
loss: 2.011174
loss: 2.652064
loss: 2.271539
loss: 1.401000
loss: 1.059430
准确率为: 31.930000 %
loss: 1.496392
loss: 1.516793
loss: 1.226782
loss: 1.345947
loss: 2.559560
loss: 1.396280
loss: 1.727328
loss: 1.331092
loss: 1.380867
loss: 3.07434

Process Process-113:
Process Process-114:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/mc/env-py3.5/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 36, in _worker_loop
    r = index_queue.get()
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/usr/lib/python3.5/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)


KeyboardInterrupt: 

  File "/home/mc/env-py3.5/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 36, in _worker_loop
    r = index_queue.get()
  File "/usr/lib/python3.5/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 334, in get
    with self._rlock:
KeyboardInterrupt
  File "/usr/lib/python3.5/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
  File "/usr/lib/python3.5/multiprocessing/synchronize.py", line 96, in __enter__
    return self._semlock.__enter__()
KeyboardInterrupt
