
Commit 049f108

Replace .step(synchronize=False) with optimizer.skip_synchronize() (#1132)

* Replace .step(synchronize=False) with optimizer.already_synchronized()

Signed-off-by: Alex Sergeev <alsrgv@users.noreply.github.com>

* Fix docs

Signed-off-by: Alex Sergeev <alsrgv@users.noreply.github.com>

* Rename to skip_synchronize() and fix test

Signed-off-by: Alex Sergeev <alsrgv@users.noreply.github.com>
alsrgv committed Jun 11, 2019
1 parent 2c81bfc commit 049f108
Showing 3 changed files with 52 additions and 24 deletions.
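
At the call-site level, the commit swaps the synchronize=False keyword argument of step() for a skip_synchronize() context manager. Below is a minimal sketch of the migration in user code, assuming an optimizer created by hvd.DistributedOptimizer; the model, loss, and clipping value are illustrative placeholders, not part of this commit:

import torch
import torch.nn.functional as F
import horovod.torch as hvd

hvd.init()
model = torch.nn.Linear(10, 2)
optimizer = hvd.DistributedOptimizer(
    torch.optim.SGD(model.parameters(), lr=0.01),
    named_parameters=model.named_parameters())

data = torch.randn(4, 10)
target = torch.tensor([0, 1, 0, 1])
loss = F.nll_loss(F.log_softmax(model(data), dim=1), target)
loss.backward()

optimizer.synchronize()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

# Old API, removed by this commit:
#   optimizer.step(synchronize=False)
# New API introduced here:
with optimizer.skip_synchronize():
    optimizer.step()
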
1 change: 1 addition & 0 deletions docs/conf.py
@@ -64,6 +64,7 @@
     'special-members': '__init__',
     'imported-members': None,
     'undoc-members': None,
+    'exclude-members': 'contextmanager',
 }


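The new 'exclude-members': 'contextmanager' entry presumably keeps the freshly imported contextlib.contextmanager out of the generated API reference, since 'imported-members' would otherwise pick it up as a member of horovod.torch. A sketch of how these keys typically sit in docs/conf.py, assuming the hunk above edits Sphinx's autodoc_default_options (the dictionary name is not visible in the diff):

# docs/conf.py (sketch, assuming Sphinx >= 1.8 and sphinx.ext.autodoc).
# A value of None enables the option without an argument.
extensions = ['sphinx.ext.autodoc']

autodoc_default_options = {
    'special-members': '__init__',
    'imported-members': None,
    'undoc-members': None,
    # Hide contextlib.contextmanager, which is now imported at module level
    # in horovod.torch and would otherwise appear in the rendered docs.
    'exclude-members': 'contextmanager',
}
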
63 changes: 43 additions & 20 deletions horovod/torch/__init__.py
@@ -17,6 +17,7 @@
 from __future__ import division
 from __future__ import print_function

+from contextlib import contextmanager
 import warnings

 from horovod.common.util import check_extension
@@ -83,6 +84,7 @@ def __init__(self, params, named_parameters, compression,
         self._grad_accs = []
         self._requires_update = set()
         self._synchronized = False
+        self._should_synchronize = True
         if size() > 1:
             self._register_hooks()

@@ -157,13 +159,32 @@ def synchronize(self):

         self._synchronized = True

-    def step(self, closure=None, synchronize=True):
-        if synchronize:
+    @contextmanager
+    def skip_synchronize(self):
+        """
+        A context manager used to specify that optimizer.step() should
+        not perform synchronization.
+
+        It's typically used in a following pattern:
+
+        .. code-block:: python
+
+            optimizer.synchronize()
+            with optimizer.skip_synchronize():
+                optimizer.step()
+        """
+        self._should_synchronize = False
+        yield
+        self._should_synchronize = True
+
+    def step(self, closure=None):
+        if self._should_synchronize:
             if self._synchronized:
-                warnings.warn("optimizer.step(synchronize=True) called after "
+                warnings.warn("optimizer.step() called without "
+                              "optimizer.skip_synchronize() context after "
                               "optimizer.synchronize(). This can cause training "
                               "slowdown. You may want to consider using "
-                              "optimizer.step(synchronize=False) if you use "
+                              "optimizer.skip_synchronize() context if you use "
                               "optimizer.synchronize() in your code.")
             self.synchronize()
         self._synchronized = False
@@ -184,30 +205,32 @@ def DistributedOptimizer(optimizer, named_parameters=None,
     An optimizer that wraps another torch.optim.Optimizer, using an allreduce to
     average gradient values before applying gradients to model weights.

-    Allreduce operations are executed after each gradient is computed by `loss.backward()`
-    in parallel with each other. The `step()` method ensures that all allreduce operations are
+    Allreduce operations are executed after each gradient is computed by ``loss.backward()``
+    in parallel with each other. The ``step()`` method ensures that all allreduce operations are
     finished before applying gradients to the model.

-    DistributedOptimizer exposes the `synchronize()` method, which forces allreduce operations
+    DistributedOptimizer exposes the ``synchronize()`` method, which forces allreduce operations
     to finish before continuing the execution. It's useful in conjunction with gradient
-    clipping, or other operations that modify gradients in place before `step()` is executed.
-    Make sure to pass `synchronize=False` to `step()` method if you're calling `synchronize()`
+    clipping, or other operations that modify gradients in place before ``step()`` is executed.
+    Make sure to use ``optimizer.skip_synchronize()`` if you're calling ``synchronize()``
     in your code.

     Example of gradient clipping:
-    ```
-    output = model(data)
-    loss = F.nll_loss(output, target)
-    loss.backward()
-    optimizer.synchronize()
-    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
-    optimizer.step(synchronize=False)
-    ```
+
+    .. code-block:: python
+
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.synchronize()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
+        with optimizer.skip_synchronize():
+            optimizer.step()

     Arguments:
         optimizer: Optimizer to use for computing gradients and applying updates.
         named_parameters: A mapping between parameter names and values. Used for naming of
-                          allreduce operations. Typically just `model.named_parameters()`.
+                          allreduce operations. Typically just ``model.named_parameters()``.
         compression: Compression algorithm used during allreduce to reduce the amount
                      of data sent during the each parameter update step. Defaults to
                      not using compression.
@@ -228,8 +251,8 @@ def broadcast_parameters(params, root_rank):
 def broadcast_parameters(params, root_rank):
     """
     Broadcasts the parameters from root rank to all other processes.
-    Typical usage is to broadcast the `model.state_dict()`,
-    `model.named_parameters()`, or `model.parameters()`.
+    Typical usage is to broadcast the ``model.state_dict()``,
+    ``model.named_parameters()``, or ``model.parameters()``.

     Arguments:
         params: One of the following:
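
The new skip_synchronize() is a plain contextlib.contextmanager that toggles a flag consulted by step(). Below is a stripped-down, Horovod-free sketch of the same pattern; class and method names mirror the diff above, the allreduce work is replaced by prints, and unlike the real code the flag is reset in a finally block:

from contextlib import contextmanager
import warnings


class ToyOptimizer:
    """Stand-in illustrating the flag-toggling pattern from the diff above."""

    def __init__(self):
        self._should_synchronize = True
        self._synchronized = False

    def synchronize(self):
        # In Horovod this waits on outstanding allreduce handles; here it is a stub.
        print("synchronizing gradients")
        self._synchronized = True

    @contextmanager
    def skip_synchronize(self):
        # The real implementation sets the flag, yields, and resets it; try/finally
        # is used here so the toy also recovers if the body raises.
        self._should_synchronize = False
        try:
            yield
        finally:
            self._should_synchronize = True

    def step(self):
        if self._should_synchronize:
            if self._synchronized:
                warnings.warn("step() called without skip_synchronize() "
                              "after synchronize()")
            self.synchronize()
        self._synchronized = False
        print("applying update")


opt = ToyOptimizer()
opt.synchronize()            # explicit sync, e.g. before gradient clipping
with opt.skip_synchronize():
    opt.step()               # skips the redundant synchronize and the warning
opt.step()                   # normal path: synchronizes, then applies the update
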
12 changes: 8 additions & 4 deletions test/test_torch.py
@@ -1261,10 +1261,14 @@ def test_gradient_clipping(self):
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
         clipped_grad = model.weight.grad.item()
         assert abs(prior_grad) > abs(clipped_grad)
-        optimizer.step(synchronize=False)
+        with optimizer.skip_synchronize():
+            optimizer.step()

     def test_synchronize_step_warning(self):
-        """Test that .synchronize() followed by .step(synchronize=True) will produce a warning."""
+        """
+        Test that .synchronize() followed by .step() without
+        optimizer.skip_synchronize() context will produce a warning.
+        """
         hvd.init()
         size = hvd.size()

@@ -1288,9 +1292,9 @@ def test_synchronize_step_warning(self):
         optimizer.synchronize()
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
         with warnings.catch_warnings(record=True) as ws:
-            optimizer.step(synchronize=True)
+            optimizer.step()
             assert len(ws) == 1
-            assert 'optimizer.step(synchronize=True) called after optimizer.synchronize()' \
+            assert 'optimizer.step() called without optimizer.skip_synchronize()' \
                 in str(ws[0].message)

     def test_no_named_parameters(self):
