# 1. Prepare section

# In [38]:
import os
from collections import OrderedDict

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import torchvision.transforms as transforms
import cv2
import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Kept as a plain string so it can be passed straight to `.to(device)`.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)

# Variable definitions
# NOTE(review): a batch size of 2 looks like a debugging value; confirm before real training.
batch_size = 2

# Out: cuda


# 2. Transform section

# In [39]:
class ComposeTransform():
    """
    Chain several transforms into one callable, applied left to right.

    Parameters
    --------------
    transforms: list
        Transform instances (any callables), in the order they should run.
        ``None`` or an empty list makes the composition the identity.
    """
    def __init__(self, transforms=None):
        self.transforms = transforms

    def __call__(self, x):
        # A falsy `transforms` (None / []) means "nothing to apply".
        for step in (self.transforms or ()):
            x = step(x)
        return x


class BaseTransform():
    """
    Base class for the hand-written transforms in this file.

    Subclasses override ``__call__`` to take the input and return the
    transformed value. They may consult ``self.debug`` to print intermediate
    results.

    Parameters
    --------------
    debug: bool
        Stored on the instance as ``self.debug`` (default False).
    """
    def __init__(self, debug=False):
        self.debug = debug

    def __call__(self, *args, **kwargs):
        # Accept whatever a concrete transform would receive. Calling the base
        # class directly then reports NotImplementedError instead of an
        # unrelated TypeError about the number of arguments.
        raise NotImplementedError(
            f"{type(self).__name__} must override __call__"
        )


class SimpleTransform(BaseTransform):
    """
    Provisional collection of frequently used input transforms.

    Split into theme-specific classes once the number of helpers grows.

    Applied in order:
        resize    -> ``x.reshape(shape)`` (a reshape, not an interpolating resize)
        to_tensor -> ``torch.from_numpy`` (so ``x`` must be a numpy array)
        to_float  -> ``.float()``
    Pixel values are NOT rescaled; the output keeps the input's value range.

    Parameters
    --------------
    debug: bool
        If True, print the input and the result after every step.
    shape: tuple
        Target shape for ``resize``; defaults to (1, 28, 28).
    """
    def __init__(self, debug=False, shape=(1, 28, 28)):
        super().__init__(debug)
        self.shape = shape
        self.applied_transforms = [
            self.resize,
            self.to_tensor,
            self.to_float,
        ]

    def __call__(self, x):
        if self.debug:
            print(x)
        for transform in self.applied_transforms:
            x = transform(x)
            if self.debug:
                print('-------------------')
                print(str(transform))
                print(x)
        return x

    def resize(self, x):
        return x.reshape(self.shape)

    def to_float(self, x):
        return x.float()

    def to_tensor(self, x):
        # Shares memory with the numpy array; to_float() then makes a copy.
        return torch.from_numpy(x)

class SimpleTargetTransform(BaseTransform):
    """
    Provisional class for preprocessing applied to the target (label).

    Collects frequently used target transforms. Split into theme-specific
    classes once the number of helpers grows. With no transforms registered
    it returns the target unchanged.

    Parameters
    --------------
    debug: bool
        If True, print the input and the result after every step
        (same output layout as SimpleTransform).
    """
    def __init__(self, debug=False):
        super().__init__(debug)
        self.applied_transforms = [
            # transforms to apply to the target go here
        ]

    def __call__(self, y):
        if self.debug:
            print(y)
        for transform in self.applied_transforms:
            y = transform(y)
            if self.debug:
                print('-------------------')
                print(str(transform))
                print(y)
        return y

# 3. Dataset section

# In [40]:
class Dataset():
    """
    Map-style dataset backed by a CSV file, one sample per row.

    In train mode column 0 is the label and the remaining columns are the
    features. Otherwise every column is a feature and only ``x`` is returned.

    Parameters
    --------------
    path: str or file-like
        Anything ``pd.read_csv`` accepts.
    transform: callable, optional
        Applied to the feature row (a numpy array).
    target_transform: callable, optional
        Applied to the label (train mode only).
    train: bool
        Whether column 0 holds the label.
    header: int, None or 'infer'
        Forwarded to ``pd.read_csv``. The default 'infer' keeps the old
        behaviour (first row is a header). Pass ``None`` for header-less CSVs,
        otherwise their first sample is silently consumed as column names.
    """
    def __init__(self, path, transform= None, target_transform = None, train=True, header='infer'):
        self.transform = transform
        self.target_transform = target_transform
        # pandas is slow for per-row access, so convert to numpy right away.
        self.data = pd.read_csv(path, header=header).values
        self.train = train

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if not self.train:
            x = self.data[idx, :]
            if self.transform:
                x = self.transform(x)
            return x

        label = self.data[idx, 0]
        x = self.data[idx, 1:]
        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            label = self.target_transform(label)
        return x, label

