Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Z3: optimizations for grad norm calculation and gradient clipping #5504

Merged
merged 25 commits into from
Aug 15, 2024

Conversation

nelyahu
Copy link
Contributor

@nelyahu nelyahu commented May 7, 2024

This PR add the below functionality:

  1. complete_grad_norm_calculation_for_cpu_offload: move total_norm to CPU, as expected device in such case is CPU..
  2. repalce get_global_norm() with torch.linalg.norm for better performance.
  3. unscale_and_clip_grads: replace clipping based on if statement to use torch.clamp for better performance.

change (3) is taken from #5547 (which was closed)

@jomayeri
Copy link
Contributor

Changing this line has been associated with several bugs #5422, #5538

@nelyahu nelyahu changed the title z3 scaled_global_grad_norm: repalce get_global_norm with torch.norm Z3: optimizations for grad norm calculation and gradient clipping May 27, 2024
@loadams
Copy link
Contributor

loadams commented Jun 26, 2024

Changing this line has been associated with several bugs #5422, #5538

@nelyahu - thoughts on this comment, seems last time this line was modified users ran into issues?

@nelyahu
Copy link
Contributor Author

nelyahu commented Jun 26, 2024

Changing this line has been associated with several bugs #5422, #5538

@nelyahu - thoughts on this comment, seems last time this line was modified users ran into issues?

@loadams, Yes - this optimization was already pushed and reverted due to ds-chat (failures in cpu-offload configurations).
i did offline debugging of those failure and improved the code change so it will pass. Since then ds-chat tests where added to DeepSpeed repo CI and it is now passing.
Are there any other tests (full model training for example), that does not exists in the CI, which can be manually ran?

@tjruwase
Copy link
Contributor

i did offline debugging of those failure and improved the code change so it will pass

@nelyahu, it great that you narrowed this down. Do you think a unit test can be added for this case?

@loadams
Copy link
Contributor

loadams commented Jul 9, 2024

i did offline debugging of those failure and improved the code change so it will pass

@nelyahu, it great that you narrowed this down. Do you think a unit test can be added for this case?

@nelyahu - we've stabilized the CI, thoughts on adding this test?

@nelyahu
Copy link
Contributor Author

nelyahu commented Jul 10, 2024

@loadams Oh, Sorry- I missed the last comment.
Sure, yes we can add such UT that will cover it.
But i cannot address it immediately. I will update this PR once we have a unit test.

@nelyahu nelyahu requested a review from loadams as a code owner July 18, 2024 15:02
this will allow covering the gradient clipping flow with zero3
and catch issues, such as was observed during this PR
@nelyahu
Copy link
Contributor Author

nelyahu commented Jul 18, 2024

@loadams / @tjruwase as request i make sure the regression was discussed here will be covered by a unit tests. I used TestZeroPartialOffloadConfigSweep, and added it gradient_clipping so it will go though the problematic flow.
i reproduced the issue using the UT, and made sure it is fixed.

@nelyahu
Copy link
Contributor Author

nelyahu commented Jul 23, 2024

@loadams can you re-run workflows? i suspect that the failure are not related to this PR

@loadams loadams enabled auto-merge July 23, 2024 20:11
@loadams loadams added this pull request to the merge queue Aug 14, 2024
Merged via the queue into microsoft:master with commit 6eed634 Aug 15, 2024
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants