In [48]:
# Network definition

import numpy as np
import chainer
import chainer.links as L
import chainer.functions as F
from chainer.datasets import split_dataset_random
from chainer import iterators
from chainer import optimizers
from chainer.dataset import concat_examples
from chainer.cuda import to_cpu
from chainer import training
from chainer.training import extensions
from chainercv.datasets import DirectoryParsingLabelDataset
from chainer.datasets import split_dataset_random

class MyNet(chainer.Chain):
    """Small CNN: three ReLU convolution layers, a ReLU hidden FC layer, and a linear output layer.

    Args:
        n_out: Number of output units of the final linear layer (one score per class).
    """

    def __init__(self, n_out):
        super(MyNet, self).__init__()
        with self.init_scope():
            # Convolution2D(in_channels, out_channels, ksize, stride, pad).
            # in_channels=None lets Chainer infer it from the first batch.
            # NOTE(review): stride=3 with a 3x3 kernel downsamples by ~3x per layer;
            # confirm this is intended rather than stride=1.
            self.conv1 = L.Convolution2D(None, 32, ksize=3, stride=3, pad=1)
            self.conv2 = L.Convolution2D(32, 64, ksize=3, stride=3, pad=1)
            self.conv3 = L.Convolution2D(64, 128, ksize=3, stride=3, pad=1)
            # in_size=None: the flattened conv feature size is inferred lazily.
            self.fc4 = L.Linear(None, 1000)
            self.fc5 = L.Linear(1000, n_out)

    def __call__(self, x):
        """Return raw (un-activated) class scores for the input batch ``x``."""
        h = x
        for layer in (self.conv1, self.conv2, self.conv3, self.fc4):
            h = F.relu(layer(h))
        # The final layer is left linear; any loss/softmax is applied by the caller.
        return self.fc5(h)


In [49]:
class PreprocessDataset(chainer.dataset.DatasetMixin):
    """Wrap an ``(image, label)`` dataset and scale every image by 1/255.

    Args:
        pair: Base dataset whose ``i``-th example is an ``(image, label)`` pair.
    """

    def __init__(self, pair):
        self.base = pair

    def __len__(self):
        return len(self.base)

    def get_example(self, i):
        """Return the ``i``-th ``(image, label)`` pair with the image multiplied by 1/255."""
        image, label = self.base[i]
        image = np.asarray(image)
        # Integer images cannot hold the scaled result; use float32, which is
        # the dtype Chainer layers use by default.
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(np.float32)
        # Out-of-place multiply: an in-place ``*=`` would modify the array owned
        # by the base dataset. If that dataset caches arrays, every read (one per
        # epoch) would shrink the pixel values again.
        image = image * (1.0 / 255.0)
        return image, label

def load_dataset(path, train_ratio=0.9, seed=0):
    """Load a labelled image directory tree and split it randomly into train/test.

    Labels are derived by chainercv's ``DirectoryParsingLabelDataset`` and each
    image is scaled by 1/255 through ``PreprocessDataset``.

    Args:
        path: Root directory passed to ``DirectoryParsingLabelDataset``.
        train_ratio: Fraction of examples placed in the training split (0 < r < 1).
        seed: Seed for ``split_dataset_random`` so the split is reproducible.

    Returns:
        ``(train, test)`` sub-datasets.

    Raises:
        ValueError: If ``train_ratio`` is not strictly between 0 and 1.
    """
    if not 0.0 < train_ratio < 1.0:
        raise ValueError('train_ratio must be in (0, 1), got {}'.format(train_ratio))
    dataset = PreprocessDataset(DirectoryParsingLabelDataset(path))
    train_size = int(len(dataset) * train_ratio)
    train, test = split_dataset_random(dataset, train_size, seed=seed)
    return train, test

In [50]:
from matplotlib import pyplot as plt
from chainer.datasets import cifar
# from load_datasets import load_dataset