# 4. Model section

# In [41]:
class FlattenLayer(nn.Module):
    """
    Collapse every dimension except the first (batch) one.

    (N, C, H, W) -> (N, C*H*W)
    """
    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)

# In [42]:
class Discriminator(nn.Module):
    """
    Small CNN classifier for 1x28x28 images.

    Shapes (B = batch size):
        (B, 1, 28, 28) -conv3x3-> (B, 10, 26, 26) -conv3x3-> (B, 10, 24, 24)
        -maxpool2-> (B, 10, 12, 12) -conv3x3-> (B, 50, 10, 10)
        -maxpool2-> (B, 50, 5, 5) -flatten-> (B, 1250) -linear-> (B, num_classes)

    Returns raw logits (no softmax), as expected by nn.CrossEntropyLoss.

    Parameters
    --------------
    num_classes: int
        Size of the output layer (default 10).
    """
    def __init__(self, num_classes=10):
        super().__init__()

        # (B, 1, 28, 28) -> (B, 10, 12, 12)
        self.conv1 = nn.Conv2d(1, 10, kernel_size=3)
        self.conv2 = nn.Conv2d(10, 10, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2, 2)

        # (B, 10, 12, 12) -> (B, 50, 5, 5)
        self.conv3 = nn.Conv2d(10, 50, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2, 2)

        # (B, 50, 5, 5) -> (B, num_classes)
        self.flatten = FlattenLayer()
        self.linear = nn.Linear(50 * 5 * 5, num_classes)

    def forward(self, x):
        # ReLU after every conv. Without a nonlinearity, conv1 followed by conv2
        # is just one linear convolution, so the extra layer adds no capacity.
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool1(x)
        x = torch.relu(self.conv3(x))
        x = self.pool2(x)
        x = self.flatten(x)
        return self.linear(x)





# 5. Main function section

# In [43]:
def train_net(net, train_loader, eval_loader, optim_cls=optim.SGD, loss_fn=nn.CrossEntropyLoss(), n_iter=20, device= 'cpu', lr=0.1):
    """
    Train `net` for `n_iter` epochs and validate it after every epoch.

    Parameters
    --------------
    net: nn.Module
        Model to train; it is moved to `device`.
    train_loader: iterable of (x, label) batches
        Used for parameter updates; wrapped in a tqdm progress bar.
    eval_loader: iterable of (x, label) batches
        Passed to `val_net` at the end of each epoch.
    optim_cls: optimizer class
        Instantiated as ``optim_cls(net.parameters(), lr=lr)``.
    loss_fn: callable
        ``loss_fn(logits, label)`` returning a scalar loss tensor.
    n_iter: int
        Number of epochs.
    device: str or torch.device
        Where the model and each batch are placed.
    lr: float
        Learning rate (default 0.1, the previously hard-coded value).

    Returns
    --------------
    losses, val_losses, train_acc, val_acc: list
        One entry per epoch. Train loss is the mean over batches and train
        accuracy is over that epoch's samples only. An epoch without training
        batches records NaN.
    """
    # Move the model before building the optimizer, as the torch.optim docs advise.
    net = net.to(device)
    optimizer = optim_cls(net.parameters(), lr=lr)
    losses = []
    val_losses = []
    train_acc = []
    val_acc = []

    for epoch in range(n_iter):
        # Per-epoch counters. These used to live outside the loop, which made
        # train_acc a running average over the whole run.
        running_loss = 0.0
        n_batches = 0
        n = 0
        n_acc = 0
        net.train()
        with tqdm.tqdm(train_loader) as pbar:
            for x, label in pbar:
                x = x.to(device)
                label = label.to(device)
                h = net(x)
                loss = loss_fn(h, label)

                # Backpropagation and parameter update.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                n_batches += 1
                n += len(label)
                _, y_pred = h.max(1)
                n_acc += (y_pred == label).float().sum().item()
                pbar.set_postfix(OrderedDict(
                    epoch=epoch + 1,
                    loss=running_loss / n_batches,
                    ))
        losses.append(running_loss / n_batches if n_batches else float('nan'))
        train_acc.append(n_acc / n if n else float('nan'))
        val_loss, val_acc_ = val_net(net, eval_loader, loss_fn, device=device)
        val_losses.append(val_loss)
        val_acc.append(val_acc_)

    return losses, val_losses, train_acc, val_acc

