Does DPOTrainer loss mask the prompts? #1041

Closed
tokestermw opened this issue Nov 29, 2023 · 4 comments
Labels
DPO Question related to DPO and DPOTrainer

Comments

@tokestermw

Hi, quick question: DataCollatorForCompletionOnlyLM trains only on the responses by loss-masking the prompts.
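For context, that completion-only masking amounts to replacing the prompt portion of the labels with the ignore index (-100) that PyTorch's CrossEntropyLoss skips by default. A minimal sketch with made-up token ids (the helper name is hypothetical, not part of TRL):

```python
IGNORE_INDEX = -100  # CrossEntropyLoss's default ignore_index

def mask_prompt_labels(prompt_ids, response_ids, ignore_index=IGNORE_INDEX):
    """Return (input_ids, labels) where prompt positions carry the
    ignore index, so only response tokens contribute to the LM loss."""
    input_ids = list(prompt_ids) + list(response_ids)
    labels = [ignore_index] * len(prompt_ids) + list(response_ids)
    return input_ids, labels

input_ids, labels = mask_prompt_labels([10, 11, 12], [20, 21])
print(labels)  # [-100, -100, -100, 20, 21]
```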

Does it work this way with DPOTrainer (DPODataCollatorWithPadding) as well? Looking at the code, it does look like it trains on the prompts, but maybe that doesn't matter with the DPO loss.

The reason I ask is that my dataset has long prompts but short responses, and the resulting DPO-trained model is barely different from the reference model (even with 90+% accuracy). Loss-masking the prompts might help focus the learning, but that's just a guess.

Thanks!

@lvwerra lvwerra added the DPO Question related to DPO and DPOTrainer label Nov 30, 2023
@lvwerra
Member

lvwerra commented Nov 30, 2023

As far as I know we don't mask prompts in DPO, but @kashif might know more about this.

@rpowalski

I am having the same problem with DPO training not improving over the baseline model.

Nevertheless, DPODataCollatorWithPadding does seem to set an ignore index for the label positions corresponding to the prompt, so the cause is probably elsewhere.
Here is the relevant piece of code that does that:

chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
    prompt_tokens["input_ids"]
)
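To make the quoted line concrete, here is a toy run-through with hypothetical token ids; -100 stands in for self.label_pad_token_id (TRL's default ignore index):

```python
label_pad_token_id = -100  # stands in for self.label_pad_token_id

prompt_tokens = {"input_ids": [101, 5, 6]}
chosen_sequence_tokens = {
    "input_ids": [101, 5, 6, 7, 8],
    "labels":    [101, 5, 6, 7, 8],
}

# The collator overwrites the label positions covering the prompt,
# so only the chosen response tokens contribute to the log-probs:
chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [
    label_pad_token_id
] * len(prompt_tokens["input_ids"])

print(chosen_sequence_tokens["labels"])  # [-100, -100, -100, 7, 8]
```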

@tokestermw
Author

ah thanks! @rpowalski

@sadaisystems

sadaisystems commented Mar 15, 2024

Gentlemen, if I understood your discussion (and the code snippet mentioned) correctly, DPOTrainer does in fact mask out the prompt tokens by default? Were you able to figure out why your models did not improve?

@rpowalski @tokestermw

4 participants