inject broadcast to init_impl
Signed-off-by: Yuxi Hu <darrenyxhu@gmail.com>
yuxihu committed Mar 20, 2019
1 parent fbfdcb7 commit 492138a
Showing 2 changed files with 19 additions and 18 deletions.
examples/mxnet_mnist.py (6 additions, 3 deletions)

@@ -123,12 +123,15 @@ def evaluate(model, data_iter, context):
 # Initialize parameters
 initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                              magnitude=2)
-# Horovod: wrap initializer with DistributedInitializer
-initializer = hvd.DistributedInitializer(initializer)
 model.initialize(initializer, ctx=context)

+# Fetch and broadcast parameters
+params = model.collect_params()
+if params is not None:
+    hvd.broadcast_parameters(params, root_rank=0)
+
 # Create trainer, loss function and train metric
-trainer = gluon.Trainer(model.collect_params(), opt, kvstore=None)
+trainer = gluon.Trainer(params, opt, kvstore=None)
 loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
 metric = mx.metric.Accuracy()

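With this change the example follows an initialize-then-broadcast pattern instead of wrapping the initializer. For readers trying the new API in isolation, a minimal standalone sketch of the same pattern follows; the toy model, optimizer settings, and device choice are illustrative placeholders, not part of this commit:

import mxnet as mx
from mxnet import gluon
import horovod.mxnet as hvd

hvd.init()
context = mx.cpu()  # use mx.gpu(hvd.local_rank()) on GPU machines

# A toy model; with no input shape declared, its weight stays
# deferred-initialized until the first forward pass.
model = gluon.nn.Dense(10)

# Initialize parameters locally on every worker...
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                             magnitude=2)
model.initialize(initializer, ctx=context)

# ...then broadcast rank 0's values so all workers start identical.
# Parameters still awaiting deferred initialization are handled by the
# injected _init_impl wrapper added in this commit.
params = model.collect_params()
if params is not None:
    hvd.broadcast_parameters(params, root_rank=0)

# Horovod: wrap the optimizer so gradients are averaged across workers
opt = mx.optimizer.create('sgd', learning_rate=0.01)
opt = hvd.DistributedOptimizer(opt)
trainer = gluon.Trainer(params, opt, kvstore=None)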
horovod/mxnet/__init__.py (13 additions, 15 deletions)

@@ -30,6 +30,7 @@
 from horovod.mxnet.mpi_ops import mpi_threads_supported

 import mxnet as mx
+import types


 # This is where Horovod's DistributedOptimizer wrapper for MXNet goes
@@ -68,19 +69,14 @@ def set_wd_mult(self, args_wd_mult):
         self._optimizer.set_wd_mult(args_wd_mult)


-# DistributedInitializer wraps MXNet Initializer which initializes and broadcasts parameter.
-class DistributedInitializer(mx.initializer.Initializer):
-    def __init__(self, init, root_rank=0):
-        self._init = init
-        self._root_rank = root_rank
-
-    def __call__(self, desc, arr):
-        self._init(desc, arr)
-        broadcast_(arr, self._root_rank, desc)
-        arr.wait_to_read()
-
-    def _init_weight(self, name, arr):
-        self._init._init_weight(name, arr)
+# Wrapper to inject Horovod broadcast after parameter initialization
+def _append_broadcast_init(param, root_rank):
+    init_impl = getattr(param, '_init_impl')
+    def wrapped_init_impl(self, *args, **kwargs):
+        init_impl(*args, **kwargs)
+        broadcast_(self.data(), root_rank=root_rank)
+        self.data().wait_to_read()
+    return wrapped_init_impl


 def broadcast_parameters(params, root_rank=0):
@@ -104,8 +100,10 @@ def broadcast_parameters(params, root_rank=0):
             try:
                 tensors.append(p.data())
             except mx.gluon.parameter.DeferredInitializationError:
-                # skip broadcasting deferred init param
-                pass
+                # Inject wrapper method with post-initialization broadcast to
+                # handle parameters with deferred initialization
+                new_init = _append_broadcast_init(p, root_rank)
+                p._init_impl = types.MethodType(new_init, p)
     else:
         raise ValueError('invalid params of type: %s' % type(params))

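The deferred-initialization path relies on per-instance method injection: a closure captures the parameter's original bound _init_impl, and types.MethodType rebinds the wrapper onto that single instance, so the broadcast runs exactly when MXNet finally materializes the parameter. A stripped-down sketch of the same technique on a plain class (the names here are illustrative, not Horovod's):

import types

class Param(object):
    def __init__(self, name):
        self.name = name
        self.value = None

    def _init_impl(self, value):
        self.value = value

def append_hook(obj, hook):
    # Capture the original *bound* method before replacing the attribute.
    original = obj._init_impl
    def wrapped(self, *args, **kwargs):
        original(*args, **kwargs)  # run the real initialization first
        hook(self)                 # then the injected post-init step
    # Rebind on this one instance; the class and other instances are untouched.
    obj._init_impl = types.MethodType(wrapped, obj)

p = Param('weight')
append_hook(p, lambda param: print('broadcast %s' % param.name))
p._init_impl(42)  # prints "broadcast weight"; p.value is now 42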
