Skip to content
Browse files

Fixed issue with broadcasting optimizer state for params that don't r…

…equire grads, added back utility function to broadcast any object (#1609)

Signed-off-by: Travis Addair <>
  • Loading branch information
tgaddair committed Jan 9, 2020
1 parent 438880e commit b505f149b4cb1c1e8f61e8d3f19aeba15383be71
Showing with 99 additions and 5 deletions.
  1. +52 −2 horovod/torch/
  2. +47 −3 test/
@@ -19,8 +19,12 @@
from __future__ import print_function

from contextlib import contextmanager

import io
import warnings

import cloudpickle

from horovod.common.util import check_extension

@@ -211,6 +215,7 @@ def zero_grad(self):
"This is prohibited as it can cause a race condition.")
return super(self.__class__, self).zero_grad()

class _DistributedAdasumOptimizer(torch.optim.Optimizer):
def __init__(self, params, named_parameters, compression,
@@ -381,6 +386,7 @@ def zero_grad(self):
"This is prohibited as it can cause a race condition.")
return super(self.__class__, self).zero_grad()

def DistributedOptimizer(optimizer, named_parameters=None,
@@ -491,7 +497,8 @@ def broadcast_optimizer_state(optimizer, root_rank):
if len(state_dict['state']) == 0:
for group in optimizer.param_groups:
for p in group['params']:
p.grad =
if p.requires_grad and id(p) not in state_dict['state']:
p.grad =
# This function accepts a torch.optim.Optimizer or a DistributedOptimizer
# wrapped around a torch optimizer. Calling step() with a DistributedOptimizer
# forces allreduce on all model parameters, which will result in deadlock
@@ -582,7 +589,50 @@ def _from_tensor():
# Synchronized broadcast of all parameters
broadcast_parameters(params, root_rank)

# Post-broadcast clenaup for non-tensor parameters
# Post-broadcast cleanup for non-tensor parameters
for key, p in params:
if key in callbacks:

def broadcast_object(obj, root_rank, name=None):
Serializes and broadcasts an object from root rank to all other processes.
Typical usage is to broadcast the `optimizer.state_dict()`, for example:
.. code-block:: python
state_dict = broadcast_object(optimizer.state_dict(), 0)
if hvd.rank() > 0:
obj: An object capable of being serialized without losing any context.
root_rank: The rank of the process from which parameters will be
broadcasted to all other processes.
name: Optional name to use during broadcast, will default to the class
The object that was broadcast from the `root_rank`.
if name is None:
name = str(type(obj))

if rank() == root_rank:
b = io.BytesIO()
cloudpickle.dump(obj, b)
t = torch.ByteTensor(bytearray(b.getvalue()))
sz = torch.IntTensor([t.shape[0]])
broadcast_(sz, root_rank, name + '.sz')
sz = torch.IntTensor([0])
broadcast_(sz, root_rank, name + '.sz')
t = torch.ByteTensor(sz.tolist()[0])

broadcast_(t, root_rank, name + '.t')

if rank() != root_rank:
buf = io.BytesIO(t.numpy().tobytes())
obj = cloudpickle.load(buf)

return obj
@@ -19,17 +19,21 @@
from __future__ import print_function

from distutils.version import LooseVersion

import collections
import inspect
import itertools
import numpy as np
import os
import pytest
import tempfile
import torch
import torch.nn.functional as F
import unittest
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import horovod.torch as hvd

from common import mpi_env_rank_and_size
@@ -1101,6 +1105,46 @@ def create_model(opt_class):

@pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion('0.4.1'),
reason='Cannot optimize parameters that do not require gradients before PyTorch 0.4.1')
def test_broadcast_state_no_grad(self):
class ModelNoGrad(nn.Module):
def __init__(self, a, b):
super(ModelNoGrad, self).__init__()
self.a = nn.Parameter(, requires_grad=False)
self.b = nn.Parameter(b)

def forward(self, x):
return torch.index_select(self.b, 0, self.a.long()) * x


a = torch.Tensor([1, 3])
b = torch.rand(4)

model = ModelNoGrad(a, b)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())

hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

assert optimizer.param_groups[0]['params'][0].grad is None
assert torch.all(torch.eq(optimizer.param_groups[0]['params'][1].grad, torch.zeros([4]))).item()

def test_broadcast_object(self):

expected_obj = {
'hello': 123,
0: [1, 2]
obj = expected_obj if hvd.rank() == 0 else {}

obj = hvd.broadcast_object(obj, root_rank=0)
self.assertDictEqual(obj, expected_obj)

def test_compression_fp16(self):
valid_dtypes = [torch.float32, torch.float64]
invalid_dtypes = [torch.uint8, torch.int8, torch.int16,

0 comments on commit b505f14

Please sign in to comment.
You can’t perform that action at this time.