
Commit

Detailed documentation
Co-authored-by: Chiheon Kim <frost.conv@kakaobrain.com>
Co-authored-by: Myungryong Jeong <paul.june@kakaobrain.com>
Co-authored-by: Sungbin Lim <leo.brain@kakaobrain.com>
4 people committed Jun 26, 2019
1 parent c176235 commit 0aa8484
Showing 16 changed files with 1,123 additions and 47 deletions.
2 changes: 1 addition & 1 deletion README.ko.md
@@ -1,4 +1,4 @@
# torchgpipe
# torchgpipe <img src="docs/_static/not-pipe.svg" height="20" />

[![PyPI](https://img.shields.io/pypi/v/torchgpipe.svg)](https://pypi.org/project/torchgpipe)
[![Build Status](https://travis-ci.org/kakaobrain/torchgpipe.svg?branch=master)](https://travis-ci.org/kakaobrain/torchgpipe)
4 changes: 2 additions & 2 deletions README.md
@@ -1,4 +1,4 @@
# torchgpipe
# torchgpipe <img src="docs/_static/not-pipe.svg" height="20" />

[![PyPI](https://img.shields.io/pypi/v/torchgpipe.svg)](https://pypi.org/project/torchgpipe)
[![Build Status](https://travis-ci.org/kakaobrain/torchgpipe.svg?branch=master)](https://travis-ci.org/kakaobrain/torchgpipe)
@@ -65,7 +65,7 @@ $ pip install torchgpipe
```

To train a module with GPipe, simply wrap it with `torchgpipe.GPipe`. Your
module must be `nn.Sequential` as GPipe will automatically break up the module
module must be `nn.Sequential` as GPipe will automatically split the module
into partitions with consecutive layers. `balance` argument determines the
number of layers in each partition. `chunks` argument specifies the number of
micro-batches. Input, output, and intermediate tensors must be `Tensor` or
`Tuple[Tensor, ...]`.
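
For example, a minimal sketch of this usage (`a` through `d` stand for
arbitrary layers, and `data_loader` is assumed):

```python
from torch import nn
from torchgpipe import GPipe

model = nn.Sequential(a, b, c, d)
model = GPipe(model, balance=[1, 1, 1, 1], chunks=8)

for input in data_loader:
    output = model(input)  # micro-batches are pipelined across the devices
```
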
8 changes: 8 additions & 0 deletions docs/_static/not-pipe.svg
26 changes: 26 additions & 0 deletions docs/api.rst
@@ -0,0 +1,26 @@
API
===

GPipe Module
~~~~~~~~~~~~

.. autoclass:: torchgpipe.GPipe(module, balance, \**kwargs)

   .. automethod:: forward(input)

   .. autoattribute:: devices
      :annotation:
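
For example, a minimal sketch of using the class above (the balance is
illustrative)::

    model = GPipe(model, balance=[1, 1], chunks=4)
    output = model(input)  # forward() pipelines micro-batches over the devices
    model.devices          # the devices allotted to each partition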

Inspecting GPipe Timeline
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchgpipe.current_microbatch()

.. autofunction:: torchgpipe.is_recomputing()
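
For example, a module might use :func:`torchgpipe.is_recomputing` to avoid
duplicating side effects during recomputation (this module is a hypothetical
sketch)::

    import torchgpipe
    from torch import nn

    class Snapshot(nn.Module):
        def forward(self, x):
            if not torchgpipe.is_recomputing():
                pass  # perform side effects only on the first forward pass
            return x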

Automatic Balancing
~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchgpipe_balancing.balance_by_time(module, canary, partitions, device, timeout)

.. autofunction:: torchgpipe_balancing.balance_by_size(module, canary, partitions, device)
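
A sketch of combining the balancing utilities with :class:`torchgpipe.GPipe`
(the sample shape is an assumption)::

    import torch
    from torchgpipe import GPipe
    from torchgpipe_balancing import balance_by_time

    sample = torch.rand(128, 3, 224, 224)  # a canary input for timing
    balance = balance_by_time(model, sample, partitions=4)
    model = GPipe(model, balance, chunks=8)
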
45 changes: 45 additions & 0 deletions docs/benchmarks.rst
@@ -0,0 +1,45 @@
Benchmarks
==========

ResNet-101
~~~~~~~~~~

ResNet-101 Performance Benchmark
--------------------------------

========== =================== =======
Experiment Throughput          Speedup
========== =================== =======
naive-1    100.506 samples/sec 1.000x
pipeline-1  73.925 samples/sec 0.736x
pipeline-2 135.691 samples/sec 1.350x
pipeline-4 230.216 samples/sec 2.291x
pipeline-8 312.945 samples/sec 3.114x
========== =================== =======

The reproducible code and the experiment details for Tesla P40 GPUs can be
found in `examples/resnet101_performance_benchmark`_.

.. _examples/resnet101_performance_benchmark:
https://github.com/kakaobrain/torchgpipe/
tree/master/examples/resnet101_performance_benchmark

AmoebaNet-D
~~~~~~~~~~~

.. AmoebaNet-D Performance Benchmark
.. ---------------------------------
AmoebaNet-D Memory Benchmark
----------------------------

========== =========== ========== ================ =================
Experiment AmoebaNet-D # of       Total Model      Total Peak
           (L, F)      Parameters Parameter Memory Activation Memory
========== =========== ========== ================ =================
naive-1    (6, 208)    90M        1.00GB           --
pipeline-1 (6, 416)    358M       4.01GB           6.64GB
pipeline-2 (6, 544)    613M       6.45GB           11.31GB
pipeline-4 (12, 544)   1.16B      13.00GB          18.72GB
pipeline-8 (24, 512)   2.01B      22.42GB          35.78GB
========== =========== ========== ================ =================
23 changes: 17 additions & 6 deletions docs/changelog.rst
@@ -1,13 +1,24 @@
Changelog
=========

v0.0.2 (WIP)
~~~~~~~~~~~~

Not released yet.

- Added support for PyTorch 1.1.
- Refined public APIs.
- Proper exceptions for invalid usage.
- Provided inspecting utilities: :func:`torchgpipe.current_microbatch` and
  :func:`torchgpipe.is_recomputing`.
- Reimplemented deferred batch normalization by subclassing.
- Provided :mod:`torchgpipe_balancing` for automatic balancing.

v0.0.1
------
~~~~~~

Released on May 14, 2019 to evaluate usability and efficiency internally.

Provided a basically functional GPipe implementation, including pipeline
parallelism, checkpointing with preceding recomputation, and deferred
BatchNorm.

Supported Python 3.6+ and PyTorch 1.0.
- Provided a functional GPipe implementation, including pipeline parallelism,
  checkpointing, and deferred batch normalization.
- Supported Python 3.6+ and PyTorch 1.0.
10 changes: 10 additions & 0 deletions docs/conf.py
@@ -39,6 +39,9 @@
extensions = [
    # We follow Google style docstrings just like PyTorch.
    'sphinx.ext.napoleon',

    # Allow referencing sections by their titles.
    'sphinx.ext.autosectionlabel',
]

# Add any paths that contain templates here, relative to this directory.
@@ -64,11 +67,18 @@
html_theme = 'alabaster'

html_theme_options = {
    'logo': 'not-pipe.svg',
    'logo_name': True,
    'description': 'GPipe for PyTorch',

    'github_user': 'kakaobrain',
    'github_repo': 'torchgpipe',
    'github_type': 'star',

    'extra_nav_links': {
        'Source Code': 'https://github.com/kakaobrain/torchgpipe',
        'Original Paper': 'https://arxiv.org/abs/1811.06965',
    },
}

# Add any paths that contain custom static files (such as style sheets) here,
102 changes: 102 additions & 0 deletions docs/gpipe.rst
@@ -0,0 +1,102 @@
Understanding GPipe
===================

GPipe uses (a) :ref:`pipeline parallelism` and (b) automatic recomputation of
the forward propagation during the backpropagation, which together make it
possible to train a very large model. We refer to (b) as :ref:`checkpointing`,
following the well-known terminology in the PyTorch community.

Pipeline Parallelism
~~~~~~~~~~~~~~~~~~~~

GPipe splits a model into multiple partitions and places each partition on a
different device to secure more memory capacity. For example, we may split a
model occupying 40GB of CUDA memory into 4 partitions, each occupying 10GB.

This approach is called `model parallelism`. However, typical deep learning
models are composed of sequential layers: usually a layer cannot start until
the prior layer has finished. If a model is composed of fully sequential
layers, then even if we spread the model over two or more devices, only one
device can be utilized at a time.

.. image:: img/model-parallel.svg
   :align: center
   :height: 110
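
A minimal sketch of such naive model parallelism (the two halves and the
devices are illustrative)::

    from torch import nn

    class Naive(nn.Module):
        def __init__(self, seq1: nn.Sequential, seq2: nn.Sequential):
            super().__init__()
            self.seq1 = seq1.to('cuda:0')  # first half on GPU 0
            self.seq2 = seq2.to('cuda:1')  # second half on GPU 1

        def forward(self, x):
            x = self.seq1(x.to('cuda:0'))     # GPU 1 is idle here
            return self.seq2(x.to('cuda:1'))  # GPU 0 is idle here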

GPipe splits a mini-batch into multiple micro-batches to make the devices
work in parallel as much as possible. This is called `pipeline parallelism`.
Essentially, pipeline parallelism is a stack of small data-parallel steps:
when a partition has finished processing a micro-batch, it passes the output
to the next partition and immediately starts to work on the next micro-batch,
so the partitions can overlap.

.. image:: img/pipeline-parallel.svg
   :align: center
   :height: 110
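
A conceptual sketch of the micro-batch split (not the library's actual
internals)::

    chunks = 4
    micro_batches = input.chunk(chunks)  # split the mini-batch
    for micro_batch in micro_batches:
        # a partition starts on the next micro-batch as soon as it hands
        # the current output to the following partition
        ...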

.. seealso::
   `Model Parallel Best Practices in PyTorch Tutorials
   <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`_

There is still idle time, called the `bubble`, because every partition has to
wait for the first micro-batch from the prior partition. The bubble can be
reduced by choosing smaller micro-batches. However, a larger batch size
usually utilizes a GPU more efficiently, so the GPU may be underutilized if
the micro-batches are too small.

A faster partition has to wait for an adjacent slower partition, so imbalance
between partitions may also cause GPU underutilization. Note that the overall
performance is determined by the slowest partition.

.. image:: img/imbalance.svg
   :align: center
   :height: 110

Checkpointing
~~~~~~~~~~~~~

Checkpointing is applied to each partition to minimize the overall memory
consumption of the model. During forward propagation, only the tensors at the
boundaries between partitions are remembered. All other intermediate tensors
are freed and recomputed during backpropagation when necessary. Specifically,
with checkpointing, the hidden layers consume only the memory required by a
single micro-batch.
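
The same idea is available in plain PyTorch as
:func:`torch.utils.checkpoint.checkpoint`; a minimal sketch (``block`` and
``x`` are illustrative)::

    from torch.utils.checkpoint import checkpoint

    # Activations inside `block` are not stored during the forward pass;
    # they are recomputed when backpropagation reaches this point.
    y = checkpoint(block, x)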

Checkpointing is a trade-off between performance and memory: recomputation
takes just as much time as the forward propagation itself. When you use
:class:`torchgpipe.GPipe`, you can turn off checkpointing with the
``checkpoint='never'`` option.
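
For example (the balance is illustrative)::

    model = GPipe(model, balance=[1, 1, 1, 1], chunks=8, checkpoint='never')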

Deferred Batch Normalization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

One of the goals of GPipe is `transparency`: GPipe shouldn't affect existing
hyperparameters or outputs during training. However, a module which processes
per mini-batch rather than per single sample might be affected by GPipe,
since each module sees only one micro-batch at a time.

Meanwhile, batch normalization is a module commonly used in CNNs. During
training, its forward propagation performs two procedures, both per
mini-batch rather than per micro-batch:

1. Normalizing the mini-batch by the mean and variance of that mini-batch.
2. Tracking the moving statistics (mean and variance) of mini-batches to
   normalize batches in evaluation.
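
A quick numerical illustration of the dependency (shapes are illustrative)::

    import torch

    x = torch.randn(64, 3)  # a mini-batch
    micro = x.chunk(4)      # four equal micro-batches
    # Micro-batch means average back to the mini-batch mean...
    torch.stack([m.mean(0) for m in micro]).mean(0)  # == x.mean(0)
    # ...but averaged micro-batch variances do not recover x.var(0),
    # because the variance between micro-batch means is lost.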

GPipe cannot provide transparency for the first procedure (normalization):
per-mini-batch normalization introduces a dependency among the micro-batches,
which breaks the parallelism of GPipe. The second procedure (tracking moving
statistics), however, can be made transparent with GPipe by accumulating the
statistics of all micro-batches.

:mod:`torchgpipe` provides this functionality as `deferred batch
normalization`. In the current implementation, however, it is slower than
vanilla batch normalization; that is why it is turned off by default. If you
need transparent moving statistics, turn it on with the
``deferred_batch_norm=True`` option of :class:`~torchgpipe.GPipe`::

    model = GPipe(model, balance=[1, 1, 1, 1], chunks=8,
                  # Turn on deferred batch normalization.
                  deferred_batch_norm=True)
