[feat] Gossip/SlowMo (#378)
Add SlowMo Distributed Data Parallel for clusters with slow interconnects

Co-authored-by: Vinayak Tantia <tantia.vinayak1@gmail.com>
blefaudeux and vtantia committed Nov 8, 2021
1 parent 8347c1a commit 21464e0
Showing 26 changed files with 3,038 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -23,6 +23,7 @@ test-results/
# Environments
.env
.venv
.vscode
env/
venv/
ENV/
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
and gradient memory to be sharded despite being needed from different layers due to
weight sharing. [#836]
- [MEVO]: a custom layer to help big vocab trainings. Experimental. Docs is still TBD. [#840]
- SlowMoDistributedDataParallel[feature][experimental] - A distributed training wrapper intended for clusters with slow network interconnects (e.g. Ethernet). On such clusters it improves performance compared to Distributed Data Parallel. [#378]

## [0.4.1] - 2021-09-17
### Fixed
7 changes: 7 additions & 0 deletions docs/source/api/experimental/nn/slowmo_ddp.rst
@@ -0,0 +1,7 @@
SlowMo Distributed Data Parallel
================================

.. autoclass:: fairscale.experimental.nn.data_parallel.SlowMoDistributedDataParallel
    :members:
    :undoc-members:
    :exclude-members: eval, forward, load_state_dict, state_dict, train, training
1 change: 1 addition & 0 deletions docs/source/api/index.rst
@@ -12,3 +12,4 @@ API Reference
nn/fsdp
nn/checkpoint/checkpoint_activations
experimental/nn/offload_model
experimental/nn/slowmo_ddp
13 changes: 13 additions & 0 deletions docs/source/conf.py
@@ -92,6 +92,19 @@
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True

# List of custom sections allowed. It is especially useful when the argument
# list is very long for a constructor or function. This helps split the
# arguments into different sections, helping us to understand the arguments
# better.
napoleon_custom_sections = [
    ("SlowMo Parameters", "params_style"),
    ("LocalSGD Parameters", "params_style"),
    ("SGP Parameters", "params_style"),
    ("Debugging Parameters", "params_style"),
    ("Parameters for Advanced Users", "params_style"),
]


# -- Options for HTML output -------------------------------------------------


81 changes: 81 additions & 0 deletions docs/source/deep_dive/slowmo_ddp.rst
@@ -0,0 +1,81 @@
SlowMo Distributed Data Parallel
================================

Training neural networks in a distributed data-parallel manner scales sub-linearly because of the time spent on communication
between the different nodes (and, to a lesser extent, on synchronization between them). So, a distributed training run
with 8 nodes is not 8x faster than a run with 1 node, as we might expect it to be.

SlowMo Distributed Data Parallel aims to solve this by replacing the typical exact allreduce of gradients with an approximate
averaging of parameters. This approximate averaging reduces both the time spent on communication and the synchronization between different
nodes. It uses one of the following two (configurable) base algorithms for this purpose:

* Local SGD (papers `#1 <https://arxiv.org/abs/1602.05629>`_ and `#2 <https://arxiv.org/abs/1705.09056>`_). This algorithm does an allreduce of the parameters every few iterations.

* `Stochastic Gradient Push <https://arxiv.org/abs/1811.10792>`_ (SGP). This algorithm involves one-to-one communications between nodes.

These base algorithms (LocalSGD and SGP), when used on their own, result in reduced model quality (measured as accuracy in a classification
setting). The `SlowMo <https://arxiv.org/abs/1910.00643>`_ algorithm alleviates this issue by doing a slow momentum step, typically every 48 iterations.

The training process with SlowMo looks as follows:

1. Compute the forward pass.

2. Compute the backward pass.

3. During the backward pass, using a backward hook, on each node, the gradients are synchronized using allreduce across the different GPUs on
that node.

4. Perform the ``optimizer.step()`` to update parameters on each node with the gradients of that node.

5. Approximately average the parameters using a base algorithm - one of LocalSGD or SGP (both are described above).

6. Perform the slow momentum update step once every ``slowmo_frequency`` (typically 48) iterations. In this step, the parameters on different
nodes are (exactly) averaged, followed by a ``slowmo_optimizer.step()``. Note that this ``slowmo_optimizer`` is different from the original optimizer,
and its step is performed in a `Zero-1 <./oss_sdp_fsdp.html>`_ like manner to save memory.
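
As a rough illustration of step 6, the following is a minimal sketch of one slow momentum update on a flat list of parameters. It follows the update rule as described in the SlowMo paper, not the fairscale implementation; the function name ``slowmo_update`` and the arguments ``base_lr`` and ``slowmo_lr`` are names assumed here for illustration only.

```python
from typing import List, Tuple


def slowmo_update(
    prev_params: List[float],
    averaged_params: List[float],
    slow_buffer: List[float],
    base_lr: float,
    slowmo_momentum: float,
    slowmo_lr: float = 1.0,  # hypothetical name for the slow learning rate
) -> Tuple[List[float], List[float]]:
    """One slow momentum step (sketch, assuming the rule from the SlowMo paper).

    prev_params:     parameters right after the previous slow momentum step
    averaged_params: parameters after exact averaging across nodes
    slow_buffer:     slow momentum buffer carried between slow steps
    """
    new_params = []
    new_buffer = []
    for x_prev, x_avg, u in zip(prev_params, averaged_params, slow_buffer):
        # Treat the total displacement since the last slow step as a pseudo-gradient.
        u_new = slowmo_momentum * u + (x_prev - x_avg) / base_lr
        new_buffer.append(u_new)
        # Move from the old parameters along the slow momentum direction.
        new_params.append(x_prev - slowmo_lr * base_lr * u_new)
    return new_params, new_buffer
```

With ``slowmo_momentum=0.0`` and ``slowmo_lr=1.0`` this sketch reduces to plain averaging: the new parameters equal ``averaged_params``, which is why 0.0 is a natural baseline when tuning the momentum.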

Best practices for using ``SlowMoDistributedDataParallel``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

1. SlowMo will be useful in deep learning workloads which run on more than 2 nodes in clusters with a slow interconnect, e.g. Ethernet.

2. SlowMo should be useful in your workload if the following condition holds:

:math:`\textrm{time_taken_for_all_reduce_of_gradients} \times (1 - \frac{1}{\textrm{localsgd_frequency}} ) > \textrm{time_taken_for_backward_pass}`

Notes:

* In case you are using SGP as the base algorithm, the value of ``localsgd_frequency`` can be plugged in as 2.

* The formula above is a simplified version of:
:math:`\textrm{time_taken_for_all_reduce_of_gradients} > \textrm{time_taken_for_backward_pass} + \frac{\textrm{time_taken_for_all_reduce_of_gradients}}{\textrm{localsgd_frequency}}`
The left and right hand sides denote the total backward duration (combining the computation of gradients in the backward pass and the
communication cost) for DDP and SlowMo DDP, respectively. Since DDP overlaps the computation of gradients with their communication, it is
bottlenecked by the latter. In contrast, there is an extra ``time_taken_for_backward_pass`` on the right hand side because we do not
overlap the backward pass with communication in the current implementation of SlowMo.

* In clusters with slower interconnect, ``time_taken_for_all_reduce_of_gradients`` will go up, leading to SlowMo being more useful. ``localsgd_frequency``
is also an important factor here. More details on varying that to affect performance are in tip 2 of
`Performance tips for SlowMoDistributedDataParallel`_.

3. ``slowmo_momentum`` will need to be tuned to obtain good model quality. A grid search across {0.0, 0.1, 0.2, 0.4, 0.6} should be good enough
for tuning. A tuned ``slowmo_momentum`` value stays consistent across multiple runs with similar settings. When the number of nodes used is increased,
however, a higher value of ``slowmo_momentum`` is likely to be needed. More details about this can be found in the
`documentation <../api/experimental/nn/slowmo_ddp.html>`_.

4. Adding SlowMo to existing Distributed Data Parallel code involves two steps, which can be found in the `tutorial <../tutorials/slowmo_ddp.html>`_.
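
The condition in item 2 can be checked with measured timings. The helper below is a small sketch of that arithmetic; the function name ``slowmo_likely_helps`` and its argument names are made up for this example, and the timings are illustrative values rather than measurements.

```python
def slowmo_likely_helps(
    allreduce_time: float,
    backward_time: float,
    localsgd_frequency: int = 3,
) -> bool:
    """Evaluate the rule of thumb from item 2 (sketch, hypothetical helper).

    allreduce_time: time taken for an allreduce of the gradients
    backward_time:  time taken for the backward pass
    For SGP as the base algorithm, pass localsgd_frequency=2.
    """
    if localsgd_frequency < 1:
        raise ValueError("localsgd_frequency must be at least 1")
    return allreduce_time * (1 - 1 / localsgd_frequency) > backward_time


# Illustrative timings in seconds: 0.9 * (1 - 1/3) = 0.6 > 0.4
print(slowmo_likely_helps(allreduce_time=0.9, backward_time=0.4))
# A faster interconnect: 0.5 * (1 - 1/3) is about 0.33, which is below 0.4
print(slowmo_likely_helps(allreduce_time=0.5, backward_time=0.4))
```

Note that ``localsgd_frequency=1`` makes the left-hand side zero, so the check never passes in that case.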

Performance tips for ``SlowMoDistributedDataParallel``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

1. ``nprocs_per_node`` should be set to the number of GPUs on a node (this number should be the same on each node). This allows the API
to exploit the fast interconnect between different GPUs on a node.

2. Increasing the ``localsgd_frequency`` results in an increase in speed. However, it comes with a tradeoff of reducing the model quality.
We recommend keeping the ``localsgd_frequency`` at 3.

3. ``slowmo_memory_efficient`` should typically be used (this is the default behavior). It reduces memory usage by sharding the additional
slow momentum optimizer's parameters in a `Zero-1`_ like manner.

4. A call to ``model.zero_grad(set_to_none=True)`` should be made after ``optimizer.step()`` in order to save memory for the
``model.perform_slowmo()`` step. More details about this can be found in the
`documentation for perform_slowmo() <../api/experimental/nn/slowmo_ddp.html#:~:text=net.perform_slowmo(optimizer)-,perform_slowmo,-(optimizer%3A%20torch.optim>`_.
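
The tips above can be collected into a set of keyword arguments. This is only a sketch: the helper ``recommended_slowmo_kwargs`` is hypothetical, the parameter names are the ones mentioned in this document, and the API documentation remains the reference for the actual constructor signature.

```python
from typing import Any, Dict


def recommended_slowmo_kwargs(gpus_per_node: int, slowmo_momentum: float) -> Dict[str, Any]:
    """Collect the settings suggested in the tips above (hypothetical helper)."""
    if gpus_per_node < 1:
        raise ValueError("gpus_per_node must be at least 1")
    return {
        # Tip 1: match the number of GPUs on each node.
        "nprocs_per_node": gpus_per_node,
        # Tip 2: the recommended speed/quality tradeoff.
        "localsgd_frequency": 3,
        # Tip 3: shard the slow momentum optimizer state.
        "slowmo_memory_efficient": True,
        # Needs tuning per workload (see the best practices above).
        "slowmo_momentum": slowmo_momentum,
    }


kwargs = recommended_slowmo_kwargs(gpus_per_node=8, slowmo_momentum=0.5)
# With an initialized process group, these would be passed as:
# model = SlowMoDistributedDataParallel(model, **kwargs)
```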
2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -42,6 +42,7 @@ modules and easy to use APIs.
deep_dive/adascale
deep_dive/pipeline_parallelism
deep_dive/activation_checkpointing
deep_dive/slowmo_ddp

|
|
@@ -56,6 +57,7 @@ modules and easy to use APIs.
tutorials/adascale
tutorials/pipe
tutorials/layer_memory_tracking
tutorials/slowmo_ddp

|
|
67 changes: 67 additions & 0 deletions docs/source/tutorials/slowmo_ddp.rst
@@ -0,0 +1,67 @@
Efficient Data Parallel Training with SlowMo Distributed Data Parallel
======================================================================

SlowMo Distributed Data Parallel reduces the communication between different
nodes while performing data parallel training. It is mainly useful on
clusters with low interconnect speeds between different nodes. When using
SlowMo, the models on the different nodes are no longer kept in sync after each
iteration, which affects the optimization dynamics. The end
result is close to the result of Distributed Data Parallel, but is not exactly
the same.

If you have code that is set up to use Distributed Data Parallel, switching to SlowMo Distributed Data Parallel
only requires replacing the DDP call with a call to
``fairscale.experimental.nn.data_parallel.SlowMoDistributedDataParallel``, and adding a
``model.perform_slowmo(optimizer)`` call after ``optimizer.step()``, preceded by
``model.zero_grad(set_to_none=True)`` in order to reduce peak memory usage.
The places where ``use_slowmo`` is used below demonstrate these changes:

.. code-block:: python

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP
    from fairscale.experimental.nn.data_parallel import SlowMoDistributedDataParallel as SlowMoDDP


    def train(
        rank: int,
        world_size: int,
        epochs: int,
        use_slowmo: bool):

        # process group init
        dist_init(rank, world_size)

        # Problem statement
        model = MyAwesomeModel().to(rank)
        if use_slowmo:
            # Wrap the model into SlowMoDDP
            model = SlowMoDDP(model, slowmo_momentum=0.5, nprocs_per_node=8)
        else:
            model = DDP(model, device_ids=[rank])

        dataloader = MySuperFastDataloader()
        loss_fn = MyVeryRelevantLoss()
        optimizer = MyAmazingOptimizer()

        # Any relevant training loop, with a line at the very end specific to SlowMoDDP, e.g.:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                # Train
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss.backward()
                optimizer.step()
                model.zero_grad(set_to_none=use_slowmo)  # free memory for the perform_slowmo() call below
                if use_slowmo:
                    model.perform_slowmo(optimizer)

In the example above, when using SlowMoDDP, we reduce the total communication between
nodes by a factor of 3, as the default ``localsgd_frequency`` is 3.

SlowMoDDP takes in ``slowmo_momentum`` as a parameter. This parameter may need to be tuned
depending on your use case. It also takes in ``nprocs_per_node``, which should typically be set
to the number of GPUs on a node. Please look at the
`documentation <../api/experimental/nn/slowmo_ddp.html>`_
for more details on these parameters as well as other advanced settings of the SlowMo algorithm.
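
To make the communication saving concrete, the sketch below counts how often inter-node synchronization would happen over a training run. It assumes, for illustration only, that averaging happens on every iteration count that is a multiple of the corresponding frequency; the helper name ``count_sync_points`` is hypothetical, and 48 is the typical ``slowmo_frequency`` rather than a verified default.

```python
from typing import Tuple


def count_sync_points(
    total_iterations: int,
    localsgd_frequency: int = 3,
    slowmo_frequency: int = 48,
) -> Tuple[int, int]:
    """Return (LocalSGD averaging rounds, slow momentum steps) for a run.

    Assumes a sync fires whenever the iteration count is a multiple of the
    given frequency, which is a simplification made for this example.
    """
    localsgd_rounds = total_iterations // localsgd_frequency
    slowmo_steps = total_iterations // slowmo_frequency
    return localsgd_rounds, slowmo_steps


# Over 480 iterations, regular DDP would synchronize across nodes 480 times.
localsgd_rounds, slowmo_steps = count_sync_points(480)
print(localsgd_rounds, slowmo_steps)  # prints "160 10"
```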
6 changes: 6 additions & 0 deletions fairscale/experimental/nn/data_parallel/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from .gossip import SlowMoBaseAlgorithm, SlowMoDistributedDataParallel # noqa
19 changes: 19 additions & 0 deletions fairscale/experimental/nn/data_parallel/gossip/__init__.py
@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from .distributed import SlowMoBaseAlgorithm, SlowMoDistributedDataParallel
from .gossiper import PushPull, PushSum
from .graph_manager import (
    DynamicBipartiteExponentialGraph,
    DynamicBipartiteLinearGraph,
    DynamicDirectedExponentialGraph,
    DynamicDirectedLinearGraph,
    GraphManager,
    NPeerDynamicDirectedExponentialGraph,
    RingGraph,
)
from .mixing_manager import MixingManager, UniformMixing
from .utils import communicate
from .utils.cuda_metering import CudaEventRecorder
