
fix(data): Handle integer labels in DataCollatorWithFlattening #41687

Open
yashisthebatman wants to merge 1 commit into huggingface:main from yashisthebatman:fix-collator-int-labels

Conversation

@yashisthebatman

What does this PR do?

This PR fixes a TypeError in DataCollatorWithFlattening that occurs when processing datasets with integer labels, which is a common format for sequence classification tasks.

The collator was written with the assumption that the labels field would always be a sequence, and it attempted to slice it (sample["labels"][1:]). When sample["labels"] is an integer, this causes the application to crash.

The fix introduces a check to determine if the labels are sequence-like.

  • If the labels are a sequence (the original expected format), the collator preserves the exact original logic.
  • If the labels are integers, the collator now correctly broadcasts the integer label across all tokens of the corresponding input_ids.

This change makes the collator more robust to common data formats without altering its behavior for existing use cases, thereby preventing unexpected crashes. A new test case has been added to verify this fix.

Fixes #41652
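For context, the shape of the fix can be sketched as follows. This is a minimal, self-contained sketch, not the PR's exact diff; the names `features`, `separator_id`, and the default of `-100` are assumptions for illustration:

```python
import numpy as np

def flatten_with_labels(features, separator_id=-100):
    # Minimal sketch of the collator's labels handling (not the real class).
    batch = {"input_ids": [], "labels": []}
    is_labels_provided = "labels" in features[0]
    # Sequence-like labels are token-level; plain ints are sequence-level.
    is_labels_sequence = is_labels_provided and isinstance(
        features[0]["labels"], (list, tuple, np.ndarray)
    )
    for sample in features:
        input_ids = list(sample["input_ids"])
        batch["input_ids"] += input_ids
        if is_labels_sequence:
            # Original behaviour: mask the first token with the separator id.
            batch["labels"] += [separator_id] + list(sample["labels"][1:])
        elif is_labels_provided:
            # New behaviour: broadcast the integer label across all tokens.
            batch["labels"] += [sample["labels"]] * len(input_ids)
    return batch
```

With integer labels, `labels` stays the same length as the flattened `input_ids`, so the original `[1:]` slicing is never applied to an `int`.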

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

This change relates to data collators, which are often used with the Trainer. Tagging from the original issue and the trainer section.
@ArthurZucker @SunMarc


@LPirch LPirch left a comment


Thanks for your work! There is one thing I didn't fully understand (see comments) which needs further clarification.

Comment thread src/transformers/data/data_collator.py Outdated

# --- Start of our fix ---
# Determine if labels are sequence-like (token-level) or single values (sequence-level)
are_labels_sequence = is_labels_provided and isinstance(features[0].get("labels"), (list, tuple, np.ndarray))


minor suggestion: is_labels_sequence would be a more consistent naming with the rest of the function

# We must create a label for each token in the input_ids.
# The label for the sequence is applied to all its tokens.
# We do NOT add a separator, as that concept doesn't apply to this kind of label.
batch["labels"] += [sample["labels"]] * len(input_ids)


Why not batch["labels"].append(sample["labels"])?

What I had in mind was a sequence classification task in which a model receives a sequence (batch_size, seq_length) and outputs scores for each class (batch_size, num_classes). The collator would now return a (batch_size, seq_length) tensor which needs further processing on the user side.

Is there a use case where expanding the class to the number of tokens is required?

Comment thread src/transformers/data/data_collator.py Outdated
Comment on lines +1431 to +1434
# New logic for sequence-level integer labels.
# We must create a label for each token in the input_ids.
# The label for the sequence is applied to all its tokens.
# We do NOT add a separator, as that concept doesn't apply to this kind of label.
Member


All of these comment lines should probably be removed, particularly # --- Start of our fix ---! The ones in the test are okay, but descriptions of "new logic" and "original logic" will confuse people, since they won't know which bits were added in a PR and which weren't.

@yashisthebatman yashisthebatman force-pushed the fix-collator-int-labels branch from 3e1997a to 01e748b Compare October 17, 2025 13:06
@yashisthebatman
Author

@LPirch Thanks for the review! I have removed the unnecessary comments and followed the more consistent naming scheme. Regarding "Why not batch["labels"].append(sample["labels"])?":
The entire purpose of DataCollatorWithFlattening is to create one single long sequence, so the labels tensor must have the same length as the final input_ids tensor. Using .append() would create mismatched lengths (e.g., input_ids of length 7 but labels of length 2), which would immediately crash any model's loss function.
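A toy illustration of the length mismatch (standalone made-up numbers, not library code):

```python
# Two samples (lengths 4 and 3) packed into one flattened sequence.
input_ids = [101, 7, 8, 102] + [101, 9, 102]

labels_broadcast = [0] * 4 + [1] * 3  # broadcast: one label per token
labels_append = [0, 1]                # append: one label per sample

assert len(labels_broadcast) == len(input_ids)  # 7 == 7, lengths match
assert len(labels_append) != len(input_ids)     # 2 != 7, token-level loss would fail
```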

@LPirch

LPirch commented Oct 17, 2025

Let me give a bit more context to prevent misunderstandings: in sequence classification, you don't have a label per token but a label per sample. Now, when sequences get very long, you might want to use flash attention (as in my case). The only way to process a batch of long sequences is to pack them using the flattening collator and run them through the language model. The output then needs to be unpacked using the position_ids tensor (see example below) to recover the original samples.

In my case, you then need a pooling that creates an embedding per sequence and a classification head (linear layers) producing an output score. The loss function is then just a regular classification loss.

unpacking example:

def pool_last_token_flat(last_hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    # position_ids == 0 marks each packed sequence's start; the last token of a
    # sequence is one position before the next start (or the final position).
    seq_starts = torch.argwhere(position_ids == 0)[:, 1]
    end_idx = torch.cat([seq_starts[1:] - 1, torch.tensor([last_hidden_states.shape[1] - 1], device=last_hidden_states.device)])
    return last_hidden_states[0, end_idx]

This pooling takes the last layer's hidden vector of the last token of each (sub-)sequence in the flattened sequence.

TLDR: We don't create flat sequences for the sake of flattening but because flash attention 2 can't handle padding. Recovering the batch structure is required anyway in the case of sequence classification.

@yashisthebatman
Author

Hi @LPirch ,
Thank you for that context! That makes total sense. I didn't consider the flash attention -> pack -> unpack -> pool workflow.
You're right, my original fix is useful for preventing a crash but actively unhelpful for your use case, since you'd just have to undo it. What do you think of this?
We add a boolean flag to the collator: pack_sequence_labels=False.
Default (False): It applies my original fix (broadcasting the label). This is the "safe" option for anyone not doing what you're doing, and it guarantees input_ids and labels have the same length.
Your case (True): You'd set pack_sequence_labels=True. The collator would then just collect the integer labels into a simple list [label1, label2, ...], giving you the (batch_size,) tensor you need for your pooling logic.
This feels like the cleanest way to support both scenarios without breaking things for anyone.
If that sounds good to you, I'll go ahead and implement it.
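Sketched out, the proposed flag might look roughly like this (a hypothetical illustration of the idea, not the eventual implementation; `collate_int_labels` is a made-up helper name):

```python
def collate_int_labels(features, pack_sequence_labels=False):
    # Illustrative helper: how the flag would route integer labels.
    labels = []
    for sample in features:
        if pack_sequence_labels:
            # One entry per sample -> a (batch_size,) labels tensor downstream.
            labels.append(sample["labels"])
        else:
            # Default: broadcast so len(labels) == len(flattened input_ids).
            labels += [sample["labels"]] * len(sample["input_ids"])
    return labels
```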

@LPirch

LPirch commented Oct 17, 2025

Sounds good! Thanks for your support 🙌

@yashisthebatman yashisthebatman force-pushed the fix-collator-int-labels branch from 01e748b to 41b7bb9 Compare October 17, 2025 15:10
@yashisthebatman
Author

@LPirch I have made all the changes we discussed; please take a look.


@LPirch LPirch left a comment


Looks perfect. Thanks for your work!

@yashisthebatman
Author

@LPirch Thank you for accepting my PR. I am new to open source; could you tell me where I could improve, or share any suggestions that would help me become a better contributor?

@LPirch

LPirch commented Oct 17, 2025

@yashisthebatman I think some steps in the CI pipeline failed (code quality). Can you fix those?

About your question: you're doing well! Keep the positivity in your communication, stay curious and open to learning from others, and focus on understanding the bigger picture before finalizing a PR; those are the key ingredients if you ask me. The rest is experience.

@yashisthebatman yashisthebatman force-pushed the fix-collator-int-labels branch from 41b7bb9 to 9faee26 Compare October 17, 2025 15:54
@yashisthebatman
Author

@LPirch Fixed those formatting errors; it should work this time. Thank you for your patience.

@yashisthebatman
Author

@LPirch all the tests have passed successfully

@yashisthebatman
Author

@LPirch Could you please merge this PR? I have successfully resolved all the formatting issues.

@LPirch

LPirch commented Oct 20, 2025

Just to clarify: I'm not a maintainer of this project. I can't assign, resolve or merge anything. I'm just a user who reported an issue.

From now on, @Rocketknight1 would need to take the lead. I suggest we wait patiently until he finds the time and see if he has additional feedback. Also, let's not bloat the conversation too much.

@Rocketknight1
Member

Rocketknight1 commented Oct 20, 2025

Hey all! I'm sorry for the delays, but I finally got around to doing a proper review! I have one main question: I realize that packing is used as a way of improving performance with some FlashAttention backends, but do our SequenceClassification models actually support packed inputs?

Packing works for CLM training because the objective for those models is token-level, so you can just pack sequences together and concatenate input_ids with a separator, then do the same with their labels. However, most SequenceClassification models emit one logit distribution per sequence, not per token. If you pack inputs together, it's not clear to me how sequence classification training is supposed to work, and I think the user would have to unpack the outputs before passing them to the model anyway. What exactly is the workflow supposed to look like when a Flattening collator is combined with a SequenceClassification model?

@yashisthebatman
Author

@Rocketknight1 A SequenceClassification model won't know what to do with a packed input straight out of the box. This feature is really for an advanced, manual workflow that some users need for performance, especially with things like Flash Attention 2 that don't like padding.

So the idea is, a user would grab this collator with pack_sequence_labels=True. This gives them:

  1. The packed input_ids and position_ids they need to feed the base model.
  2. And crucially, a clean labels tensor of shape (batch_size,). This is the missing piece my PR provides.

From there, the workflow looks like this on the user's end:

First, they'd get the raw hidden states by calling just the base model, bypassing the classification head:

# User's code
outputs = model.base_model(packed_input_ids)
last_hidden_state = outputs.last_hidden_state # A single giant tensor

Next, they have to do the manual part: use the position_ids to find the last token of each original sequence in that giant tensor and pool them. This turns the flattened output back into a standard (batch_size, hidden_size) tensor.

And now they have everything they need. They can feed their pooled outputs into model.classifier, and the labels tensor from the collator will match up perfectly for the loss calculation.

So basically, this PR doesn't make the models understand packed inputs; rather, it provides the correctly shaped labels tensor that makes this whole manual packing/unpacking dance possible for sequence classification.
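The unpack-and-pool step in that workflow could be sketched like this (a numpy illustration with assumed shapes, not the transformers API):

```python
import numpy as np

def pool_last_tokens(last_hidden_state: np.ndarray, position_ids: np.ndarray) -> np.ndarray:
    # last_hidden_state: (1, total_len, hidden); position_ids: (1, total_len).
    starts = np.flatnonzero(position_ids[0] == 0)                # start of each packed sample
    ends = np.append(starts[1:] - 1, position_ids.shape[1] - 1)  # last token of each sample
    return last_hidden_state[0, ends]                            # (batch_size, hidden)

# Two samples of lengths 3 and 2 packed into one flattened sequence of length 5.
position_ids = np.array([[0, 1, 2, 0, 1]])
hidden = np.arange(10, dtype=float).reshape(1, 5, 2)
pooled = pool_last_tokens(hidden, position_ids)  # hidden states at flat positions 2 and 4
```

The pooled `(batch_size, hidden)` tensor then lines up with the `(batch_size,)` labels from the collator for a regular classification loss.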

@Rocketknight1
Member

Hmm. I see what you're going for, but there's a distribution shift there: if a model is fine-tuned for sequence classification, it will generally not see concatenated documents with separators during training. Packing works for CLM training because documents are often concatenated across document boundaries anyway, so the model is familiar with that, but I expect the performance of a sequence classification model to sharply drop if you use the hidden states from a later sequence that follows multiple earlier sequences and separator tokens.

In other words, packed sequences aren't actually independent and the later sequences can still attend to the earlier ones, they just learn not to do that if they see enough concatenated docs in training!

As such, I'm not sure we want to add this feature - I think it's very niche and likely to be confusing to users who expect it to work, when really it requires a lot of manual effort to separate the outputs for the classification head, and probably also degrades performance in most cases as well.

@LPirch

LPirch commented Oct 21, 2025

Oh, I was not aware that there are cross-sequence information leaks in a packed sequence. I didn't think of this as an option because (as you said) this would also contaminate gradients on regular token predictions, i.e. attending to tokens of previous sequences in a batch. From a user perspective, I find it unintuitive that flash attention implicitly sacrifices soundness for efficiency, but that maybe belongs in a different issue.

Generally, I can understand your argument though and I would be fine if you close it. I have already monkey-patched the collator to my needs in the meantime and only thought it would be a useful feature for others as well.

@yashisthebatman
Author

yashisthebatman commented Oct 21, 2025

Thanks for clarifying your position, @LPirch. Given the valid concerns about attention leakage and the potential for user confusion, I agree that the best path forward is to close this PR. This was a really insightful discussion, and I appreciate the detailed feedback from both of you. I'll go ahead and close this now.

@Rocketknight1
Member

Hang on, my mistake! I dug into the FlashAttention docs and I was wrong; as long as position_ids and cu_seq_lens are set correctly then there shouldn't be information leakage and performance shouldn't degrade. I didn't realize that cu_seq_lens controlled attention masking like that.

I'm debating whether it's worth adding this after all, in that case

@yashisthebatman
Author

@Rocketknight1 Hey, wow, thank you so much for digging into the FlashAttention docs and for that update! That's a super important clarification and really great to know.
No problem at all, I'll keep the PR open. Please take your time to think it over. If you do decide it's a worthwhile addition, just let me know if there's anything else you'd like me to do. No rush at all!

@yashisthebatman
Author

@Rocketknight1 any update?

@Rocketknight1
Member

Will try to review it tomorrow!

Member

@Rocketknight1 Rocketknight1 left a comment


With apologies for the extremely long delay, I finally did a proper review! I think it's good, with two comments. The old behaviour for sequence labels is preserved, and the behaviour in the case of non-sequence labels was definitely unwanted before, so this seems like a straightforward improvement.

Take a look at the two highlighted areas and let me know when you're ready for final approval and merging.

  batch["labels"] += [sample["labels"]] * len(input_ids)
  else:
-     batch["labels"] += [separator_id] + input_ids[1:]
+     batch["labels"] += [self.separator_id] + input_ids[1:]
Member


I think we should revert this change, since it overrides any passed separator_id

Suggested change
- batch["labels"] += [self.separator_id] + input_ids[1:]
+ batch["labels"] += [separator_id] + input_ids[1:]

separator_id = self.separator_id
is_labels_provided = "labels" in features[0]

is_labels_sequence = is_labels_provided and isinstance(features[0].get("labels"), (list, tuple, np.ndarray))
Member


Is this check exhaustive? I think torch.Tensor may also be possible.



Development

Successfully merging this pull request may close these issues.

DataCollatorWithFlattening incompatible with sequence classification

3 participants