# 1.Prepare section

In [6]:
import os
from collections import OrderedDict

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import random_split
import torchvision.transforms as transforms
import cv2
import tqdm
import numpy as np
import matplotlib.pyplot as plt

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)

cuda


# 2.Transform section

In [7]:
class ComposeTransform():
    """
    Chain several transforms and apply them one after another.
    """
    def __init__(self, transforms=None):
        """
        Parameters
        --------------
        transforms: list
            Transform instances (callables), applied in list order.
            None or an empty list makes this transform a pass-through.
        """
        self.transforms = transforms

    def __call__(self, x):
        # A falsy pipeline (None / []) yields no steps, so x is returned unchanged.
        pipeline = self.transforms or []
        result = x
        for step in pipeline:
            result = step(result)
        return result


class BaseTransform():
    """
    Base class for the notebook's custom transforms.

    Subclasses override ``__call__(self, x)`` and return the transformed ``x``.
    """
    def __init__(self, debug=False):
        """
        Parameters
        --------------
        debug: bool
            Flag that subclasses may consult to emit diagnostic output.
        """
        self.debug = debug

    def __call__(self, x=None):
        # Accept the input argument so that calling a subclass which forgot to
        # override __call__ as ``t(x)`` reports NotImplementedError rather than
        # a confusing TypeError about the argument count.
        raise NotImplementedError(
            f'{type(self).__name__} must implement __call__(self, x)')


class SimpleTransform(BaseTransform):
    """
    Catch-all transform for frequently used operations:
    PIL image -> OpenCV-style uint8 ndarray -> resized to `size`.

    Split into themed classes once it accumulates more operations.
    """
    def __init__(self, debug=False, size=(224, 224)):
        """
        Parameters
        --------------
        debug: bool
            If True, print the shape/dtype of the array after each stage.
        size: tuple of int
            Target size passed to ``cv2.resize`` as ``(width, height)``.
            The default reproduces the previous hard-coded 224x224.
        """
        super().__init__(debug=debug)
        self.size = tuple(size)

    def __call__(self, x):
        x = self.pil2cv(x)
        self._log('pil2cv', x)
        x = cv2.resize(x, self.size)
        self._log('resize', x)
        return x

    def _log(self, stage, x):
        """Print shape/dtype of an intermediate array when debugging is enabled."""
        if self.debug:
            print(f'[SimpleTransform] {stage}: shape={x.shape}, dtype={x.dtype}')

    def pil2cv(self, image):
        '''Convert a PIL image to an OpenCV-style array (RGB -> BGR, RGBA -> BGRA).

        NOTE: the channel order therefore reaches ToTensor as BGR; this is
        consistent as long as train and test share the same transform.
        '''
        new_image = np.array(image, dtype=np.uint8)
        if new_image.ndim == 2:  # grayscale: no channel axis to reorder
            pass
        elif new_image.shape[2] == 3:  # color
            new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
        elif new_image.shape[2] == 4:  # with alpha channel
            new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
        return new_image

# 3.Dataset section

This notebook uses the CIFAR-10 dataset bundled with torchvision, so no custom Dataset is defined and this section is unused.

# 4.Model section

In [8]:
class VGG(nn.Module):
    """
    VGG-16 style classifier: 13 3x3 conv layers in five max-pooled stages,
    followed by three fully connected layers (no batch normalization).

    An adaptive average pool fixes the final feature map to 7x7 before the
    classifier. For 224x224 inputs the feature map is already 7x7, so the
    pool is a no-op and the behaviour matches the original design. Other
    input sizes (e.g. native 32x32 CIFAR images) also work.

    Parameters
    ----------
    num_classes: int
        Number of output logits (default 10, for CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super(VGG, self).__init__()

        self.conv01 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv02 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv03 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv04 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv05 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv06 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv07 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.conv08 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv09 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(2, 2)

        self.conv11 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv12 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv13 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.pool5 = nn.MaxPool2d(2, 2)

        # Parameter-free: guarantees a 512x7x7 feature map for fc1 regardless
        # of input resolution (identity when the map is already 7x7).
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        self.fc1 = nn.Linear(512 * 7 * 7, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes) for a (batch, 3, H, W) input."""
        x = F.relu(self.conv01(x))
        x = F.relu(self.conv02(x))
        x = self.pool1(x)

        x = F.relu(self.conv03(x))
        x = F.relu(self.conv04(x))
        x = self.pool2(x)

        x = F.relu(self.conv05(x))
        x = F.relu(self.conv06(x))
        x = F.relu(self.conv07(x))
        x = self.pool3(x)

        x = F.relu(self.conv08(x))
        x = F.relu(self.conv09(x))
        x = F.relu(self.conv10(x))
        x = self.pool4(x)

        x = F.relu(self.conv11(x))
        x = F.relu(self.conv12(x))
        x = F.relu(self.conv13(x))
        x = self.pool5(x)

        x = self.avgpool(x)
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 512*7*7), this can never fold several samples into one row.
        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)

        return x

