In [15]:
"""
    作者：Troublemaker
    功能：
    版本：
    日期：2020/4/5 19:57
    脚本：cnn.py
"""
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

torch.manual_seed(1)
# 设置超参数
epoches = 2
batch_size = 50
learning_rate = 0.001

# 搭建CNN
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()   # 继承__init__功能
        ## 第一层卷积
        self.conv1 = nn.Sequential(
            # 输入[1,28,28]
            nn.Conv2d(
                in_channels=1,    # 输入图片的高度
                out_channels=16,  # 输出图片的高度
                kernel_size=5,    # 5x5的卷积核，相当于过滤器
                stride=1,         # 卷积核在图上滑动，每隔一个扫一次
                padding=2,        # 给图外边补上0
            ),
            # 经过卷积层 输出[16,28,28] 传入池化层
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)   # 经过池化 输出[16,14,14] 传入下一个卷积
        )
        ## 第二层卷积
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,    # 同上
                out_channels=32,
                kernel_size=5,
                stride=1,
                padding=2
            ),
            # 经过卷积 输出[32, 14, 14] 传入池化层
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)  # 经过池化 输出[32,7,7] 传入输出层
        )
        ## 输出层
        self.linear = nn.Linear(in_features=32*7*7, out_features=10)
        # softmax输出分类
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)           # [batch, 32,7,7]
        x = x.view(x.size(0), -1)   # 保留batch, 将后面的乘到一起 [batch, 32*7*7]
        linear_o = self.linear(x)     # 输出[50,10]
        output = self.softmax(linear_o) # 输出[50,10]
        return linear_o


# 下载MNist数据集
train_data = torchvision.datasets.MNIST(
    root="./mnist/",  # 训练数据保存路径
    train=True,
    transform=torchvision.transforms.ToTensor(),  # 数据范围已从(0-255)压缩到(0,1)
    download=False,  # 是否需要下载
)
# print(train_data.train_data.size())   # [60000,28,28]
# print(train_data.train_labels.size())  # [60000]
# plt.imshow(train_data.train_data[0].numpy())
# plt.show()

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
print(test_data.test_data.size())    # [10000, 28, 28]
# print(test_data.test_labels.size())  # [10000, 28, 28]
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255
test_y = test_data.test_labels[:2000]

# 装入Loader中
train_loader = Data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=3)



torch.Size([10000, 28, 28])




In [8]:
test_x.shape

torch.Size([2000, 1, 28, 28])

In [18]:

# def main():
# cnn 实例化
cnn = CNN()
print(cnn)

# 定义优化器和损失函数
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
loss_function = nn.CrossEntropyLoss()

