MXNet: support broadcasting deferred initialization parameters in Gluon #915

Merged: 3 commits merged into horovod:master on Mar 28, 2019

Conversation

@yuxihu (Collaborator) commented Mar 14, 2019

Fixes #895

In Gluon, initialization of some parameters is deferred until their shapes become known during the first forward pass. This makes it difficult to sync parameters among workers when the workers train with different random seeds.

To solve the issue, we inject a broadcast into the init_impl of deferred-initialized Gluon Parameters. Thanks @romerojosh for suggesting this idea.
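For context, a minimal sketch of the deferred initialization behavior described above (the layer and shapes below are just an illustrative example, not taken from this PR):

import mxnet as mx

net = mx.gluon.nn.Dense(10)   # in_units is unknown at construction time
net.initialize()              # initialization is deferred until shapes are known
try:
    net.weight.data()         # raises DeferredInitializationError before the first forward pass
except mx.gluon.parameter.DeferredInitializationError:
    pass
net(mx.nd.ones((1, 20)))      # first forward pass infers the shape and runs the real init
print(net.weight.shape)       # (10, 20) -- only now can the parameter be broadcast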

@yuxihu force-pushed the gluon_param branch 3 times, most recently from 50154ba to 0ac89b1 on March 15, 2019 05:34
@yuxihu changed the title from "[WIP] MXNet: support broadcasting deferred initialization parameters in Gluon" to "MXNet: support broadcasting deferred initialization parameters in Gluon" on Mar 15, 2019
@yuxihu (Collaborator, Author) commented Mar 15, 2019

@apeforest @eric-haibin-lin @ctcyang Please help review. I will update the README and imagenet example after I get some feedback from you.

@ptrendx commented Mar 15, 2019

Hi @yuxihu - maybe instead of DistributedOptimizer and DistributedInitializer, Gluon should just have a DistributedTrainer? I think that would be much saner.

@alsrgv (Member) commented Mar 15, 2019

Does the Gluon trainer have a callback mechanism where we could add the broadcasting step after all the variables are initialized? It would be nice to keep the API the same across frameworks.

@yuxihu (Collaborator, Author) commented Mar 15, 2019

@alsrgv The broadcast needs to happen here, after the try-except and before the hybrid_forward call. We do not have a callback there.

There is a callback after the forward call, but that seems a bit late since the first forward pass has already run. We would also need logic to make sure we only broadcast once, because that callback is invoked after every forward pass.

@eric-haibin-lin any opinion on this?
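For illustration only, this is roughly what a forward-hook based broadcast would look like and why it is awkward: the hook fires after every forward pass, so a guard flag is needed to broadcast only once, and the first forward pass has already run on unsynchronized parameters. broadcast_ is assumed to come from horovod.mxnet.mpi_ops and net is a placeholder Gluon block:

_broadcast_done = False

def _broadcast_once(block, inputs, outputs):
    # Runs after every forward pass, hence the guard flag
    global _broadcast_done
    if not _broadcast_done:
        for param in block.collect_params().values():
            broadcast_(param.data(), root_rank=0)
            param.data().wait_to_read()
        _broadcast_done = True

net.register_forward_hook(_broadcast_once)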

@yuxihu (Collaborator, Author) commented Mar 15, 2019

@ptrendx DistributedOptimizer is a well-adopted mechanism in Horovod across frameworks. DistributedInitializer is one option we came up with to solve the deferred-init parameter broadcasting issue. I think the current Gluon Trainer is designed to work well with distributed training using the kvstore (PS) approach. How would the DistributedTrainer differ from the Trainer? We are happy to hear more about what you think. @eric-haibin-lin is working on improving distributed training with MXNet. Maybe we can have some discussions around it.

@alsrgv (Member) commented Mar 16, 2019

What about i.finish_deferred_init? Could the broadcast happen there?

@eric-haibin-lin (Collaborator) left a comment

Whenever a parameter is initialized or re-initialized, the mx.initializer.Initializer will be called. This happens in various places, including .finish_deferred_init, so a wrapper around the initializer that broadcasts any newly initialized parameter sounds reasonable.
There are hooks in Gluon before and after a block executes forward, but none of them executes a callback right after parameter initialization. On the Gluon side it probably does not make sense to add another register-hook function just for parameter initialization.

I'm also interested in hearing @ptrendx's DistributedTrainer proposal.
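To make the suggestion above concrete, a minimal sketch of such a wrapper initializer, assuming broadcast_ comes from horovod.mxnet.mpi_ops (this is only an illustration, not the final implementation):

import mxnet as mx

class DistributedInitializer(mx.initializer.Initializer):
    def __init__(self, base_init, root_rank=0):
        super(DistributedInitializer, self).__init__()
        self._base_init = base_init
        self._root_rank = root_rank

    def __call__(self, desc, arr):
        # Run the wrapped initializer, then overwrite the result with the
        # values from root_rank so every worker starts from the same weights
        self._base_init(desc, arr)
        broadcast_(arr, root_rank=self._root_rank)
        arr.wait_to_read()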

@alsrgv (Member) commented Mar 17, 2019

What happens when you restore the model from a checkpoint? Normally, the approach is to reload the model on rank 0 and broadcast it to everyone else.

@ptrendx commented Mar 18, 2019

Basically the problem is that what constitutes an "Optimizer" differs between TF and MXNet. In TF, the optimizer computes gradients and applies them (so it sits at a pretty high level, with knowledge of all the gradients), whereas the Optimizer in MXNet is a low-level class that just takes one or a few already-computed gradients and applies them.
MXNet's Trainer, on the other hand, is much closer to what the Optimizer in TF is: it knows about all model parameters (and so all the gradients), it has an allreduce_grads method which would fit Horovod perfectly, and it then calls optimizers to actually do the update.

An additional problem with having Horovod as part of the optimizer in MXNet is that the previous distributed training strategy (kvstore) hooks in at different places, which makes it confusing and error-prone for users who need to know when the actual reduction happens (with kvstore, gradients are allreduced after calling trainer.allreduce_grads; with Horovod they are not).
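For reference, a rough illustration of the standard Gluon training step where this difference shows up; trainer is assumed to be an mx.gluon.Trainer, and net, loss_fn, data, label, and batch_size are placeholders:

import mxnet as mx

with mx.autograd.record():
    loss = loss_fn(net(data), label)
loss.backward()

trainer.allreduce_grads()   # kvstore: gradients are reduced here;
                            # optimizer-level Horovod: effectively a no-op
trainer.update(batch_size)  # optimizer-level Horovod: the allreduce happens
                            # inside the wrapped optimizer's update instead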

@alsrgv (Member) commented Mar 18, 2019

Sounds like a great fit. Could you provide an MNIST code example that would showcase DistributedTrainer?

@yuxihu (Collaborator, Author) commented Mar 18, 2019

@alsrgv When restoring from a checkpoint in Gluon, I think each worker needs to load the parameters from file, so there should be no need to broadcast them. If we only loaded on rank 0, the rank 0 worker and the other workers would not have the same set of initialized parameters because of the deferred-initialized ones, so the broadcast would not work. @eric-haibin-lin correct me if I missed anything here.
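A rough illustration of the per-worker restore described above; net is a placeholder Gluon block, model.params a placeholder file name, and ctx a placeholder context:

# Every worker loads the checkpoint itself, so no broadcast is required
net.load_parameters('model.params', ctx=ctx)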

@yuxihu (Collaborator, Author) commented Mar 18, 2019

@ptrendx If we can integrate with Horovod at the Trainer level, the internal workflow between kvstore and Horovod will be more or less consistent. One complication with this proposal is that the Module API does not have a Trainer; the current integration at the optimizer level makes it easy to support both the Gluon and Module APIs. Even though we need to keep Module API support in mind along the way, what you proposed is very interesting to explore, since it would let us take advantage of the improvements being made at the Trainer level (e.g. the AMP feature you are working on). Shall we set up a meeting to discuss further?

@romerojosh (Collaborator) commented Mar 18, 2019

This is an idea of what the DistributedTrainer implementation could look like:

import types
import mxnet as mx
# broadcast_ and allreduce_ are assumed to come from horovod.mxnet.mpi_ops

# Wrapper to inject Horovod broadcast after parameter initialization
def _append_broadcast_init(param, root_rank):
    init_impl = getattr(param, '_init_impl')
    def wrapped_init_impl(self, *args, **kwargs):
        # Run the original initialization, then broadcast the freshly
        # initialized values from root_rank to all workers
        init_impl(*args, **kwargs)
        broadcast_(self.data(), root_rank=root_rank)
        self.data().wait_to_read()
    return wrapped_init_impl

# Horovod DistributedTrainer wrapper
class DistributedTrainer(mx.gluon.Trainer):
    def __init__(self, trainer):
        self._trainer = trainer

        # Broadcast non-deferred parameters
        for param in self._params:
            try:
                broadcast_(param.data(), root_rank=0)
                param.data().wait_to_read()
            except mx.gluon.parameter.DeferredInitializationError:
                pass

        # Inject wrapper method with post-initialization broadcast to
        # handle parameters with deferred initialization
        for param in self._params:
            new_init = _append_broadcast_init(param, 0)
            param._init_impl = types.MethodType(new_init, param)

    def __getattr__(self, item):
        # Delegate everything else to the wrapped Trainer
        return getattr(self._trainer, item)

    def _allreduce_grads(self):
        for i, param in enumerate(self._params):
            if param.grad_req != 'null':
                allreduce_(param.list_grad()[0], average=True, name=str(i))
The initial parameter broadcast can be addressed by injecting a Horovod broadcast into the _init_impl of the model parameters; I think this would also handle the deferred initialization issue. Then, for the trainer class, we simply override the _allreduce_grads method to perform a Horovod allreduce.

The mnist script modifications with this are straightforward. One just needs to add the line trainer = hvd.DistributedTrainer(trainer) to the script, and remove both the hvd.DistributedOptimizer wrapper and the explicit parameter broadcast, since those are now taken care of by DistributedTrainer.

Edit: the wrapping of _init_impl only applies to parameters with deferred initialization, so I added code that broadcasts non-deferred parameters to the DistributedTrainer.__init__ method.
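Based on the description above, the mnist script change would look roughly like this; net, params, and opt_params are placeholders, and hvd.DistributedTrainer is the proposed wrapper, not an existing API at this point:

import horovod.mxnet as hvd
import mxnet as mx

hvd.init()
trainer = mx.gluon.Trainer(params, 'sgd', opt_params)
trainer = hvd.DistributedTrainer(trainer)
# No hvd.DistributedOptimizer wrapper and no explicit parameter broadcast
# are needed any more: broadcasting is injected into _init_impl and the
# allreduce happens in the overridden _allreduce_grads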

@yuxihu (Collaborator, Author) commented Mar 19, 2019

@romerojosh Thanks for demonstrating the idea of DistributedTrainer. I do see it solving two major issues: (1) broadcasting deferred-initialization parameters without adding a special callback in MXNet or a DistributedInitializer in Horovod; (2) making training with kvstore and with Horovod through the MXNet Gluon API follow the same internal logic.

One concern I have is that the user experience of training with the Gluon and Module APIs will differ. Although we have been encouraging users to adopt the Gluon API, we still see users running their training jobs with the Module API. The Module API will continue to use DistributedOptimizer while the Gluon API would use DistributedTrainer. We need to document the difference well to avoid confusion as much as possible.

@alsrgv Does the DistributedTrainer look good to you? Besides the difference between Gluon and Module within MXNet, it will also introduce a slightly different user experience compared to TensorFlow and PyTorch.

@eric-haibin-lin any thoughts about the DistributedTrainer proposal?

@alsrgv (Member) commented Mar 20, 2019

Can we have code similar to this in hvd.broadcast_parameters?

        # Inject wrapper method with post-initialization broadcast to
        # handle parameters with deferred initialization
        for param in self._params:
            new_init = _append_broadcast_init(param, 0)
            param._init_impl = types.MethodType(new_init, param)

This would allow us to keep the APIs the same. I think it'd be really useful to do that to preserve portability.

One exception to the portability rule that I feel would be OK is a situation where all aspects (adding the optimizer, broadcasts, data partitioning, GPU pinning, etc.) are handled by a magical one-line high-level API.

What do you guys think?

@yuxihu force-pushed the gluon_param branch 2 times, most recently from 3c79e42 to 35ff369 on March 20, 2019 06:12
@yuxihu (Collaborator, Author) commented Mar 20, 2019

@alsrgv I changed broadcast_parameters to inject the broadcast into init_impl, which solves the deferred initialization issue and keeps the user experience the same.

I think @ptrendx and @romerojosh had some valid points regarding the idea of a DistributedTrainer. The Optimizer in MXNet is not equivalent to the one in TF, and a DistributedTrainer would let us take advantage of the improvements we are making in MXNet.
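As a rough sketch of that change, assuming params behaves like a Gluon ParameterDict, reusing the _append_broadcast_init helper from the earlier comment, and with broadcast_ assumed to come from horovod.mxnet.mpi_ops:

import types
import mxnet as mx

def broadcast_parameters(params, root_rank=0):
    for name, param in sorted(params.items()):
        try:
            # Parameters with known shapes can be broadcast right away
            broadcast_(param.data(), root_rank=root_rank)
            param.data().wait_to_read()
        except mx.gluon.parameter.DeferredInitializationError:
            # Deferred parameters get the broadcast injected into _init_impl,
            # so it runs as soon as the first forward pass determines the shape
            new_init = _append_broadcast_init(param, root_rank)
            param._init_impl = types.MethodType(new_init, param)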

@yuxihu (Collaborator, Author) commented Mar 20, 2019

@apeforest @eric-haibin-lin please review

@romerojosh (Collaborator) commented

> One exception to the portability rule that I feel would be OK is a situation where all aspects (adding the optimizer, broadcasts, data partitioning, GPU pinning, etc.) are handled by a magical one-line high-level API.
>
> What do you guys think?

@alsrgv I think having an all-encapsulating API wrapper is a bit of overkill just to resolve the design inconsistency pointed out by @ptrendx and myself. The main reason for having the DistributedTrainer wrapper is not to simplify the API for Gluon, but to add allreduces to the already existing allreduce_grads method in the trainer, which in the current design is a confusing (and undocumented) no-op when using Horovod.

If we remove the broadcast handling from the DistributedTrainer example above, the only difference between the two operating modes (Module and Gluon) is that for Module the user would use DistributedOptimizer to hook in the allreduces, while with Gluon the user would use DistributedTrainer. This is a small inconsistency, but I think it is ultimately worth it to avoid inconsistencies on the MXNet side about where allreduces are hooked in for Gluon.

@alsrgv (Member) commented Mar 21, 2019

I think the DistributedOptimizer -> DistributedTrainer (w/o broadcast logic) replacement is acceptable, since DistributedTrainer takes an optimizer and can conceptually be thought of as a powerful wrapper for the optimizer.

@romerojosh (Collaborator) commented

@yuxihu Are you okay with this idea, in particular the very small difference between Gluon and Module this introduces?

@yuxihu (Collaborator, Author) commented Mar 21, 2019

@romerojosh @eric-haibin-lin and I discussed DistributedTrainer offline yesterday. We both think it makes sense for Gluon users. I will send out another PR for review soon.

@yuxihu (Collaborator, Author) commented Mar 21, 2019

Here is the DistributedTrainer PR.

@apeforest (Member) left a comment

Can we add a test for deferred initialization? Otherwise, LGTM.

@yuxihu (Collaborator, Author) commented Mar 25, 2019

@apeforest @eric-haibin-lin @alsrgv Added unit test. Please help review and merge.

@alsrgv (Member) commented Mar 27, 2019

@yuxihu, could you rebase on latest master with CI fixes? Otherwise, LGTM.

Commits:
* Create DistributedInitializer to broadcast deferred-init param (Signed-off-by: Yuxi Hu <darrenyxhu@gmail.com>)
* inject broadcast to init_impl (Signed-off-by: Yuxi Hu <darrenyxhu@gmail.com>)
* add unit test (Signed-off-by: Yuxi Hu <darrenyxhu@gmail.com>)
@yuxihu (Collaborator, Author) commented Mar 27, 2019

> @yuxihu, could you rebase on latest master with CI fixes? Otherwise, LGTM.

Thanks for letting me know. Just rebased and triggered the CI.

@yuxihu (Collaborator, Author) commented Mar 27, 2019

@alsrgv Looks like all the tests have been queued but not running for two hours. Any issue with Travis?

@alsrgv (Member) commented Mar 27, 2019

@yuxihu, yeah, Travis free edition hasn't been great recently. I'll try to migrate to Buildkite ASAP.

@alsrgv alsrgv self-requested a review March 27, 2019 22:28
@alsrgv (Member) commented Mar 27, 2019

@apeforest @eric-haibin-lin are you OK to merge this assuming the tests pass?

@apeforest (Member) left a comment

LGTM

@alsrgv alsrgv merged commit 3a1083e into horovod:master Mar 28, 2019
@yuxihu yuxihu deleted the gluon_param branch March 28, 2019 21:26
shirosankaku pushed a commit to SmorkalovME/horovod that referenced this pull request May 30, 2019

MXNet: support broadcasting deferred initialization parameters in Gluon (horovod#915)

* Create DistributedInitializer to broadcast deferred-init param
* inject broadcast to init_impl
* add unit test

Signed-off-by: Yuxi Hu <darrenyxhu@gmail.com>
Signed-off-by: Yana Shchyokotova <yana.shchyokotova@intel.com>
zsh-thu pushed a commit to zsh-thu/horovod that referenced this pull request Jun 3, 2019

MXNet: support broadcasting deferred initialization parameters in Gluon (horovod#915)
jeffdaily pushed a commit to ROCm/horovod that referenced this pull request Nov 27, 2019

MXNet: support broadcasting deferred initialization parameters in Gluon (horovod#915)