Skip to content

Commit

Permalink
Create DistributedInitializer to broadcast deferred-init param
Browse files Browse the repository at this point in the history
Signed-off-by: Yuxi Hu <darrenyxhu@gmail.com>
  • Loading branch information
yuxihu committed Mar 20, 2019
1 parent 4f7319e commit fbfdcb7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
14 changes: 5 additions & 9 deletions examples/mxnet_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,23 @@ def evaluate(model, data_iter, context):
model.cast(args.dtype)
model.hybridize()

# Create optimizer
optimizer_params = {'momentum': args.momentum,
                    'learning_rate': args.lr * hvd.size(),
                    'rescale_grad': 1.0 / args.batch_size}
opt = mx.optimizer.create('sgd', **optimizer_params)
# Horovod: wrap optimizer with DistributedOptimizer
opt = hvd.DistributedOptimizer(opt)

# Initialize parameters
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                             magnitude=2)
# Horovod: wrap initializer with DistributedInitializer so every parameter is
# broadcast from the root rank as it is initialized (covers Gluon's deferred
# initialization); a separate hvd.broadcast_parameters() call is no longer
# needed.
initializer = hvd.DistributedInitializer(initializer)
model.initialize(initializer, ctx=context)

# Create trainer, loss function and train metric
trainer = gluon.Trainer(model.collect_params(), opt, kvstore=None)
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

Expand Down
15 changes: 15 additions & 0 deletions horovod/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ def set_wd_mult(self, args_wd_mult):
self._optimizer.set_wd_mult(args_wd_mult)


# DistributedInitializer wraps an MXNet Initializer so that each parameter is
# broadcast from the root rank immediately after it is initialized.  Per the
# commit intent, this covers parameters created through deferred
# initialization, which an up-front broadcast_parameters() call would miss.
class DistributedInitializer(mx.initializer.Initializer):
    """Initializer wrapper that initializes a parameter and then broadcasts it.

    Arguments:
        init: underlying mx.initializer.Initializer to delegate to.
        root_rank: rank of the process whose values are broadcast to all
            other processes (defaults to 0).
    """

    def __init__(self, init, root_rank=0):
        # Fix: call the base Initializer constructor so state it sets up
        # (e.g. verbosity flags and kwargs used by serialization) exists on
        # this instance; the original skipped it.
        super(DistributedInitializer, self).__init__()
        self._init = init
        self._root_rank = root_rank

    def __call__(self, desc, arr):
        # Initialize locally, then overwrite with the root rank's values so
        # all workers start from identical parameters.
        self._init(desc, arr)
        broadcast_(arr, self._root_rank, desc)
        # Block until the asynchronous broadcast has completed before the
        # parameter is used.
        arr.wait_to_read()

    def _init_weight(self, name, arr):
        # Delegate weight initialization to the wrapped initializer.
        self._init._init_weight(name, arr)


def broadcast_parameters(params, root_rank=0):
"""
Broadcasts the parameters from root rank to all other processes.
Expand Down

0 comments on commit fbfdcb7

Please sign in to comment.