Commit: Proofread docs

sublee committed Nov 29, 2019
1 parent 636baa6 commit ad9c82b
Showing 4 changed files with 22 additions and 13 deletions.
8 changes: 8 additions & 0 deletions docs/api.rst
@@ -4,6 +4,8 @@ API
 GPipe Module
 ~~~~~~~~~~~~
 
+.. py:module:: torchgpipe
+
 .. autoclass:: torchgpipe.GPipe(module, balance, \**kwargs)
 
    .. automethod:: forward(input)
@@ -23,6 +25,8 @@
 Skip Connections
 ~~~~~~~~~~~~~~~~
 
+.. py:module:: torchgpipe.skip
+
 .. autodecorator:: torchgpipe.skip.skippable([stash], [pop])
 
    .. automethod:: torchgpipe.skip.skippable.Skippable.isolate(ns, [only=names])
@@ -42,9 +46,13 @@ Inspecting GPipe Timeline
 
 .. autofunction:: torchgpipe.is_recomputing()
 
+.. _torchgpipe.balance:
+
 Automatic Balancing
 ~~~~~~~~~~~~~~~~~~~
 
+.. py:module:: torchgpipe.balance
+
 .. autofunction:: torchgpipe.balance.balance_by_time(partitions, module, sample, timeout=1.0, device=torch.device('cuda'))
 
 .. autofunction:: torchgpipe.balance.balance_by_size(partitions, module, input, chunks=1, param_scale=2.0, device=torch.device('cuda'))
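For context, these balancing functions compute the ``balance`` argument of
:class:`~torchgpipe.GPipe`. A minimal usage sketch based on the signatures
above; the model, batch shape, and chunk count are placeholders, not part of
this commit::

    import torch
    from torch import nn

    from torchgpipe import GPipe
    from torchgpipe.balance import balance_by_time

    # A placeholder sequential model; any nn.Sequential works here.
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))

    # Profile each layer's runtime on a sample batch, then split the
    # layers into one partition per available GPU.
    partitions = torch.cuda.device_count()
    sample = torch.rand(32, 64)
    balance = balance_by_time(partitions, model, sample)

    # Hand the measured balance to GPipe.
    model = GPipe(model, balance, chunks=8)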
17 changes: 9 additions & 8 deletions docs/guide.rst
@@ -179,13 +179,14 @@ Checkpointing drastically helps to reduce memory usage, but the overall
 training would slow down by about 25%. You can handle how to apply
 checkpointing on your model. There are three options:
 
-- ``always`` -- Apply checkpointing over all micro-batches.
-- ``except_last`` (default) -- Apply checkpointing except the last micro-batch.
-- ``never`` -- Checkpointing is never applied.
+- ``'always'`` -- Apply checkpointing over all micro-batches.
+- ``'except_last'`` (default) -- Apply checkpointing except the last
+  micro-batch.
+- ``'never'`` -- Checkpointing is never applied.
 
 Usually, checkpointing at the last micro-batch may not be useful because the
 saved memory will be reconstructed immediately. That's why we choose
-``except_last`` as the default option.
+``'except_last'`` as the default option.
 
 If you decide not to use checkpointing at all, :class:`nn.DataParallel
 <torch.nn.DataParallel>` might be more efficient than GPipe.
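For reference, the mode quoted above is passed as the ``checkpoint`` keyword
of :class:`~torchgpipe.GPipe`. A minimal sketch, assuming two GPUs and a toy
model that is not part of this commit::

    from torch import nn
    from torchgpipe import GPipe

    model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5))

    # 'except_last' is the default. 'always' saves the most memory at the
    # cost of extra recomputation; 'never' disables checkpointing.
    model = GPipe(model, balance=[2, 1], chunks=4, checkpoint='always')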
@@ -240,7 +241,7 @@ Sequential:
    a :class:`nn.Sequential <torch.nn.Sequential>` model.
 
 .. _the sequential ResNet example:
-   https://github.com/kakaobrain/torchgpipe/tree/master/examples/resnet
+   https://github.com/kakaobrain/torchgpipe/tree/master/benchmarks/models/resnet
 
 :class:`nn.Sequential <torch.nn.Sequential>` assumes that every underlying
 layer takes only one argument. Calling ``forward(x)`` on
@@ -258,7 +259,7 @@
 Tensor or Tensors:
    As we discussed above, each layer must take only one argument due to
    :class:`nn.Sequential <torch.nn.Sequential>`. There is one more restriction.
-   Every underlying layers' input and output must be ``Tensor`` or
+   Every underlying layers' input and output must be :class:`~torch.Tensor` or
    ``Tuple[Tensor, ...]``::
 
        # OK
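The ``# OK`` example above is collapsed in this view. A sketch of the
restriction it illustrates, with hypothetical module names::

    from typing import Tuple
    from torch import Tensor, nn

    class Good(nn.Module):
        # One Tensor in and a tuple of Tensors out are both allowed.
        def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
            return x, x * 2

    class Bad(nn.Module):
        # Two positional arguments, one of them a non-Tensor: neither is
        # allowed, because nn.Sequential passes along a single value.
        def forward(self, x: Tensor, scale: float) -> Tensor:
            return x * scale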
@@ -435,8 +436,8 @@ multiple skip tensors. However, there are restrictions:
 
 Then, how can we instantiate multiple skippable modules from the same class in
 a sequential module? You can isolate some skip names into a
-:class:`~torch.skip.Namespace`. For example, a conceptual U-Net can be designed
-like this. There are 3 pairs of ``Encoder`` and ``Decoder``::
+:class:`~torchgpipe.skip.Namespace`. For example, a conceptual U-Net can be
+designed like this. There are 3 pairs of ``Encoder`` and ``Decoder``::
 
     # 1F. Encoder -------- Decoder -- Segment
     #       \                /
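A minimal sketch of the namespace isolation described in this hunk, assuming
the yield-based ``stash``/``pop`` protocol of ``torchgpipe.skip`` and purely
illustrative computations::

    from torch import Tensor, nn
    from torchgpipe.skip import Namespace, pop, skippable, stash

    @skippable(stash=['skip'])
    class Encoder(nn.Module):
        def forward(self, x: Tensor) -> Tensor:
            yield stash('skip', x)
            return x + 1  # placeholder computation

    @skippable(pop=['skip'])
    class Decoder(nn.Module):
        def forward(self, x: Tensor) -> Tensor:
            skip = yield pop('skip')
            return x + skip

    # Isolate each Encoder/Decoder pair into its own namespace so the
    # three 'skip' tensors do not collide in one sequential model.
    ns_1f, ns_2f, ns_3f = Namespace(), Namespace(), Namespace()
    model = nn.Sequential(
        Encoder().isolate(ns_1f),
        Encoder().isolate(ns_2f),
        Encoder().isolate(ns_3f),
        Decoder().isolate(ns_3f),
        Decoder().isolate(ns_2f),
        Decoder().isolate(ns_1f),
    )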
8 changes: 4 additions & 4 deletions torchgpipe/balance/__init__.py
@@ -66,8 +66,8 @@ def balance_by_time(partitions: int,
             current CUDA device)
 
     Returns:
-        A list of number of layers in each partition. Use it for the
-        ``balance`` parameter of :class:`~torchgpipe.GPipe`.
+        A list of number of layers in each partition. Use it for the `balance`
+        parameter of :class:`~torchgpipe.GPipe`.
 
     .. note::
         `module` and `sample` must be placed on the same device.
@@ -145,8 +145,8 @@ def balance_by_size(partitions: int,
             device)
 
     Returns:
-        A list of number of layers in each partition. Use it for the
-        ``balance`` parameter of :class:`~torchgpipe.GPipe`.
+        A list of number of layers in each partition. Use it for the `balance`
+        parameter of :class:`~torchgpipe.GPipe`.
 
     .. note::
         `module` and `input` must be placed on the same CUDA device.
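A companion sketch for ``balance_by_size``, mirroring the ``balance_by_time``
usage above but profiling CUDA memory; the model, sizes, and ``param_scale``
value are placeholders::

    import torch
    from torch import nn

    from torchgpipe import GPipe
    from torchgpipe.balance import balance_by_size

    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))

    # param_scale roughly accounts for optimizer state per parameter;
    # pick a larger value for optimizers that keep extra buffers.
    partitions = torch.cuda.device_count()
    input = torch.rand(32, 64)
    balance = balance_by_size(partitions, model, input,
                              chunks=8, param_scale=4.0)

    model = GPipe(model, balance, chunks=8)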
2 changes: 1 addition & 1 deletion torchgpipe/gpipe.py
@@ -205,7 +205,7 @@ class GPipe(Module):
     chunks: int = 1
 
     #: The checkpoint mode to determine when to enable checkpointing. It is one
-    #: of ``always``, ``except_last``, or ``never``.
+    #: of ``'always'``, ``'except_last'``, or ``'never'``.
     checkpoint: str = 'except_last'
 
     def __init__(self,
