Llama 2 model divergence with FSDP #28826
Comments
cc @younesbelkada I think we have seen something similar recently?
@Teng-xu are you correctly enabling mixed precision through the training args?
Yeah, bf16 was passed into the training args, and I can verify it is being applied correctly.
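For readers following along, a minimal sketch of what "bf16 in the training args" with FSDP typically looks like (this is my reconstruction, not the reporter's actual script; the `fsdp_config` contents are assumptions):

```python
# Hypothetical sketch of the relevant mixed-precision + FSDP settings;
# the reporter's actual training script is attached below, not shown here.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    bf16=True,                        # bfloat16 mixed precision
    fsdp="full_shard auto_wrap",      # PyTorch FSDP via the Trainer
    fsdp_config={                     # assumed wrapping policy for Llama
        "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
    },
)
```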
To provide more context on this issue, I am attaching a simple script to reproduce it, along with its output. Note that I am just using a random tensor as the dataset; for consistency, I saved the labels from another training script and loaded them from a pickle object. Script:
Output of script:
Tagging @pacman100 to take a look.
Hi @rnadimp
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
`transformers` version: 4.37.1
Who can help?
When fine-tuning a Llama 2 model with HF 4.37 and PyTorch FSDP, we found model divergence in comparison to HF 4.31. Fine-tuning with 4.31 works fine, but with HF 4.37 the loss consistently rises instead of stabilizing when `attn_implementation="flash_attention_2"` is set, while `attn_implementation="sdpa"` works fine.
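As background (my addition, not from the thread): flash attention and SDPA compute the same attention math but accumulate partial sums in a different order, and floating-point addition is not associative, so small numeric differences between the two implementations are expected and can be amplified in reduced precision such as bf16. A divergence this large, however, points to a bug rather than ordinary rounding. A minimal standalone illustration of the non-associativity:

```python
# Standalone illustration (not from the issue): floating-point addition
# is not associative, so two mathematically equivalent attention kernels
# (e.g. flash attention vs. SDPA) can produce slightly different numbers.
a, b, c = 1e16, -1e16, 1.0
left = (a + b) + c   # the large terms cancel first, so c survives -> 1.0
right = a + (b + c)  # c is absorbed into b before cancellation -> 0.0
print(left, right)   # 1.0 0.0
```

The same effect is much stronger in bf16, which has only 8 bits of mantissa, which is why reordered reductions matter more under mixed precision.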
Information
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
Reproduction
The model is initialized as:

```python
model = AutoModelForCausalLM.from_pretrained(pretrained_model_weights, attn_implementation="flash_attention_2")
```
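Since the reporter observed that SDPA does not diverge, a possible interim workaround is to load the model with `attn_implementation="sdpa"` instead. This is a hedged sketch only (the variable `pretrained_model_weights` and the bf16 dtype are assumptions based on the thread; running it requires the actual model weights):

```python
# Hypothetical workaround sketch based on the thread's observation that
# SDPA attention trains stably while flash_attention_2 diverges.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_weights,        # assumed: path/ID from the reporter's setup
    attn_implementation="sdpa",      # instead of "flash_attention_2"
    torch_dtype=torch.bfloat16,      # assumed: bf16, per the mixed-precision discussion
)
```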
Expected behavior
The loss should not go up as training proceeds.