Fix num_items_in_batch over-counting for causal LM losses #46204

Merged
vasqu merged 4 commits into main from fix-num-items-in-batch-causal-lm-v2 on May 27, 2026

Conversation

@qgallouedec
Member

The bug

Trainer._get_num_items_in_batch counts labels at every position, but the causal LM loss shifts labels, so position 0 of each row is never a prediction target. The denominator is therefore too large by num_rows per micro-batch (one extra item for each row whose position-0 label is not masked), which systematically under-scales every causal LM loss and gradient.

labels (trainer counts these):           shift_labels = labels[..., 1:]
                                         (what ForCausalLMLoss uses)
┌────┬────┬────┬────┬────┐                    ┌────┬────┬────┬────┐
│ t0 │ t1 │ t2 │ t3 │ t4 │  count = 5         │ t1 │ t2 │ t3 │ t4 │  count = 4
└────┴────┴────┴────┴────┘                    └────┴────┴────┴────┘
  ↑
  position 0 is dropped by the shift — over-count = 1 per row
shift_labels = labels[..., 1:]           # ← numerator: 4 CE terms
loss = sum_ce / num_items_in_batch       # ← denominator: counted 5
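
The mismatch can be reproduced without the Trainer. Below is a minimal, framework-free sketch that uses plain Python lists in place of tensors. `count_targets` is a hypothetical helper (not part of transformers), and -100 is assumed to be the ignore index.

```python
IGNORE_INDEX = -100  # assumed ignore index for masked label positions


def count_targets(rows, shift=False):
    """Count non-ignored labels; with shift=True, drop position 0 of each row first."""
    start = 1 if shift else 0
    return sum(1 for row in rows for label in row[start:] if label != IGNORE_INDEX)


# Two rows of five labels each; the second row ends with two masked positions.
labels = [
    [11, 12, 13, 14, 15],
    [21, 22, 23, IGNORE_INDEX, IGNORE_INDEX],
]

counted_over_full_labels = count_targets(labels)          # every position
counted_after_shift = count_targets(labels, shift=True)   # equivalent of labels[..., 1:]
over_count = counted_over_full_labels - counted_after_shift

print(counted_over_full_labels, counted_after_shift, over_count)
```

In this toy batch the over-count equals the number of rows, because each row starts with a non-masked label.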

loss_type is set on every PreTrainedModel (see modeling_utils.py). Only causal LM loss types are touched; ForMaskedLM, classification, etc. are unaffected.

The reported causal LM loss becomes slightly larger (correctly so, since the denominator was too big). The relative shift is about num_rows / total_tokens per step, and gradient magnitudes scale by the same factor.
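
To make the size of the shift concrete, here is a short derivation under stated assumptions: S is the summed cross-entropy over the shifted targets, N is the count taken over the full label tensor (total_tokens), and R is num_rows, with every row starting on a non-masked label. The symbols are introduced here for illustration only.

```latex
\begin{aligned}
L_{\text{old}} &= \frac{S}{N}, \qquad L_{\text{new}} = \frac{S}{N - R},\\[4pt]
\frac{L_{\text{new}}}{L_{\text{old}}} &= \frac{N}{N - R} = 1 + \frac{R}{N - R} \approx 1 + \frac{R}{N} \quad (R \ll N).
\end{aligned}
```

As an illustrative example (numbers not taken from this PR), R = 8 rows and N = 4096 counted tokens give a relative increase of 8 / 4088, roughly 0.2%.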

Tests

All TrainerGradientAccumulationTest tests pass (4/4). A padded vs padding-free reproducer now matches with Δgrad_norm = 0.

How it surfaced

TRL's invariance suite compared SFT with padding_free=False vs padding_free=True. Loss curves almost matched, but grad_norm drifted with a systematic +0.17 bias over 50 steps. The padding-free collator masks labels[position_ids == 0] = -100, which incidentally matches the post-shift count, so it exposed the padded path's over-count.
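
The coincidence described above can be checked on a toy batch. The sketch below uses plain Python lists and a hypothetical `count` helper; it assumes -100 as the ignore index and that the padding-free path packs all sequences into a single row whose position_ids restart at 0.

```python
IGNORE_INDEX = -100  # assumed ignore index

# Three sequences of lengths 4, 2 and 3.
sequences = [[1, 2, 3, 4], [5, 6], [7, 8, 9]]

# Padded path: one row per sequence, labels right-padded with IGNORE_INDEX.
max_len = max(len(seq) for seq in sequences)
padded_labels = [seq + [IGNORE_INDEX] * (max_len - len(seq)) for seq in sequences]

# Padding-free path: one packed row, with the first token of each sequence masked
# (the equivalent of labels[position_ids == 0] = -100).
position_ids = [pos for seq in sequences for pos in range(len(seq))]
flat_labels = [tok for seq in sequences for tok in seq]
packed_labels = [IGNORE_INDEX if pos == 0 else tok for tok, pos in zip(flat_labels, position_ids)]


def count(rows, start=0):
    """Count non-ignored labels, optionally skipping the first `start` positions per row."""
    return sum(1 for row in rows for label in row[start:] if label != IGNORE_INDEX)


padded_full = count(padded_labels)              # old count on the padded path
padded_shifted = count(padded_labels, start=1)  # count over labels[..., 1:]
packed_full = count([packed_labels])            # padding-free count
packed_shifted = count([packed_labels], start=1)

print(padded_full, padded_shifted, packed_full, packed_shifted)
```

Only the padded count over the full labels disagrees with the others. Slicing the packed row changes nothing, since its position 0 is already masked.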

@qgallouedec qgallouedec requested review from SunMarc and vasqu May 26, 2026 00:32

Comment thread src/transformers/trainer.py Outdated
# Causal LM losses shift labels internally (predictions at position i target label[i+1]), so position 0 of
# each row is never a prediction target. The valid-prediction count used by `num_items_in_batch` must therefore
# be taken over `labels[..., 1:]`, not the full label tensor
self._loss_shifts_labels = getattr(model_to_inspect, "loss_type", None) in (
Contributor

I think this is too simple. I think if we were to go that way we should inspect the actual loss function, e.g. CsmForConditionalGeneration would fail here, no?

But my bigger issue is that this covers only the simplest use case, where we prepare the tokenized input and pass it as is. What if we use a data collator that already prepares the shifted labels? We would now count 1 too much.

But definitely needed to fix in general!

@vasqu
Contributor

vasqu commented May 26, 2026

Also definitely need a test that covers this edge case

Member

@SunMarc SunMarc left a comment

Thanks, left some minor comments but happy to merge it in general. Can you add a small test to check that we calculate num_items_in_batch correctly when the model is a ForCausalLM?

Comment on lines +511 to +518
# Causal LM losses shift labels internally (predictions at position i target label[i+1]), so position 0 of
# each row is never a prediction target. The valid-prediction count used by `num_items_in_batch` must therefore
# be taken over `labels[..., 1:]`, not the full label tensor
self._loss_shifts_labels = getattr(model_to_inspect, "loss_type", None) in (
"ForCausalLM",
"ForConditionalGeneration",
)

Member

As @vasqu pointed out, maybe we can use the following so that it is a bit more robust.

from transformers.loss.loss_utils import LOSS_MAPPING, ForCausalLMLoss

# True only when the model's loss_type resolves to the causal LM loss function.
self._loss_shifts_labels = (
    LOSS_MAPPING.get(getattr(model_to_inspect, "loss_type", None)) is ForCausalLMLoss
)
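
The difference between matching strings and checking the resolved function can be illustrated with a toy registry. Everything below is hypothetical and stands in for the real mapping: `TOY_LOSS_MAPPING`, the two loss functions, and the `"SomeOtherCausalModel"` key are invented for this sketch and say nothing about the actual contents of `LOSS_MAPPING`.

```python
def toy_causal_lm_loss():
    """Stand-in for a loss function that shifts labels internally."""


def toy_masked_lm_loss():
    """Stand-in for a loss function that does not shift labels."""


# Hypothetical registry: several loss_type strings resolve to the same function.
TOY_LOSS_MAPPING = {
    "ForCausalLM": toy_causal_lm_loss,
    "ForConditionalGeneration": toy_causal_lm_loss,
    "SomeOtherCausalModel": toy_causal_lm_loss,  # a key a hard-coded tuple would miss
    "ForMaskedLM": toy_masked_lm_loss,
}


def shifts_by_string(loss_type):
    """String matching: only as complete as the hard-coded tuple."""
    return loss_type in ("ForCausalLM", "ForConditionalGeneration")


def shifts_by_function(loss_type):
    """Identity check on the resolved function: follows the registry."""
    return TOY_LOSS_MAPPING.get(loss_type) is toy_causal_lm_loss


for name in ("ForCausalLM", "SomeOtherCausalModel", "ForMaskedLM", None):
    print(name, shifts_by_string(name), shifts_by_function(name))
```

The identity check also handles a missing `loss_type` gracefully, because `dict.get(None)` returns `None`.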

Comment thread src/transformers/trainer.py Outdated
Comment on lines +2144 to +2149
# Causal LM losses shift labels; count over `labels[..., 1:]` to avoid over-counting position 0.
labels_for_count = (
[batch["labels"][..., 1:] for batch in batch_samples]
if self._loss_shifts_labels
else [batch["labels"] for batch in batch_samples]
)
Member

Maybe to address @vasqu's point, we can also handle the case where shift_labels is prepared by the user?

labels_for_count = [
    (
        batch["shift_labels"]
        if "shift_labels" in batch
        else batch["labels"][..., 1:]
        if self._loss_shifts_labels
        else batch["labels"]
    )
    for batch in batch_samples
]
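
The three paths in this suggestion (pre-shifted labels, causal LM without pre-shifted labels, and non-causal LM) can be exercised with a small framework-free sketch. It uses lists instead of tensors; `labels_for_count` and `num_items` are hypothetical stand-ins for the Trainer logic, and -100 is assumed to be the ignore index.

```python
IGNORE_INDEX = -100  # assumed ignore index


def labels_for_count(batch, loss_shifts_labels):
    """Pick the label rows to count over, in the priority order suggested above."""
    if "shift_labels" in batch:
        return batch["shift_labels"]                 # collator already shifted
    if loss_shifts_labels:
        return [row[1:] for row in batch["labels"]]  # equivalent of labels[..., 1:]
    return batch["labels"]                           # loss does not shift


def num_items(batch_samples, loss_shifts_labels):
    """Count non-ignored labels across all micro-batches."""
    return sum(
        1
        for batch in batch_samples
        for row in labels_for_count(batch, loss_shifts_labels)
        for label in row
        if label != IGNORE_INDEX
    )


raw_batch = {"labels": [[1, 2, 3, 4], [5, 6, IGNORE_INDEX, IGNORE_INDEX]]}
pre_shifted_batch = {
    "labels": raw_batch["labels"],
    "shift_labels": [[2, 3, 4, IGNORE_INDEX], [6, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX]],
}

causal_count = num_items([raw_batch], loss_shifts_labels=True)
non_causal_count = num_items([raw_batch], loss_shifts_labels=False)
pre_shifted_count = num_items([pre_shifted_batch], loss_shifts_labels=True)

print(causal_count, non_causal_count, pre_shifted_count)
```

With pre-shifted labels present, the count matches the causal path without slicing a second time.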

qgallouedec and others added 2 commits May 27, 2026 13:54
- Inspect the actual loss function via LOSS_MAPPING instead of matching
  loss_type strings (catches CsmForConditionalGeneration etc.).
- If the data collator already provides `shift_labels`, count over that
  tensor directly instead of slicing labels again.
- Add unit tests for `_get_num_items_in_batch` covering the causal LM
  path (with and without pre-shifted labels) and the non-causal-LM path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Member

@SunMarc SunMarc left a comment

Thanks !

@vasqu vasqu enabled auto-merge May 27, 2026 14:11
@vasqu
Contributor

vasqu commented May 27, 2026

@qgallouedec
Member Author

Seems like we have 1 new failure

0a3d375 should fix it

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=46204&sha=0a3d37

@vasqu vasqu added this pull request to the merge queue May 27, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks May 27, 2026
@vasqu vasqu added this pull request to the merge queue May 27, 2026
Merged via the queue into main with commit 67265ef May 27, 2026
44 of 45 checks passed
@vasqu vasqu deleted the fix-num-items-in-batch-causal-lm-v2 branch May 27, 2026 16:03
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request May 28, 2026
…ace#46204)

* Fix `num_items_in_batch` over-counting for causal LM losses

* Address review: use LOSS_MAPPING, honor pre-shifted labels, add tests

- Inspect the actual loss function via LOSS_MAPPING instead of matching
  loss_type strings (catches CsmForConditionalGeneration etc.).
- If the data collator already provides `shift_labels`, count over that
  tensor directly instead of slicing labels again.
- Add unit tests for `_get_num_items_in_batch` covering the causal LM
  path (with and without pre-shifted labels) and the non-causal-LM path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix test_train_and_predict_loss_parity

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>