[BUG] DataCollatorForSeq2Seq with PaddingStrategy.MAX_LENGTH may not pad labels #30521
Comments
Thanks for raising this issue! Yeah, that seems like a valid bug imo. The padding strategy isn't respected when the labels are padded. I'd change these lines:

transformers/src/transformers/data/data_collator.py, lines 591 to 592 in 73014b5
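Those two lines currently derive the label padding length from the batch alone, roughly like this (paraphrased; the exact code at 73014b5 may differ slightly):

```python
# Paraphrased current behaviour: labels are always padded to the longest label
# sequence in the batch, so padding="max_length" and max_length are ignored.
if labels is not None:
    max_label_length = max(len(l) for l in labels)
```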
I'd change them to something like:

```python
no_padding = self.padding == False or self.padding == PaddingStrategy.DO_NOT_PAD
if labels is not None and not no_padding:
    max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
    max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length
```
Running this for a similar example to yours:

```python
from transformers import BartTokenizer, DataCollatorForSeq2Seq
from transformers.utils import PaddingStrategy

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

inputs = [
    {'input_ids': [151644, 8948, 198], 'attention_mask': [1, 1, 1], 'labels': [1, -100, -100]},
    {'input_ids': [151644, 8948, 198, 2610], 'attention_mask': [1, 1, 1, 1], 'labels': [2, 5, -100, -100]},
    {'input_ids': [151644, 8948, 198, 2610, 525], 'attention_mask': [1, 1, 1, 1, 1], 'labels': [3, 4, 6, -100, -100]},
]

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    padding=PaddingStrategy.MAX_LENGTH,
    max_length=10,
)

res = data_collator(inputs)
print(res['input_ids'].shape, res['labels'].shape)
```

Output:
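Under the stock collator this prints torch.Size([3, 10]) torch.Size([3, 5]), i.e. the labels are only padded to the longest label in the batch; with the change above they should also be padded to max_length, giving torch.Size([3, 10]) torch.Size([3, 10]), the extra positions filled with label_pad_token_id (-100 by default).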
cc @Rocketknight1 as you appear to be the most recent person to touch the data collators :)
Transformers maintainer here: Yes, this looks like a bug! Also @vasqu your solution looks good - would you be willing to make a PR to add it?
@Rocketknight1 opened a PR at #30556, including some tests for the seq2seq collator since there haven't been any.
It seems that when padding with the MAX_LENGTH strategy, the labels are not padded to max_length in the same way as the inputs.
Test case: a DataCollatorForSeq2Seq configured with padding=PaddingStrategy.MAX_LENGTH and max_length=10, called on a small batch (much like the example above).
results:
torch.Size([3, 10]) torch.Size([3, 5])
expected results:
torch.Size([3, 10]) torch.Size([3, 10])
Shouldn't the following code compute the pad length for the labels according to the configured padding strategy?

transformers/src/transformers/data/data_collator.py, line 592 in 73014b5
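In the meantime, one possible workaround (a rough sketch, reusing the inputs and data_collator from the example above and assuming the default label_pad_token_id of -100) is to pre-pad the labels to max_length during preprocessing, so that the longest label in the batch already equals max_length:

```python
# Hypothetical helper (not part of transformers): pre-pad labels to max_length
# with the label pad id so the collator's "longest label in the batch"
# already matches max_length.
def pad_labels(example, max_length=10, label_pad_token_id=-100):
    example["labels"] = example["labels"] + [label_pad_token_id] * (max_length - len(example["labels"]))
    return example

padded_inputs = [pad_labels(dict(example)) for example in inputs]
res = data_collator(padded_inputs)
print(res["input_ids"].shape, res["labels"].shape)  # both torch.Size([3, 10])
```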