In [338]:
import torch.nn as nn
from torch.nn import functional as F
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader


In [339]:
train_set=torchvision.datasets.MNIST(root='./data',
                                     train=True,
                                     download=False,
                                     transform=torchvision.transforms.ToTensor())

In [340]:
test_set=torchvision.datasets.MNIST(root='./data',
                                    transform=torchvision.transforms.ToTensor(),
                                    download=False,
                                    train=False)

In [341]:
transform = transforms.Compose([   #将输入的图片进行“[]”中的一系列的处理
    transforms.ToTensor(),         #将输入图片进行转成3通道的张量，并将像素值压缩到0-1 
    transforms.Normalize((0.1307,),(0.3081,))  # 这两个值使均值和方差,将图像数据进行标准化
                                               #（这里使数据形成正态分布）
]
)
batch_size=12

In [342]:
train_dataloader=DataLoader(train_set,batch_size=batch_size, shuffle=True,num_workers=0)
test_dataloader=DataLoader(test_set,batch_size=batch_size,shuffle=True,num_workers=0)

## 搭建网络，模型训练

In [343]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(  # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,      # input height
                out_channels=16,    # n_filters
                kernel_size=5,      # filter size
                stride=1,           # filter movement/step
                padding=2,          # 如果想要 con2d 出来的图片长宽没有变化, padding=(kernel_size-1)/2 当 stride=1
                ),                  # output shape (16, 28, 28)
            nn.ReLU(),              # activation
            nn.MaxPool2d(kernel_size=2),    # 在 2x2 空间里向下采样, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(         # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # output shape (32, 14, 14)
            nn.ReLU(),                      # activation 
            nn.MaxPool2d(2),                # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   # fully connected layer, output 10 classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)            # 展平多维的卷积图成 (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output

cnn = CNN()

In [344]:
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)   # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()   # the target label is not one-hotted
epochs=10

In [346]:
def train_model():
    cnn.train()
    train_loss_all=[];train_acc_all=[]
    for epoch in range(2):
        print("+++++++第{}轮训练++++".format(epoch+1))
        train_loss=0.0;train_correct=0.0;train_num=0.0;train_accuracy=0.0
        for step, data in enumerate(train_dataloader):
            img,label=data
            output=cnn(img)
            loss = loss_func(output,label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step() 
            ###计算损失，预测正确数量，准确率
            train_correct+=(output.argmax(1)==label).sum().item()##第i轮预测正确数量
            train_loss+=(loss.item()*img.size(0))   ##第i轮损失值
            train_num+=img.size(0)
            if step%300==0:
                print("loss:",train_loss/train_num)

        train_loss_all.append(train_loss/train_num)
        train_acc_all.append(train_correct/train_num)
        print("train loss{:.4f}, train accuracy{}/{} ({:.4f})".format(train_loss_all[-1],train_correct,train_num,
                                                                      train_acc_all[-1]))
train_model()

+++++++第1轮训练++++
loss: 0.034036751836538315
loss: 0.10468880747803842
loss: 0.0943584059322576
loss: 0.09222525307257064
loss: 0.08814896368841814
loss: 0.08277153241690804
loss: 0.07964189589825853
loss: 0.0795858176105023
loss: 0.07747923125589863
loss: 0.07572384189554146
loss: 0.07412676711866867
loss: 0.07131481050443399
loss: 0.06955558339409776
loss: 0.06759861760855175
loss: 0.06670221177130967
loss: 0.06613573548899923
loss: 0.06507879180144087
train loss0.0647, train accuracy58796.0/60000.0 (0.9799)
+++++++第2轮训练++++
loss: 0.012058623135089874
loss: 0.03415434030962251
loss: 0.03187426500474198
loss: 0.03661498846658218
loss: 0.037411481576938535
loss: 0.03647618953139637
loss: 0.03572638172251256
loss: 0.037468621352681586
loss: 0.03758775558007896
loss: 0.037450450938337765
loss: 0.037910165207493705
loss: 0.03807287186924797
loss: 0.037883391338761674
loss: 0.03716896984554072
loss: 0.03717418113791107
loss: 0.038106647019900075
loss: 0.0377742952826942
train loss0.0373, tr

In [347]:
def val_model():
    cnn.eval()
    test_loss_all=[];test_acc_all=[]
    for epoch in range(2):
        print("+++++++第{}轮评估++++".format(epoch+1))
        val_loss=0;val_correct=0;val_num=0;val_accuracy=0
        for step,data in enumerate(test_dataloader):
            img,label=data
            output=cnn(img)
            loss = loss_func(output,label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #计算损失，准确率
            val_loss+=loss.item()*img.size(0)
            val_correct+=(output.argmax(1)==label).sum().item()
            val_num+=img.size(0)
            if step%200==0:
                print("val loss{}".format(val_loss/val_num))
        test_loss_all.append(val_loss/val_num)
        test_acc_all.append(val_correct/val_num)
        print("val loss is:{:.4f}, val accuracy is:{}/{} {:.4f}%".format(test_loss_all[-1],val_correct,
                                                                 val_num,100*test_acc_all[-1]))
val_model()

+++++++第1轮评估++++
val loss0.00048447493463754654
val loss0.035059050744985595
val loss0.028841819254027956
val loss0.028369459748891088
val loss0.028602542346601416
val loss is:0.0279, val accuracy is:9911/10000 99.1100%
+++++++第2轮评估++++
val loss0.002067637164145708
val loss0.00867289263656808
val loss0.01025436863539399
val loss0.009862066880653995
val loss0.010363350169286604
val loss is:0.0103, val accuracy is:9961/10000 99.6100%