def val_net(net, eval_loader, loss_fn, device= 'cpu'):
    """
    Evaluate `net` on `eval_loader` without updating it.

    The model is put in eval mode and gradients are disabled.

    Parameters
    --------------
    net: nn.Module
        Model to evaluate; it is moved to `device`.
    eval_loader: iterable of (x, label) batches
        The data to evaluate on. This parameter was previously ignored in
        favour of a global `val_loader`.
    loss_fn: callable
        ``loss_fn(logits, label)`` returning a scalar loss tensor.
    device: str or torch.device
        Where the model and each batch are placed.

    Returns
    --------------
    (val_loss, val_acc): tuple of float
        Mean per-batch loss and overall accuracy. Both are NaN when
        `eval_loader` yields no batches.
    """
    net = net.to(device)
    net.eval()
    running_loss = 0.0
    n_batches = 0
    n = 0
    n_acc = 0
    # No autograd graph is needed for evaluation; this saves memory and time.
    with torch.no_grad():
        for x, label in eval_loader:
            x = x.to(device)
            label = label.to(device)
            h = net(x)
            running_loss += loss_fn(h, label).item()
            n_batches += 1
            n += len(label)
            _, y_pred = h.max(1)
            n_acc += (y_pred == label).float().sum().item()
    if n_batches == 0 or n == 0:
        return float('nan'), float('nan')
    return running_loss / n_batches, n_acc / n

def pred_net(net, test_loader, device= 'cpu'):
    """
    Predict class indices for every batch in `test_loader`.

    The model is switched to eval mode and run without gradient tracking.

    Parameters
    --------------
    net: nn.Module
        Model producing (B, num_classes) logits; it is moved to `device`.
    test_loader: iterable of input batches
        Each item is an input tensor only (no labels).
    device: str or torch.device
        Where the model and each batch are placed.

    Returns
    --------------
    torch.Tensor
        1-D tensor of argmax predictions, concatenated in loader order.
        Empty (dtype long) when the loader yields nothing.
    """
    net = net.to(device)
    net.eval()
    y_preds = []
    with torch.no_grad():
        for x in test_loader:
            x = x.to(device)
            _, y_pred = net(x).max(1)
            y_preds.append(y_pred)
    # torch.cat([]) raises, so handle the empty loader explicitly.
    if not y_preds:
        return torch.empty(0, dtype=torch.long, device=device)
    return torch.cat(y_preds, dim=0)

# 6. Train section

# In [44]:
"""
一応作っておくが、ここはいわば組み立て工場のような立ち位置で
どのTransform, Dataset, Dataloader, Model を使うかで
書き方がだいぶ違う
そのためその場その場で組み立てた方がいい
"""

# Transform組み立て
transform = ComposeTransform([
    SimpleTransform(debug=False),
    transforms.Normalize((0.5, ), (0.5, ))
    ])

# Dataset組み立て
train_dataset = Dataset('/content/sample_data/mnist_train_small.csv', transform=transform)
train_size = int(len(train_dataset)*0.8)
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

x = train_dataset[0][0].view(1,1,28,28).to(device)
label = train_dataset[0][1]

# Dataloader組み立て
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)


print(x)
aa = train_loader.__iter__()
bb, cc = aa.next()
print(bb[0])


# Model組み立て
D_net = Discriminator()
y = D_net(bb)
print(y)

# MainFunction組み立て
losses, val_losses, train_acc, val_acc= train_net(D_net, train_loader, val_loader, device=device)

# Out: KeyboardInterrupt: ignored

# 7. Validation section

# In [None]:
# Per-epoch loss curves. Label them so train and validation can be told apart.
plt.plot(losses, label='train loss')
plt.plot(val_losses, label='val loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()
print(train_acc)
print(val_acc)

# 8. Test section