
Padding side incorrect for Mistral DPO #1217

Closed
hengjiUSTC opened this issue Jan 11, 2024 · 9 comments · Fixed by #1290

Comments

@hengjiUSTC

hengjiUSTC commented Jan 11, 2024

Regarding the code here: https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L538C4-L550, does it always pad to the right?

I get an error when running DPO for Mistral 7B.

I am still seeing this error after setting tokenizer.padding_side = 'left'.
This is my demo notebook: https://colab.research.google.com/drive/1sVqbYEOqjJYl7CzNzXzviEBB6A984cMq?usp=sharing

Tokenizer already set with left padding
[Screenshot 2024-01-10 22:53: tokenizer configured with padding_side='left']

Train
[Screenshot 2024-01-11 10:28: training cell]

I still get: ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to call tokenizer.padding_side = 'left' before tokenizing the input.

[Screenshot 2024-01-10 22:54: error traceback]

It seems https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L538C4-L550 or some other part of the code might be causing this bug?

transformers 4.36.2
trl 0.7.7
peft 0.6.0
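
For context, a stripped-down sketch of the kind of setup that triggers this (the checkpoint name and loading arguments here are illustrative; the full setup is in the linked notebook):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # illustrative checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # set as the error message suggests

# The Flash Attention 2 path is the one that performs the padding_side check.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# Training this model with trl's DPOTrainer still hits the ValueError above.
```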

@hengjiUSTC hengjiUSTC changed the title Padding side incorrect for Mixtral DPO Padding side incorrect for Mistral DPO Jan 11, 2024
@hengjiUSTC
Author

I did some debugging and am fairly confident that https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L538C4-L550 gives the incorrect padding side and leads to the crash in transformers.

File ~/learn-llm/venv/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:899, in MistralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    897     is_padding_right = attention_mask[:, -1].sum().item() != batch_size
    898     if is_padding_right:
--> 899         raise ValueError(
    900             "You are attempting to perform batched generation with padding_side='right'"
    901             " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
    902             " call `tokenizer.padding_side  = 'left'` before tokenizing the input. "
    903         )
    905 if self._use_flash_attention_2:
    906     # 2d mask is passed through the layers
    907     attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None

ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to  call `tokenizer.padding_side  = 'left'` before tokenizing the input. 
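
For reference, the check that raises here only looks at the last column of the attention mask: with right padding, shorter sequences end in zeros, so the column sum is smaller than the batch size. A toy sketch in plain PyTorch (not tied to Mistral) of the same condition:

```python
import torch

# Right-padded batch of three sequences with lengths 5, 3 and 4
# (1 = real token, 0 = padding).
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1],
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 0],
])
batch_size = attention_mask.shape[0]

# Same condition as in MistralModel.forward: the last column must be all ones,
# which only holds when the padding sits on the left.
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
print(is_padding_right)  # True -> the ValueError above would be raised
```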

@gchhablani

Can I take this up @vwxyzjn?

@arkapal3

I had the same issue and went down a debugging rabbit hole, but realised in the end I had forgotten to set use_cache to False for my reference model.

Remember you need to set use_cache = False for both the main and reference model. If you do so, the check on is_padding_right above won't trigger (which it shouldn't, because we are only using the forward pass to compute log probs, not doing generation).
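
A minimal sketch of what that looks like in code (the checkpoint name is illustrative; the rest is plain transformers API):

```python
from transformers import AutoModelForCausalLM

model_name = "mistralai/Mistral-7B-v0.1"  # illustrative; any Mistral checkpoint

model = AutoModelForCausalLM.from_pretrained(model_name)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)

# DPO only runs forward passes to compute log probs, so the KV cache is not
# needed; leaving use_cache=True is what sends Mistral's Flash Attention path
# into the padding_side check shown in the traceback above.
model.config.use_cache = False
ref_model.config.use_cache = False

# ...then pass both `model` and `ref_model` to DPOTrainer as usual.
```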

@hengjiUSTC
Author

> I had the same issue and went down a debugging rabbit hole, but realised in the end I had forgotten to set use_cache to False for my reference model.
>
> Remember you need to set use_cache = False for both the main and reference model. If you do so, the check on is_padding_right above won't trigger (which it shouldn't, because we are only using the forward pass to compute log probs, not doing generation).

I checked, and this solution works.

@vwxyzjn
Contributor

vwxyzjn commented Jan 12, 2024

Hi thanks for the issue. Maybe during training the tokenizer.padding_side should be set to right?

@hengjiUSTC
Author

> Hi thanks for the issue. Maybe during training the tokenizer.padding_side should be set to right?

Is there a reason training should use right padding? Most of the DPO scripts I have seen set the Mistral tokenizer padding to left.

@vwxyzjn
Contributor

vwxyzjn commented Jan 12, 2024

In both our repo and DPO's original repo, we do right padding. Out of curiosity, which DPO scripts have you seen that pad left? Left padding is primarily used for generation, such as in PPO, but it's unnecessary in DPO because DPO does not do generation.

[Screenshots: right padding used in the TRL and original DPO repositories]
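
To make the padding-side distinction concrete, a small sketch (using a GPT-2 tokenizer purely for illustration) of how tokenizer.padding_side changes where the pad tokens land:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works for this demo
tok.pad_token = tok.eos_token

batch = ["short", "a much longer example sentence"]

tok.padding_side = "right"   # typical for training / forward-only loss computation
right = tok(batch, padding=True, return_tensors="pt")

tok.padding_side = "left"    # needed for batched generation, e.g. in PPO
left = tok(batch, padding=True, return_tensors="pt")

print(right["attention_mask"])  # zeros at the end of the shorter row
print(left["attention_mask"])   # zeros at the start of the shorter row
```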

@hengjiUSTC
Author

hengjiUSTC commented Jan 12, 2024

@younesbelkada
Contributor

This should now be fixed on TRL main!
