# 1.Prepare section

In [None]:
import os
from collections import OrderedDict

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split
import torchvision.transforms as transforms
import cv2
import tqdm
import numpy as np
import matplotlib.pyplot as plt

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


# 2.Transform section

In [None]:
class ComposeTransform():
    """
    Chain several transforms into a single callable.
    """
    def __init__(self, transforms=None):
        """
        Parameters
        --------------
        transforms: list
            Transform instances (callables), applied in list order.
            ``None`` or an empty list makes the composition a no-op.
        """
        self.transforms = transforms

    def __call__(self, x):
        """Feed ``x`` through every transform in order and return the result."""
        # Nothing configured: hand the input back untouched.
        if not self.transforms:
            return x
        result = x
        for step in self.transforms:
            result = step(result)
        return result


class BaseTransform():
    """
    Base class for custom transforms.

    Subclasses override ``__call__`` to take one input and return the
    transformed value.
    """
    def __init__(self, debug=False):
        """
        Parameters
        --------------
        debug: bool
            Flag stored on the instance as ``self.debug`` for subclasses to consult.
        """
        self.debug = debug

    def __call__(self, x=None):
        """
        Transform ``x``; must be overridden by subclasses.

        The signature accepts the input so that calling an un-overridden
        transform reports NotImplementedError rather than a TypeError about
        the argument count. ``x`` defaults to None so a zero-argument call
        still raises NotImplementedError as before.

        Raises
        --------------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError()


class SimpleTransform(BaseTransform):
    """
    Catch-all transform for the time being.
    Holds the commonly used operations; split it up by theme once the
    number of functions grows.
    """
    def __call__(self, x):
        """Convert ``x`` with ``pil2cv`` and resize the result to 227 x 227."""
        if self.debug:
            # TODO: make it possible to inspect the intermediate values of x here
            pass
        return cv2.resize(self.pil2cv(x), (227, 227))


    def pil2cv(self, image):
        """
        PIL image -> OpenCV-style array.

        The input is turned into a uint8 numpy array. 2-D (monochrome) arrays
        are returned as they are, 3-channel arrays are converted RGB -> BGR,
        4-channel arrays RGBA -> BGRA, and any other channel count is
        returned unconverted.
        """
        converted = np.array(image, dtype=np.uint8)
        if converted.ndim == 2:  # monochrome
            return converted
        channels = converted.shape[2]
        if channels == 3:  # colour
            return cv2.cvtColor(converted, cv2.COLOR_RGB2BGR)
        if channels == 4:  # with alpha
            return cv2.cvtColor(converted, cv2.COLOR_RGBA2BGRA)
        return converted

# 3.Dataset section

This notebook uses a ready-made dataset, so this section is not used.

# 4.Model section

In [None]:
NUM_CLASSES = 10

class AlexNet(nn.Module):
    """
    AlexNet-style CNN with BatchNorm after each pooling stage.

    The spatial sizes in the layer comments assume 3 x 227 x 227 inputs,
    which is what the classifier's 256 * 6 * 6 input width requires.

    Parameters
    --------------
    num_classes: int
        Width of the final linear layer (number of output logits).
        Defaults to the module-level NUM_CLASSES.
    """
    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()

        self.output1 = nn.Sequential (
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),   # out : 56 x 56 (227 + 2*2 - 11) / 4 + 1
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),   # out : 27 x 27
            nn.BatchNorm2d(96)
        )

        self.output2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),   # out : 13 x 13
            nn.BatchNorm2d(256)
        )

        self.features = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),   # out : 6 x 6
            nn.BatchNorm2d(256)
        )

        self.classifier = nn.Sequential (
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes),
        )


    def forward(self, x):
        """Return the class logits, shape (batch, num_classes), for a batch of images."""
        x = self.output1(x)
        x = self.output2(x)
        x = self.features(x)

        # Flatten everything but the batch dimension; the first Linear layer
        # enforces the expected width of 256 * 6 * 6 = 9216.
        x = torch.flatten(x, 1)   # out : 9216
        return self.classifier(x)

# 5.Main function section

In [None]:
def train_net(net, train_loader, eval_loader, optim_cls=optim.SGD, loss_fn=nn.CrossEntropyLoss(), n_iter=20, device= 'cpu', lr=0.1):
    """
    Train ``net`` for ``n_iter`` epochs, validating after each one.

    Parameters
    --------------
    net: nn.Module
        Model to train; it is moved to ``device``.
    train_loader:
        Iterable of ``(x, label)`` batches; must support ``len()``.
    eval_loader:
        Validation batches, passed to ``val_net`` after every epoch.
    optim_cls:
        Optimizer class, called as ``optim_cls(net.parameters(), lr=lr)``.
    loss_fn:
        Callable ``loss_fn(output, label)`` returning a scalar tensor.
    n_iter: int
        Number of epochs.
    device: str
        Device the model and the batches are moved to.
    lr: float
        Learning rate handed to the optimizer.

    Returns
    --------------
    tuple of four lists, one entry per epoch:
        ``(train_losses, val_losses, train_acc, val_acc)``.
        The training loss is the mean of the batch losses; the training
        accuracy is the fraction of correct predictions within that epoch.
    """
    # Move the model before building the optimizer so the optimizer is created
    # against the parameters on their final device.
    net = net.to(device)
    optimizer = optim_cls(net.parameters(), lr=lr)
    train_losses = []
    val_losses = []
    train_acc = []
    val_acc = []

    for epoch in range(n_iter):
        # Per-epoch counters: resetting them here keeps each train_acc entry
        # an accuracy for this epoch only, not a running average over all
        # epochs seen so far.
        running_loss = 0.0
        n = 0
        n_acc = 0
        net.train()  # val_net switches the model to eval mode, so re-enable training mode
        with tqdm.tqdm(train_loader) as pbar:
            for i, (x, label) in enumerate(pbar):
                x = x.to(device)
                label = label.to(device)
                h = net(x)
                loss = loss_fn(h, label)
                running_loss += loss.item()
                n += len(label)
                _, y_pred = h.max(1)
                n_acc += (y_pred == label).float().sum().item()

                # Parameter update via backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.set_postfix(OrderedDict(
                    epoch=epoch + 1,
                    loss=running_loss / (i + 1),
                    ))
        train_losses.append(running_loss / len(train_loader))
        train_acc.append(n_acc / n)
        val_loss, val_acc_ = val_net(net, eval_loader, loss_fn, device=device)
        val_losses.append(val_loss)
        val_acc.append(val_acc_)

    return train_losses, val_losses, train_acc, val_acc

def val_net(net, val_loader, loss_fn, device= 'cpu'):
    """
    Evaluate ``net`` on ``val_loader`` without updating it.

    Parameters
    --------------
    net: nn.Module
        Model to evaluate; it is moved to ``device`` and left in eval mode.
    val_loader:
        Iterable of ``(x, label)`` batches; must support ``len()``.
    loss_fn:
        Callable ``loss_fn(output, label)`` returning a scalar tensor.
    device: str
        Device the model and the batches are moved to.

    Returns
    --------------
    (val_loss, val_acc):
        Mean of the batch losses and the fraction of correct predictions.
    """
    net = net.to(device)
    net.eval()
    n = 0
    n_acc = 0
    running_loss = 0.0
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for x, label in val_loader:
            x = x.to(device)
            label = label.to(device)
            h = net(x)
            loss = loss_fn(h, label)
            running_loss += loss.item()
            n += len(label)
            _, y_pred = h.max(1)
            n_acc += (y_pred == label).float().sum().item()
    val_acc = n_acc / n
    val_loss = running_loss / len(val_loader)
    return val_loss, val_acc

def pred_net(net, test_loader, device= 'cpu'):
    """
    Predict class indices for every sample in ``test_loader``.

    Parameters
    --------------
    net: nn.Module
        Model used for prediction; it is moved to ``device`` and left in eval mode.
    test_loader:
        Iterable of batches. A batch is either an input tensor or a
        list/tuple whose first element is the input tensor (for example
        ``(x, label)`` from a labelled dataset); anything after the first
        element is ignored.
    device: str
        Device the model and the batches are moved to.

    Returns
    --------------
    torch.Tensor
        1-D tensor with the index of the largest output per sample, in
        loader order. Empty (dtype long) when the loader yields nothing.
    """
    y_preds = []
    net = net.to(device)
    # Inference mode: Dropout is disabled and BatchNorm uses its running
    # statistics instead of updating them.
    net.eval()
    with torch.no_grad():
        for batch in test_loader:
            # Labelled datasets yield (inputs, labels); only the inputs are needed.
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            x = x.to(device)
            h = net(x)
            _, y_pred = h.max(1)
            y_preds.append(y_pred)
    if not y_preds:
        # torch.cat rejects an empty list, so return an empty prediction tensor.
        return torch.empty(0, dtype=torch.long, device=device)
    return torch.cat(y_preds, dim=0)

# 6.Train section

In [None]:
# Transform pipeline: PIL -> OpenCV array resized to 227 x 227 -> tensor -> normalised
transform = ComposeTransform([
    SimpleTransform(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Build the dataset
dataset = torchvision.datasets.CIFAR10('./datasets', train=True, 
                                         download=True, transform=transform)
print(dataset[0][0].size())
train_size = int(len(dataset)*0.8)
val_size = len(dataset) - train_size
# Seed the split so the same samples land in train/validation on every run;
# otherwise validation numbers are not comparable between runs.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED))

# Build the dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, 
                                           shuffle=True, num_workers=4)
# Validation only aggregates loss/accuracy over all batches, so shuffling buys nothing.
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=128, 
                                           shuffle=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Build the model
net = AlexNet().to(device)
if device == 'cuda':
    net = nn.DataParallel(net)
    torch.backends.cudnn.benchmark = True

# Run the main training function
train_losses, val_losses, train_acc, val_acc= train_net(net, train_loader, val_loader, device=device)

Files already downloaded and verified
torch.Size([3, 227, 227])


  cpuset_checked))
100%|██████████| 313/313 [01:09<00:00,  4.49it/s, epoch=1, loss=1.44]
100%|██████████| 313/313 [01:08<00:00,  4.55it/s, epoch=2, loss=0.941]
100%|██████████| 313/313 [01:09<00:00,  4.47it/s, epoch=3, loss=0.718]
100%|██████████| 313/313 [01:10<00:00,  4.46it/s, epoch=4, loss=0.566]
100%|██████████| 313/313 [01:10<00:00,  4.45it/s, epoch=5, loss=0.43]
100%|██████████| 313/313 [01:11<00:00,  4.40it/s, epoch=6, loss=0.317]
100%|██████████| 313/313 [01:10<00:00,  4.43it/s, epoch=7, loss=0.232]
100%|██████████| 313/313 [01:10<00:00,  4.43it/s, epoch=8, loss=0.158]
100%|██████████| 313/313 [01:10<00:00,  4.44it/s, epoch=9, loss=0.113]
100%|██████████| 313/313 [01:10<00:00,  4.43it/s, epoch=10, loss=0.0852]
100%|██████████| 313/313 [01:10<00:00,  4.43it/s, epoch=11, loss=0.0705]
 98%|█████████▊| 307/313 [01:09<00:01,  5.02it/s, epoch=12, loss=0.0498]

# 7.Validate section

In [None]:
# Loss per epoch; the labels and legend make the two curves distinguishable.
plt.plot(train_losses, label='train loss')
plt.plot(val_losses, label='validation loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()
print(train_acc)
print(val_acc)

# 8.Test section

In [None]:
# Transform pipeline: already assembled in the training section

# Build the dataset
test_set = torchvision.datasets.CIFAR10('./datasets', train=False, 
                                        download=True, transform=transform)

# Build the dataloader
test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, 
                                          shuffle=False, num_workers=4)

# Model: already assembled in the training section

# Run the main prediction function.
# Pass the training device explicitly: pred_net defaults to 'cpu', which would
# move the (possibly DataParallel-wrapped) CUDA model off the GPU.
y_preds = pred_net(net, test_loader, device=device)

# Postprocess
print(y_preds)