
[BUG] 'type:transformer' partitioning doesn't ensure non-zero parameters on each pipeline rank. #5078

Open
siddharth9820 opened this issue Feb 5, 2024 · 5 comments


siddharth9820 commented Feb 5, 2024

Running Megatron-DeepSpeed with pipelining calls PipeModule with the `type:transformer` partitioning method, which leads to this line of code:

    self.parts = ds_utils.partition_balanced(weights=binary_weights, num_parts=num_stages)

I tried running this with a model with 42 layers, tensor parallelism = 4, and pipeline parallelism = 16. Pipeline ranks 15 and 16 were assigned 0 layers. Something needs to change to guarantee that a non-zero number of layers is assigned to each rank.
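To make the failure mode concrete, here is a minimal, self-contained sketch (not DeepSpeed's actual `partition_balanced`, and `greedy_bottleneck_partition` is a hypothetical helper) of a contiguous partitioner that minimizes the largest part but can leave trailing pipeline stages empty. Uniform layer weights are assumed, with no single weight above the bound:

```python
import math

def greedy_bottleneck_partition(weights, num_parts):
    """Greedily pack each contiguous part up to the optimal bottleneck
    B = ceil(total / num_parts). The maximum part weight is minimized,
    but trailing parts can end up with zero items."""
    bound = math.ceil(sum(weights) / num_parts)
    bounds = [0]
    idx = 0
    for _ in range(num_parts):
        load = 0
        while idx < len(weights) and load + weights[idx] <= bound:
            load += weights[idx]
            idx += 1
        bounds.append(idx)
    return bounds

# 42 uniform transformer layers over 16 pipeline stages: stages 0-13 get
# 3 layers each, and the last two stages get nothing -- the failure mode
# reported above.
bounds = greedy_bottleneck_partition([1] * 42, 16)
sizes = [bounds[i + 1] - bounds[i] for i in range(16)]
```

The point is that a partitioner whose only objective is the maximum part weight has no incentive to keep every part non-empty.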


tjruwase commented Feb 5, 2024

@siddharth9820, thanks for reporting this error. I am curious whether this is a recent regression caused by the PR below, which changed the balancing algorithm:
#4312

Can you please try earlier DeepSpeed versions (v0.13.0 or v0.12.6), or revert that PR?

siddharth9820 (Contributor, Author) commented Feb 5, 2024

@tjruwase I am able to reproduce the error outside of Megatron-DeepSpeed as well:

[screenshot: standalone reproduction showing empty pipeline stages]

I'll try the other versions too. Thanks for the pointer.

About potential fixes: could you first assign 1 layer to each rank, and then run this function on the remaining n - m layers and m ranks? But that wouldn't be an ideal fix if the weights aren't uniform.
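As a rough illustration of that idea under the assumption of uniform layer weights, a contiguous partitioner could reserve one layer per rank up front and spread only the surplus. `partition_nonempty` is a hypothetical helper, not DeepSpeed code:

```python
def partition_nonempty(num_layers, num_parts):
    """Reserve one layer per pipeline rank, then distribute the remaining
    n - m layers as evenly as possible, so no rank is ever empty.
    Assumes uniform layer weights."""
    assert num_layers >= num_parts, "need at least one layer per rank"
    surplus = num_layers - num_parts
    bounds = [0]
    idx = 0
    for p in range(num_parts):
        # one reserved layer plus this rank's share of the surplus
        extra = surplus * (p + 1) // num_parts - surplus * p // num_parts
        idx += 1 + extra
        bounds.append(idx)
    return bounds

# The case from this issue: every rank now gets 2 or 3 layers, none empty.
bounds = partition_nonempty(42, 16)
sizes = [bounds[i + 1] - bounds[i] for i in range(16)]
```

This only guarantees a minimum of one layer per rank; with non-uniform weights, the surplus would need to be balanced by weight rather than by count.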

@siddharth9820 siddharth9820 changed the title 'type:transformer' partitioning doesn't ensure non-zero parameters on each pipeline rank. [BUG] 'type:transformer' partitioning doesn't ensure non-zero parameters on each pipeline rank. Feb 5, 2024

tjruwase commented Feb 5, 2024

@siddharth9820, thanks for the update. This seems like an implementation bug, as I find it hard to believe that both the new and old algorithms fail on these seemingly practical cases:

  1. Old algorithm - Fast Optimal Load Balancing Algorithms for 1D Partitioning
  2. New algorithm - https://www8.cs.umu.se/kurser/TDBAfl/VT06/algorithms/BOOK/BOOK2/NODE45.HTM
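For reference, the "new algorithm" linked above is the textbook linear-partition dynamic program. A simplified sketch (not DeepSpeed's actual implementation) shows that its objective, minimizing the largest contiguous part sum, is indifferent to empty parts: for 42 uniform layers and 16 parts the optimal bottleneck is 3, which 14 parts of 3 layers already achieve, so an optimal solution is free to leave two parts empty.

```python
def linear_partition_bottleneck(weights, k):
    """O(n^2 * k) DP for the linear partition problem: split a sequence
    into at most k contiguous parts, minimizing the largest part sum.
    Empty parts are permitted and cost nothing."""
    n = len(weights)
    prefix = [0]
    for w in weights:
        prefix.append(prefix[-1] + w)
    INF = float("inf")
    # dp[i][j]: minimal bottleneck for the first i items split into j parts
    dp = [[INF] * (k + 1) for _ in range(n + 1)]
    dp[0][0] = 0
    for i in range(n + 1):
        for j in range(1, k + 1):
            for t in range(i + 1):  # items t..i-1 form part j (may be empty)
                if dp[t][j - 1] < INF:
                    cand = max(dp[t][j - 1], prefix[i] - prefix[t])
                    if cand < dp[i][j]:
                        dp[i][j] = cand
    return dp[n][k]
```

Nothing in this objective penalizes an empty part, which is consistent with the behavior reported in this issue.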


tjruwase commented Feb 5, 2024

> Could you first assign 1 layer to each rank, and then run this function on the remaining n - m layers and m ranks? But that wouldn't be an ideal fix if the weights aren't uniform.

Yes, it does not seem like this approach would be balanced. I think it will only increase the minimum from zero to one. Right?

siddharth9820 (Contributor, Author) commented:

Yes, it won't be balanced, but at least it will "run" with Megatron-DeepSpeed. With the current approach, I was getting "empty parameter" errors during optimizer initialization. I believe this was happening on the second-to-last pipeline rank, since it became parameterless.
