In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F 
import torch.utils.data as Data 
import torch.optim as optim
import torchvision

In [3]:
# Number of samples per mini-batch, shared by the data loaders.
batch_size = 64

# Preprocessing applied to every image: convert to a tensor, then standardize
# with the given per-channel mean and standard deviation.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST pixel statistics — confirm.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

In [65]:
# download=True fetches MNIST into the root directory when the files are
# missing, so this cell also works on a fresh machine; files that are already
# present are reused rather than downloaded again.
train_data = torchvision.datasets.MNIST(
    './dataset/mnist/', train=True, transform=transform, download=True
)
# Training batches are reshuffled on every pass over the loader.
train_loader = Data.DataLoader(train_data, shuffle=True, batch_size=batch_size)

# NOTE(review): `test_date` is presumably a typo for "test_data"; the name is
# kept unchanged in case other cells refer to it.
test_date = torchvision.datasets.MNIST(
    './dataset/mnist/', train=False, transform=transform, download=True
)
# Evaluation keeps a fixed order (no shuffling).
test_loader = Data.DataLoader(test_date, shuffle=False, batch_size=batch_size)


In [72]:
class CNN(nn.Module):
    """Small convolutional classifier: two conv -> max-pool -> ReLU stages
    followed by a single linear layer that emits 10 scores per sample.

    The linear layer takes 320 features. For 1x28x28 inputs the second stage
    produces 20 channels of 4x4 values (20 * 4 * 4 = 320).
    """

    def __init__(self):
        super().__init__()
        # Creation order is conv1, conv2, pooling, fc; it matters for seeded
        # parameter initialisation, and the attribute names define the
        # state_dict keys.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.pooling = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(in_features=320, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch ``x``.

        The leading dimension of ``x`` is treated as the batch dimension; all
        remaining dimensions of the second stage's output are flattened before
        the linear layer.
        """
        n_samples = x.size(0)
        stage_one = F.relu(self.pooling(self.conv1(x)))
        stage_two = F.relu(self.pooling(self.conv2(stage_one)))
        flattened = stage_two.view(n_samples, -1)
        return self.fc(flattened)


# Module-level instance used by the training and evaluation cells.
model = CNN()

In [73]:
# Loss applied to the raw scores returned by the model and the integer targets.
criterion = nn.CrossEntropyLoss()

# SGD with momentum over every parameter of the module-level model.
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.5,
)

In [74]:
def train(epoch, log_interval=300):
    """Run one pass over ``train_loader``, taking one optimizer step per batch.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``train_loader``.

    Args:
        epoch: Zero-based epoch index. Only used in the log lines, where it is
            printed as ``epoch + 1``.
        log_interval: Number of batches between log lines. Each line reports
            the mean loss over the batches seen since the previous line.
            ``0``, ``None`` or a negative value disables logging.
    """
    # Make sure the model is in training mode, even if an evaluation pass
    # switched it to eval mode earlier.
    model.train()

    logging_enabled = bool(log_interval) and log_interval > 0
    running_loss = 0
    for batch_idx, (inputs, target) in enumerate(train_loader):
        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Report after every `log_interval`-th batch (1-based count).
        if logging_enabled and (batch_idx + 1) % log_interval == 0:
            print("[%d,%5d] loss: %.3f" % (epoch+1, batch_idx+1, running_loss/log_interval))
            running_loss = 0


In [75]:
def test():
    """Evaluate ``model`` on ``test_loader`` and print the accuracy.

    Uses the module-level ``model`` and ``test_loader``. The model is put into
    eval mode for the duration of the pass and its previous mode is restored
    afterwards, so calling this between training epochs does not leave the
    model in eval mode. Gradients are not tracked.

    Prints ``ACC: <percent> %`` with the percentage truncated to an integer,
    or a notice when the loader yields no samples. Returns nothing.
    """
    correct = 0
    total = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, labels in test_loader:
                outputs = model(inputs)
                # Index of the highest score along dim 1 is the predicted class.
                prediction = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (prediction == labels).sum().item()
    finally:
        # Restore whichever mode the model was in before evaluation.
        model.train(was_training)

    if total == 0:
        # Avoid a ZeroDivisionError when there is nothing to evaluate.
        print("ACC: n/a (test_loader yielded no samples)")
        return
    print("ACC: %d %%" % (100*correct/total))

In [76]:
# Driver: runs only when this code executes as the main module.
if __name__ == '__main__':
    # Five epochs; each one is a full training pass followed by an evaluation
    # pass, so an accuracy line is printed after every epoch's loss lines.
    for epoch in range(5):
        train(epoch)
        test()

[1,  300] loss: 0.692
[1,  600] loss: 0.193
[1,  900] loss: 0.138
ACC: 96 %
[2,  300] loss: 0.112
[2,  600] loss: 0.093
[2,  900] loss: 0.090
ACC: 97 %
[3,  300] loss: 0.077
[3,  600] loss: 0.076
[3,  900] loss: 0.076
ACC: 98 %
[4,  300] loss: 0.061
[4,  600] loss: 0.065
[4,  900] loss: 0.063
ACC: 98 %
[5,  300] loss: 0.056
[5,  600] loss: 0.061
[5,  900] loss: 0.050
ACC: 98 %
