By using ~chainer.training.Trainer, you don't need to write the training loop explicitly any more. Furthermore, Chainer provides many useful extensions that can be used with ~chainer.training.Trainer to visualize your results, evaluate your model, and store and manage log files more easily.
This example will show how to use the ~chainer.training.Trainer to train a fully-connected feed-forward neural network on the MNIST dataset.
Note
If you would like to know how to write a training loop without using the ~chainer.training.Trainer, please check train_loop instead of this tutorial.
Load the MNIST dataset, which contains a training set of images and class labels as well as a corresponding test set.
from chainer.datasets import mnist
train, test = mnist.get_mnist()
Note
You can use a Python list as a dataset. That's because ~chainer.dataset.Iterator can take any object as a dataset, as long as its elements can be accessed via the [] accessor and its length can be obtained with the len() function. For example, a list of tuples like this can be used as a dataset:
train = [(x1, t1), (x2, t2), ...]
There are many utility dataset classes defined in ~chainer.datasets. It's recommended to use them in actual applications.
For example, if your dataset consists of a number of image files, it would take a large amount of memory to load all of that data into a list as above. In that case, you can use ~chainer.datasets.ImageDataset, which just keeps the paths to the image files. The actual image data is loaded from disk only when the corresponding element is requested via the [] accessor. Until then, no images are loaded into memory, which reduces memory use.
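To make the dataset requirement in the note concrete, here is a minimal sketch of an object that supports only the [] accessor and len(). The class name LazyDataset and its contents are hypothetical stand-ins, not part of Chainer; each example is built only when it is indexed, which is the same idea the note describes for image files.

```python
class LazyDataset:
    """Hypothetical dataset that builds each example only when it is indexed."""

    def __init__(self, n):
        self._n = n
        self.load_count = 0  # counts how many examples were actually built

    def __len__(self):
        return self._n

    def __getitem__(self, i):
        if not 0 <= i < self._n:
            raise IndexError(i)
        self.load_count += 1
        x = float(i)  # stand-in for an image loaded from disk
        t = i % 2     # stand-in for a class label
        return x, t


dataset = LazyDataset(1000)
print(len(dataset))        # 1000
print(dataset[3])          # (3.0, 1)
print(dataset.load_count)  # 1
```

Only one example has been built even though the dataset reports a length of 1000. This sketch handles integer indices only, which is enough to illustrate the idea.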
~chainer.dataset.Iterator creates a mini-batch from the given dataset.
batchsize = 128
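from chainer import iterators

# The training iterator repeats and shuffles by default.
# The test iterator makes a single pass over the data in its original order.
train_iter = iterators.SerialIterator(train, batchsize)
test_iter = iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False)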
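As a rough illustration of what building mini-batches from an indexable dataset involves, the following sketch yields lists of examples in plain Python. The function serial_batches and its seed parameter are hypothetical; this is a simplified stand-in, not Chainer's implementation, and it ignores details such as how the real iterator treats the final partial batch when repeating.

```python
import random


def serial_batches(dataset, batch_size, repeat=True, shuffle=True, seed=0):
    """Yield mini-batches (lists of examples) from an indexable dataset."""
    rng = random.Random(seed)
    order = list(range(len(dataset)))
    while True:
        if shuffle:
            rng.shuffle(order)  # visit the examples in a new order each epoch
        for start in range(0, len(order), batch_size):
            yield [dataset[i] for i in order[start:start + batch_size]]
        if not repeat:
            return  # a single pass, as used for evaluation


data = [(x, x % 2) for x in range(10)]
test_batches = list(serial_batches(data, 4, repeat=False, shuffle=False))
print([len(b) for b in test_batches])  # [4, 4, 2]
```

With repeat=False and shuffle=False the data is visited exactly once and in order, which is why those settings suit the test iterator.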
Here, we are going to use the same model as the one defined in train_loop.
from chainer import Chain
import chainer.functions as F
import chainer.links as L

class MLP(Chain):

    def __init__(self, n_mid_units=100, n_out=10):
        super(MLP, self).__init__()
        with self.init_scope():
            # Passing None lets each layer infer its input size on the first call
            self.l1 = L.Linear(None, n_mid_units)
            self.l2 = L.Linear(None, n_mid_units)
            self.l3 = L.Linear(None, n_out)

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)
gpu_id = 0  # Set to -1 if you use CPU

model = MLP()
if gpu_id >= 0:
    model.to_gpu(gpu_id)
~chainer.training.Trainer is a class that holds all of the components needed for training. The main components are shown below.
Basically, all you need to pass to ~chainer.training.Trainer is an ~chainer.training.Updater. However, the ~chainer.training.Updater contains an ~chainer.dataset.Iterator and an ~chainer.Optimizer. Since the ~chainer.dataset.Iterator can access the dataset and the ~chainer.Optimizer has references to the model, the ~chainer.training.Updater can access the model to update its parameters.
So, ~chainer.training.Updater can perform the training procedure as shown below:
- Retrieve the data from the dataset and construct a mini-batch (~chainer.dataset.Iterator)
- Pass the mini-batch to the model and calculate the loss
- Update the parameters of the model (~chainer.Optimizer)
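The three steps above can be sketched without any framework. In the toy example below, iterate_minibatches plays the role of the iterator and update combines the loss calculation with a plain gradient-descent step. Both function names, the one-parameter model t = w * x, and the learning rate are assumptions made for illustration only; this is not how Chainer implements its updaters.

```python
def iterate_minibatches(dataset, batch_size):
    # Step 1: retrieve data from the dataset and construct mini-batches
    for start in range(0, len(dataset), batch_size):
        yield dataset[start:start + batch_size]


def update(w, batch, lr):
    """Compute the mean squared error on one batch and take a gradient step."""
    n = len(batch)
    # Step 2: pass the mini-batch to the model (w * x) and calculate the loss
    loss = sum((w * x - t) ** 2 for x, t in batch) / n
    # Step 3: update the parameter using the gradient of the loss
    grad = sum(2 * (w * x - t) * x for x, t in batch) / n
    return w - lr * grad, loss


dataset = [(float(x), 3.0 * x) for x in range(1, 9)]  # targets follow t = 3x
w = 0.0
for epoch in range(50):
    for batch in iterate_minibatches(dataset, batch_size=4):
        w, loss = update(w, batch, lr=0.01)
print(round(w, 3))  # close to 3.0
```

An updater bundles exactly this kind of inner loop body, so the trainer only has to call it repeatedly.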
Now let's create the ~chainer.training.Updater object!
from chainer import optimizers, training

max_epoch = 10

# Wrap your model with Classifier so that the loss calculation is included in the model.
# Since we do not specify a loss function here, the default 'softmax_cross_entropy' is used.
model = L.Classifier(model)

# Select your optimization method
optimizer = optimizers.MomentumSGD()

# Give the optimizer a reference to the model
optimizer.setup(model)

# Get an updater that uses the Iterator and Optimizer
updater = training.updaters.StandardUpdater(train_iter, optimizer, device=gpu_id)
Note
Here, the model defined above is passed to ~chainer.links.Classifier and wrapped in a new ~chainer.Chain. ~chainer.links.Classifier, which in fact inherits from the ~chainer.Chain class, keeps the given ~chainer.Chain model in its ~chainer.links.Classifier.predictor attribute. Once you give the input data and the corresponding class labels to the model by the () operator,
- ~chainer.links.Classifier.__call__ of the model is invoked. The data is then given to ~chainer.links.Classifier.predictor to obtain the output y.
- Next, together with the given labels, the output y is passed to the loss function, which is determined by the ~chainer.links.Classifier.lossfun argument in the constructor of ~chainer.links.Classifier.
- The loss is returned as a ~chainer.Variable.
In ~chainer.links.Classifier, the ~chainer.links.Classifier.lossfun is set to ~chainer.functions.softmax_cross_entropy by default.
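The call sequence described in this note can be mimicked in a few lines of plain Python. SimpleClassifier, toy_predictor, and the list-based softmax_cross_entropy below are hypothetical stand-ins written for illustration. The real ~chainer.links.Classifier works on arrays, returns a ~chainer.Variable, and does more than this sketch.

```python
import math


def softmax_cross_entropy(y, t):
    """Mean cross-entropy of integer labels t under a softmax over each row of y."""
    total = 0.0
    for logits, label in zip(y, t):
        m = max(logits)  # subtract the maximum for numerical stability
        log_z = m + math.log(sum(math.exp(v - m) for v in logits))
        total += log_z - logits[label]
    return total / len(t)


class SimpleClassifier:
    """Hypothetical wrapper that keeps a model in `predictor` and returns a loss."""

    def __init__(self, predictor, lossfun=softmax_cross_entropy):
        self.predictor = predictor
        self.lossfun = lossfun

    def __call__(self, x, t):
        y = self.predictor(x)      # forward pass through the wrapped model
        return self.lossfun(y, t)  # loss computed from the output and the labels


def toy_predictor(batch):
    # Stand-in model: two logits per example, favouring class 1 for positive inputs
    return [[-v, v] for v in batch]


model = SimpleClassifier(toy_predictor)
loss = model([2.0, -2.0], [1, 0])
print(loss)
```

Because the wrapper takes both inputs and labels and returns the loss, an optimizer only needs a reference to the wrapper to train the model held in predictor.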
~chainer.training.updaters.StandardUpdater is the simplest class among several updaters. There are also the ~chainer.training.updaters.ParallelUpdater and the ~chainer.training.updaters.MultiprocessParallelUpdater to utilize multiple GPUs. The ~chainer.training.updaters.MultiprocessParallelUpdater uses the NVIDIA NCCL library, so you need to install NCCL and re-install CuPy before using it.
Lastly, we will set up the ~chainer.training.Trainer. The only requirement for creating a ~chainer.training.Trainer is to pass the ~chainer.training.Updater object that we created above. You can also pass a ~chainer.training.Trainer.stop_trigger as the second argument, as a tuple like (length, unit), to tell the trainer when to stop the training. The length is given as an integer and the unit is given as a string which should be either epoch or iteration. If ~chainer.training.Trainer.stop_trigger is not set, the training will never stop.
# Set up a Trainer
trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='mnist_result')
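The meaning of a (length, unit) stop trigger can be illustrated with a small stand-alone function. The names should_stop and iters_per_epoch are hypothetical, and this is only a simplified model of the behaviour described above, not the trigger class Chainer uses internally.

```python
def should_stop(stop_trigger, epoch, iteration):
    """Return True once training has reached the (length, unit) limit."""
    length, unit = stop_trigger
    if unit == 'epoch':
        return epoch >= length
    if unit == 'iteration':
        return iteration >= length
    raise ValueError("unit must be 'epoch' or 'iteration'")


iters_per_epoch = 3  # hypothetical: dataset size divided by batch size
iteration = 0
while not should_stop((2, 'epoch'), iteration // iters_per_epoch, iteration):
    iteration += 1  # one parameter update would happen here
print(iteration)  # 6
```

With three iterations per epoch, a trigger of (2, 'epoch') stops the loop after six updates. A trigger of (6, 'iteration') would stop at the same point.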
The ~chainer.training.Trainer.out argument specifies an output directory used to save the log files and, when you use the ~chainer.training.extensions.PlotReport extension, the image files of plots that show how the loss, accuracy, etc. change over time. Next, we will explain how to display or save that information by using trainer ~chainer.training.Extension objects.
The ~chainer.training.Trainer extensions provide the following capabilities:
- Save log files automatically (~chainer.training.extensions.LogReport)
- Display the training information to the terminal periodically (~chainer.training.extensions.PrintReport)
- Visualize the loss progress by plotting a graph periodically and save it as an image file (~chainer.training.extensions.PlotReport)
- Automatically serialize the state periodically (~chainer.training.extensions.snapshot / ~chainer.training.extensions.snapshot_object)
- Display a progress bar to the terminal to show the progress of training (~chainer.training.extensions.ProgressBar)
- Save the model architecture as a Graphviz dot file (~chainer.training.extensions.dump_graph)
To use this wide variety of tools for your training task, pass ~chainer.training.Extension objects to the ~chainer.training.Trainer.extend method of your ~chainer.training.Trainer object.
from chainer.training import extensions
trainer.extend(extensions.LogReport())
trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
trainer.extend(extensions.snapshot_object(model.predictor, filename='model_epoch-{.updater.epoch}'))
trainer.extend(extensions.Evaluator(test_iter, model, device=gpu_id))
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']))
trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
trainer.extend(extensions.dump_graph('main/loss'))
~chainer.training.extensions.LogReport collects the loss and accuracy automatically every epoch or iteration and stores the information in the log file in the directory specified by the ~chainer.training.Trainer.out argument when you create a ~chainer.training.Trainer object.
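Once training has run, the stored values can be inspected programmatically. The sketch below assumes that the log file is a JSON list with one dictionary per reporting interval, keyed by the same names passed to PrintReport. To stay self-contained it parses a hand-written excerpt (log_text, a hypothetical stand-in whose numbers are copied from the first two epochs of the training output shown later in this tutorial) rather than reading the file from disk.

```python
import json

# Hypothetical excerpt in the assumed log format (a JSON list of dictionaries)
log_text = '''
[
  {"epoch": 1, "main/loss": 1.53241, "validation/main/loss": 0.74935},
  {"epoch": 2, "main/loss": 0.578334, "validation/main/loss": 0.444722}
]
'''

entries = json.loads(log_text)

# Pick the entry with the lowest validation loss
best = min(entries, key=lambda e: e["validation/main/loss"])
print(best["epoch"])  # 2
```

If the assumption about the format holds, replacing json.loads(log_text) with json.load on the opened log file in the output directory would give the same kind of list.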
The ~chainer.training.extensions.snapshot method saves the ~chainer.training.Trainer object at the designated timing (default: every epoch) in the directory specified by ~chainer.training.Trainer.out. The ~chainer.training.Trainer object, as mentioned before, has an ~chainer.training.Updater which contains an ~chainer.Optimizer and a model inside. Therefore, as long as you have the snapshot file, you can use it to resume the training or to make inferences later using the previously trained model.
However, when you keep the whole ~chainer.training.Trainer object, it can be tedious to retrieve only the model from inside it. By using ~chainer.training.extensions.snapshot_object, you can save a particular object (in this case, the model wrapped by ~chainer.links.Classifier) as a separate snapshot. ~chainer.links.Classifier is a ~chainer.Chain object which keeps the model, itself a ~chainer.Chain object, as its ~chainer.links.Classifier.predictor property. All the parameters are under ~chainer.links.Classifier.predictor, so taking the snapshot of ~chainer.links.Classifier.predictor is enough to keep all the trained parameters.
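The snapshot filenames used above contain the placeholder {.updater.epoch}. Assuming the template is expanded with Python's str.format and the trainer as its only argument, as the placeholder syntax suggests, the resulting name can be reproduced with stand-in objects. The SimpleNamespace objects below are hypothetical substitutes that expose only the attribute the template reads.

```python
from types import SimpleNamespace

# Stand-in for a trainer whose updater has just finished epoch 10
trainer = SimpleNamespace(updater=SimpleNamespace(epoch=10))

template = 'model_epoch-{.updater.epoch}'
filename = template.format(trainer)
print(filename)  # model_epoch-10
```

This matches the file name mnist_result/model_epoch-10 that is loaded for evaluation at the end of this tutorial, after ten epochs of training.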
This is a list of commonly used trainer extensions:
~chainer.training.extensions.LogReport
This extension collects the loss and accuracy values every epoch or iteration and stores them in a log file. The log file will be located under the output directory (specified by the out argument of the ~chainer.training.Trainer object).
~chainer.training.extensions.snapshot
This extension saves the ~chainer.training.Trainer object at the designated timing (default: every epoch) in the output directory. The ~chainer.training.Trainer object, as mentioned before, has an ~chainer.training.Updater which contains an ~chainer.Optimizer and a model inside. Therefore, as long as you have the snapshot file, you can use it to resume the training or to make inferences later using the previously trained model.
~chainer.training.extensions.snapshot_object
The ~chainer.training.extensions.snapshot extension above saves the whole ~chainer.training.Trainer object. However, in some cases, it is tedious to retrieve only the inside of the model. By using ~chainer.training.extensions.snapshot_object, you can save a particular object (in the example above, the model wrapped by ~chainer.links.Classifier) as a separate snapshot. Taking the snapshot of ~chainer.links.Classifier.predictor is enough to keep all the trained parameters, because ~chainer.links.Classifier (which is a subclass of ~chainer.Chain) keeps the model as its ~chainer.links.Classifier.predictor property, and all the parameters are under this property.
~chainer.training.extensions.dump_graph
This extension saves the structure of the computational graph of the model. The graph is saved in Graphviz dot format under the output directory of the ~chainer.training.Trainer.
~chainer.training.extensions.Evaluator
The ~chainer.training.extensions.Evaluator extension requires ~chainer.dataset.Iterator objects that use the evaluation dataset, together with the model object. It evaluates the model using the given dataset (typically a validation dataset) at the specified timing interval.
~chainer.training.extensions.PrintReport
This extension outputs the specified values to the standard output.
~chainer.training.extensions.PlotReport
This extension plots the values specified by its arguments and saves the plot as an image file.
This is not an exhaustive list of built-in extensions. Please take a look at extensions for more of them.
Just call the ~chainer.training.Trainer.run method of the ~chainer.training.Trainer object to start training.
trainer.run()
epoch  main/loss  main/accuracy  validation/main/loss  validation/main/accuracy  elapsed_time
1      1.53241    0.638409       0.74935               0.835839                  4.93409
2      0.578334   0.858059       0.444722              0.882812                  7.72883
3      0.418569   0.886844       0.364943              0.899229                  10.4229
4      0.362342   0.899089       0.327569              0.905558                  13.148
5      0.331067   0.906517       0.304399              0.911788                  15.846
6      0.309019   0.911964       0.288295              0.917722                  18.5395
7      0.292312   0.916128       0.272073              0.921776                  21.2173
8      0.278291   0.92059        0.261351              0.923457                  23.9211
9      0.266266   0.923541       0.253195              0.927314                  26.6612
10     0.255489   0.926739       0.242415              0.929094                  29.466
Let's see the plot of loss progress saved in the mnist_result directory.
How about the accuracy?
Furthermore, let's visualize the computational graph saved with ~chainer.training.extensions.dump_graph using Graphviz.
% dot -Tpng mnist_result/cg.dot -o mnist_result/cg.png
From top to bottom, you can see the data flow in the computational graph. It basically shows how data and parameters are passed to each ~chainer.Function.
Evaluation using the snapshot of a model is as easy as what is explained in train_loop.
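import matplotlib.pyplot as plt
from chainer import serializers

# Load the trained parameters into a fresh model
model = MLP()
serializers.load_npz('mnist_result/model_epoch-10', model)

# Show the output
x, t = test[0]
plt.imshow(x.reshape(28, 28), cmap='gray')
plt.show()
print('label:', t)

# Add a batch dimension before passing the single image to the model
y = model(x[None, ...])

print('predicted_label:', y.data.argmax(axis=1)[0])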
label: 7
predicted_label: 7
The prediction looks correct. Success!