[BUG] DataCollatorForSeq2Seq with PaddingStrategy.MAX_LENGTH may not pad labels #30521

Closed
muzhi1991 opened this issue Apr 27, 2024 · 4 comments · Fixed by #30556
Comments

@muzhi1991 (Contributor) commented Apr 27, 2024

It seems that when the padding strategy is set to MAX_LENGTH, the labels are not padded to the same max_length as the inputs.

Test case below:

from transformers import DataCollatorForSeq2Seq
from transformers.utils import PaddingStrategy

# tokenizer: any pretrained tokenizer loaded beforehand, e.g. with AutoTokenizer.from_pretrained(...)
inputs = [{'input_ids': [151644, 8948, 198], 'attention_mask': [1, 1, 1], 'labels': [-100, -100, -100]},
          {'input_ids': [151644, 8948, 198, 2610], 'attention_mask': [1, 1, 1, 1], 'labels': [-100, -100, -100, -100]},
          {'input_ids': [151644, 8948, 198, 2610, 525], 'attention_mask': [1, 1, 1, 1, 1], 'labels': [-100, -100, -100, -100, -100]}]
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    padding=PaddingStrategy.MAX_LENGTH,
    max_length=10,
)
res = data_collator(inputs)

print(res['input_ids'].shape, res['labels'].shape)

Results:
torch.Size([3, 10]) torch.Size([3, 5])

Expected results:
torch.Size([3, 10]) torch.Size([3, 10])

Should the following code take the configured padding strategy into account when computing the label pad length?

max_label_length = max(len(l) for l in labels)

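For reference, a tiny standalone illustration of why the labels stop at length 5 (a sketch only, not the collator's actual code): the label pad length is taken from the longest label in the batch, so max_length is never consulted for labels.

# Sketch of the current label-padding behaviour, assuming right-side padding
# with label_pad_token_id = -100 (the collator's default).
label_pad_token_id = -100
labels = [[-100] * 3, [-100] * 4, [-100] * 5]

max_label_length = max(len(l) for l in labels)  # -> 5, max_length=10 is ignored
padded = [l + [label_pad_token_id] * (max_label_length - len(l)) for l in labels]
print([len(l) for l in padded])  # [5, 5, 5]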
muzhi1991 changed the title from "DataCollatorForSeq2Seq with PaddingStrategy.MAX_LENGTH may not pad labels" to "[BUG] DataCollatorForSeq2Seq with PaddingStrategy.MAX_LENGTH may not pad labels" on Apr 27, 2024
@vasqu (Contributor) commented Apr 27, 2024

Thanks for raising this issue! Yeah, that seems like a valid bug imo. The padding strategy isn't respected with max_length.

I'd change these lines:

if labels is not None:
max_label_length = max(len(l) for l in labels)

to something like:

no_padding = self.padding == False or self.padding == PaddingStrategy.DO_NOT_PAD
if labels is not None and not no_padding:
    max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
    max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length

DO_NOT_PAD is also not respected, but it doesn't matter too much since it leads to the same end state as longest. So the first line might be unnecessary; it just saves some computation, I guess.
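For reference, a quick standalone sanity check of that computation (a sketch only; the strategies are written as their string values and proposed_max_label_length is just an illustrative helper, not anything in transformers):

def proposed_max_label_length(labels, padding, max_length=None):
    # Mirror the proposed logic: skip label padding when no padding is requested,
    # use max_length for MAX_LENGTH, otherwise fall back to the longest label.
    no_padding = padding is False or padding == 'do_not_pad'
    if labels is None or no_padding:
        return None  # label padding would be skipped entirely
    max_padding = padding == 'max_length' and max_length is not None
    return max_length if max_padding else max(len(l) for l in labels)

labels = [[-100] * 3, [-100] * 4, [-100] * 5]
print(proposed_max_label_length(labels, 'max_length', max_length=10))  # 10
print(proposed_max_label_length(labels, 'longest'))                    # 5
print(proposed_max_label_length(labels, False))                        # None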

Running this for a similar example to yours:

from transformers import BartTokenizer, DataCollatorForSeq2Seq
from transformers.utils import PaddingStrategy

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
inputs = [{'input_ids': [151644, 8948, 198],'attention_mask': [1, 1, 1],'labels': [1, -100, -100]},
          {'input_ids': [151644, 8948, 198, 2610],'attention_mask': [1, 1, 1, 1],'labels': [2, 5, -100, -100]},
          {'input_ids': [151644, 8948, 198, 2610, 525], 'attention_mask': [1, 1, 1, 1, 1],'labels': [3, 4, 6, -100, -100]}]

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    padding=PaddingStrategy.MAX_LENGTH,
    max_length=10,
)
res = data_collator(inputs)
print(res['input_ids'].shape, res['labels'].shape)

Output: torch.Size([3, 10]) torch.Size([3, 10])

@amyeroberts (Collaborator)

cc @Rocketknight1 as you appear to be the most recent person to touch the data collators :)

@Rocketknight1 (Member)

Transformers maintainer here: Yes, this looks like a bug! Also @vasqu your solution looks good - would you be willing to make a PR to add it?

@vasqu (Contributor) commented Apr 29, 2024

@Rocketknight1 opened a PR at #30556, including some tests for the seq2seq collator since there weren't any.
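A minimal sketch of the kind of regression test that could cover this (hypothetical, not necessarily the test added in #30556):

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

def test_seq2seq_collator_pads_labels_to_max_length():
    # With padding='max_length', labels should come out with the same second
    # dimension as input_ids.
    tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
    features = [
        {'input_ids': [0, 1, 2], 'attention_mask': [1, 1, 1], 'labels': [1, -100, -100]},
        {'input_ids': [0, 1, 2, 3, 4], 'attention_mask': [1, 1, 1, 1, 1], 'labels': [3, 4, -100, -100, -100]},
    ]
    collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding='max_length', max_length=10)
    batch = collator(features)
    assert batch['input_ids'].shape == (2, 10)
    assert batch['labels'].shape == (2, 10)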
