In [None]:
%pylab inline

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
import torchvision
import torchvision.transforms as transforms

In [None]:
# ToTensor turns each image into a float tensor in [0, 1]; Normalize(0.5, 0.5)
# then maps it to [-1, 1] via (x - 0.5) / 0.5. imshow() inverts this scaling.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])

In [None]:
# To use CIFAR-10 instead, replace "FashionMNIST" with "CIFAR10".
# Note: the CNN defined below expects 1-channel 28x28 images, so it must also be
# adapted to CIFAR-10's 3-channel 32x32 images.
trainset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

In [None]:
# To use CIFAR-10 instead, replace "FashionMNIST" with "CIFAR10"
# (keep this consistent with the training-set cell).
testset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

In [None]:
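# To use your own images, store them in folders as shown below and load them with torchvision.datasets.ImageFolder
# (class names are taken from the sub-folder names). ImageFolder loads images as 3-channel RGB, so the CNN's
# input channels and expected image size must be adapted to match your data.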
# ./data/my_image_dataset/
# 　├ train/
# 　│　├ class1/
# 　│　│　　├ 1.png
# 　│　│　　└ ...
# 　│　└ class2/
# 　│　　　　├ 2.png
# 　│　　　　└ ...
# 　└ test/
# 　　　├ class1/
# 　　　│　　├ 3.png
# 　　　│　　└ ...
# 　　　└ class2/
# 　　　　　　├ 4.png
# 　　　　　　└ ...
# trainset = torchvision.datasets.ImageFolder(root='./data/my_image_dataset/train', transform=transform)
# testset = torchvision.datasets.ImageFolder(root='./data/my_image_dataset/test', transform=transform)

In [None]:
# Print the training dataset's text summary to check that it loaded.
print(trainset)

In [None]:
# CIFAR-10のクラス
# classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Read the class names from the dataset instance rather than the FashionMNIST class.
# torchvision's CIFAR10 and ImageFolder datasets populate `classes` per instance, so
# this keeps working when the dataset cells above are swapped.
classes = trainset.classes
print(classes)

In [None]:
def imshow(img):
    """Display a channel-first image tensor that was normalized with Normalize(0.5, 0.5).

    Undoes the normalization, copies the tensor to host memory, converts it to a
    NumPy array in (H, W, C) order and shows it with matplotlib.

    Args:
        img: image tensor in (C, H, W) layout, e.g. the output of
            torchvision.utils.make_grid. It may live on any device.
    """
    # Import explicitly instead of relying on names injected by the %pylab magic.
    import numpy as np
    import matplotlib.pyplot as plt

    img = img / 2 + 0.5     # unnormalize: inverse of Normalize(0.5, 0.5)
    # detach/cpu so tensors that require grad or live on a GPU can be converted.
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # (C, H, W) -> (H, W, C)
    plt.show()

In [None]:
# Shuffled training batches of 4 images, prepared by 2 background worker processes.
trainloader = torch.utils.data.DataLoader(
    trainset,
    shuffle=True,
    batch_size=4,
    num_workers=2,
)

In [None]:
# Test batches of 4 images in a fixed order (no shuffling), loaded by 2 worker processes.
testloader = torch.utils.data.DataLoader(
    testset,
    shuffle=False,
    batch_size=4,
    num_workers=2,
)

In [None]:
# Iterator over shuffled training batches; the following cells draw batches from it.
dataiter = iter(trainloader)

In [None]:
# Draw one training batch, show it as an image grid, and print its shape and label names.
# Use the built-in next(): DataLoader iterators in current PyTorch have no .next() method.
x, y = next(dataiter)
imshow(torchvision.utils.make_grid(x))
print(x.shape)
print([classes[i] for i in y.tolist()])

In [None]:
class CNN(nn.Module):
    """LeNet-style CNN: two conv/ReLU/max-pool stages followed by three linear layers.

    Args:
        in_channels: number of input image channels (1 for FashionMNIST).
        num_classes: number of output logits.
        image_size: height/width of the square input images (28 for FashionMNIST).

    The defaults reproduce the original 1-channel, 28x28, 10-class network.
    For 3-channel 32x32 images (e.g. CIFAR-10) use ``CNN(in_channels=3, image_size=32)``.

    Raises:
        ValueError: if ``image_size`` is too small to survive both conv/pool stages.
    """

    def __init__(self, in_channels=1, num_classes=10, image_size=28):
        super().__init__()
        # Each 5x5 convolution (no padding) shrinks the side by 4; each 2x2 pool halves it.
        side = ((image_size - 4) // 2 - 4) // 2
        if side < 1:
            raise ValueError(
                'image_size={} is too small for two conv(5)+pool(2) stages'.format(image_size))
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * side * side, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map a (batch, in_channels, image_size, image_size) tensor to (batch, num_classes) logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # flatten keeps the batch dimension, so a wrongly sized input fails in fc1
        # instead of being silently reshaped into a different batch size by view(-1, ...).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [None]:
# Sanity check: run an untrained CNN on one training batch and print its predicted class indices.
cnn = CNN()
x, y = next(dataiter)  # DataLoader iterators have no .next() method in current PyTorch
with torch.no_grad():  # inference only; no autograd graph is needed
    a = cnn(x)
pred_y = torch.argmax(a, dim=1)
print(pred_y)

In [None]:
from itertools import islice

# Training configuration.
NUM_EPOCHS = 10
LEARNING_RATE = 0.01
# Batches used per epoch; set to None to iterate over the whole trainloader
# (using all the data is better when enough compute is available).
BATCHES_PER_EPOCH = 250

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cnn = CNN()
cnn.to(device)
optimizer = optim.SGD(cnn.parameters(), lr=LEARNING_RATE)
cnn.train()
for epoch in range(NUM_EPOCHS):
    sumloss = 0.0
    nbatches = 0
    # islice(iterable, None) yields every item, so BATCHES_PER_EPOCH=None means a full epoch.
    for x, y in islice(trainloader, BATCHES_PER_EPOCH):
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        a = cnn(x)
        loss = F.cross_entropy(a, y)
        loss.backward()
        optimizer.step()
        sumloss += loss.item()
        nbatches += 1
    # Report the mean batch loss so the value does not scale with the number of batches.
    print('epoch: {}, loss: {:.4f}'.format(epoch, sumloss / max(nbatches, 1)))

In [None]:
# Classification accuracy of the trained CNN over the whole test set.
# eval() is a no-op for the current layers, but it is required once dropout or
# batch-norm layers are added.
cnn.eval()
correct = 0
total = 0
with torch.no_grad():
    for x, y in testloader:
        x = x.to(device)
        y = y.to(device)
        a = cnn(x)
        pred_y = torch.argmax(a, dim=1)
        correct += (pred_y == y).sum().item()
        total += pred_y.size(0)

print(correct / total)

In [None]:
# Fresh iterator over the (unshuffled) test batches for the prediction demo below.
dataiter = iter(testloader)

In [None]:
# Show one test batch and compare the CNN's predicted class names with the true ones.
x, y = next(dataiter)  # DataLoader iterators have no .next() method in current PyTorch
imshow(torchvision.utils.make_grid(x))
x = x.to(device)
y = y.to(device)
cnn.eval()
with torch.no_grad():  # inference only; no autograd graph is needed
    a = cnn(x)
pred_y = torch.argmax(a, dim=1)
print([classes[i] for i in pred_y.tolist()])
print([classes[i] for i in y.tolist()])