diff --git a/examples/mxnet_mnist.py b/examples/mxnet_mnist.py index 3ed7f8f426..63f3dddc76 100644 --- a/examples/mxnet_mnist.py +++ b/examples/mxnet_mnist.py @@ -112,27 +112,23 @@ def evaluate(model, data_iter, context): model.cast(args.dtype) model.hybridize() -# Define hyper parameters +# Create optimizer optimizer_params = {'momentum': args.momentum, 'learning_rate': args.lr * hvd.size(), 'rescale_grad': 1.0 / args.batch_size} - -# Add Horovod Distributed Optimizer opt = mx.optimizer.create('sgd', **optimizer_params) +# Horovod: wrap optimizer with DistributedOptimizer opt = hvd.DistributedOptimizer(opt) # Initialize parameters initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) +# Horovod: wrap initializer with DistributedInitializer +initializer = hvd.DistributedInitializer(initializer) model.initialize(initializer, ctx=context) -# Fetch and broadcast parameters -params = model.collect_params() -if params is not None: - hvd.broadcast_parameters(params, root_rank=0) - # Create trainer, loss function and train metric -trainer = gluon.Trainer(params, opt, kvstore=None) +trainer = gluon.Trainer(model.collect_params(), opt, kvstore=None) loss_fn = gluon.loss.SoftmaxCrossEntropyLoss() metric = mx.metric.Accuracy() diff --git a/horovod/mxnet/__init__.py b/horovod/mxnet/__init__.py index 5250d7f95f..9b48b747a8 100644 --- a/horovod/mxnet/__init__.py +++ b/horovod/mxnet/__init__.py @@ -68,6 +68,21 @@ def set_wd_mult(self, args_wd_mult): self._optimizer.set_wd_mult(args_wd_mult) +# DistributedInitializer wraps MXNet Initializer which initializes and broadcasts parameter. 
class DistributedInitializer(mx.initializer.Initializer):
    """Wrap an MXNet initializer and broadcast every array it initializes.

    Each call first lets the wrapped initializer fill the array, then calls
    ``broadcast_`` on that array with ``root_rank`` and waits for the array
    to become readable. The intent (per the accompanying example) is that
    all workers start from the parameter values produced on ``root_rank``
    without a separate ``broadcast_parameters`` step.

    Args:
        init: The initializer to wrap. It must be callable as
            ``init(desc, arr)`` and provide ``_init_weight(name, arr)``.
        root_rank: Rank passed to ``broadcast_`` as the broadcast source.
            Defaults to 0.
    """

    def __init__(self, init, root_rank=0):
        # Run the base-class constructor so the attributes it defines (used
        # by inherited helpers such as ``dumps()``) exist on the wrapper too.
        # The two-argument form of ``super`` keeps this valid on Python 2.
        super(DistributedInitializer, self).__init__()
        self._init = init
        self._root_rank = root_rank

    def __call__(self, desc, arr):
        """Initialize ``arr`` with the wrapped initializer, then broadcast it.

        Args:
            desc: Descriptor/name of the parameter; forwarded to the wrapped
                initializer and reused as the third argument of
                ``broadcast_`` (presumably the operation name -- confirm
                against ``broadcast_``'s signature).
            arr: Array to initialize; it is passed to ``broadcast_`` after
                initialization.
        """
        self._init(desc, arr)
        broadcast_(arr, self._root_rank, desc)
        # Wait on the array before returning so the caller does not observe
        # it while the initialization/broadcast may still be pending.
        arr.wait_to_read()

    def _init_weight(self, name, arr):
        """Delegate weight initialization to the wrapped initializer."""
        self._init._init_weight(name, arr)