-
Notifications
You must be signed in to change notification settings - Fork 268
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add SlowMo Distributed Data Parallel for clusters with slow interconnects Co-authored-by: Vinayak Tantia <tantia.vinayak1@gmail.com>
- Loading branch information
1 parent
8347c1a
commit 21464e0
Showing
26 changed files
with
3,038 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ test-results/ | |
# Environments | ||
.env | ||
.venv | ||
.vscode | ||
env/ | ||
venv/ | ||
ENV/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
SlowMo Distributed Data Parallel | ||
================================ | ||
|
||
.. autoclass:: fairscale.experimental.nn.data_parallel.SlowMoDistributedDataParallel | ||
:members: | ||
:undoc-members: | ||
:exclude-members: eval, forward, load_state_dict, state_dict, train, training |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
SlowMo Distributed Data Parallel | ||
================================ | ||
|
||
Training neural networks in a distributed data-parallel manner results in non-linear scaling (slowdown) due to the time spent on communication | ||
between the different nodes (as well as, to a lesser extent though, synchronization between the different nodes). So, a distributed training run | ||
with 8 nodes is not 8x faster than a run with 1 node as we would expect it to be. | ||
|
||
SlowMo Distributed Data Parallel aims to solve this by replacing the typical exact allreduce between gradients with an approximate | ||
averaging of parameters. This approximate averaging reduces both the time spent on communication as well as the synchronization between different | ||
nodes. It uses one of the following two algorithms (configurable) as a base algorithm for this purpose: | ||
|
||
* Local SGD (papers `#1 <https://arxiv.org/abs/1602.05629>`_ and `#2 <https://arxiv.org/abs/1705.09056>`_). This algorithm does an allreduce of the parameters every few iterations. | ||
|
||
* `Stochastic Gradient Push <https://arxiv.org/abs/1811.10792>`_ (SGP). This algorithm involves one-to-one communications between nodes. | ||
|
||
These base algorithms (LocalSGD and SGP), when used only by themselves, result in reduced model quality (measured as accuracy in a classification | ||
setting). The `SlowMo <https://arxiv.org/abs/1910.00643>`_ algorithm alleviates this issue by doing a slow momentum step, typically, every 48 iterations. | ||
|
||
The training process with SlowMo looks as follows: | ||
|
||
1. Compute the forward pass. | ||
|
||
2. Compute the backward pass. | ||
|
||
3. During the backward pass, using a backward hook, on each node, the gradients are synchronized using allreduce across the different GPUs on | ||
that node. | ||
|
||
4. Perform the ``optimizer.step()`` to update parameters on each node with the gradients of that node. | ||
|
||
5. Approximately average the parameters using a base algorithm - one of LocalSGD or SGP (both are described above). | ||
|
||
6. Perform the slow momentum update step once every ``slowmo_frequency`` (typically 48) iterations. In this step, the parameters on different | ||
nodes are (exactly) averaged, followed by a ``slowmo_optimizer.step()``. Note that this ``slowmo_optimizer`` is different from the original optimizer, | ||
and it is done in a `Zero-1 <./oss_sdp_fsdp.html>`_ like manner to save memory. | ||
|
||
Best practices for using ``SlowMoDistributedDataParallel`` | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
1. SlowMo will be useful in deep learning workloads which run on more than 2 nodes in clusters with a slow interconnect, eg Ethernet. | ||
|
||
2. SlowMo should be useful in your workload if the following condition holds: | ||
|
||
:math:`\textrm{time_taken_for_all_reduce_of_gradients} \times (1 - \frac{1}{\textrm{localsgd_frequency}} ) > \textrm{time_taken_for_backward_pass}` | ||
|
||
Notes: | ||
|
||
* In case you are using SGP as the base algorithm, the value of ``localsgd_frequency`` can be plugged in as 2. | ||
|
||
* The formula above is a simplified version of: | ||
:math:`\textrm{time_taken_for_all_reduce_of_gradients} > \textrm{time_taken_for_backward_pass} + \frac{\textrm{time_taken_for_all_reduce_of_gradients}}{\textrm{localsgd_frequency}}` | ||
The left and right hand sides denote the total backward duration (combining the computation of gradients in the backward pass and the | ||
communication cost) for DDP and SlowMo DDP, respectively. Since DDP overlaps the computation of gradients with their communication, it is | ||
bottlenecked by the latter. In contrast, there is an extra ``time_taken_for_backward_pass`` on the right hand side because we do not | ||
overlap the backward pass with communication in the current implementation of SlowMo. | ||
|
||
* In clusters with slower interconnect, ``time_taken_for_all_reduce_of_gradients`` will go up, leading to SlowMo being more useful. ``localsgd_frequency`` | ||
is also an important factor here. More details on varying that to affect performance are in tip 2 of | ||
`Performance tips for SlowMoDistributedDataParallel`_. | ||
|
||
3. ``slowmo_momentum`` will need to be tuned for obtaining good model quality. A grid search across {0.0, 0.1, 0.2, 0.4, 0.6} should be good enough | ||
for tuning. This ``slowmo_momentum`` value holds consistent across multiple runs with similar settings. When the number of nodes used is increased, | ||
however, a higher value of ``slow_momentum`` should be needed. More details about this can be found in the | ||
`documentation <../api/experimental/nn/slowmo_ddp.html>`_. | ||
|
||
4. Adding SlowMo to existing Distributed Data Parallel code involves two steps, which can be found in the `tutorial <../tutorials/slowmo_ddp.html>`_. | ||
|
||
Performance tips for ``SlowMoDistributedDataParallel`` | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
1. ``nprocs_per_node`` should be set to the number of GPUs on a node (this number should be the same on each node). This allows the API | ||
to exploit the fast interconnect between different GPUs on a node. | ||
|
||
2. Increasing the ``localsgd_frequency`` results in an increase in speed. However, it comes with a tradeoff of reducing the model quality. | ||
We recommend keeping the ``localsgd_frequency`` at 3. | ||
|
||
3. ``slowmo_memory_efficient`` should typically be used (this is the default behavior). It reduces memory usage by sharding the additional | ||
slow momentum optimizer's parameters in a `Zero-1`_ like manner. | ||
|
||
4. A call to ``model.zero_grad(set_to_none=True)`` should be made after ``optimizer.step()`` in order to save memory for the | ||
``model.perform_slowmo()`` step. More details about this can be found in the | ||
`documentation for perform_slowmo() <../api/experimental/nn/slowmo_ddp.html#:~:text=net.perform_slowmo(optimizer)-,perform_slowmo,-(optimizer%3A%20torch.optim>`_. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
Efficient Data Parallel Training with SlowMo Distributed Data Parallel | ||
====================================================================== | ||
|
||
SlowMo Distributed Data Parallel reduces the communication between different | ||
nodes while performing data parallel training. It is mainly useful for use on | ||
clusters with low interconnect speeds between different nodes. When using | ||
SlowMo, the models on the different nodes are no longer kept in sync after each | ||
iteration, which leads to the optimization dynamics being affected. The end | ||
result is close to the results of Distributed Data Parallel, but is not exactly | ||
the same. | ||
|
||
If you have code that is setup to use Distributed Data Parallel, using SlowMo Distributed Data Parallel | ||
is simply replacing the DDP call with a call to | ||
``fairscale.experimental.nn.data_parallel.SlowMoDistributedDataParallel``, and adding a | ||
``model.perform_slowmo(optimizer)`` call after ``optimizer.step()`` -- preceded by | ||
``model.zero_grad(set_to_none=True)`` in order to reduce peak memory usage. | ||
The different points at which ``use_slowmo`` is used below help demonstrate these changes: | ||
|
||
.. code-block:: python | ||
import torch | ||
from fairscale.experimental.nn.data_parallel import SlowMoDistributedDataParallel as SlowMoDDP | ||
def train( | ||
rank: int, | ||
world_size: int, | ||
epochs: int, | ||
use_slowmo: bool): | ||
# process group init | ||
dist_init(rank, world_size) | ||
# Problem statement | ||
model = MyAwesomeModel().to(rank) | ||
if use_slowmo: | ||
# Wrap the model into SlowMoDDP | ||
model = SlowMoDDP(model, slowmo_momentum=0.5, nprocs_per_node=8) | ||
else: | ||
model = DDP(model, device_ids=[rank]) | ||
dataloader = MySuperFastDataloader() | ||
loss_ln = MyVeryRelevantLoss() | ||
optimizer = MyAmazingOptimizer() | ||
# Any relevant training loop, with a line at the very end specific to SlowMoDDP, e.g.: | ||
model.train() | ||
for e in range(epochs): | ||
for (data, target) in dataloader: | ||
data, target = data.to(rank), target.to(rank) | ||
# Train | ||
outputs = model(data) | ||
loss = loss_fn(outputs, target) | ||
loss.backward() | ||
optimizer.step() | ||
model.zero_grad(set_to_none=use_slowmo) # free memory for the perform_slowmo() call below | ||
if use_slowmo: | ||
model.perform_slowmo(optimizer) | ||
In the example above, when using SlowMoDDP, we are reducing the total communication between | ||
nodes by 3 times as the default ``localsgd_frequency`` is set to 3. | ||
SlowMoDDP takes in ``slowmo_momentum`` as a parameter. This parameter may need to be tuned | ||
depending on your use case. It also takes in ``nproces_per_node`` which should be typically set | ||
to the number of GPUs on a node. Please look at the | ||
`documentation <../api/experimental/nn/slowmo_ddp.html>`_ | ||
for more details on these parameters as well as other advanced settings of the SlowMo algorithm. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from .gossip import SlowMoBaseAlgorithm, SlowMoDistributedDataParallel # noqa |
19 changes: 19 additions & 0 deletions
19
fairscale/experimental/nn/data_parallel/gossip/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from .distributed import SlowMoBaseAlgorithm, SlowMoDistributedDataParallel | ||
from .gossiper import PushPull, PushSum | ||
from .graph_manager import ( | ||
DynamicBipartiteExponentialGraph, | ||
DynamicBipartiteLinearGraph, | ||
DynamicDirectedExponentialGraph, | ||
DynamicDirectedLinearGraph, | ||
GraphManager, | ||
NPeerDynamicDirectedExponentialGraph, | ||
RingGraph, | ||
) | ||
from .mixing_manager import MixingManager, UniformMixing | ||
from .utils import communicate | ||
from .utils.cuda_metering import CudaEventRecorder |
Oops, something went wrong.