Describe the bug
When use_distributed_optimizer is enabled for models with share_embeddings_and_output_weights such as GPT2, all model gradients are reduce-scattered across DP ranks before the embedding gradients are all-reduced across PP[0] & PP[-1]. See Megatron-LM/megatron/core/distributed/finalize_model_grads.py, lines 99 to 150 at commit 0650d83.
Note that the wte gradients and lm_head gradients lie in different partitions of the contiguous gradient buffer (wte is the first weight on PP[0], lm_head is the last weight on PP[-1]), so they are reduce-scattered to different DP ranks. The subsequent all-reduce of embedding gradients across PP[0] and PP[-1] within the same DP group then adds up partial results, producing wrong embedding gradients.
For example, consider only embedding gradients with dp=2 and pp=2 on 4 GPUs:
before reduce-scatter across DP ranks:

| pp \ dp | 0 | 1 |
| --- | --- | --- |
| 0 | g0 | g1 |
| 1 | g2 | g3 |

after reduce-scatter across DP ranks:

| pp \ dp | 0 | 1 |
| --- | --- | --- |
| 0 | g0 | (g0+g1)/2 |
| 1 | (g2+g3)/2 | g3 |

after all-reduce embedding grad across PP[0] & PP[-1]:

| pp \ dp | 0 | 1 |
| --- | --- | --- |
| 0 | g0+(g2+g3)/2 | g3+(g0+g1)/2 |
| 1 | g0+(g2+g3)/2 | g3+(g0+g1)/2 |
The embedding gradients on rank1 (pp0, dp1) and rank2 (pp1, dp0) are the ones the distributed optimizer uses to update the tied weight. They are expected to be identical, but in fact they are not, as the short numeric check below illustrates.
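To make the mismatch concrete, here is a tiny plain-Python check of the tables above, using scalar stand-ins g0..g3 for the four local gradients (the /2 assumes gradients are averaged over the DP group, as in the tables):

```python
# Scalar stand-ins for the local embedding gradients in the dp=2, pp=2 example.
g0, g1, g2, g3 = 1.0, 2.0, 3.0, 4.0

# Correct flow: all-reduce the tied grads across PP[0] & PP[-1] first,
# then average across the DP group -> every owning rank sees the same value.
correct = (g0 + g1 + g2 + g3) / 2          # 5.0

# Buggy flow (reduce-scatter first, embedding all-reduce second), per the last table:
rank1_wte     = g3 + (g0 + g1) / 2         # (pp0, dp1) owns the wte shard      -> 5.5
rank2_lm_head = g0 + (g2 + g3) / 2         # (pp1, dp0) owns the lm_head shard  -> 4.5

print(correct, rank1_wte, rank2_lm_head)   # 5.0 5.5 4.5 -- the tied grads diverge
```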
To Reproduce
Run pretrain_gpt.py with pp=2 and dp=2 on 4 local GPUs. Before returning from finalize_model_grads, print the wte gradient hash on PP[0] and the lm_head gradient hash on PP[-1]:
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
    print(f'[Rank {torch.distributed.get_rank()}] embedding grad hash {model[0].module.module.language_model.embedding.word_embeddings.weight.main_grad.sum()}')
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
    print(f'[Rank {torch.distributed.get_rank()}] embedding grad hash {model[-1].module.module.word_embeddings.weight.main_grad.sum()}')
It can be observed that the gradients on rank1 and rank2 are clearly different.
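For a check that does not rely on eyeballing the printed values, a small helper along these lines (a sketch, assuming the same wrapper/attribute paths as the snippet above and a standard torch.distributed setup) can gather the per-rank sums and print them side by side:

```python
import torch
from megatron.core import parallel_state

def report_embedding_grad_sums(model):
    # Compute the local tied-embedding grad sum on the first/last pipeline stages.
    local = None
    if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
        w = model[0].module.module.language_model.embedding.word_embeddings.weight
        local = ('pp_first', torch.distributed.get_rank(), w.main_grad.sum().item())
    elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
        w = model[-1].module.module.word_embeddings.weight
        local = ('pp_last', torch.distributed.get_rank(), w.main_grad.sum().item())

    # Gather everyone's result so one rank can print them together.
    gathered = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(gathered, local)
    if torch.distributed.get_rank() == 0:
        for entry in filter(None, gathered):
            print(entry)  # with the bug, first- and last-stage sums disagree
```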
Expected behavior
Correct embedding gradients with distributed optimizer for models with tied embeddings.
Environment (please complete the following information):
Proposed fix
Move _allreduce_embedding_grads before finish_grad_sync. Will open a PR soon. Expected embedding gradient flow: all-reduce the tied embedding gradients across PP[0] & PP[-1] first, then reduce-scatter the gradient buffers across the DP group.
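As a rough sketch of what that reordering would look like inside finalize_model_grads (hedged: the real function also all-reduces layer-norm and position-embedding gradients, and the exact signatures of _allreduce_embedding_grads and the per-chunk finish_grad_sync call are assumed here):

```python
def finalize_model_grads(model, config):
    # Proposed order: all-reduce the tied wte / lm_head gradients across
    # PP[0] & PP[-1] while every DP rank still holds the full, unsharded
    # gradient buffer ...
    _allreduce_embedding_grads(model, config)

    # ... and only then reduce-scatter the contiguous grad buffers across
    # the DP group, so each rank's shard already contains the fully summed
    # embedding gradient.
    for model_chunk in model:
        model_chunk.finish_grad_sync()
```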
It works. The problem was that I had been using a Megatron release from January, without a separate bucket for the shared embedding. I just switched to the latest master and that solved it. Thanks!