In [2]:
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

In [3]:
torch.manual_seed(1)

<torch._C.Generator at 0x7f49440c9350>

In [5]:
EPOCH = 1 #训练整批数据的批次
BATCH_SIZE=50 #每个批次读取多少数据
LR=1e-3
DOWLOAD_MINIST=True


In [6]:
train_data= torchvision.datasets.MNIST(
    root='./minist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWLOAD_MINIST
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [7]:
test_data=torchvision.datasets.MNIST(root='./minist',train=False)

In [9]:
# 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)

In [10]:
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.   # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
test_y = test_data.test_labels[:2000]



In [16]:
# 定义模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2, #如果想要 con2d 出来的图片长宽没有变化, padding=(kernel_size-1)/2 当 stride=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        
        self.con2 = nn.Sequential(
            nn.Conv2d(16,32,5,1,2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        
        self.out = nn.Linear(32*7*7,10)
        
    def forward(self,x):
        x=self.conv1(x)
        x=self.con2(x)
        x=x.view(x.size(0),-1) # 展平多维的卷积图成 (batch_size, 32 * 7 * 7)
        output = self.out(x) 
        return output
    
    

In [17]:
cnn = CNN()

In [18]:
print(cnn)

CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (con2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (out): Linear(in_features=1568, out_features=10, bias=True)
)


In [19]:
optimizer = torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss()
# training and testing
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):   # 分配 batch data, normalize x when iterate train_loader
        output = cnn(b_x)               # cnn output
        loss = loss_func(output, b_y)   # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

In [31]:
test_output = cnn(test_x[:10])
print(test_output)
pred_y = torch.max(test_output, 1)[1].data.numpy()
print( torch.max(test_output,1))
# print(pred_y, 'prediction number')
# print(test_y[:10].numpy(), 'real number')

tensor([[ -8.0440,  -8.0997,  -1.4495,   0.9826, -13.4916,  -7.8216, -25.6799,
          13.1803,  -4.9843,  -1.7314],
        [ -0.1532,  -0.6090,   9.8525,  -3.9455, -16.9682, -10.0578,  -3.0559,
         -14.9017,  -0.7535, -17.0841],
        [ -4.0329,   6.9566,  -3.5273,  -4.5607,  -0.0524,  -5.6656,  -3.7588,
          -1.6082,  -2.5601,  -4.0307],
        [  9.8704, -14.9059,  -3.2396,  -8.6015,  -6.1942,  -3.4078,  -1.1316,
          -6.4802,  -4.2630,  -3.3603],
        [ -5.8548,  -4.0518,  -7.3428,  -6.9419,   8.8865, -10.4285,  -5.8612,
          -2.8335,  -6.5180,   0.9340],
        [ -4.6251,   7.7894,  -4.6738,  -5.2390,  -0.2096,  -7.9809,  -5.3063,
          -0.3419,  -2.0124,  -4.0211],
        [-15.5516,  -3.5360,  -6.2202,  -8.8890,   6.0318,  -8.0361, -11.5356,
          -3.0995,   0.6543,  -1.4803],
        [ -9.7924,  -4.1924,  -3.3156,  -2.3941,   0.9446,  -3.1389,  -8.3976,
          -4.9584,   0.2414,   3.9010],
        [ -8.0625, -20.2427, -10.0270,  -6.3685,