MNIST using Trainer

By using ~chainer.training.Trainer, you don't need to write the training loop explicitly any more. Furthermore, Chainer provides many useful extensions that can be used with ~chainer.training.Trainer to visualize your results, evaluate your model, store and manage log files more easily.

This example will show how to use the ~chainer.training.Trainer to train a fully-connected feed-forward neural network on the MNIST dataset.

Note

If you would like to know how to write a training loop without using the ~chainer.training.Trainer, please check train_loop instead of this tutorial.

1. Prepare the dataset

Load the MNIST dataset, which contains a training set of images and class labels as well as a corresponding test set.

from chainer.datasets import mnist

train, test = mnist.get_mnist()

Note

You can use a Python list as a dataset. That's because ~chainer.dataset.Iterator can take any object as a dataset whose elements can be accessed via [] accessor and whose length can be obtained with len() function. For example,

train = [(x1, t1), (x2, t2), ...]

a list of tuples like this can be used as a dataset.

There are many utility dataset classes defined in ~chainer.datasets. It's recommended to utilize them in the actual applications.

For example, if your dataset consists of a large number of image files, loading all of them into a list like the one above would take a large amount of memory. In that case, you can use ~chainer.datasets.ImageDataset, which keeps only the paths to the image files. The actual image data is loaded from disk when the corresponding element is requested via the [] accessor. Until then, no images are held in memory, which keeps memory use low.
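
The protocol described in this note (element access via [] plus len()) can be illustrated with a small sketch. The class name `LazySquares` and its `loaded` counter are hypothetical and only stand in for an expensive load such as reading an image from disk. Chainer's own dataset classes also handle slice indices, which this sketch omits.

```python
class LazySquares:
    """Hypothetical dataset that builds each example only when it is indexed."""

    def __init__(self, n):
        self.n = n
        self.loaded = 0  # counts how many examples were actually built

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        if not 0 <= i < self.n:
            raise IndexError(i)
        self.loaded += 1
        # Stand-in for an expensive load such as reading an image from disk
        return (float(i), i * i)


dataset = LazySquares(1000)
first = dataset[3]  # only this one example is materialized
```

Nothing is built when the object is created; the work happens per element, which is the same idea that lets a path-based dataset avoid holding every image in memory.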

2. Prepare the dataset iterators

~chainer.dataset.Iterator creates a mini-batch from the given dataset.

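from chainer import iterators

batchsize = 128

train_iter = iterators.SerialIterator(train, batchsize)
# The test iterator makes a single ordered pass: no repeat, no shuffle
test_iter = iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False)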
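
To make the idea of a mini-batch concrete, here is a minimal pure-Python sketch of a single pass over a dataset in fixed-size chunks. The function `single_pass_batches` and its `seed` parameter are hypothetical; this is not how SerialIterator is implemented, and it only covers the non-repeating case.

```python
import random


def single_pass_batches(dataset, batch_size, shuffle=False, seed=0):
    """Yield mini-batches that together cover the dataset exactly once."""
    order = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [dataset[i] for i in order[start:start + batch_size]]


toy = [(x, x % 2) for x in range(10)]  # ten (input, label) pairs
batches = list(single_pass_batches(toy, batch_size=4))
sizes = [len(b) for b in batches]  # the final batch holds the remainder
```

With ten examples and a batch size of four, the last batch is smaller than the others, which is what happens on a single non-repeating pass whenever the dataset size is not a multiple of the batch size.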

3. Prepare the model

Here, we are going to use the same model as the one defined in train_loop.

from chainer import Chain
import chainer.functions as F
import chainer.links as L


class MLP(Chain):

    def __init__(self, n_mid_units=100, n_out=10):
        super(MLP, self).__init__()
        with self.init_scope():
            # Passing None lets each layer infer its input size at the first call
            self.l1 = L.Linear(None, n_mid_units)
            self.l2 = L.Linear(None, n_mid_units)
            self.l3 = L.Linear(None, n_out)

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

gpu_id = 0  # Set to -1 if you use CPU

model = MLP()
if gpu_id >= 0:
    model.to_gpu(gpu_id)

4. Prepare the Updater

~chainer.training.Trainer is a class that holds all of the necessary components needed for training. The main components are shown below.

image

Basically, all you need to pass to ~chainer.training.Trainer is an ~chainer.training.Updater. The ~chainer.training.Updater, in turn, contains an ~chainer.dataset.Iterator and an ~chainer.Optimizer. Since the ~chainer.dataset.Iterator can access the dataset and the ~chainer.Optimizer has references to the model, the ~chainer.training.Updater can access the model to update its parameters.

So, ~chainer.training.Updater can perform the training procedure as shown below:

  1. Retrieve the data from the dataset and construct a mini-batch (~chainer.dataset.Iterator)
  2. Pass the mini-batch to the model and calculate the loss
  3. Update the parameters of the model (~chainer.Optimizer)
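
The three steps above can be sketched without any framework. The example below is an illustration only: it uses a hypothetical one-parameter model y = w * x, a mean squared error loss, and plain gradient descent rather than the MomentumSGD optimizer used in this tutorial.

```python
def update_step(w, batch, lr):
    """One step: forward pass, mean squared error, gradient, parameter update."""
    n = len(batch)
    # Step 2: compute the loss of the model y = w * x on the mini-batch
    loss = sum((w * x - t) ** 2 for x, t in batch) / n
    # Gradient of the loss with respect to w
    grad = sum(2 * (w * x - t) * x for x, t in batch) / n
    # Step 3: plain gradient descent update
    return w - lr * grad, loss


data = [(float(x), 3.0 * x) for x in range(1, 9)]  # targets follow t = 3x
w, losses = 0.0, []
for epoch in range(20):
    for start in range(0, len(data), 4):  # Step 1: build a mini-batch
        w, loss = update_step(w, data[start:start + 4], lr=0.01)
        losses.append(loss)
```

After the loop, `w` has moved from 0 towards the true slope of 3 and the recorded losses shrink accordingly. An updater bundles exactly this kind of per-iteration work so that it does not have to be written by hand.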

Now let's create the ~chainer.training.Updater object!

from chainer import optimizers, training

max_epoch = 10

# Wrap your model with Classifier so that the loss calculation becomes part of the model.
# Since we do not specify a loss function here, the default 'softmax_cross_entropy' is used.
model = L.Classifier(model)

# Select the optimization method
optimizer = optimizers.MomentumSGD()

# Give the optimizer a reference to the model
optimizer.setup(model)

# Get an updater that uses the Iterator and Optimizer
updater = training.updaters.StandardUpdater(train_iter, optimizer, device=gpu_id)

Note

Here, the model defined above is passed to ~chainer.links.Classifier and wrapped into a new ~chainer.Chain. ~chainer.links.Classifier, which itself inherits from the ~chainer.Chain class, keeps the given ~chainer.Chain model in its ~chainer.links.Classifier.predictor attribute. When you pass the input data and the corresponding class labels to the model with the () operator:

  1. ~chainer.links.Classifier.__call__ of the model is invoked. The data is then given to ~chainer.links.Classifier.predictor to obtain the output y.
  2. Next, together with the given labels, the output y is passed to the loss function, which is determined by the ~chainer.links.Classifier.lossfun argument in the constructor of ~chainer.links.Classifier.
  3. The loss is returned as a ~chainer.Variable.

In ~chainer.links.Classifier, the ~chainer.links.Classifier.lossfun is set to ~chainer.functions.softmax_cross_entropy by default.
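
The split between a predictor and a loss function can be sketched in plain Python. `TinyClassifier` below is a hypothetical stand-in for illustration, not the implementation of ~chainer.links.Classifier, and the loss works on a single example given as a list of scores rather than on a batch of arrays.

```python
import math


def softmax_cross_entropy(logits, label):
    """Negative log-probability of the true label under a softmax."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum - logits[label]


class TinyClassifier:
    """Hypothetical stand-in that mimics the predictor/lossfun split."""

    def __init__(self, predictor, lossfun=softmax_cross_entropy):
        self.predictor = predictor
        self.lossfun = lossfun

    def __call__(self, x, t):
        y = self.predictor(x)      # forward pass through the wrapped model
        return self.lossfun(y, t)  # compute and return the loss


clf = TinyClassifier(lambda x: [0.0, 0.0, 0.0])  # always outputs equal scores
loss = clf(x=None, t=1)
```

With three equal scores the model is maximally uncertain, so the loss equals log(3). Because the wrapper only adds the loss computation, the trained parameters all live in the predictor.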

~chainer.training.updaters.StandardUpdater is the simplest class among several updaters. There are also the ~chainer.training.updaters.ParallelUpdater and the ~chainer.training.updaters.MultiprocessParallelUpdater to utilize multiple GPUs. The ~chainer.training.updaters.MultiprocessParallelUpdater uses the NVIDIA NCCL library, so you need to install NCCL and re-install CuPy before using it.

5. Setup Trainer

Lastly, we will set up the ~chainer.training.Trainer. The only requirement for creating a ~chainer.training.Trainer is to pass the ~chainer.training.Updater object that we created above. You can also pass a ~chainer.training.Trainer.stop_trigger as the second argument, as a tuple like (length, unit), to tell the trainer when to stop the training. The length is given as an integer and the unit is given as a string, which should be either epoch or iteration. If you do not set ~chainer.training.Trainer.stop_trigger, the training never stops.

# Set up a Trainer
trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='mnist_result')
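
As a rough illustration of how a (length, unit) tuple can be read, the sketch below checks a stop condition against the current epoch and iteration counters. The function `should_stop` is hypothetical and is not part of Chainer. The iteration count per epoch assumes the 60000 MNIST training images and the batch size of 128 used above.

```python
def should_stop(stop_trigger, epoch, iteration):
    """Return True once the (length, unit) budget has been reached."""
    length, unit = stop_trigger
    if unit == 'epoch':
        return epoch >= length
    if unit == 'iteration':
        return iteration >= length
    raise ValueError("unit must be 'epoch' or 'iteration'")


# Roughly how many mini-batches of 128 it takes to see 60000 images once
iters_per_epoch = -(-60000 // 128)  # ceiling division
```

Stopping after (10, 'epoch') therefore corresponds to a few thousand iterations, so an iteration-based trigger with a comparable budget would be on the order of ten times `iters_per_epoch`.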

The ~chainer.training.Trainer.out argument specifies an output directory used to save the log files and, when you use the ~chainer.training.extensions.PlotReport extension, the image files of plots that show how the loss, accuracy, etc. progress over time. Next, we will explain how to display or save this information by using trainer ~chainer.training.Extension objects.

6. Add Extensions to the Trainer object

The ~chainer.training.Trainer extensions provide the following capabilities:

  • Save log files automatically (~chainer.training.extensions.LogReport)
  • Display the training information to the terminal periodically (~chainer.training.extensions.PrintReport)
  • Visualize the loss progress by plotting a graph periodically and save it as an image file (~chainer.training.extensions.PlotReport)
  • Automatically serialize the state periodically (~chainer.training.extensions.snapshot / ~chainer.training.extensions.snapshot_object)
  • Display a progress bar to the terminal to show the progress of training (~chainer.training.extensions.ProgressBar)
  • Save the model architecture as a Graphviz's dot file (~chainer.training.extensions.dump_graph)

To use this wide variety of tools for your training task, pass ~chainer.training.Extension objects to the ~chainer.training.Trainer.extend method of your ~chainer.training.Trainer object.


from chainer.training import extensions

trainer.extend(extensions.LogReport())
trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
trainer.extend(extensions.snapshot_object(model.predictor, filename='model_epoch-{.updater.epoch}'))
trainer.extend(extensions.Evaluator(test_iter, model, device=gpu_id))
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']))
trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
trainer.extend(extensions.dump_graph('main/loss'))
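
The `{.updater.epoch}` placeholder in the snapshot file names is ordinary Python `str.format` syntax with attribute access. The sketch below shows how such a template expands, assuming the extension formats the name with the trainer as its single argument. `fake_trainer` is a hypothetical stand-in built from `SimpleNamespace`, not a real Trainer.

```python
from types import SimpleNamespace

# Hypothetical stand-in for a trainer whose updater has finished 10 epochs
fake_trainer = SimpleNamespace(updater=SimpleNamespace(epoch=10))

template = 'snapshot_epoch-{.updater.epoch}'
filename = template.format(fake_trainer)  # attribute lookup on the first argument
```

Under that assumption, each snapshot gets a name that includes the epoch at which it was taken, which is why the evaluation step later in this tutorial loads a file called model_epoch-10.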


This is a list of commonly used trainer extensions:

~chainer.training.extensions.LogReport

This extension collects the loss and accuracy values every epoch or iteration and stores them in a log file. The log file is located under the output directory (specified by the out argument of the ~chainer.training.Trainer object).

~chainer.training.extensions.snapshot

This extension saves the ~chainer.training.Trainer object at the designated timing (default: every epoch) in the output directory. The ~chainer.training.Trainer object, as mentioned before, has an ~chainer.training.Updater which contains an ~chainer.Optimizer and a model inside. Therefore, as long as you have the snapshot file, you can use it later to resume training or to make inferences with the previously trained model.

~chainer.training.extensions.snapshot_object

The ~chainer.training.extensions.snapshot extension above saves the whole ~chainer.training.Trainer object. However, in some cases, it is tedious to retrieve only the model from it. By using ~chainer.training.extensions.snapshot_object, you can save a particular object (in the example above, the model wrapped by ~chainer.links.Classifier) as a separate snapshot. Taking the snapshot of ~chainer.links.Classifier.predictor is enough to keep all the trained parameters, because ~chainer.links.Classifier (which is a subclass of ~chainer.Chain) keeps the model as its ~chainer.links.Classifier.predictor property, and all the parameters are under this property.

~chainer.training.extensions.dump_graph

This extension saves the structure of the computational graph of the model. The graph is saved in Graphviz dot format under the output directory of the ~chainer.training.Trainer.

~chainer.training.extensions.Evaluator

This extension requires a ~chainer.dataset.Iterator over the evaluation dataset and the model object. It evaluates the model on the given dataset (typically a validation dataset) at the specified interval.

~chainer.training.extensions.PrintReport

This extension outputs the specified values to the standard output.

~chainer.training.extensions.PlotReport

This extension plots the values specified by its arguments and saves the plot as an image file.

This is not an exhaustive list of built-in extensions. Please take a look at extensions for more of them.
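
The general idea behind extensions, callbacks that a training loop invokes at chosen intervals, can be sketched with a toy loop. `MiniTrainer` and its `every` parameter are hypothetical and greatly simplified. This is not how ~chainer.training.Trainer is implemented.

```python
class MiniTrainer:
    """Toy loop, not Chainer's Trainer: runs callbacks at fixed intervals."""

    def __init__(self, n_iterations):
        self.n_iterations = n_iterations
        self.iteration = 0
        self._extensions = []

    def extend(self, func, every=1):
        # Register a callback to be invoked every `every` iterations
        self._extensions.append((func, every))

    def run(self):
        while self.iteration < self.n_iterations:
            self.iteration += 1  # a real loop would update the model here
            for func, every in self._extensions:
                if self.iteration % every == 0:
                    func(self)


seen = []
mini = MiniTrainer(n_iterations=10)
mini.extend(lambda tr: seen.append(tr.iteration), every=4)
mini.run()
```

Here the callback fires at iterations 4 and 8. Logging, evaluation, and snapshotting all fit this pattern: each one is a piece of work attached to the loop rather than written inside it.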

7. Start Training

Just call the ~chainer.training.Trainer.run method of the ~chainer.training.Trainer object to start training.

trainer.run()

epoch       main/loss   main/accuracy  validation/main/loss  validation/main/accuracy  elapsed_time
1           1.53241     0.638409       0.74935               0.835839                  4.93409
2           0.578334    0.858059       0.444722              0.882812                  7.72883
3           0.418569    0.886844       0.364943              0.899229                  10.4229
4           0.362342    0.899089       0.327569              0.905558                  13.148
5           0.331067    0.906517       0.304399              0.911788                  15.846
6           0.309019    0.911964       0.288295              0.917722                  18.5395
7           0.292312    0.916128       0.272073              0.921776                  21.2173
8           0.278291    0.92059        0.261351              0.923457                  23.9211
9           0.266266    0.923541       0.253195              0.927314                  26.6612
10          0.255489    0.926739       0.242415              0.929094                  29.466
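
The same values can be post-processed programmatically. The sketch below assumes the log written by ~chainer.training.extensions.LogReport is a JSON list with one record per epoch. Instead of reading a file, it parses an inline sample holding two rows copied from the table above, with only a subset of the columns.

```python
import json

# Two entries copied from the table above, in the shape of a JSON list of records
sample_log = '''
[
  {"epoch": 9, "main/loss": 0.266266, "validation/main/accuracy": 0.927314},
  {"epoch": 10, "main/loss": 0.255489, "validation/main/accuracy": 0.929094}
]
'''

records = json.loads(sample_log)
# Pick the record with the highest validation accuracy
best = max(records, key=lambda r: r['validation/main/accuracy'])
best_epoch = best['epoch']
```

Selecting the best epoch this way is useful when validation accuracy does not improve monotonically and you want to load the matching model snapshot.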

Let's see the plot of loss progress saved in the mnist_result directory.

image

How about the accuracy?

image

Furthermore, let's visualize the computational graph saved with ~chainer.training.extensions.dump_graph using Graphviz.

% dot -Tpng mnist_result/cg.dot -o mnist_result/cg.png

image

From the top to the bottom, you can see the data flow in the computational graph. It basically shows how data and parameters are passed to the ~chainer.Functions.

8. Evaluate a pre-trained model

Evaluating a model from its snapshot is as easy as what is explained in train_loop.

import matplotlib.pyplot as plt
from chainer import serializers

# Load the trained parameters into a fresh MLP (kept on the CPU here)
model = MLP()
serializers.load_npz('mnist_result/model_epoch-10', model)

# Show the output
x, t = test[0]
plt.imshow(x.reshape(28, 28), cmap='gray')
plt.show()
print('label:', t)

y = model(x[None, ...])

print('predicted_label:', y.data.argmax(axis=1)[0])

image

label: 7
predicted_label: 7

The prediction looks correct. Success!