In [ ]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

In [ ]:
# Preprocessing pipelines. CIFAR-10 images are 3x32x32 (channels x height x width).

# Per-channel mean / std used for normalization. These are the ImageNet
# statistics that the PyTorch documentation also suggests using.
imagenet_normalize = transforms.Normalize(
    (0.485, 0.456, 0.406),
    (0.229, 0.224, 0.225),
)

# Training pipeline: light augmentation, then tensor conversion and normalization.
transform_train = transforms.Compose([
    # Zero-pad the border by 2 pixels, then take a random 32x32 crop.
    transforms.RandomCrop(32, padding=2),
    # Flip horizontally half of the time, leave unchanged the other half.
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    imagenet_normalize,
])

# Test pipeline: no augmentation, only tensor conversion and normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    imagenet_normalize,
])

In [ ]:
# Human-readable class names, indexed by the integer label the dataset yields.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# CIFAR-10 train / test splits stored under ./data (download=True fetches them
# when needed), each with its own preprocessing pipeline.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

# Batches of 4 images, loaded in the main process (num_workers=0).
# Only the training loader shuffles; the test loader keeps dataset order.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=0,
)

In [ ]:
# Helper to display a normalized image tensor.
def imshow(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization and display the image with matplotlib.

    Args:
        img: Channel-first image tensor (C x H x W) that was normalized per
            channel as ``(x - mean) / std``. The tensor is left unmodified.
        mean: Per-channel means that were used for normalization. Defaults to
            the values used by the preprocessing transforms in this file.
        std: Per-channel standard deviations that were used for normalization.

    Prints the tensor size, then shows the image in a matplotlib figure.
    """
    print(img.size())
    # Reshape the statistics to (C, 1, 1) so they broadcast over H and W.
    mean_t = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    # Restore the original pixel values out of place: writing into `img`
    # would silently corrupt the caller's tensor (and compound the error if
    # the same tensor were shown twice).
    restored = img.detach() * std_t + mean_t
    # Float rounding can push values marginally outside [0, 1]; clamp so
    # matplotlib does not have to clip (and warn about) the data.
    restored = restored.clamp(0, 1)
    npimg = restored.cpu().numpy()
    # matplotlib expects channel-last (H x W x C) data.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Draw one random mini-batch of training images.
dataiter = iter(trainloader)
# Use the built-in next(): the iterator's .next() method is a Python 2 era
# alias that recent PyTorch releases no longer provide.
images, labels = next(dataiter)

print(images.size())

# show images
imshow(torchvision.utils.make_grid(images))
# print labels (iterate over the batch itself instead of assuming 4 items)
print(' '.join('%5s' % classes[label] for label in labels))
print(labels[1].size())

In [ ]:
from vgg16.vgg16 import VGG16   # the import path may need adjusting for your directory layout
net = VGG16()
PATH = './train_ipynb.pth'
# The model is kept on the CPU here, so map the checkpoint tensors to the CPU
# as well; otherwise loading fails on a machine without CUDA if the weights
# were saved from a GPU.
net.load_state_dict(torch.load(PATH, map_location='cpu'))
# The network is only used for prediction below: switch to evaluation mode so
# that any dropout / batch-norm layers behave deterministically.
net.eval()

In [ ]:
# Create an iterator over the test loader so batches can be pulled one at a time.
dataiter = iter(testloader)

In [ ]:
# Fetch the next test batch with the built-in next(); the iterator's .next()
# method is not available in recent PyTorch releases.
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images))
# Iterate over the batch itself instead of assuming exactly 4 items.
print('GroundTruth: ', ' '.join('%5s' % classes[label] for label in labels))
# Inference only: no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
    outputs = net(images)
# Take the index of the largest value along dim 1 as the predicted class
# (assumes outputs is batch x num_classes -- TODO confirm against VGG16).
_, predicted = torch.max(outputs, 1)
print('Predicted:   ', ' '.join('%5s' % classes[pred] for pred in predicted))