diff --git a/docs/image/trainer/mnist_accuracy.png b/docs/image/trainer/mnist_accuracy.png
new file mode 100644
index 000000000000..c6607e3421d0
Binary files /dev/null and b/docs/image/trainer/mnist_accuracy.png differ
diff --git a/docs/image/trainer/mnist_graph.png b/docs/image/trainer/mnist_graph.png
new file mode 100644
index 000000000000..089e09e6dae6
Binary files /dev/null and b/docs/image/trainer/mnist_graph.png differ
diff --git a/docs/image/trainer/mnist_loss.png b/docs/image/trainer/mnist_loss.png
new file mode 100644
index 000000000000..81607e276293
Binary files /dev/null and b/docs/image/trainer/mnist_loss.png differ
diff --git a/docs/image/trainer/mnist_output.png b/docs/image/trainer/mnist_output.png
new file mode 100644
index 000000000000..6019dd64d50d
Binary files /dev/null and b/docs/image/trainer/mnist_output.png differ
diff --git a/docs/image/trainer/trainer.png b/docs/image/trainer/trainer.png
new file mode 100644
index 000000000000..c997b9d07708
Binary files /dev/null and b/docs/image/trainer/trainer.png differ
diff --git a/docs/source/tutorial/train.rst b/docs/source/tutorial/train.rst
index 03659500d74f..e7e6fbef66fa 100644
--- a/docs/source/tutorial/train.rst
+++ b/docs/source/tutorial/train.rst
@@ -5,3 +5,4 @@ How to Train a Network
    :maxdepth: 2
 
    train_loop
+   trainer
diff --git a/docs/source/tutorial/trainer.rst b/docs/source/tutorial/trainer.rst
new file mode 100644
index 000000000000..d0c080399eac
--- /dev/null
+++ b/docs/source/tutorial/trainer.rst
@@ -0,0 +1,309 @@
+Let's try using the Trainer feature
+```````````````````````````````````
+
+By using :class:`~chainer.training.Trainer`, you don't need to write the training loop explicitly any more. Furthermore, Chainer provides many useful extensions that can be used with :class:`~chainer.training.Trainer` to visualize your results, evaluate your model, and store and manage log files more easily.
+
+This example will show how to use the :class:`~chainer.training.Trainer` to train a fully-connected feed-forward neural network on the MNIST dataset.
+
+.. note::
+
+    If you would like to know how to write a training loop without using the :class:`~chainer.training.Trainer`, please check :doc:`train_loop` instead of this tutorial.
+
+1. Prepare the dataset
+''''''''''''''''''''''
+
+Load the MNIST dataset, which contains a training set of images and class labels as well as a corresponding test set.
+
+.. testcode::
+
+    from chainer.datasets import mnist
+
+    train, test = mnist.get_mnist()
+
+.. note::
+
+    **You can use a Python list as a dataset.** That's because :class:`~chainer.dataset.Iterator` can take any object as a dataset whose elements can be accessed via the ``[]`` accessor and whose length can be obtained with the ``len()`` function. For example, a list of tuples
+
+    .. code-block:: python
+
+        train = [(x1, t1), (x2, t2), ...]
+
+    can be used as a dataset.
+
+    There are many utility dataset classes defined in :mod:`~chainer.datasets`. It is recommended to utilize them in actual applications.
+
+    For example, if your dataset consists of a number of image files, it would take a large amount of memory to load all of them into a list like the one above. In that case, you can use :class:`~chainer.datasets.ImageDataset`, which just keeps the paths to the image files. The actual image data is loaded from the disk only when the corresponding element is requested via the ``[]`` accessor. Until then, no images are loaded into memory, which reduces memory use.
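+
+    As a minimal sketch of this lazy loading (the file names below are hypothetical and only for illustration):
+
+    .. code-block:: python
+
+        from chainer.datasets import ImageDataset
+
+        # Only the paths are kept in memory at this point.
+        paths = ['images/0001.png', 'images/0002.png']
+        images = ImageDataset(paths)
+
+        x = images[0]  # the first image is read from the disk only here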
+
+2. Prepare the dataset iterators
+''''''''''''''''''''''''''''''''
+
+:class:`~chainer.dataset.Iterator` creates a mini-batch from the given dataset.
+
+.. testsetup:: *
+
+    from chainer.datasets import mnist
+
+    train, test = mnist.get_mnist()
+
+.. testcode::
+
+    from chainer import iterators
+
+    batchsize = 128
+
+    train_iter = iterators.SerialIterator(train, batchsize)
+    test_iter = iterators.SerialIterator(test, batchsize, False, False)
+
+3. Prepare the model
+''''''''''''''''''''
+
+Here, we are going to use the same model as the one defined in :doc:`train_loop`.
+
+.. testcode::
+
+    import chainer.functions as F
+    import chainer.links as L
+    from chainer import Chain
+
+    class MLP(Chain):
+
+        def __init__(self, n_mid_units=100, n_out=10):
+            super(MLP, self).__init__()
+            with self.init_scope():
+                self.l1 = L.Linear(None, n_mid_units)
+                self.l2 = L.Linear(None, n_mid_units)
+                self.l3 = L.Linear(None, n_out)
+
+        def __call__(self, x):
+            h1 = F.relu(self.l1(x))
+            h2 = F.relu(self.l2(h1))
+            return self.l3(h2)
+
+    gpu_id = 0  # Set to -1 if you use a CPU
+
+    model = MLP()
+    if gpu_id >= 0:
+        model.to_gpu(gpu_id)
+
+4. Prepare the Updater
+''''''''''''''''''''''
+
+:class:`~chainer.training.Trainer` is a class that holds all of the necessary components needed for training. The main components are shown below.
+
+.. image:: ../../image/trainer/trainer.png
+
+Basically, all you need to pass to :class:`~chainer.training.Trainer` is an :class:`~chainer.training.Updater`, which in turn holds an :class:`~chainer.dataset.Iterator` and an :class:`~chainer.Optimizer`. Since the :class:`~chainer.dataset.Iterator` can access the dataset and the :class:`~chainer.Optimizer` has a reference to the model, the :class:`~chainer.training.Updater` can access the model to update its parameters.
+
+So, the :class:`~chainer.training.Updater` can perform the training procedure as shown below:
+
+1. Retrieve the data from the dataset and construct a mini-batch (:class:`~chainer.dataset.Iterator`)
+2. Pass the mini-batch to the model and calculate the loss
+3. Update the parameters of the model (:class:`~chainer.Optimizer`)
+
+Now let's create the :class:`~chainer.training.Updater` object!
+
+.. testsetup:: *
+
+    import chainer.functions as F
+    import chainer.links as L
+    from chainer import Chain, iterators
+    from chainer.datasets import mnist
+
+    class MLP(Chain):
+
+        def __init__(self, n_mid_units=100, n_out=10):
+            super(MLP, self).__init__()
+            with self.init_scope():
+                self.l1 = L.Linear(None, n_mid_units)
+                self.l2 = L.Linear(None, n_mid_units)
+                self.l3 = L.Linear(None, n_out)
+
+        def __call__(self, x):
+            h1 = F.relu(self.l1(x))
+            h2 = F.relu(self.l2(h1))
+            return self.l3(h2)
+
+    model = MLP()
+
+    batchsize = 128
+
+    train, test = mnist.get_mnist()
+    train_iter = iterators.SerialIterator(train, batchsize)
+    test_iter = iterators.SerialIterator(test, batchsize, False, False)
+
+.. testcode::
+
+    from chainer import optimizers, training
+
+    max_epoch = 10
+
+    # Wrap your model with Classifier to include the loss calculation
+    # within the model. Since we do not specify a loss function here,
+    # the default softmax_cross_entropy is used.
+    model = L.Classifier(model)
+
+    # Select an optimization method
+    optimizer = optimizers.MomentumSGD()
+
+    # Give the optimizer a reference to the model
+    optimizer.setup(model)
+
+    # Get an updater that uses the Iterator and Optimizer; the device
+    # option sends each mini-batch to the same device as the model
+    updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)
+
+.. note::
+
+    Here, the model defined above is passed to :class:`~chainer.links.Classifier` and wrapped into a new :class:`~chainer.Chain`. :class:`~chainer.links.Classifier`, which in fact inherits from the :class:`~chainer.Chain` class, keeps the given :class:`~chainer.Chain` model in its :attr:`~chainer.links.Classifier.predictor` attribute. When you give the input data and the corresponding class labels to the model via the ``()`` operator,
+
+    1. :meth:`~chainer.links.Classifier.__call__` of the model is invoked. The data is then given to :attr:`~chainer.links.Classifier.predictor` to obtain the output ``y``.
+    2. Next, together with the given labels, the output ``y`` is passed to the loss function, which is determined by the :attr:`~chainer.links.Classifier.lossfun` argument in the constructor of :class:`~chainer.links.Classifier`.
+    3. The loss is returned as a :class:`~chainer.Variable`.
+
+    In :class:`~chainer.links.Classifier`, :attr:`~chainer.links.Classifier.lossfun` is set to
+    :func:`~chainer.functions.softmax_cross_entropy` by default.
+
+    :class:`~chainer.training.StandardUpdater` is the simplest of several updater classes. There are also :class:`~chainer.training.ParallelUpdater` and :class:`~chainer.training.updaters.MultiprocessParallelUpdater` for utilizing multiple GPUs. :class:`~chainer.training.updaters.MultiprocessParallelUpdater` uses the NVIDIA NCCL library, so you need to install NCCL and re-install CuPy before using it.
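+
+To make the role of :class:`~chainer.links.Classifier` more concrete, here is a rough sketch of what ``model(x, t)`` computes for one mini-batch. It is only an illustration, not part of the training script, and it assumes everything stays on the CPU (``gpu_id = -1``):
+
+.. code-block:: python
+
+    import chainer.functions as F
+    from chainer.dataset import concat_examples
+
+    # Fetch one mini-batch and stack it into batched arrays.
+    batch = train_iter.next()
+    x, t = concat_examples(batch)
+
+    # What Classifier does internally: run the predictor, then
+    # apply the default loss function to the output and the labels.
+    y = model.predictor(x)
+    loss = F.softmax_cross_entropy(y, t)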
+
+5. Setup Trainer
+''''''''''''''''
+
+Lastly, we will set up the :class:`~chainer.training.Trainer`. The only requirement for creating a :class:`~chainer.training.Trainer` is to pass it the :class:`~chainer.training.Updater` object that we created above. You can also pass a :attr:`~chainer.training.Trainer.stop_trigger` as the second trainer argument, a tuple of the form ``(length, unit)``, to tell the trainer when to stop the training. The ``length`` is given as an integer and the ``unit`` as a string, which should be either ``'epoch'`` or ``'iteration'``. Without a :attr:`~chainer.training.Trainer.stop_trigger`, the training will never stop.
+
+.. testsetup:: *
+
+    from chainer import optimizers, training
+
+    model = L.Classifier(model)
+    optimizer = optimizers.MomentumSGD()
+    optimizer.setup(model)
+    updater = training.StandardUpdater(train_iter, optimizer)
+
+.. testcode::
+
+    # Set up a Trainer
+    trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='mnist_result')
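+
+For example, if you prefer an iteration-based trigger, a minimal sketch of an equivalent setup could look like this (the iteration count is arbitrary):
+
+.. code-block:: python
+
+    # Stop after 6000 iterations instead of a fixed number of epochs.
+    trainer = training.Trainer(updater, (6000, 'iteration'), out='mnist_result')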
+
+The :attr:`~chainer.training.Trainer.out` argument specifies the output directory used to save the log files and the plot images that show the progress of the loss, accuracy, etc. when you use the :class:`~chainer.training.extensions.PlotReport` extension. Next, we will explain how to display and save such information by using trainer :class:`~chainer.training.Extension`\ s.
+
+6. Add Extensions to the Trainer object
+'''''''''''''''''''''''''''''''''''''''
+
+The :class:`~chainer.training.Trainer` extensions provide the following capabilities:
+
+* Save log files automatically (:class:`~chainer.training.extensions.LogReport`)
+* Display the training information to the terminal periodically (:class:`~chainer.training.extensions.PrintReport`)
+* Visualize the loss progress by plotting a graph periodically and save it as an image file (:class:`~chainer.training.extensions.PlotReport`)
+* Automatically serialize the training state periodically (:func:`~chainer.training.extensions.snapshot` / :func:`~chainer.training.extensions.snapshot_object`)
+* Display a progress bar to the terminal to show the progress of training (:class:`~chainer.training.extensions.ProgressBar`)
+* Save the model architecture as a Graphviz dot file (:func:`~chainer.training.extensions.dump_graph`)
+
+To use this wide variety of tools for your training task, pass :class:`~chainer.training.Extension` objects to the :meth:`~chainer.training.Trainer.extend` method of your :class:`~chainer.training.Trainer` object.
+
+.. testcode::
+
+    from chainer.training import extensions
+
+    trainer.extend(extensions.LogReport())
+    trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
+    trainer.extend(extensions.snapshot_object(model.predictor, filename='model_epoch-{.updater.epoch}'))
+    trainer.extend(extensions.Evaluator(test_iter, model, device=gpu_id))
+    trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']))
+    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png'))
+    trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
+    trainer.extend(extensions.dump_graph('main/loss'))
+
+:class:`~chainer.training.extensions.LogReport`
+...............................................
+
+This extension collects the reported values such as ``loss`` and ``accuracy`` every ``epoch`` or ``iteration`` and stores them in the ``log`` file placed in the directory specified by the :attr:`~chainer.training.Trainer.out` argument when you create a :class:`~chainer.training.Trainer` object.
+
+:func:`~chainer.training.extensions.snapshot`
+.............................................
+
+The :func:`~chainer.training.extensions.snapshot` extension saves the :class:`~chainer.training.Trainer` object at the designated timing (default: every epoch) in the directory specified by :attr:`~chainer.training.Trainer.out`. The :class:`~chainer.training.Trainer` object, as mentioned before, has an :class:`~chainer.training.Updater` which contains an :class:`~chainer.Optimizer` and a model inside. Therefore, as long as you have the snapshot file, you can use it to resume the training or to make inferences with the previously trained model later.
+
+:func:`~chainer.training.extensions.snapshot_object`
+....................................................
+
+However, when you save the whole :class:`~chainer.training.Trainer` object, it is sometimes tedious to retrieve only the model from it. By using :func:`~chainer.training.extensions.snapshot_object`, you can save a particular object (in this case, the model wrapped by :class:`~chainer.links.Classifier`) as a separate snapshot. :class:`~chainer.links.Classifier` is a :class:`~chainer.Chain` that keeps the model, which is itself a :class:`~chainer.Chain`, as its :attr:`~chainer.links.Classifier.predictor` attribute. Since all of the trained parameters live under :attr:`~chainer.links.Classifier.predictor`, taking a snapshot of :attr:`~chainer.links.Classifier.predictor` is enough to keep all of the trained parameters.
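+
+For example, here is a minimal sketch of resuming training from a trainer snapshot (the file name assumes the epoch-10 snapshot configured above, and the rebuilt trainer should use a stop trigger longer than 10 epochs):
+
+.. code-block:: python
+
+    from chainer import serializers
+
+    # Restore the whole training state (model, optimizer, iterators, ...)
+    # saved by the snapshot extension into a freshly built trainer.
+    serializers.load_npz('mnist_result/snapshot_epoch-10', trainer)
+
+    # Training continues from where the snapshot was taken.
+    trainer.run()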
+
+:func:`~chainer.training.extensions.dump_graph`
+...............................................
+
+This extension saves the structure of the computational graph of the model. The graph is saved in the `Graphviz <https://www.graphviz.org/>`_ dot format. The output location (directory) of the saved graph is set by the :attr:`~chainer.training.Trainer.out` argument of :class:`~chainer.training.Trainer`.
+
+:class:`~chainer.training.extensions.Evaluator`
+...............................................
+
+:class:`~chainer.training.extensions.Evaluator` requires the :class:`~chainer.dataset.Iterator` over the evaluation dataset and the model object. It evaluates the model using the given dataset (typically a validation dataset) at the specified timing interval.
+
+:class:`~chainer.training.extensions.PrintReport`
+.................................................
+
+This extension outputs the specified values to the standard output.
+
+:class:`~chainer.training.extensions.PlotReport`
+................................................
+
+:class:`~chainer.training.extensions.PlotReport` plots the values specified by its arguments and saves the plot as an image file whose name is given by the ``file_name`` argument.
+
+----
+
+Each :class:`~chainer.training.Extension` class has its own options, and some extensions are not mentioned here. Another important feature is that, by using the :attr:`~chainer.training.Extension.trigger` option, you can set the individual timing at which each :class:`~chainer.training.Extension` fires. To learn more about all of the extensions, please take a look at the documentation of :mod:`chainer.training.extensions`.
+
+7. Start Training
+'''''''''''''''''
+
+Just call the :meth:`~chainer.training.Trainer.run` method of the
+:class:`~chainer.training.Trainer` object to start training.
+
+.. code-block:: python
+
+    trainer.run()
+
+::
+
+    epoch       main/loss   main/accuracy  validation/main/loss  validation/main/accuracy  elapsed_time
+    1           1.53241     0.638409       0.74935               0.835839                  4.93409
+    2           0.578334    0.858059       0.444722              0.882812                  7.72883
+    3           0.418569    0.886844       0.364943              0.899229                  10.4229
+    4           0.362342    0.899089       0.327569              0.905558                  13.148
+    5           0.331067    0.906517       0.304399              0.911788                  15.846
+    6           0.309019    0.911964       0.288295              0.917722                  18.5395
+    7           0.292312    0.916128       0.272073              0.921776                  21.2173
+    8           0.278291    0.92059        0.261351              0.923457                  23.9211
+    9           0.266266    0.923541       0.253195              0.927314                  26.6612
+    10          0.255489    0.926739       0.242415              0.929094                  29.466
+
+Let's see the plot of the loss progress saved in the ``mnist_result`` directory.
+
+.. image:: ../../image/trainer/mnist_loss.png
+
+How about the accuracy?
+
+.. image:: ../../image/trainer/mnist_accuracy.png
+
+Furthermore, let's visualize the computational graph saved with :func:`~chainer.training.extensions.dump_graph` using Graphviz.
+
+::
+
+    % dot -Tpng mnist_result/cg.dot -o mnist_result/cg.png
+
+.. image:: ../../image/trainer/mnist_graph.png
+
+From top to bottom, you can see the data flow in the computational graph. It basically shows how data and parameters are passed to the :class:`~chainer.Function`\ s.
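+
+By the way, the table printed during training is also stored by :class:`~chainer.training.extensions.LogReport` as ``mnist_result/log``, which is a plain JSON file. As a minimal sketch, you can inspect it directly (assuming training has finished and the default log file name is used):
+
+.. code-block:: python
+
+    import json
+
+    # The log is a list of dictionaries, one per reporting interval
+    # (one per epoch in this tutorial).
+    with open('mnist_result/log') as f:
+        log = json.load(f)
+
+    print(log[-1]['validation/main/accuracy'])  # accuracy of the last epoch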
+
+8. Evaluate a pre-trained model
+'''''''''''''''''''''''''''''''
+
+Evaluating with a saved model snapshot is as easy as the procedure explained in :doc:`train_loop`.
+
+.. code-block:: python
+
+    import matplotlib.pyplot as plt
+    from chainer import serializers
+
+    # Load the MLP saved by snapshot_object
+    model = MLP()
+    serializers.load_npz('mnist_result/model_epoch-10', model)
+
+    # Show the output
+    x, t = test[0]
+    plt.imshow(x.reshape(28, 28), cmap='gray')
+    plt.show()
+    print('label:', t)
+
+    x = x[None, ...]  # add a batch dimension
+    y = model(x)
+
+    print('predicted_label:', y.data.argmax(axis=1)[0])
+
+.. image:: ../../image/trainer/mnist_output.png
+
+::
+
+    label: 7
+    predicted_label: 7
+
+The prediction looks correct. Success!
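+
+Finally, here is a minimal sketch of checking the restored model on the whole test set, assuming ``test_iter`` from step 2 is still available. It reuses :class:`~chainer.training.extensions.Evaluator` as a standalone function; the result keys below follow the naming used by :class:`~chainer.links.Classifier`:
+
+.. code-block:: python
+
+    from chainer.training import extensions
+
+    # Wrap the restored predictor in a fresh Classifier so that the
+    # loss and accuracy are computed, then run the evaluation once.
+    evaluator = extensions.Evaluator(test_iter, L.Classifier(model))
+    results = evaluator()
+
+    print('test accuracy:', results['main/accuracy'])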