def train(net, batchsize=128, gpu_id=0, max_epoch=5, train_dataset=None, valid_dataset=None, test_dataset=None, postfix='', base_lr=0.01, lr_decay=None):
    """Train ``net`` as a classifier with MomentumSGD, then print its test accuracy.

    Args:
        net: A ``chainer.Chain`` producing class scores; it is wrapped in ``L.Classifier``.
        batchsize: Minibatch size for the training, validation and test iterators.
        gpu_id: Device id passed to the updater and the evaluators.
        max_epoch: Number of training epochs.
        train_dataset: Training data. If all three datasets are None, the
            train/test split from ``load_dataset('./../resize_good_condition/')`` is used.
        valid_dataset: Per-epoch validation data. Falls back to the test data when None.
        test_dataset: Final evaluation data. Falls back to the validation data when None.
        postfix: Inserted into the trainer output directory name.
        base_lr: Initial MomentumSGD learning rate.
        lr_decay: Trigger at which the learning rate is multiplied by 0.1; None disables decay.

    Returns:
        The trained ``net`` (the ``L.Classifier`` wrapper is discarded).

    Raises:
        ValueError: If datasets are passed explicitly but ``train_dataset`` is None,
            or both ``valid_dataset`` and ``test_dataset`` are None.
    """

    # 1. Dataset
    if train_dataset is None and valid_dataset is None and test_dataset is None:
        train_data, test_data = load_dataset('./../resize_good_condition/')
        valid_data = None
    else:
        train_data, valid_data, test_data = train_dataset, valid_dataset, test_dataset
    if train_data is None:
        raise ValueError('train_dataset is required when valid/test datasets are given')
    if valid_data is None and test_data is None:
        raise ValueError('at least one of valid_dataset or test_dataset is required')
    # Without a dedicated validation set (e.g. the default path), validate on the
    # test split. NOTE: validation and test then share data, so the final
    # "Test accuracy" is not an independent estimate.
    if valid_data is None:
        valid_data = test_data
    if test_data is None:
        test_data = valid_data

    # 2. Iterators (evaluation iterators make a single ordered pass)
    train_iter = iterators.MultiprocessIterator(train_data, batchsize)
    valid_iter = iterators.MultiprocessIterator(valid_data, batchsize, repeat=False, shuffle=False)

    # 3. Model
    model = L.Classifier(net)

    # 4. Optimizer
    optimizer = optimizers.MomentumSGD(lr=base_lr)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(0.0005))

    # 5. Updater
    updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)

    # 6. Trainer
    trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='{}_original_{}result'.format(net.__class__.__name__, postfix))

    # 7. Trainer extensions
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.observe_lr())
    trainer.extend(extensions.Evaluator(valid_iter, model, device=gpu_id), name='val')
    trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'val/main/loss', 'val/main/accuracy', 'elapsed_time', 'lr']))
    trainer.extend(extensions.PlotReport(['main/loss', 'val/main/loss'], x_key='epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/accuracy', 'val/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
    if lr_decay is not None:
        trainer.extend(extensions.ExponentialShift('lr', 0.1), trigger=lr_decay)
    trainer.run()

    del trainer

    # 8. Evaluation. This evaluator is not attached to a trainer, so nothing
    # else finalizes its iterator; release the worker processes explicitly.
    test_iter = iterators.MultiprocessIterator(test_data, batchsize, repeat=False, shuffle=False)
    try:
        test_evaluator = extensions.Evaluator(test_iter, model, device=gpu_id)
        results = test_evaluator()
    finally:
        test_iter.finalize()
    print('Test accuracy:', results['main/accuracy'])

    return net

In [51]:
# n_out=8 presumably matches the number of class sub-directories in the
# default dataset loaded by `train` — TODO confirm.
net = train(
    MyNet(n_out=8),
    gpu_id=0,
)

epoch       main/loss   main/accuracy  val/main/loss  val/main/accuracy  elapsed_time  lr        
[J1           2.53451     0.128015       2.21566        0.140929           6.55583       0.01        
[J2           3.754       0.126953       2.09829        0.107853           11.4152       0.01        
[J3           2.10493     0.155971       2.13422        0.173295           16.4271       0.01        
[J4           2.72076     0.130757       2.11554        0.140929           21.3235       0.01        
[J5           2.15452     0.125279       2.27431        0.142891           26.1579       0.01        
Test accuracy: 0.14289096
