
[DPOTrainer] Fix DPO trainer + mistral + FA2 #1290

Merged
merged 1 commit into from
Jan 30, 2024

Conversation

younesbelkada
Contributor

@younesbelkada younesbelkada commented Jan 30, 2024

Fixes: #1217
Fixes: huggingface/transformers#26877
Fixes: #1266

Simply setting use_cache=False circumvents all issues with FA-2 + DPO + Mistral. In fact, we should bypass that check entirely, since we are not in text-generation mode when computing the loss. By default, use_cache is retrieved from the model config, which always falls back to True. Since the cache is never used when purely computing logits, this change is fully backward compatible.
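A minimal sketch of the fallback behavior described above (the class and function names here are illustrative, not the actual transformers/trl code): an explicit use_cache kwarg wins, otherwise the value falls back to the model config, which defaults to True — so a loss-only forward pass must pass use_cache=False explicitly to avoid building a KV cache.

```python
class Config:
    use_cache = True  # transformers configs default this to True


def forward(input_ids, config, use_cache=None):
    # Mirrors the transformers convention: an explicit kwarg wins,
    # otherwise the value falls back to the model config (True).
    use_cache = use_cache if use_cache is not None else config.use_cache
    logits = [[0.0] * 4 for _ in input_ids]  # placeholder logits
    past_key_values = [] if use_cache else None  # stand-in for the KV cache
    return logits, past_key_values


cfg = Config()
# Loss computation only needs logits; passing use_cache=False keeps the
# model from building a KV cache (the FA-2 + Mistral failure path).
logits, cache = forward([1, 2, 3], cfg, use_cache=False)
assert cache is None
```

Since generation never happens inside the loss computation, hardcoding use_cache=False changes nothing for existing users, which is why the description calls the change fully backward compatible.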

cc @kashif @vwxyzjn

To test this, I reproduced the issue by adding --attn_implementation "flash_attention_2" to the DPO shell script on an A100 machine, and I confirm this PR fixes it. Unfortunately, our CI runners are not compatible with FA2, so we cannot add a slow test for this.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif
Collaborator

kashif commented Jan 30, 2024

Thank you, great help!

@kashif kashif merged commit b415224 into main Jan 30, 2024
9 checks passed
@kashif kashif deleted the fix-dpo-fa2 branch January 30, 2024 07:25
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024