fix(data): Handle integer labels in DataCollatorWithFlattening #41687
yashisthebatman wants to merge 1 commit into huggingface:main
Conversation
LPirch left a comment
Thanks for your work! There is one thing I didn't fully understand (see comments) and which needs further clarification.
```python
# --- Start of our fix ---
# Determine if labels are sequence-like (token-level) or single values (sequence-level)
are_labels_sequence = is_labels_provided and isinstance(features[0].get("labels"), (list, tuple, np.ndarray))
```
minor suggestion: `is_labels_sequence` would be a more consistent naming with the rest of the function
```python
# We must create a label for each token in the input_ids.
# The label for the sequence is applied to all its tokens.
# We do NOT add a separator, as that concept doesn't apply to this kind of label.
batch["labels"] += [sample["labels"]] * len(input_ids)
```
Why not `batch["labels"].append(sample["labels"])`?
What I had in mind was a sequence classification task in which a model receives a sequence (batch_size, seq_length) and outputs scores for each class (batch_size, num_classes). The collator would now return a (batch_size, seq_length) tensor which needs further processing on the user side.
Is there a use case where expanding the class to the number of tokens is required?
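A minimal sketch of the two options' outputs, using made-up feature dicts (not the collator's actual code):

```python
features = [{"input_ids": [1, 2, 3], "labels": 0},
            {"input_ids": [4, 5], "labels": 1}]

# Option A (what the PR does): expand each sequence-level label to every token,
# so the labels line up with the flattened (1, total_seq_length) input_ids.
expanded = []
for f in features:
    expanded += [f["labels"]] * len(f["input_ids"])
print(expanded)  # [0, 0, 0, 1, 1]

# Option B: append one label per sample, giving a (num_sequences,) vector.
appended = [f["labels"] for f in features]
print(appended)  # [0, 1]
```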
```python
# New logic for sequence-level integer labels.
# We must create a label for each token in the input_ids.
# The label for the sequence is applied to all its tokens.
# We do NOT add a separator, as that concept doesn't apply to this kind of label.
```
All of these comment lines should probably be removed, particularly # --- Start of our fix ---! The ones in the test are okay, but descriptions of "new logic" and "original logic" will confuse people, since they won't know which bits were added in a PR and which weren't.
(Force-pushed from 3e1997a to 01e748b.)
@LPirch Thanks for the review! I have removed the unnecessary comments and followed the more consistent naming scheme. But regarding "Why not `batch["labels"].append(sample["labels"])`?":
Let me give a bit more context to prevent misunderstandings: In sequence classification, you don't have a label per token but a label per sample. Now, when sequences get very long, you might want to use flash attention (as in my case). The only way to process a batch of long sequences is to pack them using the flattening collator and run them through the language model. The output then needs to be unpacked. In my case, you then need a pooling step that creates an embedding per sequence and a classification head (linear layers) producing an output score. The loss function is then just a regular classification loss.

Unpacking example:

```python
def pool_last_token_flat(last_hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    # position_ids restart at 0 at the start of every packed (sub-)sequence
    end_idx = torch.argwhere(position_ids == 0)[:, 1][1:]
    end_idx = torch.cat([end_idx, torch.tensor([last_hidden_states.shape[1] - 1], device=last_hidden_states.device)])
    return last_hidden_states[0, end_idx]
```

This pooling takes the last layer's hidden vector of the last token of each (sub-)sequence in the flattened sequence.

TLDR: We don't create flat sequences for the sake of flattening but because flash attention 2 can't handle padding. Recovering the batch structure is required anyway in the case of sequence classification.
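The boundary recovery can also be sketched in plain Python (a hypothetical list-based helper, not the tensor code above); the last token of each packed sequence is the position just before the next restart of the position ids:

```python
def last_token_indices(position_ids):
    """Flat index of the last token of each packed sequence.

    position_ids restart at 0 at the start of every packed sequence,
    e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3] packs sequences of lengths 3, 2, 4.
    """
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    # The last token of a sequence sits right before the next restart;
    # the final sequence ends at the last position overall.
    return [s - 1 for s in starts[1:]] + [len(position_ids) - 1]

print(last_token_indices([0, 1, 2, 0, 1, 0, 1, 2, 3]))  # [2, 4, 8]
```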
Hi @LPirch,

Sounds good! Thanks for your support 🙌
(Force-pushed from 01e748b to 41b7bb9.)
@LPirch I have done all the changes that we discussed, please do check it.
@LPirch Thank you for accepting my PR. I am new to open source; could you tell me where I could improve, or share any suggestions that could help me become a better contributor?
@yashisthebatman I think some steps in the CI pipeline failed (code quality). Can you fix those? About your question: You're doing well! Keep the positivity in communication, stay curious / be open to learning from others, and keep a focus on understanding the bigger picture before finalizing a PR - those are the key ingredients if you ask me. The rest is experience.
(Force-pushed from 41b7bb9 to 9faee26.)
@LPirch Fixed those formatting errors, this time it must work. Thank you for the patience.
@LPirch All the tests have passed successfully.
@LPirch Could you please merge this? I have successfully solved all the formatting issues.
Just to clarify: I'm not a maintainer of this project. I can't assign, resolve or merge anything. I'm just a user who reported an issue. From now on, @Rocketknight1 would need to take the lead. I suggest we wait patiently until he finds the time and see if he has additional feedback. Also, let's not bloat the conversation too much.
Hey all! I'm sorry for the delays, but I finally got around to doing a proper review! I have one main question: I realize that packing is used as a way of improving performance with some FlashAttention backends, but do our sequence classification models actually support packed inputs like this? Packing works for CLM training because the objective for those models is token-level, so you can just pack sequences together and concatenate the labels as well.
@Rocketknight1 So the idea is, a user would pack their dataset with this collator as usual. From there, the workflow looks like this on the user's end. First, they'd get the raw hidden states by calling just the base model, bypassing the classification head:

```python
# User's code
outputs = model.base_model(packed_input_ids)
last_hidden_state = outputs.last_hidden_state  # A single giant tensor
```

Next, they have to do the manual part: use the position ids to pool one embedding per packed (sub-)sequence. And now they have everything they need. They can feed their pooled outputs into a classification head. So basically this PR doesn't make the models understand packed inputs; rather, it provides the correctly shaped labels.
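Putting the pieces together, here is a self-contained sketch of that manual pooling-plus-classification-head step, with made-up shapes and random weights (assumed names, not the PR's code):

```python
import torch

# Hypothetical packed batch: 3 sequences of lengths 3, 2 and 4 in one row.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
hidden_size = 8
last_hidden_state = torch.randn(1, position_ids.shape[1], hidden_size)

# Pool one embedding per packed sequence: take the hidden state of each
# sequence's last token (the position right before the next restart of 0).
starts = torch.argwhere(position_ids[0] == 0).squeeze(-1)
ends = torch.cat([starts[1:] - 1, torch.tensor([position_ids.shape[1] - 1])])
pooled = last_hidden_state[0, ends]  # (num_sequences, hidden_size)

# Feed the pooled embeddings into a plain classification head.
num_classes = 4
head = torch.nn.Linear(hidden_size, num_classes)
logits = head(pooled)  # (num_sequences, num_classes)
```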
Hmmn. I see what you're going for, but there's a distribution shift there: If a model is fine-tuned for sequence classification, it will generally not see concatenated documents with separators during training. Packing works for CLM training because documents are often concatenated across document boundaries anyway, so the model is familiar with that, but I expect the performance of a sequence classification model to sharply drop if you use the hidden states from a later sequence that follows multiple earlier sequences and separator tokens. In other words, packed sequences aren't actually independent and the later sequences can still attend to the earlier ones; they just learn not to do that if they see enough concatenated docs in training!

As such, I'm not sure we want to add this feature - I think it's very niche and likely to be confusing to users who expect it to work, when really it requires a lot of manual effort to separate the outputs for the classification head, and probably also degrades performance in most cases as well.
Oh, I was not aware that there are cross-sequence information leaks in a packed sequence. I didn't think of this as an option because (as you said) this would also contaminate gradients on regular token predictions, i.e. attending to tokens of previous sequences in a batch. From a user perspective, I find it unintuitive that flash attention implicitly sacrifices soundness for efficiency, but that maybe belongs in a different issue. Generally, I can understand your argument though, and I would be fine if you close it. I have already monkey-patched the collator to my needs in the meantime and only thought it would be a useful feature for others as well.
Thanks for clarifying your position, @LPirch. Given the valid concerns about attention leakage and the potential for user confusion, I agree that the best path forward is to close this PR. This was a really insightful discussion, and I appreciate the detailed feedback from both of you. I'll go ahead and close this now.
Hang on, my mistake! I dug into the FlashAttention docs and I was wrong; as long as the position ids are passed, packed sequences are masked from attending to each other, so there is no cross-sequence leakage. I'm debating whether it's worth adding this after all; in that case the distribution-shift concern doesn't apply.
@Rocketknight1 Hey, wow, thank you so much for digging into the FlashAttention docs and for that update! That's a super important clarification and really great to know.
@Rocketknight1 Any update?
Will try to review it tomorrow! |
Rocketknight1 left a comment
With apologies for the extremely long delay, I finally did a proper review! I think it's good, with two comments. The old behaviour for sequence labels is preserved, and the behaviour in the case of non-sequence labels was definitely unwanted before, so this seems like a straightforward improvement.
Take a look at the two highlighted areas and let me know when you're ready for final approval + merging
```diff
     batch["labels"] += [sample["labels"]] * len(input_ids)
 else:
-    batch["labels"] += [separator_id] + input_ids[1:]
+    batch["labels"] += [self.separator_id] + input_ids[1:]
```
I think we should revert this change, since it overrides any passed separator_id
```diff
-    batch["labels"] += [self.separator_id] + input_ids[1:]
+    batch["labels"] += [separator_id] + input_ids[1:]
```
```python
separator_id = self.separator_id
is_labels_provided = "labels" in features[0]
...
is_labels_sequence = is_labels_provided and isinstance(features[0].get("labels"), (list, tuple, np.ndarray))
```
Is this check exhaustive? I think `torch.Tensor` may also be possible.
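One possible way to make the check future-proof without enumerating container types (a sketch, not the PR's code) is to treat anything that isn't a plain number as token-level. Note this still misses 0-d tensors/arrays, so simply adding `torch.Tensor` to the isinstance tuple may be the more practical fix:

```python
import numbers

def is_sequence_label(label):
    # Plain ints/floats (and numpy scalars, which register as numbers.Number)
    # are sequence-level; lists, tuples, np.ndarray and torch.Tensor are not.
    return not isinstance(label, numbers.Number)

print(is_sequence_label(3))       # False: a single class id
print(is_sequence_label([1, 2]))  # True: token-level labels
```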
What does this PR do?

This PR fixes a `TypeError` in `DataCollatorWithFlattening` that occurs when processing datasets with integer labels, which is a common format for sequence classification tasks. The collator was written with the assumption that the `labels` field would always be a sequence, and it attempted to slice it (`sample["labels"][1:]`). When `sample["labels"]` is an integer, this causes the application to crash.

The fix introduces a check to determine if the `labels` are sequence-like. If they are not, the single sequence-level label is expanded to cover every token in the `input_ids`. This change makes the collator more robust to common data formats without altering its behavior for existing use cases, thereby preventing unexpected crashes. A new test case has been added to verify this fix.
Fixes #41652
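A minimal illustration of the failure mode described above (a standalone sketch, not the collator code):

```python
sample = {"input_ids": [101, 2023, 102], "labels": 1}  # sequence-level int label

try:
    _ = sample["labels"][1:]  # the slice the old code path effectively performed
    crashed = False
except TypeError as err:
    crashed = True
    print(f"crashed: {err}")  # 'int' object is not subscriptable
```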
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
This change relates to data collators, which are often used with the Trainer. Tagging from the original issue and the trainer section.
@ArthurZucker @SunMarc