# 5.Main function section

In [9]:
def train_net(net, train_loader, eval_loader, optim_cls=optim.SGD, loss_fn=nn.CrossEntropyLoss(),
              n_iter=20, device='cpu', lr=0.1, optim_kwargs=None):
    """
    Train `net` for `n_iter` epochs, evaluating on `eval_loader` after every epoch.

    Parameters
    ----------
    net: nn.Module
        Model to train; it is moved to `device`.
    train_loader, eval_loader: iterable
        Yield ``(x, label)`` batches.
    optim_cls: optimizer class
        Instantiated as ``optim_cls(net.parameters(), lr=lr, **optim_kwargs)``.
    loss_fn: callable
        ``loss_fn(logits, label)`` returning a scalar tensor.
    n_iter: int
        Number of epochs.
    device: str
        Device to train on.
    lr: float
        Learning rate. The default 0.1 reproduces the previously hard-coded value.
        NOTE(review): in the run recorded in this notebook, plain SGD at 0.1 left the
        training loss at ~2.3 (ln 10, i.e. chance level for 10 classes). Consider a
        smaller lr and/or momentum via `optim_kwargs`.
    optim_kwargs: dict or None
        Extra keyword arguments for the optimizer, e.g. ``{'momentum': 0.9}``.

    Returns
    -------
    train_losses, val_losses, train_acc, val_acc: list of float
        One entry per epoch. Losses are the mean of per-batch losses; accuracies
        are fractions of correctly classified samples in that epoch.

    Raises
    ------
    ValueError
        If `train_loader` yields no batches.
    """
    net = net.to(device)
    optimizer = optim_cls(net.parameters(), lr=lr, **(optim_kwargs or {}))
    train_losses = []
    val_losses = []
    train_acc = []
    val_acc = []

    for epoch in range(n_iter):
        net.train()
        # Reset the counters every epoch so train_acc/train_losses describe
        # this epoch only (previously accuracy accumulated over all epochs).
        running_loss = 0.0
        n_batches = 0
        n_seen = 0
        n_correct = 0
        with tqdm.tqdm(train_loader) as pbar:
            for x, label in pbar:
                x = x.to(device)
                label = label.to(device)
                h = net(x)
                loss = loss_fn(h, label)

                running_loss += loss.item()
                n_batches += 1
                n_seen += len(label)
                n_correct += (h.argmax(dim=1) == label).sum().item()

                # Backpropagate and update parameters
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.set_postfix(OrderedDict(
                    epoch=epoch + 1,
                    loss=running_loss / n_batches,
                ))

        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")
        train_losses.append(running_loss / n_batches)
        train_acc.append(n_correct / n_seen)

        val_loss, val_acc_ = val_net(net, eval_loader, loss_fn, device=device)
        val_losses.append(val_loss)
        val_acc.append(val_acc_)

    return train_losses, val_losses, train_acc, val_acc

