# CIFAR-10データセットを用いて画像分類してみる

## import

In [1]:
import os
import sys

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

sys.path.append('../src')
import utils

sys.path.append('./src')
from model import *

In [2]:
# Read project settings; this run writes its artifacts under <output_dir>/01.
config = utils.readConfig('../config.json')
output_path = '/'.join((config['filepath']['output_dir'], '01'))

# Presumably creates output_path plus a 'graph' subdirectory -- see utils.makeDirs.
utils.makeDirs(output_path, ['graph'])

## GPU or CPU 

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda', 0)
else:
    device = torch.device('cpu')

## set seed

In [None]:
# Global seed for every RNG used in this notebook.
# TODO: take this from argparse (or similar) instead of hard-coding it.
seed = 42

# Seed PyTorch, NumPy and Python's `random` with the same value.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# cuDNN reproducibility: request deterministic kernels and disable benchmark autotuning.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def seed_worker(worker_id):
    """Seed NumPy and Python's `random` from the current torch seed.

    Intended for use as a DataLoader ``worker_init_fn``. Inside a worker process,
    ``torch.initial_seed()`` is the seed PyTorch assigned to that worker. It is
    reduced to 32 bits because ``np.random.seed`` only accepts values < 2**32.

    Args:
        worker_id: Index of the worker. It is required by the ``worker_init_fn``
            signature but unused here.
    """
    derived_seed = torch.initial_seed() % 2 ** 32
    random.seed(derived_seed)
    np.random.seed(derived_seed)

# Dedicated torch RNG for DataLoader shuffling, seeded with the notebook-wide `seed`.
# (Generator.manual_seed returns the generator itself, so this can be chained.)
g = torch.Generator().manual_seed(seed)

## Preprocessing

In [None]:
# Preprocessing applied to every CIFAR-10 image.
transform = transforms.Compose([
    # Convert the image from integers in 0..255 to float tensors in [0, 1].
    transforms.ToTensor(),
    # Per-channel mean 0.5 / std 0.5 for the 3 RGB channels maps [0, 1] onto [-1, 1].
    # Normalize's third positional argument is `inplace`, so mean and std must each be
    # a single tuple, not three separate ones.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

## load dataset

In [None]:
# CIFAR-10 train (50k) / test (10k) splits, downloaded on first use and preprocessed with `transform`.
_cifar_kwargs = dict(root="../output/cifar_data", download=True, transform=transform)
train_dataset = datasets.CIFAR10(train=True, **_cifar_kwargs)
validation_dataset = datasets.CIFAR10(train=False, **_cifar_kwargs)

## dataloader 

In [None]:
# Training loader: shuffled, and wired to the seeded generator `g` and `seed_worker`
# so that shuffle order (and any worker-side RNG use) is reproducible across runs.
# seed_worker only runs when num_workers > 0; passing it is harmless otherwise.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=g,
)
# Validation loader: fixed order, so no RNG involvement.
validation_dataloader = DataLoader(validation_dataset, batch_size=4, shuffle=False)

# Human-readable names, indexed by CIFAR-10 label id.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

## functions to show an image

In [None]:
def imshow(img):
    """Show a normalized image tensor with matplotlib.

    Undoes the mean-0.5 / std-0.5 normalization (back to [0, 1]). It then moves
    channels last, because the (1, 2, 0) transpose expects a channels-first 3-D tensor.
    """
    restored = img * 0.5 + 0.5  # inverse of Normalize(mean=0.5, std=0.5)
    plt.imshow(restored.numpy().transpose(1, 2, 0))
    plt.show()


# Preview one shuffled training batch.
# Use the builtin next(); DataLoader iterators no longer provide a .next() method.
images, labels = next(iter(train_dataloader))

# Show the batch as a single image grid.
imshow(torchvision.utils.make_grid(images))
# Print one class name per image. Iterating the labels themselves keeps this correct
# even if batch_size changes.
print(' '.join('%5s' % classes[label] for label in labels))

## initialize

In [None]:
# Model I/O sizes: 3 input (colour) channels and 10 output classes (one per entry of `classes`).
input_channels, output_shape = 3, 10

In [None]:
# Build ResNet34 (from the project's `model` module) and move it to the chosen device.
model = ResNet34(input_channels, output_shape)
model = model.to(device)  # nn.Module.to returns the module itself

# Cross-entropy expects raw (unnormalized) class scores from the model.
criterion = nn.CrossEntropyLoss()
# Adam with weight decay (an L2 penalty added to the gradients).
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=5e-4)

## train

In [None]:
# NOTE(review): not referenced anywhere in this notebook; training uses `criterion`
# (CrossEntropyLoss). Kept so the name still exists; candidate for deletion.
loss_metric = nn.MSELoss(reduction='mean')

In [None]:
# Train for 10 epochs and report the mean training loss of each epoch.
num_batches = len(train_dataloader)
for epoch in range(10):
    # Enable training-mode behaviour explicitly (matters for BatchNorm/Dropout, if the model has any).
    model.train()

    running_loss = 0.0
    for inputs, labels in train_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward, backward and parameter update.
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean loss over this epoch's batches. The original divided by `len(dataloader)`,
    # an undefined name.
    print(f'[{epoch + 1}, {num_batches}] loss: {running_loss / num_batches}')

print('Finished Training')

## save

In [None]:
# Persist only the trained weights (the state_dict), not the pickled module object.
PATH = '../output/cifar_net.pth'
torch.save(obj=model.state_dict(), f=PATH)

## test 

In [None]:
# Preview one validation batch alongside the model's predictions.
# Use the builtin next(); DataLoader iterators no longer provide a .next() method.
images, labels = next(iter(validation_dataloader))

# Show the images (kept on the CPU for plotting).
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[label] for label in labels))

# Rebuild the SAME architecture that was trained and saved (ResNet34). Loading its
# state_dict into a different class such as Conv_Net fails on mismatched keys.
net = ResNet34(input_channels, output_shape)
# map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine, and vice versa.
net.load_state_dict(torch.load(PATH, map_location=device))
net.to(device)
net.eval()  # inference-mode behaviour for BatchNorm/Dropout, if the model has any

with torch.no_grad():
    outputs = net(images.to(device))
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[p] for p in predicted))

# Overall accuracy on the whole validation split.
correct = 0
total = 0
with torch.no_grad():
    for images, labels in validation_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the %d test images: %d %%' % (
    total, 100 * correct / total))

In [None]:
# Per-class accuracy on the validation split.
class_correct = [0.] * len(classes)
class_total = [0.] * len(classes)

# Ensure the model sits on the same device as the inputs and is in inference mode.
net.to(device)
net.eval()

with torch.no_grad():
    for images, labels in validation_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        _, predicted = torch.max(net(images), 1)
        hits = predicted == labels
        # Walk the actual batch. This avoids hard-coding batch_size=4 and the .squeeze()
        # pitfall, where a 1-sample batch becomes a 0-d tensor.
        for label, hit in zip(labels.tolist(), hits.tolist()):
            class_correct[label] += hit
            class_total[label] += 1


for i, name in enumerate(classes):
    print('Accuracy of %5s : %2d %%' % (
        name, 100 * class_correct[i] / class_total[i]))

In [None]:
# Notebook cell: evaluate `net` as the last expression so Jupyter displays its repr
# (the module structure).
net