
Let's try using the Trainer feature

By using Trainer, you don't need to write the training loop explicitly any more. Furthermore, Chainer provides many useful extensions that can be used with Trainer to visualize your results, evaluate your model, and store and manage log files more easily.

This example will show how to use the Trainer to train a fully-connected feed-forward neural network on the MNIST dataset.


If you would like to know how to write a training loop without using the Trainer, please check train_loop instead of this tutorial.

1. Prepare the dataset

Load the MNIST dataset, which contains a training set of images and class labels as well as a corresponding test set.

from chainer.datasets import mnist

train, test = mnist.get_mnist()


You can use a Python list as a dataset. That's because Iterator can take as a dataset any object whose elements can be accessed via the [] accessor and whose length can be obtained with the len() function. For example,

train = [(x1, t1), (x2, t2), ...]

a list of tuples like this can be used as a dataset.
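The only contract here is len() plus [] indexing, so a hand-written class works just as well as a list. The sketch below is illustrative only: SquaresDataset and make_batch are hypothetical names (not part of Chainer), and make_batch merely imitates the way an iterator gathers examples by index.

```python
class SquaresDataset:
    """Hypothetical dataset: element i is the (input, label) pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # len() tells an iterator how many examples there are
        return self.n

    def __getitem__(self, i):
        # [] returns one (input, label) example, computed on demand
        if not 0 <= i < self.n:
            raise IndexError(i)
        return (i, i * i)


def make_batch(dataset, indices):
    """Gather a mini-batch as a list of examples, using only [] access."""
    return [dataset[i] for i in indices]


list_dataset = [(0, 0), (1, 1), (2, 4), (3, 9)]
custom_dataset = SquaresDataset(4)

print(len(list_dataset), len(custom_dataset))  # 4 4
print(make_batch(list_dataset, [1, 3]))        # [(1, 1), (3, 9)]
print(make_batch(custom_dataset, [1, 3]))      # [(1, 1), (3, 9)]
```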

There are many utility dataset classes defined in chainer.datasets. It's recommended to use them in real applications.

For example, if your dataset consists of a number of image files, loading all of them into a list like the one above would take a large amount of memory. In that case, you can use ImageDataset, which keeps only the paths to the image files. The actual image data is loaded from disk when the corresponding element is requested via the [] accessor; until then, no images are loaded into memory, which keeps memory use low.
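The lazy-loading idea behind ImageDataset can be sketched in a few lines. LazyFileDataset below is a hypothetical class, not Chainer's implementation: it stores only paths and reads a file's raw bytes when an element is requested (real image decoding is left out). The load_count attribute exists only to show that nothing is read before [] is used.

```python
import os
import tempfile


class LazyFileDataset:
    """Hypothetical dataset that keeps file paths and reads a file only on access."""

    def __init__(self, paths):
        self.paths = list(paths)
        self.load_count = 0  # number of files actually read so far

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        # The file is opened here, not in __init__
        self.load_count += 1
        with open(self.paths[i], 'rb') as f:
            return f.read()


# Create three small files that stand in for image files
tmpdir = tempfile.mkdtemp()
paths = []
for k in range(3):
    path = os.path.join(tmpdir, 'img{}.bin'.format(k))
    with open(path, 'wb') as f:
        f.write(bytes([k]) * 4)
    paths.append(path)

dataset = LazyFileDataset(paths)
print(len(dataset), dataset.load_count)  # 3 0  (nothing has been read yet)
print(dataset[2])                        # b'\x02\x02\x02\x02'
print(dataset.load_count)                # 1
```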

2. Prepare the dataset iterators

Iterator creates mini-batches from the given dataset.

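from chainer import iterators

batchsize = 128

train_iter = iterators.SerialIterator(train, batchsize)
# The test iterator makes a single pass (repeat=False) in a fixed order (shuffle=False)
test_iter = iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False)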
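The repeat and shuffle flags of SerialIterator decide how the dataset is traversed. The generator below is a simplified sketch of that behavior, not Chainer's code; serial_batches is a hypothetical name. With repeat=False it stops after one pass and returns the leftover examples as a smaller final batch, which suits evaluation; with repeat=True it keeps going, filling a batch that crosses an epoch boundary with examples from the next epoch.

```python
import random
from itertools import islice


def serial_batches(dataset, batch_size, repeat=True, shuffle=True, seed=0):
    """Yield mini-batches (lists of examples); a simplified sketch of SerialIterator."""
    n = len(dataset)
    if n == 0:
        return
    rng = random.Random(seed)

    def new_order():
        order = list(range(n))
        if shuffle:
            rng.shuffle(order)  # a fresh order for every epoch
        return order

    order, pos = new_order(), 0
    while True:
        batch = []
        while len(batch) < batch_size:
            if pos == n:           # reached the end of an epoch
                if not repeat:
                    break          # single pass: keep the partial batch
                order, pos = new_order(), 0
            batch.append(dataset[order[pos]])
            pos += 1
        if batch:
            yield batch
        if pos == n and not repeat:
            return


data = list(range(5))
print(list(serial_batches(data, 2, repeat=False, shuffle=False)))
# [[0, 1], [2, 3], [4]]
print(list(islice(serial_batches(data, 2, repeat=True, shuffle=False), 4)))
# [[0, 1], [2, 3], [4, 0], [1, 2]]
```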

3. Prepare the model

Here, we are going to use the same model as the one defined in train_loop.

from chainer import Chain
import chainer.functions as F
import chainer.links as L

class MLP(Chain):

    def __init__(self, n_mid_units=100, n_out=10):
        super(MLP, self).__init__()
        with self.init_scope():
            # None: the input size is inferred at the first forward pass
            self.l1 = L.Linear(None, n_mid_units)
            self.l2 = L.Linear(None, n_mid_units)
            self.l3 = L.Linear(None, n_out)

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

gpu_id = 0  # Set to -1 if you use CPU

model = MLP()
if gpu_id >= 0:
    model.to_gpu(gpu_id)
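Passing None as the input size of L.Linear defers parameter initialization until the first forward pass, when the input dimension (784 for flattened MNIST images) becomes known. The plain-Python sketch below imitates that behavior and the same ReLU, ReLU, linear forward computation as MLP; LazyLinear and TinyMLP are hypothetical classes written for illustration, not Chainer code, and the random initialization is only a stand-in.

```python
import random


class LazyLinear:
    """Hypothetical fully-connected layer whose input size is inferred on first call."""

    def __init__(self, in_size, out_size, seed=0):
        self.in_size = in_size
        self.out_size = out_size
        self.rng = random.Random(seed)
        self.W = None  # created lazily when in_size is None
        self.b = [0.0] * out_size
        if in_size is not None:
            self._init_weights(in_size)

    def _init_weights(self, in_size):
        # Small random weights scaled by 1/sqrt(in_size)
        self.in_size = in_size
        self.W = [[self.rng.gauss(0.0, 1.0 / in_size ** 0.5) for _ in range(in_size)]
                  for _ in range(self.out_size)]

    def __call__(self, x):
        if self.W is None:
            self._init_weights(len(x))  # infer the input size from the data
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(self.W, self.b)]


def relu(v):
    return [max(0.0, a) for a in v]


class TinyMLP:
    """Same structure as MLP: two ReLU hidden layers followed by a linear output."""

    def __init__(self, n_mid_units=100, n_out=10):
        self.l1 = LazyLinear(None, n_mid_units, seed=1)
        self.l2 = LazyLinear(None, n_mid_units, seed=2)
        self.l3 = LazyLinear(None, n_out, seed=3)

    def __call__(self, x):
        h1 = relu(self.l1(x))
        h2 = relu(self.l2(h1))
        return self.l3(h2)


mlp = TinyMLP(n_mid_units=8, n_out=10)
print(mlp.l1.in_size)          # None (not known yet)
y = mlp([0.5] * 784)           # one flattened 28x28 "image"
print(mlp.l1.in_size, len(y))  # 784 10
```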

4. Prepare the Updater

Trainer is a class that holds all of the necessary components needed for training. The main components are shown below.


Basically, all you need to pass to Trainer is an Updater. However, Updater contains an Iterator and an Optimizer. Since the Iterator can access the dataset and the Optimizer has references to the model, Updater can access the model to update its parameters.

So, Updater can perform the training procedure as shown below:

  1. Retrieve the data from the dataset and construct a mini-batch (Iterator)
  2. Pass the mini-batch to the model and calculate the loss
  3. Update the parameters of the model (Optimizer)
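The three steps above can be mimicked in plain Python. The snippet is a hypothetical illustration rather than Chainer's StandardUpdater: it fits a one-parameter model y = w * x with a mean squared error loss and updates w with a momentum rule, v = momentum * v - lr * grad, then w = w + v. The values lr=0.01 and momentum=0.9 are assumed to match MomentumSGD's defaults.

```python
def update_step(batch, w, v, lr=0.01, momentum=0.9):
    """One hypothetical updater step on y = w * x with mean squared error."""
    # 1. The mini-batch has already been built from the dataset (the Iterator's job)
    xs = [x for x, _ in batch]
    ts = [t for _, t in batch]
    # 2. Forward pass and loss
    ys = [w * x for x in xs]
    loss = sum((y - t) ** 2 for y, t in zip(ys, ts)) / len(batch)
    # Gradient of the mean squared error with respect to w
    grad = sum(2 * x * (y - t) for x, y, t in zip(xs, ys, ts)) / len(batch)
    # 3. Parameter update with momentum (the Optimizer's job)
    v = momentum * v - lr * grad
    w = w + v
    return w, v, loss


dataset = [(1.0, 2.0), (2.0, 4.0)] * 2  # targets follow t = 2 * x
batch_size = 2
w, v = 0.0, 0.0
for epoch in range(200):
    for start in range(0, len(dataset), batch_size):
        batch = dataset[start:start + batch_size]
        w, v, loss = update_step(batch, w, v)

print(round(w, 3))  # 2.0
```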

Now let's create the Updater object!

from chainer import optimizers, training

max_epoch = 10

# Wrap your model by Classifier and include the process of loss calculation within your model.
# Since we do not specify a loss function here, the default 'softmax_cross_entropy' is used.
model = L.Classifier(model)

# Selection of your optimizing method
optimizer = optimizers.MomentumSGD()

# Give the optimizer a reference to the model
optimizer.setup(model)

# Get an updater that uses the Iterator and Optimizer
updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)


Here, the model defined above is passed to Classifier and turned into a new Chain. Classifier, which in fact inherits from the Chain class, keeps the given Chain model in its predictor attribute. Once you give the input data and the corresponding class labels to the model by the () operator,

  1. Classifier.__call__ of the model is invoked. The data is then given to predictor to obtain the output y.
  2. Next, together with the given labels, the output y is passed to the loss function, which is determined by the lossfun argument in the constructor of Classifier.
  3. The loss is returned as a Variable.
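The Classifier steps above amount to a forward pass followed by a loss computation. Here is a plain-Python sketch of softmax cross-entropy (averaged over the batch) and accuracy for a batch of class scores. classifier_forward, softmax_cross_entropy and accuracy are hypothetical helpers written for illustration, not Chainer's functions, and the toy predictor simply returns fixed scores.

```python
import math


def softmax_cross_entropy(scores, labels):
    """Mean over the batch of -log softmax(scores)[label]."""
    total = 0.0
    for row, t in zip(scores, labels):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in row))
        total += log_sum_exp - row[t]
    return total / len(labels)


def accuracy(scores, labels):
    """Fraction of examples whose highest score is at the true label."""
    hits = sum(1 for row, t in zip(scores, labels)
               if max(range(len(row)), key=row.__getitem__) == t)
    return hits / len(labels)


def classifier_forward(predictor, x, t):
    """Hypothetical analogue of Classifier.__call__: predict, then compute the loss."""
    y = predictor(x)
    loss = softmax_cross_entropy(y, t)
    acc = accuracy(y, t)  # reported alongside the loss
    return loss, acc


def fixed_predictor(x):
    # Toy predictor: ignores x and returns scores for two examples, two classes
    return [[2.0, 0.0], [0.0, 2.0]]


loss, acc = classifier_forward(fixed_predictor, x=None, t=[0, 1])
print(round(loss, 4), acc)  # 0.1269 1.0
```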

In Classifier, lossfun is set to softmax_cross_entropy by default. StandardUpdater is the simplest class among several updaters. There are also the ParallelUpdater and the MultiprocessParallelUpdater for utilizing multiple GPUs. The MultiprocessParallelUpdater uses the NVIDIA NCCL library, so you need to install NCCL and re-install CuPy before using it.

5. Set up the Trainer

Lastly, we will set up the Trainer. The only requirement for creating a Trainer is to pass the Updater object that we created above. You can also pass a stop_trigger as the second argument of Trainer, as a tuple like (length, unit), to tell the trainer when to stop the training. The length is given as an integer and the unit is given as a string, which should be either epoch or iteration. Without setting stop_trigger, the training will never stop.

# Set up a Trainer
trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='mnist_result')
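Assuming the 60,000-image MNIST training set returned by get_mnist, batchsize = 128, and that an epoch ends once the iterator has handed out as many examples as the dataset holds, the hypothetical helper below estimates how many updates a (length, unit) stop trigger corresponds to. It is a back-of-the-envelope sketch, not Chainer's trigger implementation.

```python
import math


def iterations_for(length, unit, n_examples, batch_size):
    """Approximate number of updates before a (length, unit) stop trigger fires."""
    if unit == 'iteration':
        return length
    if unit == 'epoch':
        # Each update consumes batch_size examples; an epoch is complete once
        # n_examples examples have been consumed in total.
        return math.ceil(length * n_examples / batch_size)
    raise ValueError("unit must be 'epoch' or 'iteration'")


print(iterations_for(10, 'epoch', 60000, 128))        # 4688
print(iterations_for(1000, 'iteration', 60000, 128))  # 1000
```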

The out argument specifies an output directory used to save the log files and, when you use the PlotReport extension, the image files of plots showing the progress of loss, accuracy, etc. Next, we will explain how to display or save that information by using Trainer extensions.

6. Add Extensions to the Trainer object

The Trainer extensions provide the following capabilities:

  • Save log files automatically (LogReport)
  • Display the training information in the terminal periodically (PrintReport)
  • Visualize the loss progress by plotting a graph periodically and save it as an image file (PlotReport)
  • Automatically serialize the state periodically (snapshot / snapshot_object)
  • Display a progress bar in the terminal to show the progress of training (ProgressBar)
  • Save the model architecture as a Graphviz dot file (dump_graph)

To use this wide variety of tools for your training task, pass Extension objects to the extend() method of your Trainer object.


# Use a non-interactive backend so that PlotReport can save figures in a headless environment
import matplotlib
matplotlib.use('Agg')

from chainer.training import extensions

trainer.extend(extensions.LogReport())
trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
trainer.extend(extensions.snapshot_object(model.predictor, filename='model_epoch-{.updater.epoch}'))
trainer.extend(extensions.Evaluator(test_iter, model, device=gpu_id))
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']))
trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
trainer.extend(extensions.dump_graph('main/loss'))

LogReport collects loss and accuracy automatically every epoch or iteration and stores the information in the log file in the directory specified by the out argument when you create a Trainer object.
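Assuming LogReport's default behavior of writing a JSON list of dictionaries (one entry per report interval) to a file named log, the file is easy to post-process. best_entry is a hypothetical helper; the demo writes a synthetic two-entry log to a temporary directory, so its numbers are placeholders, not training results.

```python
import json
import os
import tempfile


def best_entry(log_path, key, minimize=True):
    """Return the log entry with the smallest (or largest) value of key."""
    with open(log_path) as f:
        entries = json.load(f)  # assumed: a list of dicts such as {"epoch": 1, "main/loss": ...}
    entries = [e for e in entries if key in e]
    pick = min if minimize else max
    return pick(entries, key=lambda e: e[key])


# Synthetic stand-in for mnist_result/log (values are placeholders)
out_dir = tempfile.mkdtemp()
log_path = os.path.join(out_dir, 'log')
with open(log_path, 'w') as f:
    json.dump([{'epoch': 1, 'validation/main/loss': 0.5},
               {'epoch': 2, 'validation/main/loss': 0.4}], f)

print(best_entry(log_path, 'validation/main/loss')['epoch'])  # 2
```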

The snapshot extension saves the Trainer object at the designated timing (default: every epoch) in the directory specified by out. The Trainer object, as mentioned before, has an Updater, which contains an Optimizer and a model inside. Therefore, as long as you have the snapshot file, you can use it to resume the training or to make inferences with the previously trained model later.

However, when you keep the whole Trainer object, it is in some cases very tedious to retrieve only the inside of the model. By using snapshot_object, you can save a particular object (in this case, the MLP model wrapped by Classifier, i.e., model.predictor) as a separate snapshot. Classifier is a Chain object which keeps the model, itself also a Chain object, as its predictor property, and all the parameters are under predictor, so taking a snapshot of predictor is enough to keep all the trained parameters.
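The filename templates passed to snapshot and snapshot_object, such as 'snapshot_epoch-{.updater.epoch}', are ordinary Python format strings; our understanding is that the extension formats them with the Trainer object, so {.updater.epoch} reads trainer.updater.epoch. The snippet below shows that mechanism with a stand-in object built from types.SimpleNamespace instead of a real Trainer.

```python
from types import SimpleNamespace

# Stand-in for a Trainer after the 10th epoch: only the attribute the templates read
fake_trainer = SimpleNamespace(updater=SimpleNamespace(epoch=10))

print('snapshot_epoch-{.updater.epoch}'.format(fake_trainer))  # snapshot_epoch-10
print('model_epoch-{.updater.epoch}'.format(fake_trainer))     # model_epoch-10
```

This is why the last model snapshot of a 10-epoch run is named model_epoch-10.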

The dump_graph extension saves the structure of the computational graph of the model in Graphviz's dot format. The output location (directory) of the graph is set by the out argument of Trainer.

The Evaluator extension requires an Iterator over the evaluation dataset and the model object. It evaluates the model using the given dataset (typically a validation dataset) at the specified timing interval.

PrintReport outputs the specified values to the standard output. PlotReport plots the values specified by its arguments and saves the plot as an image file whose name is given by the file_name argument.

Each extension class has different options, and some extensions are not mentioned here. One other important feature is the trigger option, which lets you set an individual timing for firing each extension. For more details on all extensions, please take a look at the official document: Trainer extensions (reference/extensions.html).
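Conceptually, after each update the Trainer checks every extension's trigger and calls the extension when it fires. MiniTrainer below is a deliberately tiny, hypothetical loop (far simpler than chainer.training.Trainer, and assuming a whole number of iterations per epoch) that shows how per-extension triggers such as (1, 'epoch') and (4, 'iteration') interleave.

```python
class MiniTrainer:
    """Toy training loop that fires extensions on interval triggers."""

    def __init__(self, iters_per_epoch, stop_iteration):
        self.iters_per_epoch = iters_per_epoch
        self.stop_iteration = stop_iteration
        self.iteration = 0
        self.extensions = []

    def extend(self, fn, trigger=(1, 'epoch')):
        period, unit = trigger
        if unit == 'epoch':
            period *= self.iters_per_epoch  # convert the interval to iterations
        elif unit != 'iteration':
            raise ValueError("unit must be 'epoch' or 'iteration'")
        self.extensions.append((fn, period))

    def run(self):
        while self.iteration < self.stop_iteration:
            self.iteration += 1  # stands in for one updater step
            for fn, period in self.extensions:
                if self.iteration % period == 0:
                    fn(self)


calls = []
toy_trainer = MiniTrainer(iters_per_epoch=3, stop_iteration=9)
toy_trainer.extend(lambda t: calls.append(('epoch_ext', t.iteration)))
toy_trainer.extend(lambda t: calls.append(('iter_ext', t.iteration)), trigger=(4, 'iteration'))
toy_trainer.run()
print(calls)
# [('epoch_ext', 3), ('iter_ext', 4), ('epoch_ext', 6), ('iter_ext', 8), ('epoch_ext', 9)]
```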

7. Start Training

Just call the run() method of the Trainer object to start training.

trainer.run()

epoch       main/loss   main/accuracy  validation/main/loss  validation/main/accuracy  elapsed_time
1           1.53241     0.638409       0.74935               0.835839                  4.93409
2           0.578334    0.858059       0.444722              0.882812                  7.72883
3           0.418569    0.886844       0.364943              0.899229                  10.4229
4           0.362342    0.899089       0.327569              0.905558                  13.148
5           0.331067    0.906517       0.304399              0.911788                  15.846
6           0.309019    0.911964       0.288295              0.917722                  18.5395
7           0.292312    0.916128       0.272073              0.921776                  21.2173
8           0.278291    0.92059        0.261351              0.923457                  23.9211
9           0.266266    0.923541       0.253195              0.927314                  26.6612
10          0.255489    0.926739       0.242415              0.929094                  29.466

Let's see the plot of loss progress saved as loss.png in the mnist_result directory.


How about the accuracy?


Furthermore, let's visualize the computational graph saved by dump_graph using Graphviz.

% dot -Tpng mnist_result/cg.dot -o mnist_result/cg.png


From top to bottom, you can see the data flow in the computational graph. It basically shows how data and parameters are passed to the Functions.

8. Evaluate a pre-trained model

Evaluation using the snapshot of a model is as easy as what was explained in train_loop.

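import matplotlib.pyplot as plt
from chainer import serializers

model = MLP()
serializers.load_npz('mnist_result/model_epoch-10', model)

# Show the output
x, t = test[0]
plt.imshow(x.reshape(28, 28), cmap='gray')
print('label:', t)

# Add a batch axis, run the model, and take the class with the highest score
y = model(x[None, ...])
print('predicted_label:', y.array.argmax(axis=1)[0])

label: 7
predicted_label: 7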

The prediction looks correct. Success!