In [2]:
import sys
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
 
# A slightly simplified VGG-11-style network with 8 convolutional layers.
class vgg(nn.Module):
    """Compact VGG-11-style classifier producing 10 logits per image.

    ``feature`` holds five stages of 3x3 convolutions (8 convs in total),
    each stage ending in a 2x2 max-pool. ``fc`` is a two-layer head on the
    flattened features. The head expects 512 input features, which is what
    the five pools leave for a 32x32 RGB input (32 -> 1 spatially, 512
    channels).
    """

    def __init__(self):
        super().__init__()
        # One entry per stage: the (in_channels, out_channels) of each conv.
        stage_specs = [
            [(3, 64)],
            [(64, 128)],
            [(128, 256), (256, 256)],
            [(256, 512), (512, 512)],
            [(512, 512), (512, 512)],
        ]
        stages = []
        for conv_specs in stage_specs:
            layers = []
            for c_in, c_out in conv_specs:
                layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
                layers.append(nn.ReLU(True))
            layers.append(nn.MaxPool2d(2, 2))  # halve height and width
            stages.append(nn.Sequential(*layers))
        self.feature = nn.Sequential(*stages)

        head = [nn.Linear(512, 100), nn.ReLU(True), nn.Linear(100, 10)]
        self.fc = nn.Sequential(*head)

    def forward(self, x):
        """Run the conv stages, flatten per sample, and return the head's logits."""
        feats = self.feature(x)
        # Flatten everything after the batch dimension for the linear head.
        flat = feats.view(feats.shape[0], -1)
        return self.fc(flat)
 
# Convert images to tensors, then shift/scale every channel with mean 0.5 and
# std 0.5 (i.e. (x - 0.5) / 0.5).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])
 
def get_acc(output, label):
    """Return the fraction of rows whose highest-scoring column equals ``label``.

    Args:
        output: score tensor with one row per sample; the class is the
            arg-max over dimension 1.
        label: tensor of target class indices, one per row of ``output``.

    Returns:
        Accuracy as a Python float.
    """
    predicted = output.max(dim=1).indices
    hits = (predicted == label).sum().item()
    return hits / output.shape[0]

def train(net, train_data, num_epochs, optimizer, criterion):
    """Train ``net`` in place for ``num_epochs`` passes over ``train_data``.

    After each epoch, print the mean training loss and accuracy. Both are
    weighted by batch size, so a smaller final batch counts for its real
    share of the epoch rather than as much as a full batch. An epoch that
    sees no samples reports nan instead of raising ZeroDivisionError.

    Args:
        net: model to train; it is switched to training mode.
        train_data: iterable of ``(inputs, labels)`` batches (e.g. a
            DataLoader); it no longer needs to support ``len()``.
        num_epochs: number of passes over ``train_data``.
        optimizer: optimizer over ``net``'s parameters.
        criterion: loss function, assumed to return the per-batch *mean*
            (the default ``reduction='mean'``), so ``loss * batch_size``
            recovers the batch's summed loss.
    """
    net = net.train()
    for epoch in range(num_epochs):
        loss_sum = 0.0
        num_correct = 0
        num_seen = 0
        for im, label in train_data:
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward + parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = label.shape[0]
            loss_sum += loss.item() * batch_size
            # Count hits directly (arg-max over dim 1, like get_acc) so the
            # epoch accuracy is exact, not a mean of per-batch fractions.
            num_correct += (output.max(1).indices == label).sum().item()
            num_seen += batch_size
        if num_seen:
            epoch_loss = loss_sum / num_seen
            epoch_acc = num_correct / num_seen
        else:
            epoch_loss = epoch_acc = float('nan')
        print("Epoch %d. Train Loss: %f, Train Acc: %f, " %
              (epoch, epoch_loss, epoch_acc))
 
def evaluate(net, data, criterion):
    """Evaluate ``net`` on every batch of ``data`` without tracking gradients.

    Puts ``net`` into eval mode. Returns ``(mean_loss, accuracy)``, both
    weighted by batch size so a smaller final batch does not skew the
    result. Both values are nan when ``data`` yields no samples.

    ``criterion`` is assumed to return the per-batch mean (the default
    ``reduction='mean'``).
    """
    net.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    # One no_grad context for the whole pass instead of one per batch.
    with torch.no_grad():
        for im, label in data:
            output = net(im)
            batch_size = label.shape[0]
            loss_sum += criterion(output, label).item() * batch_size
            num_correct += (output.max(1).indices == label).sum().item()
            num_seen += batch_size
    if num_seen == 0:
        return float('nan'), float('nan')
    return loss_sum / num_seen, num_correct / num_seen


if __name__ == '__main__':
    train_set = CIFAR10('./data', train=True, transform=transform, download=True)
    train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
    test_set = CIFAR10('./data', train=False, transform=transform, download=True)
    test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

    net = vgg()
    print(net)
    optimizer = torch.optim.SGD(net.parameters(), lr=1e-1)
    criterion = nn.CrossEntropyLoss()  # loss function: cross-entropy

    train(net, train_data, 20, optimizer, criterion)

    test_loss, test_acc = evaluate(net, test_data, criterion)
    print("Test Loss: %f, Test Acc: %f, "
          % (test_loss, test_acc))

Files already downloaded and verified
Files already downloaded and verified
vgg(
  (feature): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace