Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spread layers more uniformly when using partition_uniform #4053

Merged
merged 7 commits into from
Aug 3, 2023

Conversation

marcobellagente93
Copy link
Contributor

The floor operation in the old implementation accumulates residuals which are all dumped in the last stage, potentially causing highly unbalanced partitions when both the residual and the number of pipeline stage is large.
The proposed change spreads out the layers more uniformly resulting in a more balanced partition.

As an example:

  • old: partition_unform(num_items=38, num_parts=8) returns [0, 4, 8, 12, 16, 20, 24, 28, 38]
  • new: partition_unform(num_items=38, num_parts=8) returns [0, 5, 10, 15, 20, 25, 30, 34, 38]

@marcobellagente93
Copy link
Contributor Author

@microsoft-github-policy-service agree

@mrwyattii
Copy link
Contributor

@marcobellagente93 could you please run the precommit and commit any changes made?

pre-commit run --all-files
git add deepspeed/runtime/utils.py
git commit -m "formatting"
git push origin more-uniform-pipeline

@mrwyattii mrwyattii added this pull request to the merge queue Jul 27, 2023
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Jul 27, 2023
@mrwyattii mrwyattii added this pull request to the merge queue Aug 3, 2023
Merged via the queue into microsoft:master with commit e831863 Aug 3, 2023
16 checks passed
polisettyvarma pushed a commit to polisettyvarma/DeepSpeed that referenced this pull request Aug 7, 2023
…4053)

* update partition_uniform util function

* formatting

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants