In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
from chainer.datasets import cifar
from chainer import iterators
from chainer import optimizers
from chainer import training
from chainer.training import extensions
import chainer.links as L
import networks

def train(model_object, batchsize=64, gpu_id=0, max_epoch=20, train_dataset=None, test_dataset=None):
    """Train ``model_object`` as a classifier and return the wrapped model.

    Args:
        model_object: Network to train; it is wrapped in ``L.Classifier``.
        batchsize: Mini-batch size used by both iterators.
        gpu_id: Device ID passed to the updater and the evaluator. A negative
            value skips the ``to_gpu`` transfer (CPU run).
        max_epoch: Number of epochs after which the trainer stops.
        train_dataset: Training dataset. When ``None``, the CIFAR-10
            training split is used.
        test_dataset: Validation dataset. When ``None``, the CIFAR-10
            test split is used.

    Returns:
        The ``L.Classifier`` instance wrapping ``model_object``.
    """

    # 1. Dataset
    # Fall back to CIFAR-10 independently for each split, so that passing
    # only one of the two datasets does not leave the other one as None.
    if train_dataset is None or test_dataset is None:
        cifar_train, cifar_test = cifar.get_cifar10()
        train_data = cifar_train if train_dataset is None else train_dataset
        test_data = cifar_test if test_dataset is None else test_dataset
    else:
        train_data, test_data = train_dataset, test_dataset

    # 2. Iterator
    train_iter = iterators.SerialIterator(train_data, batchsize)
    # The validation iterator makes a single, unshuffled pass.
    test_iter = iterators.SerialIterator(test_data, batchsize, repeat=False, shuffle=False)

    # 3. Model
    model = L.Classifier(model_object)
    if gpu_id >= 0:
        model.to_gpu(gpu_id)

    # 4. Optimizer
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    # 5. Updater
    updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)

    # 6. Trainer
    out_dir = '{}_cifar10_result'.format(model_object.__class__.__name__)
    trainer = training.Trainer(updater, (max_epoch, 'epoch'), out=out_dir)

    # 7. Evaluator

    class TestModeEvaluator(extensions.Evaluator):
        """Evaluator that clears the ``train`` flag for the duration of evaluation."""

        def evaluate(self):
            target = self.get_target('main')

            # NOTE(review): the target here is the L.Classifier wrapper; a flag
            # set only on the wrapper is presumably not seen by the wrapped
            # network, so mirror it onto `predictor` when that object already
            # exposes a `train` attribute. Confirm against networks.DeepCNN.
            predictor = getattr(target, 'predictor', None)
            mirror = predictor is not None and hasattr(predictor, 'train')
            if mirror:
                previous = predictor.train

            target.train = False
            if mirror:
                predictor.train = False
            try:
                return super(TestModeEvaluator, self).evaluate()
            finally:
                # Restore training mode even if evaluation raised.
                target.train = True
                if mirror:
                    predictor.train = previous

    trainer.extend(extensions.LogReport())
    trainer.extend(TestModeEvaluator(test_iter, model, device=gpu_id))
    trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']))
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
    trainer.run()

    return model

# Pass `gpu_id=-1` instead to run on the CPU.
model = train(model_object=networks.DeepCNN(10), gpu_id=1)

epoch       main/loss   main/accuracy  validation/main/loss  validation/main/accuracy  elapsed_time
[J1           1.55118     0.436781       1.33559               0.519705                  6.09307       
[J2           1.24418     0.551937       1.20892               0.563197                  11.7061       
[J3           1.09452     0.609075       1.17599               0.580215                  17.1812       
[J4           0.977321    0.652469       1.12981               0.598627                  22.8672       
[J5           0.867811    0.691057       1.06906               0.619924                  28.3825       
[J6           0.76205     0.729073       1.08851               0.62918                   33.7901       
[J7           0.654319    0.769626       1.1346                0.630076                  39.1825       
[J8           0.547686    0.805678       1.17452               0.630175                  44.6538       
[J9           0.443951    0.84413        1.24817           