Fix `num_items_in_batch` over-counting for causal LM losses #46204
    # Causal LM losses shift labels internally (predictions at position i target label[i+1]), so position 0 of
    # each row is never a prediction target. The valid-prediction count used by `num_items_in_batch` must therefore
    # be taken over `labels[..., 1:]`, not the full label tensor
    self._loss_shifts_labels = getattr(model_to_inspect, "loss_type", None) in (
I think this is too simple. If we were to go that way, we should inspect the actual loss function; e.g. CsmForConditionalGeneration would fail here, no?
But my bigger issue is that this covers only the simplest use case, where we just prepare the tokenized input and pass that along. What if we were to use a data collator that properly prepares the shifted labels? We would now count one too many.
But this definitely needs to be fixed in general!

Also, we definitely need a test that covers this edge case.
SunMarc left a comment
Thanks, left some minor comments but happy to merge it in general. Can you add a small test to check that we calculate `num_items_in_batch` correctly when the model is a ForCausalLM?
    # Causal LM losses shift labels internally (predictions at position i target label[i+1]), so position 0 of
    # each row is never a prediction target. The valid-prediction count used by `num_items_in_batch` must therefore
    # be taken over `labels[..., 1:]`, not the full label tensor
    self._loss_shifts_labels = getattr(model_to_inspect, "loss_type", None) in (
        "ForCausalLM",
        "ForConditionalGeneration",
    )
As @vasqu pointed out, maybe we can use the following so that it is a bit more robust:
    from transformers.loss.loss_utils import LOSS_MAPPING, ForCausalLMLoss

    self._loss_shifts_labels = (
        LOSS_MAPPING.get(getattr(model_to_inspect, "loss_type", None)) is ForCausalLMLoss
    )

    # Causal LM losses shift labels; count over `labels[..., 1:]` to avoid over-counting position 0.
    labels_for_count = (
        [batch["labels"][..., 1:] for batch in batch_samples]
        if self._loss_shifts_labels
        else [batch["labels"] for batch in batch_samples]
    )
Maybe to address @vasqu's point, we can also take into account the case where `shift_labels` is prepared by the user?
    labels_for_count = [
        batch["shift_labels"] if "shift_labels" in batch
        else batch["labels"][..., 1:] if self._loss_shifts_labels
        else batch["labels"]
        for batch in batch_samples
    ]

- Inspect the actual loss function via LOSS_MAPPING instead of matching loss_type strings (catches CsmForConditionalGeneration etc.).
- If the data collator already provides `shift_labels`, count over that tensor directly instead of slicing labels again.
- Add unit tests for `_get_num_items_in_batch` covering the causal LM path (with and without pre-shifted labels) and the non-causal-LM path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
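The three cases named in the commit message (causal LM, causal LM with pre-shifted labels, non-causal LM) can be illustrated with a small, framework-free sketch. This is not the Trainer implementation: plain nested lists stand in for tensors, `row[1:]` stands in for `labels[..., 1:]`, and the helper names `count_valid`, `labels_for_count` and `num_items_in_batch` are hypothetical, chosen only to mirror the suggestion above.

```python
IGNORE_INDEX = -100  # label value that the loss skips


def count_valid(labels):
    """Count label entries that are real targets (not IGNORE_INDEX)."""
    return sum(tok != IGNORE_INDEX for row in labels for tok in row)


def labels_for_count(batch, loss_shifts_labels):
    """Pick the labels the loss will actually score for one micro-batch."""
    if "shift_labels" in batch:
        # The collator already shifted; slicing again would drop a real target.
        return batch["shift_labels"]
    if loss_shifts_labels:
        # The loss shifts internally, so position 0 of each row is never scored.
        return [row[1:] for row in batch["labels"]]
    return batch["labels"]


def num_items_in_batch(batch_samples, loss_shifts_labels):
    """Sum the valid-target counts over all micro-batches."""
    return sum(count_valid(labels_for_count(b, loss_shifts_labels)) for b in batch_samples)


# Two rows of 4 tokens; the second row ends in one ignored (padding) position.
labels = [[11, 12, 13, 14], [21, 22, 23, IGNORE_INDEX]]
# Pre-shifted variant: drop position 0 and pad the end with IGNORE_INDEX.
shift_labels = [row[1:] + [IGNORE_INDEX] for row in labels]

causal = num_items_in_batch([{"labels": labels}], loss_shifts_labels=True)
pre_shifted = num_items_in_batch(
    [{"labels": labels, "shift_labels": shift_labels}], loss_shifts_labels=True
)
non_causal = num_items_in_batch([{"labels": labels}], loss_shifts_labels=False)
```

In this toy batch the causal and pre-shifted paths agree on the count, while the non-causal path still counts every non-ignored position, which is the behaviour the requested unit tests are meant to pin down.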

Seems like we have 1 new failure: https://app.circleci.com/pipelines/github/huggingface/transformers/175911/workflows/b1032ab4-41a0-44f1-b1cd-072bad362f4d/jobs/2326963
Thanks tho, overall LGTM as well.

0a3d375 should fix it.

View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=46204&sha=0a3d37
…ace#46204)
* Fix `num_items_in_batch` over-counting for causal LM losses
* Address review: use LOSS_MAPPING, honor pre-shifted labels, add tests
  - Inspect the actual loss function via LOSS_MAPPING instead of matching loss_type strings (catches CsmForConditionalGeneration etc.).
  - If the data collator already provides `shift_labels`, count over that tensor directly instead of slicing labels again.
  - Add unit tests for `_get_num_items_in_batch` covering the causal LM path (with and without pre-shifted labels) and the non-causal-LM path.
  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* fix test_train_and_predict_loss_parity
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The bug
`Trainer._get_num_items_in_batch` counts labels at every position. But for causal LM the loss shifts labels, so position 0 is never a prediction target. The denominator is too large by `num_rows` per micro-batch, systematically under-scaling every causal LM loss and gradient.

`loss_type` is set on every `PreTrainedModel` (see `modeling_utils.py`). Only causal LM loss types are touched; `ForMaskedLM`, classification, etc. are unaffected.

Reported causal LM loss becomes slightly larger (correctly so, the denominator was too big). Shift is `num_rows / total_tokens` per step. Gradient magnitudes scale by the same factor.

Tests
All `TrainerGradientAccumulationTest` tests pass (4/4). A padded vs padding-free reproducer now matches with Δgrad_norm = 0.

How it surfaced
TRL's invariance suite compared SFT with `padding_free=False` vs `padding_free=True`. Loss curves almost matched, but `grad_norm` drifted with a systematic +0.17 bias over 50 steps. The padding-free collator masks `labels[position_ids == 0] = -100`, which incidentally matches the post-shift count, so it exposed the padded path's over-count.
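To make the counting argument concrete, here is a minimal sketch with plain Python lists instead of tensors. The token ids, sequence lengths and the helper `count_valid` are made up for illustration; only the `-100` ignore value and the masking of positions where `position_ids == 0` come from the description above. It compares the pre-fix count (all positions) with the post-shift count for a padded batch and for a packed, padding-free batch.

```python
IGNORE_INDEX = -100  # label value that the loss skips


def count_valid(rows, shifted):
    """Count non-ignored labels; if `shifted`, skip position 0 of each row."""
    start = 1 if shifted else 0
    return sum(tok != IGNORE_INDEX for row in rows for tok in row[start:])


# Three sequences of lengths 5, 3 and 4 (hypothetical token ids).
seqs = [[1, 2, 3, 4, 5], [6, 7, 8], [9, 10, 11, 12]]
max_len = max(len(s) for s in seqs)

# Padded path: one row per sequence, right-padded with IGNORE_INDEX.
padded = [s + [IGNORE_INDEX] * (max_len - len(s)) for s in seqs]

# Padding-free path: the first token of each sequence (position_ids == 0)
# is masked, then all sequences are packed into a single row.
masked = [[IGNORE_INDEX] + s[1:] for s in seqs]
packed = [[tok for s in masked for tok in s]]

old_padded = count_valid(padded, shifted=False)  # count before the fix
new_padded = count_valid(padded, shifted=True)   # count after the fix
old_packed = count_valid(packed, shifted=False)
new_packed = count_valid(packed, shifted=True)

# Dividing the same summed loss by the smaller count scales the loss
# (and its gradients) up by this factor in the padded path.
scale = old_padded / new_padded
```

In this toy setup the padded path over-counts by exactly the number of rows before the fix, while the packed path gives the same count either way because position 0 of every sequence is already masked. That asymmetry is consistent with the drift described above appearing only when the two paths were compared.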