In [1]:
%matplotlib inline

In [2]:
from chainer.datasets import cifar
from chainer import iterators
from chainer import optimizers
from chainer import training
from chainer.training import extensions
import chainer.links as L
import networks

def train(model_object, batchsize=64, gpu_id=0, max_epoch=20, train_dataset=None, test_dataset=None):
    """Train ``model_object`` as a classifier and return the wrapped model.

    The network is wrapped in ``L.Classifier``, optimized with Adam via a
    ``StandardUpdater``/``Trainer`` pair, and evaluated on the test split
    through an evaluator extension. Log, print and plot extensions are
    registered as well; the trainer's output directory is named
    ``'<ClassName>_cifar10_result'`` after the class of ``model_object``.

    Args:
        model_object: The network to train; it becomes the predictor of an
            ``L.Classifier``.
        batchsize (int): Mini-batch size for both iterators.
        gpu_id (int): Device id. A negative value skips ``to_gpu`` and is
            passed through as ``device`` to the updater and the evaluator.
        max_epoch (int): Number of epochs after which training stops.
        train_dataset: Training dataset. When ``None``, the CIFAR-10
            training split from ``cifar.get_cifar10()`` is used.
        test_dataset: Evaluation dataset. When ``None``, the CIFAR-10 test
            split from ``cifar.get_cifar10()`` is used.

    Returns:
        The ``L.Classifier`` instance wrapping ``model_object``.
    """
    # 1. Dataset
    # Each split falls back to CIFAR-10 independently, so a caller may
    # override only one of them without the other ending up as ``None``.
    train_data, test_data = train_dataset, test_dataset
    if train_data is None or test_data is None:
        default_train, default_test = cifar.get_cifar10()
        if train_data is None:
            train_data = default_train
        if test_data is None:
            test_data = default_test

    # 2. Iterator
    train_iter = iterators.SerialIterator(train_data, batchsize)
    # The evaluation iterator makes a single, unshuffled pass.
    test_iter = iterators.SerialIterator(test_data, batchsize, repeat=False, shuffle=False)

    # 3. Model
    model = L.Classifier(model_object)
    if gpu_id >= 0:
        model.to_gpu(gpu_id)

    # 4. Optimizer
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    # 5. Updater
    updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)

    # 6. Trainer
    out_dir = '{}_cifar10_result'.format(model_object.__class__.__name__)
    trainer = training.Trainer(updater, (max_epoch, 'epoch'), out=out_dir)

    # 7. Evaluator
    class TestModeEvaluator(extensions.Evaluator):
        """Evaluator that turns the ``train`` flag off while evaluating."""

        def evaluate(self):
            target = self.get_target('main')

            # The flag is always toggled on the target itself. It is also
            # toggled on the wrapped predictor, but only when the predictor
            # already exposes a boolean ``train`` attribute.
            # NOTE(review): presumably the network reads its own ``train``
            # flag rather than the classifier's -- confirm against `networks`.
            flagged_links = [target]
            predictor = getattr(target, 'predictor', None)
            if predictor is not None and isinstance(getattr(predictor, 'train', None), bool):
                flagged_links.append(predictor)

            previous = [getattr(link, 'train', True) for link in flagged_links]
            for link in flagged_links:
                link.train = False
            try:
                return super(TestModeEvaluator, self).evaluate()
            finally:
                # Restore the flags even if evaluation raised, so training
                # never continues in test mode.
                for link, flag in zip(flagged_links, previous):
                    link.train = flag

    trainer.extend(extensions.LogReport())
    trainer.extend(TestModeEvaluator(test_iter, model, device=gpu_id))
    trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']))
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
    trainer.run()

    return model

# Pass `gpu_id=-1` instead to run on the CPU.
model = train(model_object=networks.DeepCNN(10), gpu_id=1)

epoch       main/loss   main/accuracy  validation/main/loss  validation/main/accuracy  elapsed_time
[J1           2.00102     0.264066       1.64667               0.380474                  22.6498       
[J2           1.48478     0.444562       1.27334               0.555334                  44.8164       
[J3           1.19862     0.572043       1.02784               0.65416                   66.9242       
[J4           0.988633    0.660191       1.05836               0.63963                   89.0051       
[J5           0.826074    0.724724       0.73677               0.762639                  111.121       
[J6           0.698896    0.769326       0.643213              0.788117                  133.248       
[J7           0.606899    0.799296       0.592886              0.801652                  155.35        
[J8           0.536283    0.824444       0.5744                0.818471                  177.468       
[J9           0.478741    0.842931       0.56951           