[feat] Gossip/SlowMo #378

Merged (123 commits, Nov 8, 2021)
Changes from 2 commits
Commits (123)
268f2f8
Add latest version of gossip code from branch latest_master of vtanti…
vtantia Jan 5, 2021
f152379
Add code for importing GossipDataParallel in fairscale
vtantia Jan 5, 2021
bbeab4a
Add tests (currently in wrong location so will need to be moved)
vtantia Jan 5, 2021
4616722
Remove extra ad_psgd file
vtantia Jan 5, 2021
ed7b866
Add change in gitignore to ignore vscode config
vtantia Jan 5, 2021
89d865f
Perform formatting (black, isort, flake8)
vtantia Jan 5, 2021
8157603
Add scripts to load environment and format code
vtantia Jan 5, 2021
5d458f9
Add stubs for fairscale script
vtantia Jan 5, 2021
9fdd823
[Temp] Comment out a line in stubs to fix error message
vtantia Jan 5, 2021
3a09576
Remove remaining adpsgd code
vtantia Jan 6, 2021
96fec9e
Remove unnecessary function
vtantia Jan 19, 2021
d32d384
Add mypy typing to GossipDataParallel
vtantia Jan 19, 2021
015537f
Fix formatting
vtantia Jan 19, 2021
9b5aff7
Make format.sh a script
vtantia Jan 19, 2021
cc83b84
Make flaky test log message clearer
vtantia Jan 21, 2021
9c78976
Fix minor bug in mypy implementation
vtantia Jan 21, 2021
3068a34
Add tests for SGP
vtantia Jan 21, 2021
dbb4eb3
Minor mypy changes
vtantia Jan 21, 2021
4b4c373
Fix errors with multiple process groups by synchronizing appropriately
vtantia Jan 21, 2021
ab01f16
Remove deprecated file
vtantia Jan 21, 2021
993c6ff
Fix mypy in utils/helpers.py
vtantia Jan 21, 2021
5e30d5d
Finish mypy typing for distributed.py
vtantia Feb 2, 2021
00c1ff2
Add typing to and format test files
vtantia Feb 2, 2021
7d75ab2
Fix mypy errors including those for switching to Python 3.6
vtantia Feb 2, 2021
7c1e998
Temporary commit - cleaning up parameters
vtantia Feb 2, 2021
92aef32
Remove single process support to make code cleaner
vtantia Feb 2, 2021
0e6f6ea
Change localsgd to be set as an option
vtantia Feb 2, 2021
98f9d36
Refactor perform_additional_optimizer_actions function
vtantia Feb 2, 2021
9427e12
Clean up
vtantia Feb 2, 2021
b1c66c7
Factor out sgp_int
vtantia Feb 2, 2021
43efe01
Add temporary comments to prevent auto-formatting of argument separation
vtantia Feb 9, 2021
70ab95f
Rename sgp functions. Move sgp and slowmo functions together
vtantia Feb 9, 2021
383dfbc
Factorize creation of process groups in SlowMo
vtantia Feb 9, 2021
8f8a275
Remove extra variable
vtantia Feb 9, 2021
d5d4108
Change default value of localsgd_frequency to 3
vtantia Feb 9, 2021
9b99cbc
Factorize initialization of process groups
vtantia Feb 9, 2021
24bd02c
Minor name change
vtantia Feb 9, 2021
aa63481
Minor formatting change
vtantia Feb 9, 2021
242272a
Add a todo
vtantia Feb 9, 2021
2380783
Make distributed_broadcast_coalesced more generalizable
vtantia Feb 9, 2021
b67ef2c
Fix pre-commit errors (mainly mypy)
vtantia Feb 9, 2021
7c57b58
Formatting changes in scripts
vtantia Feb 9, 2021
01b34c3
Missed renaming change
vtantia Feb 9, 2021
b43d859
Precommit formatting
vtantia Feb 9, 2021
1c6549b
Add changes for fairseq fp16 optimizer
vtantia Feb 9, 2021
49db45e
Change slowmo_world_size to slowmo_num_shards
vtantia Feb 9, 2021
33a39bb
Fix flaky test and change parameter names
vtantia Feb 9, 2021
84fe38d
Fix minor bugs
vtantia Feb 9, 2021
669c90b
Fairscale pyproject change. Not sure why this happens
vtantia Feb 9, 2021
aed0595
Add a no sharding version of SlowMo. Add tests for the no sharding ve…
vtantia Feb 10, 2021
65d3861
Clean up SGP conditions
vtantia Feb 10, 2021
ed8b219
minor tweaks, seems to run fine
blefaudeux Feb 11, 2021
2bbc373
lint
blefaudeux Feb 11, 2021
e6c1b7f
Merge branch 'master' into slowmo_ben
blefaudeux Feb 11, 2021
2d7eff3
removing some changes which slipped in
blefaudeux Feb 11, 2021
ebfc864
changing the cudnn deterministic setting, seems that running all test…
blefaudeux Feb 11, 2021
898cc55
moving all the tests to pytest, would probably need a second cleanup …
blefaudeux Feb 13, 2021
ee5f94c
fix an assert on a parameter list
blefaudeux Feb 14, 2021
1675f39
small test refactor, not perfect but a bit more readable I presume
blefaudeux Feb 16, 2021
79ea7f8
does not look like setting files manually is a good idea
blefaudeux Feb 16, 2021
70e40bb
destroy process groups when done
blefaudeux Feb 17, 2021
1b74cc5
fixing unit tests firing consecutive process groups
blefaudeux Feb 19, 2021
152e004
Formatting changes
vtantia Feb 9, 2021
61d501f
Changes in documentation
vtantia Feb 10, 2021
94c6757
Add documentation for slowmo_memory_efficient
vtantia Feb 19, 2021
589a609
Make private methods start with underscore. Minor name changes
vtantia Feb 19, 2021
c10ada7
Move sgp related functions together
vtantia Feb 19, 2021
d0dece9
Minor flake8 fix
vtantia Feb 19, 2021
b9a7d8a
Remove enum SlowmoBaseAlgorithm. Use string instead
vtantia Feb 20, 2021
ecb558c
Remove extra parameter
vtantia Feb 20, 2021
4fd8be9
Change license header on all the files
vtantia Feb 20, 2021
ce50089
Rename function
vtantia Feb 20, 2021
8f7dc6e
Add tutorial for slowmo (very slightly modified from tutorial_oss.py)
vtantia Feb 20, 2021
c13e287
Fix broken tests on > 2 GPU machines
vtantia Feb 20, 2021
5d6dc69
Add SlowMo to init
vtantia Feb 20, 2021
3fa7593
Remove extra imports
vtantia Feb 20, 2021
cba3829
Minor addition missed 2 commits before
vtantia Feb 20, 2021
75b8cba
moving gossip to experimental
blefaudeux Mar 10, 2021
4ee9577
Merge branch 'master' into slowmo_ben
blefaudeux Mar 10, 2021
451cf6d
removing a change which slipped in
blefaudeux Mar 10, 2021
ef86bb1
Merge branch 'main' into slowmo_ben
blefaudeux Oct 18, 2021
b4a798f
code review + fixing an issue with model parallel tests
blefaudeux Oct 18, 2021
43ac702
removing private torch variable which seemed broken on nightly
blefaudeux Oct 18, 2021
76e87b4
addressing some more comments
blefaudeux Oct 18, 2021
fa214b7
tentatively debugging the unit tests, the interface is not too nice
blefaudeux Oct 19, 2021
1bd1b71
Fix a couple of bugs related to spawning processes
vtantia Oct 22, 2021
06f5af2
Fix a bug by ensuring that data is the same on all GPUs at setup time
vtantia Oct 22, 2021
5d025d0
Resolve comments on PR - misc
vtantia Oct 28, 2021
8fc366b
Resolve comments on PR - break rank and world_size into 2 variables
vtantia Oct 28, 2021
c89bdc9
Refactor to clean up _maybe_create_process_groups
vtantia Oct 28, 2021
1110eff
Fix non-deterministic behaviour in a clean way
vtantia Oct 28, 2021
7bf8017
Merge branch 'main' into slowmo_ben
vtantia Oct 28, 2021
da5357d
Fix bug by removing residual option
vtantia Oct 29, 2021
4d29165
Migrate list to deque to prevent future memory leak
vtantia Oct 29, 2021
e1aca67
Address PR comments
vtantia Oct 29, 2021
20f55ed
Minor formatting fixes
vtantia Oct 29, 2021
8d98b4d
Change slowmo_base_algorithm from string to Enum
vtantia Oct 29, 2021
cc3e829
Remove extra cast in the code
vtantia Oct 29, 2021
f3d91ec
Address PR comments
vtantia Oct 29, 2021
3eeef73
Update documentation to include SlowMo. Add tutorial. Remove tutorial…
vtantia Oct 27, 2021
e2a9d13
Modify docs to add custom sections
vtantia Nov 2, 2021
318dd91
Address comments in PR in docs and tutorials
vtantia Nov 2, 2021
f83d5ad
Convert class and methods to abstract to address PR review
vtantia Nov 2, 2021
fb8383d
Address further comments in PR in docs and tutorials
vtantia Nov 2, 2021
e19cc2a
Fix minor typo
vtantia Nov 2, 2021
da5bb69
Fix backticks linter error
vtantia Nov 2, 2021
ebe1196
Minor refactor - Rename an argument to remove Sphinx error
vtantia Nov 2, 2021
a19323e
Minor renaming in docs
vtantia Nov 3, 2021
306dbef
Merge branch 'main' into slowmo_ben
vtantia Nov 3, 2021
39383bd
Minor addition to CHANGELOG.md
vtantia Nov 3, 2021
5bd07f9
Merge branch 'main' into slowmo_ben
vtantia Nov 3, 2021
c6d0273
Add deep dive for SlowMo
vtantia Nov 4, 2021
b59a835
Modify deep dive and tutorial to address recommendations in code review
vtantia Nov 5, 2021
d9765a7
Minor refactor - name change
vtantia Nov 5, 2021
122e082
Modify deep dive to make condition for using SlowMo clearer
vtantia Nov 5, 2021
b325371
Modification to CHANGELOG.md to address review comments
vtantia Nov 5, 2021
45830c1
Add changes in documentation to address code review
vtantia Nov 5, 2021
c7242de
Fix minor linter error
vtantia Nov 5, 2021
22efbaa
Fix missing parameter in docs
vtantia Nov 5, 2021
68ff8f1
Fix link in docs
vtantia Nov 5, 2021
d0d94d0
Fix missing parameter in docs
vtantia Nov 5, 2021
67f6003
Modification to tutorials to address code review comments
vtantia Nov 8, 2021
9cf9153
Merge branch 'main' into slowmo_ben
vtantia Nov 8, 2021
41 changes: 22 additions & 19 deletions docs/source/deep_dive/slowmo_ddp.rst
@@ -5,61 +5,64 @@ Training neural networks in a distributed data-parallel manner results in non-linear
between the different nodes (as well as, to a lesser extent though, synchronization between the different nodes). So, a distributed training run
with 8 nodes is not 8x faster than a run with 1 node as we would expect it to be.

-SlowMo Distributed Data Parallel aims to solve this by replacing the exact allreduce between gradients, which is typically done, with an approximate
+SlowMo Distributed Data Parallel aims to solve this by replacing the typical exact allreduce between gradients with an approximate
averaging of parameters. This approximate averaging reduces both the time spent on communication as well as the synchronization between different
-nodes. It uses one of the following two algorithms (configurable) as a base algorithm for this purpose -
+nodes. It uses one of the following two algorithms (configurable) as a base algorithm for this purpose:

-* `Local <https://arxiv.org/abs/1602.05629>`_ `SGD <https://arxiv.org/abs/1705.09056>`_. This algorithm does an allreduce of the parameters every few iterations.
+* Local SGD (papers `#1 <https://arxiv.org/abs/1602.05629>`_ and `#2 <https://arxiv.org/abs/1705.09056>`_). This algorithm does an allreduce of the parameters every few iterations.

* `Stochastic Gradient Push <https://arxiv.org/abs/1811.10792>`_ (SGP). This algorithm involves one-to-one communications between nodes.

-These base algorithms (LocalSGD and SGP), when used only by themselves, result in reduced accuracy. The `SlowMo <https://arxiv.org/abs/1910.00643>`_
-algorithm removes this accuracy loss by doing a slow momentum step, typically, every 48 iterations.
+These base algorithms (LocalSGD and SGP), when used only by themselves, result in reduced model quality (measured as accuracy in a classification
+setting). The `SlowMo <https://arxiv.org/abs/1910.00643>`_ algorithm alleviates this issue by doing a slow momentum step, typically, every 48 iterations.

The training process with SlowMo looks as follows:

1. Compute the forward pass.

2. Compute the backward pass.

-3. During the backward pass, using a backward hook, the gradients are synchronized using allreduce across the different GPUs on a node.
+3. During the backward pass, using a backward hook, on each node, the gradients are synchronized using allreduce across the different GPUs on
+   that node.

-4. Perform the ``optimizer.step()`` to update parameters on a node with the gradients of that node.
+4. Perform the ``optimizer.step()`` to update parameters on each node with the gradients of that node.

5. Approximately average the parameters using a base algorithm - one of LocalSGD or SGP (both are described above).

6. Perform the slow momentum update step once every ``slowmo_frequency`` (typically 48) iterations. In this step, the parameters on different
-nodes are reduced, followed by a ``slowmo_optimizer.step()``. Note that this ``slowmo_optimizer`` is different from the original optimizer,
-and it is done in a `Zero-1 <https://fairscale.readthedocs.io/en/latest/deep_dive/oss_sdp_fsdp.html>`_ like manner to save memory.
+nodes are (exactly) averaged, followed by a ``slowmo_optimizer.step()``. Note that this ``slowmo_optimizer`` is different from the original optimizer,
+and it is done in a `Zero-1 <./oss_sdp_fsdp.html>`_ like manner to save memory.
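
As a minimal sketch of the six steps above (``MyModel``, ``dataloader``, ``loss_fn``, ``rank``, and the ``nprocs_per_node`` value are placeholders; the wrapper name, ``perform_slowmo()``, and the ``zero_grad()`` placement follow the tutorial further down in this PR):

.. code-block:: python

    import torch
    from fairscale.experimental.nn.data_parallel import SlowMoDistributedDataParallel as SlowMoDDP

    model = SlowMoDDP(MyModel().to(rank), nprocs_per_node=8)  # steps 3 and 5 are handled by the wrapper
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    for data, target in dataloader:
        data, target = data.to(rank), target.to(rank)
        output = model(data)                 # step 1: forward pass
        loss = loss_fn(output, target)
        loss.backward()                      # steps 2-3: backward pass; per-node allreduce via backward hooks
        optimizer.step()                     # step 4: local optimizer step on this node's gradients
        model.zero_grad()                    # free gradients before the slow momentum step (see tip 4 below)
        model.perform_slowmo(optimizer)      # steps 5-6: approximate averaging + slow momentum every slowmo_frequency iterations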

Best practices for using ``SlowMoDistributedDataParallel``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-1. SlowMo will be useful in deep learning workloads which run on greater than 2 nodes in clusters with a slow interconnect, eg Ethernet.
+1. SlowMo will be useful in deep learning workloads which run on more than 2 nodes in clusters with a slow interconnect, eg Ethernet.

-2. SlowMo should be useful in your workload if the following condition holds:
+2. SlowMo should be useful in your workload if the following condition holds (in case you are using SGP as the base algorithm, the value of ``localsgd_frequency`` can be plugged in as 2):

:math:`\textrm{time_taken_for_all_reduce_of_gradients} \times (1 - \frac{1}{\textrm{localsgd_frequency}} ) > \textrm{time_taken_for_backward_pass}`

In clusters with slower interconnect, ``time_taken_for_all_reduce_of_gradients`` will go up, leading to SlowMo being more useful. ``localsgd_frequency``
is also an important factor here. More details on varying that to affect performance are in tip 2 of
`Performance tips for SlowMoDistributedDataParallel`_.

-3. ``slowmo_momentum`` will need to be tuned for obtaining good accuracy. A random search across 4 values from [0.1, 0.2, ..., 0.7] should be good enough
+3. ``slowmo_momentum`` will need to be tuned for obtaining good model quality. A grid search across {0.0, 0.1, 0.2, 0.4, 0.6} should be good enough
for tuning. This ``slowmo_momentum`` value holds consistent across multiple runs with similar settings. When the number of nodes used is increased,
however, a higher value of ``slowmo_momentum`` will typically be needed. More details about this can be found in the
-`documentation <https://fairscale.readthedocs.io/en/latest/api/experimental/nn/slowmo_ddp.html>`_.
+`documentation <../api/experimental/nn/slowmo_ddp.html>`_.

-4. Adding SlowMo involves two steps, which can be found in the `tutorial <https://fairscale.readthedocs.io/en/latest/tutorials/slowmo_ddp.html>`_.
+4. Adding SlowMo to existing Distributed Data Parallel code involves two steps, which can be found in the `tutorial <../tutorials/slowmo_ddp.html>`_.
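
As a rough illustration of the condition in tip 2 above — all timings below are hypothetical, and only the inequality itself is taken from the text:

.. code-block:: python

    # Hypothetical per-iteration timings, in seconds, for one workload.
    time_taken_for_all_reduce_of_gradients = 0.30
    time_taken_for_backward_pass = 0.12
    localsgd_frequency = 3  # plug in 2 here when SGP is the base algorithm

    # The condition from tip 2: the communication saved by averaging only every
    # localsgd_frequency iterations should exceed the backward-pass time.
    saved = time_taken_for_all_reduce_of_gradients * (1 - 1 / localsgd_frequency)
    print(saved > time_taken_for_backward_pass)  # 0.30 * (1 - 1/3) = 0.20 > 0.12 -> True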

Performance tips for ``SlowMoDistributedDataParallel``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-1. ``nprocs_per_node`` should be set to the number of GPUs in a node. This allows the API to exploit the fast interconnect between different GPUs
-   on a node.
+1. ``nprocs_per_node`` should be set to the number of GPUs on a node (this number should be the same on each node). This allows the API
+   to exploit the fast interconnect between different GPUs on a node.

-2. Increasing the ``localsgd_frequency`` results in an increase in speed. However, it comes with a tradeoff of reducing the accuracy.
+2. Increasing the ``localsgd_frequency`` results in an increase in speed. However, it comes with a tradeoff of reducing the model quality.
We recommend keeping the ``localsgd_frequency`` at 3.

-3. ``slowmo_memory_efficient`` should typically be used. It reduces memory usage by sharding the extra slow momentum optimizer's parameters in
-   a `Zero-1`_ like manner.
+3. ``slowmo_memory_efficient`` should typically be used (this is the default behavior). It reduces memory usage by sharding the additional
+   slow momentum optimizer's parameters in a `Zero-1`_ like manner.

4. A call to ``model.zero_grad()`` should be made after ``optimizer.step()`` in order to save memory for the ``model.perform_slowmo()`` step.
56 changes: 12 additions & 44 deletions docs/source/tutorials/slowmo_ddp.rst
@@ -7,48 +7,14 @@ clusters with low interconnect speeds between different nodes. When using
SlowMo, the models on the different nodes are no longer kept in sync after each
iteration, which leads to the optimization dynamics being affected. The end
result is close to the results of Distributed Data Parallel, but is not exactly
-the same. Let's suppose that your trainer looks like:
+the same.

-.. code-block:: python
-
-
-    import torch
-    from torch.nn.parallel import DistributedDataParallel as DDP
-
-
-    def train(
-        rank: int,
-        world_size: int,
-        epochs: int):
-
-        # process group init
-        dist_init(rank, world_size)
-
-        # Problem statement
-        model = MyAwesomeModel().to(rank)
-        model = DDP(model, device_ids=[rank])
-        dataloader = MySuperFastDataloader()
-        loss_ln = MyVeryRelevantLoss()
-        optimizer = MyAmazingOptimizer()
-
-        # Any relevant training loop, nothing specific to SlowMoDDP
-        # For example:
-        model.train()
-        for e in range(epochs):
-            for (data, targets) in dataloader:
-                data, targets = data.to(rank), targets.to(rank)
-                # Train
-                model.zero_grad()
-                outputs = model(data)
-                loss = loss_fn(outputs, targets)
-                loss.backward()
-                optimizer.step()
-
-
-Then using SlowMo Distributed Data Parallel is simply replacing the DDP call with a call to
-``fairscale.experimental.nn.data_parallel.SlowMoDistributedDataParallel`` and adding a
-``model.perform_slowmo(optimizer)`` call after ``optimizer.step()``, as follows. The different
-points at which ``use_slowmo`` is used below help demonstrate these changes.
+If you have code that is setup to use Distributed Data Parallel, using SlowMo Distributed Data Parallel
+is simply replacing the DDP call with a call to
+``fairscale.experimental.nn.data_parallel.SlowMoDistributedDataParallel``, adding a
+``model.perform_slowmo(optimizer)`` call after ``optimizer.step()``, and moving the ``model.zero_grad()``
+to be after ``optimizer.step()``, as follows. The different points at which ``use_slowmo`` is used
+below help demonstrate these changes:

.. code-block:: python

@@ -84,18 +50,20 @@ points at which ``use_slowmo`` is used below help demonstrate these changes.
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                # Train
-                model.zero_grad()
+                if not use_slowmo:
+                    model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss.backward()
                optimizer.step()
+                if use_slowmo:
+                    model.zero_grad()

I assume it is important to have it here for SlowMo, but I don't remember why: it would be good to explain it in the docstring of perform_slowmo() and refer to this doc here.

Thanks for the update, a couple of follow-up comments on this point:

  1. Minor: would it be possible for the doc link to directly point to the perform_slowmo part of the page? (no big deal if not possible)

  2. How does this save memory? According to the doc (https://pytorch.org/docs/stable/generated/torch.nn.Module.html) it won't flush the tensors unless set_to_none is set to True

  1. Have fixed this. The link is a little ugly but it has very little chance of breaking in the future, so it might be good to go ahead with

  2. Ahh nice catch, I've fixed that. In the fairseq repo, setting to None was the default behavior of zero_grad so I got confused about that

+                model.perform_slowmo(optimizer) # SlowMoDDP specific
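
The behavior the reviewer is asking about is standard PyTorch rather than something introduced by this PR; a small sketch of the difference (``set_to_none`` was added to ``Module.zero_grad()`` around PyTorch 1.7, and older releases defaulted it to ``False``):

.. code-block:: python

    import torch

    model = torch.nn.Linear(4, 4)
    model(torch.randn(2, 4)).sum().backward()

    model.zero_grad(set_to_none=False)  # gradients are zeroed in place; their memory stays allocated
    print(model.weight.grad is None)    # False

    model.zero_grad(set_to_none=True)   # gradient tensors are dropped, so their memory can be freed
    print(model.weight.grad is None)    # True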

In the example above, when using SlowMoDDP, we are reducing the total communication between
-nodes by 3 times as the default ``localsgd_frequency`` is set to 3.
+nodes by 3 times as the default ``localsgd_frequency`` is set to 3 by default.
SlowMoDDP takes in ``slowmo_momentum`` as a parameter. This parameter may need to be tuned
depending on your use case. It also takes in ``nprocs_per_node`` which should be typically set
to the number of GPUs on a node. Please look at the
-`documentation <https://fairscale.readthedocs.io/en/latest/api/experimental/nn/slowmo_ddp.html>`_
+`documentation <../api/experimental/nn/slowmo_ddp.html>`_
for more details on these parameters as well as other advanced settings of the SlowMo algorithm.
8 changes: 5 additions & 3 deletions fairscale/experimental/nn/data_parallel/gossip/distributed.py
@@ -284,7 +284,9 @@ def __init__(

        self.slowmo_memory_efficient = slowmo_memory_efficient
        self.slowmo_num_shards = min(self.process_world_size, slowmo_num_shards) if self.slowmo_memory_efficient else 1
-        self.is_computing_slowmo = self.process_rank < self.slowmo_num_shards if self.slowmo_memory_efficient else True
+        self.is_current_node_a_slowmo_shard = (
+            self.process_rank < self.slowmo_num_shards if self.slowmo_memory_efficient else True
+        )

        self.nprocs_per_node_device = torch.tensor([self.nprocs_per_node], device=comm_device, dtype=first_param_dtype)

@@ -660,7 +662,7 @@ def _init_global_momentum_buffers(self, optimizer: torch.optim.Optimizer) -> None:

        self.world_portion_length = (total_elements + self.slowmo_num_shards - 1) // self.slowmo_num_shards

-        if not self.is_computing_slowmo:
+        if not self.is_current_node_a_slowmo_shard:
            return

        self.portion_start = self.process_rank * self.world_portion_length if self.slowmo_memory_efficient else 0
@@ -747,7 +749,7 @@ def _global_momentum_step(self, optimizer: torch.optim.Optimizer) -> None:
        if self.slowmo_memory_efficient:
            self._distributed_comm(optimizer, mode="gather")

-        if self.is_computing_slowmo:
+        if self.is_current_node_a_slowmo_shard:
            self._perform_local_optimization(optimizer)

        if self.slowmo_memory_efficient:
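
The renamed flag above decides whether a rank takes part in the slow momentum step. A simplified sketch of the sharding arithmetic visible in this hunk (the helper name and the clamping of the last shard are illustrative, not taken from the fairscale source): the first ``slowmo_num_shards`` ranks each own a contiguous portion of the flattened slow momentum state, of length ``ceil(total_elements / slowmo_num_shards)``.

.. code-block:: python

    # Simplified sketch, assuming the ceiling-division split and rank check shown above.
    def shard_bounds(total_elements: int, num_shards: int, rank: int):
        portion_length = (total_elements + num_shards - 1) // num_shards
        if rank >= num_shards:  # i.e. is_current_node_a_slowmo_shard would be False
            return None
        start = rank * portion_length
        return start, min(start + portion_length, total_elements)

    print(shard_bounds(total_elements=10, num_shards=4, rank=0))  # (0, 3)
    print(shard_bounds(total_elements=10, num_shards=4, rank=3))  # (9, 10)
    print(shard_bounds(total_elements=10, num_shards=4, rank=5))  # None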