
[FSDP] Add gradient predivide factor to avoid overflow/underflow with large world size #565

Merged 4 commits into master from fsdp_grad_predivide_factor on Apr 3, 2021

Conversation

shruti-bh (Contributor) commented:

This PR scales gradients to avoid overflow/underflow when the world_size is large.
For example, with world_size=1024, reduce_scatter(gradient) / world_size might overflow, while reduce_scatter(gradient / world_size) might underflow. We might prefer reduce_scatter(gradient / 32) / 32 instead.
See NVIDIA’s note about this.
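
For concreteness, here is a minimal sketch of the pre/post-divide idea; this is not the fairscale implementation, and the function names are illustrative only:

import torch.distributed as dist

def get_gradient_predivide_factor(world_size: int) -> float:
    # Grow the factor by powers of two while it still divides world_size and is
    # smaller than the remaining quotient; for world_size=1024 this gives 32.
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor *= 2
    return float(factor)

def reduce_scatter_averaged_grad(grad_shards, output, world_size):
    predivide = get_gradient_predivide_factor(world_size)
    postdivide = world_size / predivide
    if predivide > 1:
        for shard in grad_shards:
            shard.div_(predivide)  # divide first so the summed result does not overflow
    dist.reduce_scatter(output, grad_shards)  # sums one shard per rank
    if postdivide > 1:
        output.div_(postdivide)  # finish the averaging: overall scale is 1/world_size
    return output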

Testing:
After 10 updates, I see roughly consistent loss/gnorm between the main branch and this branch.
Main:
2021-03-31 16:55:49 | INFO | train | epoch 001 | loss 16.361 | nll_loss 16.324 | moe_gate_loss 2.56751 | overflow_expert1 24.327 | overflow_expert2 53.304 | entropy_gating 1.994 | expert1_balance_top 68.159 | expert1_balance_bottom 2.701 | unused_expert1_count 0 | expert2_balance_top 54.38 | expert2_balance_bottom 4.655 | unused_expert2_count 0 | ppl 82053.2 | wps 212239 | ups 6.47 | wpb 32768 | bsz 32 | num_updates 10 | lr 0.002 | gnorm 4.115 | loss_scale 32 | train_wall 4 | gb_free 29.2 | wall 24
This Branch:
2021-03-31 16:57:48 | INFO | train | epoch 001 | loss 16.361 | nll_loss 16.324 | moe_gate_loss 2.56767 | overflow_expert1 24.324 | overflow_expert2 53.299 | entropy_gating 1.994 | expert1_balance_top 68.156 | expert1_balance_bottom 2.7 | unused_expert1_count 0 | expert2_balance_top 54.396 | expert2_balance_bottom 4.65 | unused_expert2_count 0 | ppl 82051.3 | wps 216940 | ups 6.61 | wpb 32768 | bsz 32 | num_updates 10 | lr 0.002 | gnorm 4.116 | loss_scale 32 | train_wall 6 | gb_free 29.2 | wall 24

After 100 updates, the loss/gnorm numbers differ slightly after the second decimal place between the main branch and this branch.
Main:
2021-03-31 17:01:50 | INFO | train | epoch 001 | loss 11.751 | nll_loss 11.713 | moe_gate_loss 2.63956 | overflow_expert1 43.949 | overflow_expert2 64.359 | entropy_gating 2.022 | expert1_balance_top 84.331 | expert1_balance_bottom 1.088 | unused_expert1_count 0 | expert2_balance_top 75.253 | expert2_balance_bottom 1.84 | unused_expert2_count 0 | ppl 3356.47 | wps 229459 | ups 7 | wpb 32768 | bsz 32 | num_updates 100 | lr 0 | gnorm 0.931 | loss_scale 32 | train_wall 16 | gb_free 29.2 | wall 38

This Branch:
2021-03-31 16:59:44 | INFO | train | epoch 001 | loss 11.747 | nll_loss 11.706 | moe_gate_loss 2.79818 | overflow_expert1 44.95 | overflow_expert2 65.851 | entropy_gating 2.003 | expert1_balance_top 84.568 | expert1_balance_bottom 0.411 | unused_expert1_count 0 | expert2_balance_top 76.632 | expert2_balance_bottom 0.981 | unused_expert2_count 0 | ppl 3341.38 | wps 227973 | ups 6.96 | wpb 32768 | bsz 32 | num_updates 100 | lr 0 | gnorm 0.937 | loss_scale 32 | train_wall 18 | gb_free 29.2 | wall 37

Would love advice on how to proceed given these differences, which are expected to some degree since we are now scaling the gradients slightly differently.

@myleott (Contributor) left a comment:


A couple comments, but otherwise LG, thanks!

fairscale/nn/data_parallel/fully_sharded_data_parallel.py (outdated review thread; resolved)
@@ -1071,7 +1078,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:

    if self.world_size > 1:

Review comment (Contributor): change to if self.gradient_predivide_factor > 1
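
A hypothetical, self-contained sketch of that suggested guard; the function name below is illustrative and not part of FullyShardedDataParallel:

import torch

def predivide_(grad: torch.Tensor, gradient_predivide_factor: float) -> torch.Tensor:
    # Divide in place only when a predivide factor greater than 1 was chosen,
    # mirroring the suggested "if self.gradient_predivide_factor > 1" check.
    if gradient_predivide_factor > 1:
        grad.div_(gradient_predivide_factor)
    return grad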

@@ -1098,7 +1105,7 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
     assert torch.cuda.current_stream() == self._streams["post_backward"]
     assert param.grad is not None
     self.assert_state(TrainingState.BACKWARD_POST)
-    param.grad.data = reduced_grad
+    param.grad.data = reduced_grad.div_(self.world_size / self.gradient_predivide_factor)
Review comment (Contributor): change to:

param.grad.data = reduced_grad
if self.gradient_postdivide_factor > 1:
    param.grad.data.div_(self.gradient_postdivide_factor)
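
(Presumably gradient_predivide_factor * gradient_postdivide_factor == world_size, so the divide before the reduce-scatter and the divide here combine to the usual 1/world_size averaging; in the world_size=1024 example from the description, 32 * 32 = 1024.)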

@myleott merged commit 04001e7 into master on Apr 3, 2021
@myleott deleted the fsdp_grad_predivide_factor branch on April 3, 2021 at 21:29
@myleott (Contributor) commented on Apr 3, 2021:

@shruti-bh I pushed the lint fixes and merged this 😄

Labels: CLA Signed (managed by the Facebook bot; authors must sign the CLA before a PR can be reviewed)
4 participants