def val_net(net, val_loader, loss_fn, device='cpu'):
    """
    Evaluate `net` on `val_loader` without tracking gradients.

    Parameters
    ----------
    net: nn.Module
        Model to evaluate; moved to `device` and switched to eval mode.
    val_loader: iterable
        Yields ``(x, label)`` batches.
    loss_fn: callable
        ``loss_fn(logits, label)`` returning a scalar tensor.
    device: str
        Device to evaluate on.

    Returns
    -------
    (val_loss, val_acc): tuple of float
        Mean of the per-batch losses, and the fraction of correctly classified samples.

    Raises
    ------
    ValueError
        If `val_loader` yields no batches.
    """
    net = net.to(device)
    net.eval()
    running_loss = 0.0
    n_batches = 0
    n_seen = 0
    n_correct = 0
    # no_grad: evaluation needs no autograd graph; building one wastes memory.
    with torch.no_grad():
        for x, label in val_loader:
            x = x.to(device)
            label = label.to(device)
            h = net(x)
            running_loss += loss_fn(h, label).item()
            n_batches += 1
            n_seen += len(label)
            n_correct += (h.argmax(dim=1) == label).sum().item()
    if n_batches == 0:
        raise ValueError("val_loader yielded no batches")
    return running_loss / n_batches, n_correct / n_seen

def pred_net(net, test_loader, device='cpu'):
    """
    Predict class indices for every sample in `test_loader`.

    Parameters
    ----------
    net: nn.Module
        Trained model; moved to `device` and switched to eval mode, so dropout is disabled.
    test_loader: iterable
        Yields either input tensors or ``(x, label, ...)`` tuples/lists.
        In the latter case only the first element is used.
    device: str
        Device to run inference on.

    Returns
    -------
    torch.Tensor
        1-D tensor of predicted class indices (argmax of the logits), in loader order.
        Empty (long dtype) if the loader yields nothing.
    """
    net = net.to(device)
    net.eval()
    y_preds = []
    with torch.no_grad():
        for batch in test_loader:
            # torchvision datasets (e.g. CIFAR10) yield (image, label) pairs.
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            x = x.to(device)
            y_preds.append(net(x).argmax(dim=1))
    if not y_preds:
        return torch.empty(0, dtype=torch.long, device=device)
    return torch.cat(y_preds, dim=0)

# 6.Train section

In [None]:
# Transform pipeline (also reused by the test section):
# PIL image -> OpenCV-style uint8 array resized to 224x224 -> float tensor normalized to [-1, 1]
transform = ComposeTransform([
    SimpleTransform(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Seed for the train/validation split and model initialization, so re-runs are reproducible
SEED = 0
torch.manual_seed(SEED)

# Dataset
dataset = torchvision.datasets.CIFAR10('./datasets', train=True,
                                       download=True, transform=transform)
print(dataset[0][0].size())
train_size = int(len(dataset) * 0.8)  # 80/20 train/validation split
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED))

# DataLoaders. The validation order does not affect the metrics, so it is not shuffled.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32,
                                           shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32,
                                         shuffle=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Model
net = VGG().to(device)
if device == 'cuda':
    net = nn.DataParallel(net)
    torch.backends.cudnn.benchmark = True

# Train
train_losses, val_losses, train_acc, val_acc = train_net(net, train_loader, val_loader, device=device)

Files already downloaded and verified
torch.Size([3, 224, 224])


  cpuset_checked))
100%|██████████| 1250/1250 [09:03<00:00,  2.30it/s, epoch=1, loss=2.3]
100%|██████████| 1250/1250 [10:14<00:00,  2.03it/s, epoch=2, loss=2.3]


# 7.Validate section

In [None]:
# Per-epoch loss curves; the legend distinguishes the training and validation runs.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_losses) + 1), train_losses, label='train')
ax.plot(range(1, len(val_losses) + 1), val_losses, label='validation')
ax.set(title='Loss per epoch', xlabel='epoch', ylabel='loss')
ax.legend()
plt.show()
print(train_acc)
print(val_acc)

# 8.Test section

In [None]:
# `transform` is reused from the training section, so test images get identical preprocessing.

# Dataset
test_set = torchvision.datasets.CIFAR10('./datasets', train=False,
                                        download=True, transform=transform)

# DataLoader (not shuffled, so predictions line up with test_set.targets)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=100,
                                          shuffle=False, num_workers=4)

# `net` is the model trained above. Pass `device` so inference runs where the model
# lives: a DataParallel-wrapped model must keep its parameters on its first GPU.
y_preds = pred_net(net, test_loader, device=device)

# Postprocess: overall test accuracy against the ground-truth labels
test_targets = torch.tensor(test_set.targets)
test_acc = (y_preds.cpu() == test_targets).float().mean().item()
print(y_preds)
print(f'test accuracy: {test_acc:.4f}')