add partition optimizer status trainer #2951
Conversation
Unit Test Results: 438 files (+12), 438 suites (+12), 6h 20m 36s ⏱️ (+1h 1m 26s). For more details on these failures, see this check. Results for commit 65ae9af, compared against base commit f0c44d2. This pull request removes 5 and adds 19 tests. Note that renamed tests count towards both.
♻️ This comment has been updated with latest results.
This is awesome @xinyual! Can you also add a test?
This sounds intriguing! I guess a ZeRO-like optimizer would benefit from a ReduceScatter operation in Horovod (#1496)?
@maxhgerlach yes, though this PR only implements optimizer state sharding, not gradient sharding yet. Do you have any plans to revive #1496?
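To make the distinction concrete, here is a minimal single-process sketch of the optimizer-state-sharding idea being discussed (this is illustrative only, not Horovod's API or this PR's code): each worker owns a disjoint subset of the parameters, applies the optimizer update only to those, and the owners' values are then shared with everyone.

```python
# Hypothetical sketch of ZeRO-style optimizer state sharding, with the
# workers simulated in one process. Each worker updates only the
# parameters it "owns"; in a real setup the owner would then broadcast
# its updated values to the other workers.

def partition(params, num_workers):
    """Assign each parameter index to a worker round-robin."""
    return {i: i % num_workers for i in range(len(params))}

params = [1.0, 2.0, 3.0, 4.0]
grads = [0.1, 0.1, 0.1, 0.1]
num_workers = 2
owner = partition(params, num_workers)

lr = 0.5
# Each simulated worker updates only the parameters it owns ...
for rank in range(num_workers):
    for i in range(len(params)):
        if owner[i] == rank:
            params[i] -= lr * grads[i]
# ... the broadcast step is elided here, since all "workers" share one
# list in this single-process simulation.
print(params)  # → [0.95, 1.95, 2.95, 3.95]
```

The point of the sharding is that each worker only needs optimizer state (e.g. Adam moments) for its own partition, cutting per-worker memory roughly by the number of workers.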
@xinyual, can you rebase your PR off master? The most recent changes should fix the failing buildkite and docs tests.
Thanks for the PR. This looks very cool and would be a great addition to Horovod!
Left a few minor comments.
self._allreduce_grads()
self._update(ignore_stale_grad)
self.broadcast_params()
Is there a reason you've rewritten all of step() here instead of just calling step() of the parent class and adding just the additional call to broadcast_params()?
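The reviewer's suggestion can be sketched as follows (class and method names here are stand-ins, not Horovod's actual classes): override step() to delegate to the parent and append only the new call.

```python
# Hypothetical sketch of delegating to the parent's step() instead of
# duplicating its body. An events list records the call order so the
# resulting flow is visible.

events = []

class BaseTrainer:
    """Stand-in for the existing trainer class."""
    def step(self, batch_size, ignore_stale_grad=False):
        self._allreduce_grads()
        self._update(ignore_stale_grad)

    def _allreduce_grads(self):
        events.append("allreduce")

    def _update(self, ignore_stale_grad):
        events.append("update")

class PartitionedTrainer(BaseTrainer):
    def step(self, batch_size, ignore_stale_grad=False):
        # Reuse the parent's step() for the common logic ...
        super().step(batch_size, ignore_stale_grad)
        # ... and add only the partition-specific broadcast.
        self.broadcast_params()

    def broadcast_params(self):
        events.append("broadcast")

PartitionedTrainer().step(batch_size=32)
print(events)  # → ['allreduce', 'update', 'broadcast']
```

Besides being shorter, this keeps the subclass in sync if the parent's step() logic ever changes.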
def step(self, batch_size, ignore_stale_grad=False):
    """
    inherit from trainer, only call boardcast to make sure all parameter are consistent
Typo: boardcast -> broadcast
from mxnet.gluon.parameter import Parameter
class POS_Trainer(mx.gluon.Trainer): |
Can the name of this be more descriptive? Perhaps something like PartitionedParameterTrainer or similar? Alternatively, do you think this parameter splitting code can be made a feature of the existing hvd.DistributedTrainer instead of creating a completely separate implementation?
self._gradient_predivide_factor = gradient_predivide_factor
def partition_parameters(self, params): |
Add leading underscore to internal methods like this one. Also, add a brief comment describing what this does.
def broadcast_params(self): |
This should also be renamed to something like _broadcast_partitioned_params to better differentiate it from the existing _broadcast_parameters method. Also, add a brief comment describing what it does.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Unit Test Results (with flaky tests): 546 files (−47), 546 suites (−47), 8h 19m 25s ⏱️ (+2h 31m 32s). For more details on these failures, see this check. Results for commit 65ae9af, compared against base commit f0c44d2. This pull request removes 5 and adds 19 tests. Note that renamed tests count towards both.
I see that PR, but I think it is not suitable. The paper mentions reduce-scatter, but I believe that applies at the model level. For parameters, we may need to use reduce/allreduce. I have opened a new PR #3309 with a test case. Please see that.
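The difference between the two collectives under discussion can be shown with a tiny pure-Python model (no Horovod; this is a conceptual sketch only): allreduce leaves every rank with the full summed vector, while reduce-scatter leaves each rank with just its shard of the sum, which is why reduce-scatter pairs naturally with sharded state.

```python
# Conceptual comparison of allreduce vs reduce-scatter on per-rank
# gradient vectors. Real collectives move data between processes; here
# the ranks are just rows of a list.

grads_per_rank = [[1, 2, 3, 4], [10, 20, 30, 40]]

def allreduce(vectors):
    """Every rank ends up with the elementwise sum of all vectors."""
    return [sum(col) for col in zip(*vectors)]

def reduce_scatter(vectors):
    """Each rank ends up with only its own contiguous shard of the sum."""
    total = allreduce(vectors)
    n = len(vectors)
    shard = len(total) // n
    return [total[r * shard:(r + 1) * shard] for r in range(n)]

print(allreduce(grads_per_rank))       # → [11, 22, 33, 44] on every rank
print(reduce_scatter(grads_per_rank))  # → rank 0: [11, 22], rank 1: [33, 44]
```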