
Commit

Documentation
sublee authored and GitHub Enterprise committed Sep 30, 2019
2 parents e22fc34 + 83f656a commit 1e1e276
Showing 7 changed files with 163 additions and 37 deletions.
32 changes: 22 additions & 10 deletions docs/changelog.rst
@@ -6,18 +6,30 @@ v0.0.3 (WIP)

Not released yet.

- Added support for PyTorch 1.2.
- Redesigned the internal pipeline parallelism to represent dependencies
transparently.
- Optimized by using separate CUDA streams for device-to-device copies.
- Fixed hang at a once failed partition.
- Removed ``current_microbatch`` which actually didn't work.
- Fixed the size cumulation (`issue #3`_ by `Shiyan Deng`_) of
:func:`~torchgpipe_balancing.balance_by_size`.
Featured:
torchgpipe now overlaps copy and computation using separate CUDA streams.
Previously, a GPU could not compute a partition while copying micro-batches
across different GPUs, because everything happened on the same default CUDA
stream.

Other Improvements:
- Added support for PyTorch 1.2.
- Redesigned the internal pipeline parallelism to represent dependencies
transparently.
- Fixed the hanging issue when an exception is raised in a partition.
- Fixed the unintended size accumulation (`issue #3`_ by `Shiyan Deng`_) of
:func:`~torchgpipe_balancing.balance_by_size`.

.. _issue #3: https://github.com/kakaobrain/torchgpipe/issues/3
.. _Shiyan Deng: https://github.com/842974287

Breaking Changes:
- No more support for PyTorch 1.0.
- Changed type of :attr:`GPipe.devices <torchgpipe.GPipe.devices>` from
``tuple`` to ``list``.
- Removed ``current_microbatch``. This approach turned out to be
incompatible with checkpointing.

v0.0.2
~~~~~~

@@ -28,8 +40,8 @@ Released on June 26, 2019.
- Detailed documentation.
- Proper exceptions for invalid usage.
- Provided :ref:`automatic balancing <Automatic Balancing>`.
- Provided inspecting utilities: ``current_microbatch`` (deprecated since
v0.0.3) and :func:`~torchgpipe.is_recomputing`
- Provided inspecting utilities: ``current_microbatch`` (DO NOT USE, deprecated
since v0.0.3) and :func:`~torchgpipe.is_recomputing`
- Reimplemented deferred batch normalization by subclassing.

v0.0.1
89 changes: 88 additions & 1 deletion docs/guide.rst
@@ -51,6 +51,12 @@ parameter::
devices=[4, 5, 6, 7], # Specify GPUs.
chunks=8)

Typical model parallelism is a special case of GPipe: GPipe without
micro-batches and checkpointing is equivalent to plain model parallelism. You
can disable both with the ``chunks=1`` and ``checkpoint='never'`` options::

model = GPipe(model, balance=[2, 2], chunks=1, checkpoint='never')

Input and Output Device
~~~~~~~~~~~~~~~~~~~~~~~

@@ -105,6 +111,72 @@ CUDA memory usage of each layer. Choose the balancing tool for your needs::

.. _PyTorch JIT: https://pytorch.org/docs/stable/jit.html

Trade-offs
~~~~~~~~~~

Number of Micro-batches
-----------------------

The number of micro-batches trades off GPU utilization per micro-batch against
the total area of the bubble. You need to find the best number of micro-batches
for your model.

A GPU may slow down when processing many small micro-batches rather than fewer
large ones: a GPU is not fully utilized when each CUDA kernel is too cheap to
compute, so micro-batches that are too small cause underutilization. On the
other hand, the bubble area is minimized when each micro-batch is as small as
possible. Ideally, choose the largest number of micro-batches that does not
underutilize the GPUs.

As a side note, BatchNorm tends to perform worse with a smaller batch size. A
large number of micro-batches may therefore hurt the final performance of a
model using BatchNorm, just as in :class:`nn.DataParallel
<torch.nn.DataParallel>`.
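
For illustration only, here is a minimal sketch of sweeping candidate
``chunks`` values to find the largest one that keeps the GPUs busy. The toy
model, the ``balance=[2, 1]`` split, and the assumption of two visible GPUs
are placeholders rather than recommendations::

    import time
    import torch
    from torch import nn
    from torchgpipe import GPipe

    def build_model() -> nn.Sequential:
        # A toy three-layer pipeline; replace it with your own model.
        return nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 10))

    x = torch.randn(512, 1024)
    for chunks in [1, 2, 4, 8, 16]:
        model = GPipe(build_model(), balance=[2, 1], chunks=chunks)
        start = time.time()
        for _ in range(10):
            model(x.to(model.devices[0])).sum().backward()
        for device in model.devices:
            torch.cuda.synchronize(device)
        print(chunks, time.time() - start)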

Checkpointing
-------------

Checkpointing drastically reduces memory usage, but it slows overall training
down by about 25%. You can control how checkpointing is applied to your model
with three options:

- ``always`` -- Apply checkpointing over all micro-batches.
- ``except_last`` (default) -- Apply checkpointing except the last micro-batch.
- ``never`` -- Checkpointing is never applied.

Checkpointing the last micro-batch is usually not useful, because the saved
memory would be reconstructed immediately. That is why ``except_last`` is the
default option.

If you decide not to use checkpointing at all, :class:`nn.DataParallel
<torch.nn.DataParallel>` might be more efficient than GPipe.
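
As a quick sketch (the ``balance`` values here are placeholders), the mode is
selected with the ``checkpoint`` option::

    # Choose one of the following (do not apply them all to the same model):
    model = GPipe(model, balance=[2, 2], chunks=8, checkpoint='always')   # lowest memory, slowest
    model = GPipe(model, balance=[2, 2], chunks=8)                        # 'except_last' is the default
    model = GPipe(model, balance=[2, 2], chunks=8, checkpoint='never')    # no recomputation overhead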

Referential Transparency
~~~~~~~~~~~~~~~~~~~~~~~~

Checkpointing executes the forward propagation again at backpropagation; this
is called `recomputation`. We assume that both executions are identical, hence
all layers should be `referentially transparent
<https://en.wikipedia.org/wiki/Referential_transparency>`_ in forward
propagation. The typical cases that break referential transparency are:

In-place Operations:
We do not recommend using in-place operations with checkpointing. In
particular, if an in-place operation such as ``add_(1)`` is applied to the
input of a checkpointed partition, the recomputation cannot recover the
original input (see the sketch after this list).

Nondeterminism:
For example, :class:`nn.Dropout <torch.nn.Dropout>` produces a different
mask in recomputation than in the original forward propagation due to
randomness. This type of nondeterministic behavior is not handled by
torchgpipe yet.

Side Effects:
Some modules such as BatchNorm update their state during forward
propagation. Hence, the state updated in recomputation might not be
identical to the original state.
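
To illustrate the in-place pitfall above, here is a minimal sketch with
hypothetical toy modules (they are not part of torchgpipe)::

    from torch import nn

    class AddOneInPlace(nn.Module):
        def forward(self, x):
            return x.add_(1)   # mutates the checkpointed input in place

    class AddOne(nn.Module):
        def forward(self, x):
            return x + 1       # allocates a new tensor, leaving the input intact

The out-of-place version keeps the partition referentially transparent, so the
recomputed forward pass sees the same input as the original one.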

Restrictions
~~~~~~~~~~~~

@@ -122,7 +194,7 @@ Sequential:
>>> GPipe(model, balance=..., chunks=...)
Traceback (most recent call last)
...
TypeError: non-sequential module cannot be partitioned
TypeError: module must be nn.Sequential to be partitioned

See `the sequential ResNet example`_ to learn how to turn a model into a
:class:`nn.Sequential <torch.nn.Sequential>` model; a toy flattening sketch
also follows this list.
@@ -163,6 +235,21 @@ Tensor or Tensors:
This is because GPipe cannot assume how non-tensor inputs of a mini-batch
should be split into micro-batches.

Unique Parameters:
:class:`~torchgpipe.GPipe` places each partition on its corresponding
device, and placing a partition also moves its parameters to that device.
GPipe cannot support a module whose parameter is shared across partitions,
because the shared parameter would have to live on two or more devices::

>>> conv1 = nn.Conv2d(3, 3, 1)
>>> conv2 = nn.Conv2d(3, 3, 1)
>>> conv1.weight = conv2.weight
>>> model = nn.Sequential(conv1, conv2)
>>> model = GPipe(model, balance=[1, 1], ...)
Traceback (most recent call last)
...
ValueError: module with duplicate parameters in distinct children is not supported
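
As a toy illustration of the Sequential restriction above (this is not the
ResNet example itself), branches such as skip connections can be hidden inside
a submodule so that the top level stays a flat :class:`nn.Sequential
<torch.nn.Sequential>`::

    from torch import nn

    class Residual(nn.Module):
        """A hypothetical wrapper that keeps a skip connection inside one layer."""
        def __init__(self, module: nn.Module):
            super().__init__()
            self.module = module

        def forward(self, x):
            return x + self.module(x)

    model = nn.Sequential(
        nn.Conv2d(3, 64, 3, padding=1),
        Residual(nn.Conv2d(64, 64, 3, padding=1)),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(64, 10),
    )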

Complex Modules
~~~~~~~~~~~~~~~

2 changes: 1 addition & 1 deletion examples/resnet101_accuracy_benchmark/README.md
@@ -47,7 +47,7 @@ Then, run each benchmark:
```sh
$ python main.py naive-128
$ python main.py dataparallel-256 # 2 GPUs required
$ python main.py dataparallel-1k # 4 GPUs required
$ python main.py dataparallel-1k # 8 GPUs required
$ python main.py gpipe-256 # 2 GPUs required
$ python main.py gpipe-1k # 8 GPUs required
$ python main.py gpipe-4k # 8 GPUs required
6 changes: 3 additions & 3 deletions torchgpipe/checkpoint.py
@@ -3,9 +3,9 @@
PyTorch already provides the official checkpointing utilities in
:mod:`torch.utils.checkpoint`. The official checkpointing combines
recomputation and recursive backpropagation into one autograd function named
``Checkpoint``. Hence, the recomputation can be started only when the gradients
arrive to the function. In GPipe, the recomputation should be preceding with
the gradient to minimize GPU idle time.
``CheckpointFunction``. Hence, the recomputation can start only when the
gradients arrive at the function. In GPipe, the recomputation needs to precede
the gradient arrival to minimize GPU idle time.
We solve this problem by introducing separate autograd functions named
:class:`Recompute` and :class:`Checkpoint`. Each function represents
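
For reference, a minimal sketch of the official utility's behavior described in
this docstring (plain PyTorch, independent of torchgpipe): intermediate
activations are dropped during the forward pass and recomputed only when the
gradients reach the checkpointing function in backpropagation::

    import torch
    from torch.utils.checkpoint import checkpoint

    layers = torch.nn.Sequential(
        torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 10))
    x = torch.randn(4, 10, requires_grad=True)

    y = checkpoint(layers, x)  # forward pass: intermediate activations are not stored
    y.sum().backward()         # recomputation happens here, once the gradients arrive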
13 changes: 9 additions & 4 deletions torchgpipe/copy.py
@@ -1,5 +1,5 @@
"""Autograd functions for stream-aware copy. It can be used to overlap copy and
computation on the same GPU.
"""Autograd functions for stream-aware CUDA copy. It is used to overlap copy
and computation on the same GPU.
"""
from collections import deque
from typing import Deque, List, Optional, Tuple
@@ -23,7 +23,7 @@ class Context:


class Copy(torch.autograd.Function):
"""Copies a tensor on specific streams."""
"""Copies tensors on specific streams."""
@staticmethod
def forward(ctx: Context, # type: ignore
prev_stream: AbstractStream,
@@ -75,7 +75,12 @@ def backward(ctx: Context,


class Wait(torch.autograd.Function):
"""Synchronizes a stream to another stream."""
"""Synchronizes a stream to another stream.
Place it just before you want to start an operation on the next stream,
provided that all operations on the previous stream are done.
"""
@staticmethod
def forward(ctx: Context, # type: ignore
prev_stream: AbstractStream,
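
For reference, a minimal sketch of the underlying idea in plain PyTorch (this
is not torchgpipe's API): a copy issued on a side stream can overlap with
computation on the default stream, provided the streams are synchronized
before the copied tensor is used::

    import torch

    copy_stream = torch.cuda.Stream()
    x_cpu = torch.randn(1024, 1024, pin_memory=True)

    with torch.cuda.stream(copy_stream):
        # Asynchronous host-to-device copy runs on the side stream.
        x_gpu = x_cpu.to('cuda', non_blocking=True)

    # ... kernels issued on the default stream may run concurrently here ...

    # Like Wait: block the default stream until the copy is done.
    torch.cuda.current_stream().wait_stream(copy_stream)
    # Tell the caching allocator that the tensor is also used on this stream.
    x_gpu.record_stream(torch.cuda.current_stream())
    y = x_gpu @ x_gpu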
8 changes: 4 additions & 4 deletions torchgpipe/gpipe.py
@@ -126,9 +126,9 @@ def split_module(module: nn.Sequential,


class GPipe(Module):
"""Wraps an arbitrary :class:`~torch.nn.Sequential` module to train on
GPipe_. If the module requires lots of memory, GPipe will be very
efficient::
"""Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
to train on GPipe_. If the module requires lots of memory, GPipe will be
very efficient::
model = nn.Sequential(a, b, c, d)
model = GPipe(model, balance=[1, 1, 1, 1], chunks=8)
@@ -164,7 +164,7 @@ class GPipe(Module):
Raises:
TypeError:
the module is not a :class:`~torch.nn.Sequential`.
the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
ValueError:
invalid arguments, or wrong balance
IndexError:
50 changes: 36 additions & 14 deletions torchgpipe/pipeline.py
@@ -105,6 +105,8 @@ def fence(self, schedule: List[Tuple[int, int]]) -> None:
copy_streams = self.copy_streams

for i, j in schedule:
# Ensure that batches[j-1] is executed after batches[j] in
# backpropagation by an explicit dependency.
if j != 0:
depend(batches[j-1], batches[j])

@@ -118,11 +120,7 @@ def compute(self,
in_queues: List[InQueue],
out_queues: List[OutQueue],
) -> None:
"""Run tasks with synchronization to copy streams. A task consists of
"compute" and "finalize". "compute" is executed on worker threads
parallelly. "finalize" is executed when all worker threads complete to
execute "compute".
"""
"""Runs tasks with synchronization to copy streams."""
batches = self.batches
partitions = self.partitions
devices = self.devices
@@ -133,16 +131,41 @@
streams = [current_stream(d) for d in devices]
exc_info: Optional[ExcInfo] = None

# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[j]
partition = partitions[i]
device = devices[i]

# 1. Synchronize the current stream with the copy stream.
# Synchronize with the copied input. ([1] in the diagram)
if i != 0:
wait(batch, copy_streams[i][j], streams[i])

# 2. Determine whether checkpointing or not.
# Determine whether checkpointing or not.
checkpoint = (j < checkpoint_stop)
if checkpoint:
chk = Checkpointing(partition, batch)
@@ -155,7 +178,7 @@ def compute(batch: Batch = batch, partition: nn.Sequential = partition) -> Batch
task = Task(device, streams[i], compute=compute, finalize=None)
del compute

# 3. Compute tasks in parallel.
# Compute tasks in parallel. ([2] in the diagram)
in_queues[i].put(task)

for i, j in schedule:
@@ -170,15 +193,14 @@

task, batch = cast(Tuple[Task, Batch], payload)

# 4. Synchronize the copy stream with the current stream.
# The copy stream synchronizes to copy the output. ([3] in the
# diagram)
if i != n-1:
wait(batch, streams[i], copy_streams[i][j])

# 5. Finalize tasks.
#
# If checkpointing is enabled, here the recomputation is
# scheduled at backpropagation.
#
# Finalize tasks. If checkpointing is enabled, here the
# recomputation is scheduled at backpropagation. ([4] in the
# diagram)
task.finalize(batch)

batches[j] = batch
