
Fix overlap communication of ZeRO stage 1 and 2 #5606

Merged · 3 commits merged into microsoft:master on Jun 10, 2024

Conversation

@penn513 (Contributor) commented Jun 3, 2024

`deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer.average_tensor` only makes the reduction stream wait for the default stream. This is fine when the computation time is longer than the communication time, but when the communication time is longer, the ipg_buffer may be overwritten before the communication has completed.

![image](https://github.com/microsoft/DeepSpeed/assets/35059704/950cbf8a-f439-4cf9-a364-dcdfd47f46a0)
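For readers less familiar with CUDA stream semantics, here is a minimal PyTorch sketch of the race described above. It is illustrative only, not the DeepSpeed code: `ipg_buffer`, `fill_ipg_buffer`, and `reduce_ipg_buffer` are hypothetical stand-ins, and an initialized process group plus a CUDA device are assumed.

```python
import torch
import torch.distributed as dist

default_stream = torch.cuda.current_stream()
reduction_stream = torch.cuda.Stream()

# Hypothetical stand-in for the gradient staging buffer (ipg_buffer).
ipg_buffer = torch.empty(1 << 20, device="cuda")


def reduce_ipg_buffer():
    # The reduction stream waits for the gradient copies already queued on the
    # default stream (this dependency already exists in average_tensor) ...
    reduction_stream.wait_stream(default_stream)
    with torch.cuda.stream(reduction_stream):
        # ... then launches the reduction asynchronously on the reduction stream.
        dist.all_reduce(ipg_buffer, async_op=True)


def fill_ipg_buffer(new_grads):
    # Problem: the default stream does NOT wait for the reduction stream, so this
    # copy can overwrite ipg_buffer while the previous all_reduce is still reading
    # it -- exactly the case where communication outlasts computation.
    ipg_buffer.copy_(new_grads)
```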

To fix this bug, the easiest way is to also make the default stream wait for the reduction stream at the same point. For example, at point 1 the reduction stream needs to wait for '2', so we add a wait_stream making the reduction stream wait for the default stream. Likewise, the default stream needs to wait for 'A', so we add a wait_stream making the default stream wait for the reduction stream before 'B'.

![image](https://github.com/microsoft/DeepSpeed/assets/35059704/588a9469-d3f9-4c39-976d-3ae0502cf1d1)
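Below is a sketch of the proposed two-sided wait, reusing the hypothetical names from the snippet above (the actual change is applied inside DeepSpeed's stage-1/2 optimizer, but the dependency structure is the same):

```python
def fill_ipg_buffer_fixed(new_grads):
    # 'B' waits for 'A' in the figure: before the default stream reuses the
    # buffer, it waits back on the reduction stream. wait_stream only records a
    # device-side dependency, so the host thread is not blocked.
    default_stream.wait_stream(reduction_stream)
    ipg_buffer.copy_(new_grads)
```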

Compared with the modification in #5523, this change uses wait_stream, which does not cause host synchronization.

Compared with the modification in #5545, this modification is simpler while the logic is the same: each stream waits only for what it actually needs to wait for.
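To make the host-synchronization point concrete, the snippet below contrasts a CPU-blocking wait with the device-side dependency used here (again a sketch with the same hypothetical names; the blocking variant is for illustration only and is not necessarily the literal #5523 change):

```python
# Host-blocking wait: the CPU stalls until everything queued on the reduction
# stream has finished, pausing kernel launches in the meantime.
reduction_stream.synchronize()
ipg_buffer.copy_(new_grads)

# Device-side wait: the dependency is enqueued on the GPU and the CPU keeps
# launching work immediately.
default_stream.wait_stream(reduction_stream)
ipg_buffer.copy_(new_grads)
```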


With this modification, the losses of Qwen-1.5 with and without overlap_comm are identical.

![image](https://github.com/microsoft/DeepSpeed/assets/35059704/4d48d54e-e55b-4230-8b99-93549910a43f)


By contrast, without this fix there is an obvious gap when the sequence length is small, i.e., when the computation time is short.

![image](https://github.com/microsoft/DeepSpeed/assets/35059704/c80af498-3358-4e36-9b13-8f266551d51d)

@penn513 (Contributor, Author) commented Jun 3, 2024

@microsoft-github-policy-service agree company="Huawei"

@CurryRice233 (Contributor) commented:

Hi @tjruwase @GuanhuaWang, would you please help review this PR?

@tjruwase requested review from GuanhuaWang and removed the review request for mrwyattii on June 4, 2024 07:54
@GuanhuaWang (Member) commented Jun 5, 2024

Hi @penn513, thanks for the nice figure and PR.

Your fix of having the compute stream wait back on the reduce stream makes sense to me, especially when compute is shorter.

To make the PR more concise, could you remove your modification to the NPU FusedAdam from this PR and open a new PR for the fused Adam change? (It is mainly unrelated to this PR's title, "fix overlap communication...".)

@penn513 (Contributor, Author) commented Jun 6, 2024

> Hi @penn513, thanks for the nice figure and PR.
>
> Your fix of having the compute stream wait back on the reduce stream makes sense to me, especially when compute is shorter.
>
> To make the PR more concise, could you remove your modification to the NPU FusedAdam from this PR and open a new PR for the fused Adam change? (It is mainly unrelated to this PR's title, "fix overlap communication...".)

Thanks for your reply. It's been updated.

@jomayeri added this pull request to the merge queue on Jun 7, 2024
The github-merge-queue bot removed this pull request from the merge queue due to failed status checks on Jun 7, 2024
@jomayeri added this pull request to the merge queue on Jun 7, 2024
The github-merge-queue bot removed this pull request from the merge queue due to failed status checks on Jun 7, 2024
@loadams enabled auto-merge on June 7, 2024 22:21
@loadams added this pull request to the merge queue on Jun 9, 2024
The github-merge-queue bot removed this pull request from the merge queue due to failed status checks on Jun 9, 2024
@loadams added this pull request to the merge queue on Jun 10, 2024
The github-merge-queue bot removed this pull request from the merge queue due to failed status checks on Jun 10, 2024
@tjruwase added this pull request to the merge queue on Jun 10, 2024
Merged via the queue into microsoft:master with commit a41729f on Jun 10, 2024
15 checks passed
sfc-gh-reyazda pushed a commit to Snowflake-Labs/DeepSpeed that referenced this pull request Jun 10, 2024