Replace .step(synchronize=False) with optimizer.already_synchronized()
Signed-off-by: Alex Sergeev <alsrgv@users.noreply.github.com>
alsrgv committed Jun 10, 2019
1 parent ae96421 commit 1d3f0e5
Showing 2 changed files with 36 additions and 10 deletions.
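In user code, the change replaces the removed `synchronize` argument of `step()` with a context manager. A minimal before/after sketch (the model, data, and clipping threshold below are illustrative placeholders, not taken from this diff):

```
import torch
import horovod.torch as hvd

hvd.init()
model = torch.nn.Linear(2, 2)
optimizer = hvd.DistributedOptimizer(
    torch.optim.SGD(model.parameters(), lr=0.01),
    named_parameters=model.named_parameters())

loss = model(torch.randn(4, 2)).sum()
loss.backward()

optimizer.synchronize()                                  # wait for pending allreduces
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # modify gradients in place

# Before this commit:
#     optimizer.step(synchronize=False)
# After this commit:
with optimizer.already_synchronized():
    optimizer.step()
```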
33 changes: 27 additions & 6 deletions horovod/torch/__init__.py
@@ -17,6 +17,7 @@
 from __future__ import division
 from __future__ import print_function

+from contextlib import contextmanager
 import warnings

 from horovod.common.util import check_extension
@@ -83,6 +84,7 @@ def __init__(self, params, named_parameters, compression,
         self._grad_accs = []
         self._requires_update = set()
         self._synchronized = False
+        self._should_synchronize = True
         if size() > 1:
             self._register_hooks()

@@ -157,13 +159,31 @@ def synchronize(self):

         self._synchronized = True

-    def step(self, closure=None, synchronize=True):
-        if synchronize:
+    @contextmanager
+    def already_synchronized(self):
+        """
+        A context manager used to specify that optimizer.step() should
+        not perform synchronization.
+
+        It's typically used in the following pattern:
+        ```
+        optimizer.synchronize()
+        with optimizer.already_synchronized():
+            optimizer.step()
+        ```
+        """
+        self._should_synchronize = False
+        yield
+        self._should_synchronize = True
+
+    def step(self, closure=None):
+        if self._should_synchronize:
             if self._synchronized:
-                warnings.warn("optimizer.step(synchronize=True) called after "
+                warnings.warn("optimizer.step() called without "
+                              "optimizer.already_synchronized() context after "
                               "optimizer.synchronize(). This can cause training "
                               "slowdown. You may want to consider using "
-                              "optimizer.step(synchronize=False) if you use "
+                              "optimizer.already_synchronized() context if you use "
                               "optimizer.synchronize() in your code.")
             self.synchronize()
         self._synchronized = False
@@ -191,7 +211,7 @@ def DistributedOptimizer(optimizer, named_parameters=None,
     DistributedOptimizer exposes the `synchronize()` method, which forces allreduce operations
     to finish before continuing the execution. It's useful in conjunction with gradient
     clipping, or other operations that modify gradients in place before `step()` is executed.
-    Make sure to pass `synchronize=False` to `step()` method if you're calling `synchronize()`
+    Make sure to use `optimizer.already_synchronized()` if you're calling `synchronize()`
     in your code.

     Example of gradient clipping:
@@ -201,7 +221,8 @@ def DistributedOptimizer(optimizer, named_parameters=None,
     loss.backward()
     optimizer.synchronize()
     torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
-    optimizer.step(synchronize=False)
+    with optimizer.already_synchronized():
+        optimizer.step()
     ```

     Arguments:
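For background on the mechanism introduced above: `contextlib.contextmanager` turns a generator function into a context manager, so the code before `yield` runs when the `with` block is entered and the code after it runs when the block exits. A minimal standalone sketch of this flag-toggling pattern (the class and attribute names below are illustrative, not Horovod's):

```
from contextlib import contextmanager


class FlaggedWorker(object):
    def __init__(self):
        self.should_synchronize = True

    @contextmanager
    def already_synchronized(self):
        # Disable synchronization for the duration of the with block.
        self.should_synchronize = False
        try:
            yield
        finally:
            # Restore the flag on exit, even if the block raised.
            self.should_synchronize = True


worker = FlaggedWorker()
with worker.already_synchronized():
    assert worker.should_synchronize is False
assert worker.should_synchronize is True
```

The committed helper restores the flag after a bare `yield`; the `try`/`finally` in this sketch is one way to also restore it when the body of the `with` block raises.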
13 changes: 9 additions & 4 deletions test/test_torch.py
@@ -1261,10 +1261,14 @@ def test_gradient_clipping(self):
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
         clipped_grad = model.weight.grad.item()
         assert abs(prior_grad) > abs(clipped_grad)
-        optimizer.step(synchronize=False)
+        with optimizer.already_synchronized():
+            optimizer.step()

     def test_synchronize_step_warning(self):
-        """Test that .synchronize() followed by .step(synchronize=True) will produce a warning."""
+        """
+        Test that .synchronize() followed by .step() without
+        optimizer.already_synchronized() context will produce a warning.
+        """
         hvd.init()
         size = hvd.size()

@@ -1288,9 +1292,10 @@ def test_synchronize_step_warning(self):
         optimizer.synchronize()
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
         with warnings.catch_warnings(record=True) as ws:
-            optimizer.step(synchronize=True)
+            with optimizer.already_synchronized():
+                optimizer.step()
             assert len(ws) == 1
-            assert 'optimizer.step(synchronize=True) called after optimizer.synchronize()' \
+            assert 'optimizer.step() called without optimizer.already_synchronized()' \
                 in str(ws[0].message)

     def test_no_named_parameters(self):
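One migration note for call sites that previously passed a runtime boolean to `step(synchronize=...)`: with the new API, the context manager can be selected conditionally. A sketch, assuming Python 3.7+ for `contextlib.nullcontext` and a hypothetical `already_synced` flag:

```
import contextlib

# Skip synchronization in step() only when the gradients were already synchronized.
ctx = optimizer.already_synchronized() if already_synced else contextlib.nullcontext()
with ctx:
    optimizer.step()
```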