# 开始训练
for epoch in range(epoches):
    print("进行第{}个epoch".format(epoch))
    for step, (batch_x, batch_y) in enumerate(train_loader):
        output = cnn(batch_x)  # batch_x=[50,1,28,28]
        # output = output[0]
        loss = loss_function(output, batch_y)
        # print("output: ", output)
        # print("batch_y: ", batch_y)
        # print("output.shape: ", output.shape)
        # print("batch_y.shape: ", batch_y.shape)
        # break
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            test_output = cnn(test_x)  # [10000 ,10]
            pred_y = torch.max(test_output, 1)[1].data.numpy()
            # accuracy = sum(pred_y==test_y)/test_y.size(0)
            accuracy = ((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)


test_output = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(pred_y)
print(test_y[:10])


CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (linear): Linear(in_features=1568, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
)
进行第0个epoch
Epoch:  0 | train loss: 2.3188 | test accuracy: 0.21
Epoch:  0 | train loss: 0.4021 | test accuracy: 0.83
Epoch:  0 | train loss: 0.2665 | test accuracy: 0.85
Epoch:  0 | train loss: 0.3791 | test accuracy: 0.91
Epoch:  0 | train loss: 0.1110 | test accuracy: 0.92
Epoch:  0 | train loss: 0.2147 | test accuracy: 0.93
Epoch:  0 | train loss: 0.1130 | test accuracy: 0.94
Epoch:  0 | train loss: 0.1676 | test accuracy: 0.96
Epoch:  0 | train loss: 0.2340 | test accuracy: 0.95
Epoch:  0

In [14]:
for epoch in range(epoches):
    print("进行第{}个epoch".format(epoch))
    for step, (batch_x, batch_y) in enumerate(train_loader):
        print(batch_y)
        print(batch_y.shape)

进行第0个epoch
tensor([2, 1, 3, 8, 9, 4, 8, 7, 0, 1, 6, 2, 7, 9, 5, 8, 4, 1, 6, 0, 3, 9, 2, 0,
        9, 0, 7, 1, 1, 7, 2, 4, 1, 1, 1, 4, 3, 1, 5, 4, 7, 9, 7, 7, 3, 1, 1, 5,
        8, 0])
torch.Size([50])
tensor([6, 9, 5, 5, 1, 4, 1, 9, 4, 6, 3, 1, 0, 1, 7, 8, 1, 2, 0, 7, 1, 0, 1, 7,
        9, 4, 6, 2, 1, 9, 3, 0, 4, 0, 8, 7, 2, 4, 6, 7, 9, 8, 1, 2, 1, 9, 4, 4,
        0, 5])
torch.Size([50])
tensor([7, 6, 6, 4, 7, 4, 5, 9, 3, 1, 9, 1, 4, 0, 9, 6, 8, 2, 6, 7, 7, 1, 3, 9,
        7, 2, 3, 5, 8, 1, 3, 4, 0, 1, 9, 7, 7, 8, 0, 7, 8, 5, 2, 9, 1, 9, 6, 6,
        0, 6])
torch.Size([50])
tensor([2, 4, 3, 7, 2, 2, 3, 6, 7, 7, 0, 8, 6, 1, 0, 0, 0, 4, 1, 6, 4, 2, 4, 7,
        7, 1, 4, 0, 9, 3, 3, 4, 1, 6, 3, 0, 7, 0, 5, 9, 7, 2, 8, 5, 4, 3, 1, 8,
        7, 9])
torch.Size([50])
tensor([8, 7, 7, 7, 7, 9, 0, 5, 0, 9, 8, 5, 1, 4, 5, 5, 6, 4, 4, 6, 1, 3, 0, 3,
        2, 8, 3, 5, 6, 2, 2, 9, 1, 5, 3, 4, 1, 0, 6, 9, 7, 1, 5, 2, 7, 6, 1, 3,
        3, 6])
torch.Size([50])
tensor([0, 3, 4, 9, 3, 3, 2, 

In [13]:
test_output

tensor([[3.0313e-22, 5.7359e-22, 1.5067e-18, 3.6383e-12, 2.0558e-25, 1.5355e-20,
         1.9524e-37, 1.0000e+00, 8.9209e-20, 9.5562e-16],
        [4.8300e-14, 7.6052e-16, 1.0000e+00, 8.6907e-15, 9.8237e-26, 8.0096e-21,
         8.0753e-17, 1.5559e-33, 5.9845e-13, 7.7266e-30],
        [1.2534e-12, 1.0000e+00, 6.6348e-12, 1.1126e-11, 5.1692e-07, 2.6485e-14,
         6.7244e-11, 1.6364e-10, 8.2765e-09, 8.7886e-12],
        [1.0000e+00, 4.1332e-21, 2.6839e-14, 1.4809e-21, 6.1688e-15, 1.5420e-16,
         2.3162e-09, 4.1920e-16, 3.9538e-15, 1.6353e-13],
        [1.5804e-22, 5.1695e-17, 1.5534e-21, 3.5539e-21, 1.0000e+00, 4.5396e-21,
         5.4560e-16, 2.6763e-15, 6.6864e-19, 4.6347e-10],
        [2.6359e-15, 1.0000e+00, 8.9523e-15, 6.9645e-15, 3.5899e-08, 2.0984e-19,
         1.0372e-14, 1.3445e-11, 6.8859e-11, 2.5343e-14],
        [4.6407e-30, 7.8450e-14, 2.7526e-20, 1.4499e-21, 1.0000e+00, 2.5468e-15,
         9.7448e-20, 1.1679e-14, 6.3356e-11, 2.5740e-12],
        [2.6960e-17, 1.1462