
[BUG] Wrong embedding gradients with distributed optimizer and shared embedding #844

Closed
li-plus opened this issue May 28, 2024 · 3 comments


@li-plus

li-plus commented May 28, 2024

Describe the bug
When use_distributed_optimizer is enabled for models with share_embeddings_and_output_weights, such as GPT2, all model gradients are reduce-scattered across DP ranks before the embedding gradients are all-reduced across PP[0] & PP[-1]. See:

def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
    """
    All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
    embedding grads across first and last pipeline stages (if not tied),
    scale gradients by `num_tokens`.
    """

    config = get_model_config(model[0])

    # All-reduce / reduce-scatter across DP replicas.
    if config.timers is not None:
        config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
    for model_chunk in model:
        model_chunk.finish_grad_sync()
    if config.timers is not None:
        config.timers('all-grads-sync').stop()

    # All-reduce layer-norm grads (for sequence parallelism).
    if config.timers is not None:
        config.timers('layernorm-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_layernorm_grads(model, config)
    if config.timers is not None:
        config.timers('layernorm-grads-all-reduce').stop()

    # All-reduce embedding grads (for pipeline parallelism).
    if config.timers is not None:
        config.timers('embedding-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_embedding_grads(model, config)
    if config.timers is not None:
        config.timers('embedding-grads-all-reduce').stop()

    # normalize gradients for per-token loss normalization.
    # if we are using by the number of tokens, then we use that as a divisor. this number
    # will be the total number of non-padded tokens in the global batch.
    if num_tokens is not None:
        # the number of tokens is only present on the last stage, so broadcast it
        # to the other ranks in the pipeline parallel group.
        torch.distributed.broadcast(
            num_tokens,
            src=parallel_state.get_pipeline_model_parallel_last_rank(),
            group=parallel_state.get_pipeline_model_parallel_group(),
        )
        # all-reduce across DP ranks.
        torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group())
        for model_chunk in model:
            if num_tokens > 0:
                scaling = 1.0 / num_tokens
                model_chunk.scale_gradients(scaling)

Note that the wte gradients and the lm_head gradients lie in different partitions of the contiguous gradient buffer (wte is the first weight on PP[0], lm_head is the last weight on PP[-1]), so they are reduce-scattered to different DP ranks. The subsequent all-reduce of embedding gradients across PP[0] and PP[-1] within the same DP group then adds up partial results, producing wrong embedding gradients.

For example, consider only embedding gradients with dp=2 and pp=2 on 4 GPUs:

  1. Before reduce-scatter across DP ranks:

     pp \ dp |      0      |      1
     --------+-------------+-------------
        0    |     g0      |     g1
        1    |     g2      |     g3

  2. After reduce-scatter across DP ranks:

     pp \ dp |      0      |      1
     --------+-------------+-------------
        0    |     g0      |  (g0+g1)/2
        1    |  (g2+g3)/2  |     g3

  3. After all-reduce of embedding grads across PP[0] & PP[-1]:

     pp \ dp |       0        |       1
     --------+----------------+----------------
        0    |  g0+(g2+g3)/2  |  g3+(g0+g1)/2
        1    |  g0+(g2+g3)/2  |  g3+(g0+g1)/2

The embedding gradients on rank 1 (pp0, dp1) and rank 2 (pp1, dp0) are the ones the optimizer uses to update the tied weight. They are expected to be identical, but in fact they differ; the single-process sketch below mimics this arithmetic.
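
To make the mismatch concrete, here is a minimal single-process sketch that just mimics the arithmetic in the tables above for dp=2, pp=2 (no real collectives or Megatron code involved; g0..g3 and the /2 averaging follow the example):

import torch

torch.manual_seed(0)
g0, g1 = torch.randn(4), torch.randn(4)   # embedding grads on PP[0]:  dp=0, dp=1
g2, g3 = torch.randn(4), torch.randn(4)   # embedding grads on PP[-1]: dp=0, dp=1

# Current (buggy) order: reduce-scatter across DP first. The wte shard ends up
# owned by dp=1 on PP[0] and the lm_head shard by dp=0 on PP[-1]; the later
# all-reduce across PP[0] & PP[-1] then mixes partial sums.
rank1_grad = g3 + (g0 + g1) / 2   # (pp0, dp1) after the PP all-reduce
rank2_grad = g0 + (g2 + g3) / 2   # (pp1, dp0) after the PP all-reduce
print(torch.allclose(rank1_grad, rank2_grad))   # False: tied weights diverge

# Proposed order: all-reduce across PP[0] & PP[-1] first, then reduce-scatter
# across DP. Both shard owners now compute the same value.
tied_dp0, tied_dp1 = g0 + g2, g1 + g3     # identical on both pipeline stages
rank1_fixed = (tied_dp0 + tied_dp1) / 2   # shard owned on PP[0]
rank2_fixed = (tied_dp0 + tied_dp1) / 2   # shard owned on PP[-1]
print(torch.allclose(rank1_fixed, rank2_fixed))  # True: consistent update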

To Reproduce
Run pretrain_gpt.py with pp=2 and dp=2 on 4 local GPUs. Before returning from finalize_model_grads, print the wte gradient hash on PP[0] and the lm_head gradient hash on PP[-1]:

    if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
        print(f'[Rank {torch.distributed.get_rank()}] embedding grad hash {model[0].module.module.language_model.embedding.word_embeddings.weight.main_grad.sum()}')
    elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
        print(f'[Rank {torch.distributed.get_rank()}] embedding grad hash {model[-1].module.module.word_embeddings.weight.main_grad.sum()}')

It can be observed that the printed gradients on rank 1 and rank 2 are clearly different.

Expected behavior
Correct embedding gradients with distributed optimizer for models with tied embeddings.

Environment (please complete the following information):

  • Megatron-LM commit ID de4028a
  • PyTorch version 2.1.0+cu121
  • CUDA version 12.1
  • NCCL version 2.18.1

Proposed fix
Move _allreduce_embedding_grads before finish_grad_sync; I will open a PR soon. Expected embedding gradient flow (a condensed code sketch follows the tables below):

  1. Embedding gradients before any synchronization:

     pp \ dp |    0    |    1
     --------+---------+---------
        0    |   g0    |   g1
        1    |   g2    |   g3

  2. All-reduce embedding grads across PP[0] & PP[-1]:

     pp \ dp |    0    |    1
     --------+---------+---------
        0    |  g0+g2  |  g1+g3
        1    |  g0+g2  |  g1+g3

  3. Reduce-scatter across DP ranks:

     pp \ dp |        0         |        1
     --------+------------------+------------------
        0    |      g0+g2       |  (g0+g2+g1+g3)/2
        1    | (g0+g2+g1+g3)/2  |      g1+g3
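
For reference, a condensed sketch of what the proposed reordering inside finalize_model_grads could look like (timer bookkeeping and the num_tokens scaling from the quoted function are left out for brevity; the actual PR may differ in details):

def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
    config = get_model_config(model[0])

    # All-reduce embedding grads across PP[0] & PP[-1] first, while every rank
    # still holds the full (unsharded) grad buffer, so both stages see the same
    # tied gradient before any DP sharding happens.
    _allreduce_embedding_grads(model, config)

    # Only now reduce-scatter / all-reduce across DP replicas; the shards that
    # land on each DP rank are then consistent between PP[0] and PP[-1].
    for model_chunk in model:
        model_chunk.finish_grad_sync()

    # All-reduce layer-norm grads (for sequence parallelism), followed by the
    # num_tokens scaling, unchanged from the quoted function.
    _allreduce_layernorm_grads(model, config)
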
@deepakn94
Collaborator

Does the following fix not work: daf0006?

In particular, these lines: daf0006#diff-703512d9cce575fe32a776ec738162312b6276de08ac4846a50f07e3903cfdacR239-R245.

@li-plus
Author

li-plus commented May 29, 2024

It works. The problem is that I had been using a Megatron release from January without the separate bucket for the shared embedding. I just switched to the latest master and that solved it. Thanks!

@li-plus li-plus closed this as completed May 29, 2024
@deepakn94
Collaborator

Awesome, great to hear!
