diff --git a/horovod/torch/__init__.py b/horovod/torch/__init__.py
index 6a138b863d..f76819f7a0 100644
--- a/horovod/torch/__init__.py
+++ b/horovod/torch/__init__.py
@@ -17,6 +17,7 @@
 from __future__ import division
 from __future__ import print_function
 
+from contextlib import contextmanager
 import warnings
 
 from horovod.common.util import check_extension
@@ -83,6 +84,7 @@ def __init__(self, params, named_parameters, compression,
         self._grad_accs = []
         self._requires_update = set()
         self._synchronized = False
+        self._should_synchronize = True
         if size() > 1:
             self._register_hooks()
 
@@ -157,13 +159,31 @@ def synchronize(self):
 
         self._synchronized = True
 
-    def step(self, closure=None, synchronize=True):
-        if synchronize:
+    @contextmanager
+    def already_synchronized(self):
+        """
+        A context manager used to specify that optimizer.step() should
+        not perform synchronization.
+
+        It's typically used in the following pattern:
+        ```
+        optimizer.synchronize()
+        with optimizer.already_synchronized():
+            optimizer.step()
+        ```
+        """
+        self._should_synchronize = False
+        yield
+        self._should_synchronize = True
+
+    def step(self, closure=None):
+        if self._should_synchronize:
             if self._synchronized:
-                warnings.warn("optimizer.step(synchronize=True) called after "
+                warnings.warn("optimizer.step() called without "
+                              "optimizer.already_synchronized() context after "
                               "optimizer.synchronize(). This can cause training "
                               "slowdown. You may want to consider using "
-                              "optimizer.step(synchronize=False) if you use "
+                              "optimizer.already_synchronized() context if you use "
                               "optimizer.synchronize() in your code.")
             self.synchronize()
         self._synchronized = False
@@ -191,7 +211,7 @@ def DistributedOptimizer(optimizer, named_parameters=None,
     DistributedOptimizer exposes the `synchronize()` method, which forces allreduce operations
     to finish before continuing the execution. It's useful in conjunction with gradient clipping,
     or other operations that modify gradients in place before `step()` is executed.
-    Make sure to pass `synchronize=False` to `step()` method if you're calling `synchronize()`
+    Make sure to use `optimizer.already_synchronized()` if you're calling `synchronize()`
     in your code.
 
     Example of gradient clipping:
@@ -201,7 +221,8 @@ def DistributedOptimizer(optimizer, named_parameters=None,
     loss.backward()
     optimizer.synchronize()
     torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
-    optimizer.step(synchronize=False)
+    with optimizer.already_synchronized():
+        optimizer.step()
     ```
 
     Arguments:
diff --git a/test/test_torch.py b/test/test_torch.py
index f23a6b6855..0db0ff7ea9 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1261,10 +1261,14 @@ def test_gradient_clipping(self):
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
         clipped_grad = model.weight.grad.item()
         assert abs(prior_grad) > abs(clipped_grad)
-        optimizer.step(synchronize=False)
+        with optimizer.already_synchronized():
+            optimizer.step()
 
     def test_synchronize_step_warning(self):
-        """Test that .synchronize() followed by .step(synchronize=True) will produce a warning."""
+        """
+        Test that .synchronize() followed by .step() without
+        optimizer.already_synchronized() context will produce a warning.
+ """ hvd.init() size = hvd.size() @@ -1288,9 +1292,10 @@ def test_synchronize_step_warning(self): optimizer.synchronize() torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) with warnings.catch_warnings(record=True) as ws: - optimizer.step(synchronize=True) + with optimizer.already_synchronized(): + optimizer.step() assert len(ws) == 1 - assert 'optimizer.step(synchronize=True) called after optimizer.synchronize()' \ + assert 'optimizer.step() called without optimizer.already_synchronized()' \ in str(ws[0].message) def test_no_named_parameters(self):