
[bugfix] fix eval loss denominator under sequence_parallel #9152

Merged
Jintao-Huang merged 1 commit into modelscope:main from YarivColbeci:fix_sp_eval_loss_num_items
Apr 21, 2026

Conversation

@YarivColbeci
Contributor

PR type

  • Bug Fix

PR information

Under sequence_parallel_size > 1, Seq2SeqTrainer.compute_loss reports an
inflated eval loss (~sp_world_size× the true value) while training loss and
token accuracy remain correct.

Root cause:

  • _prepare_inputs calls sequence_parallel.prepare_inputs(inputs), which
    shards inputs['labels'] to 1 / sp_world_size of the full sequence.

  • per_token_loss_func_sp computes per-token CE on the local shard and then
    gathers across SP ranks via GatherLoss, so outputs.loss has full-sequence
    length on every rank.

  • The division site then does:

    if num_items_in_batch is None:
        num_items_in_batch = (labels[:, 1:] != -100).sum()   # labels is SHARDED
    loss = outputs.loss.sum() / num_items_in_batch            # outputs.loss is FULL length

    The numerator covers the whole (gathered) sequence, but the denominator only
    counts the local SP shard, so the reported eval loss is inflated by a factor
    of ~sp_world_size; a small numeric illustration follows this list.
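
A small numeric illustration of the mismatch (not from the repo; the token counts and per-token loss below are made up):

    sp_world_size = 4
    full_valid_tokens = 2048                                   # non -100 label tokens over the full sequence
    shard_valid_tokens = full_valid_tokens // sp_world_size    # ~512 tokens on the local SP shard

    per_token_loss = 1.5                                       # pretend every supervised token has the same CE loss
    numerator = per_token_loss * full_valid_tokens             # outputs.loss.sum(): gathered, full-sequence
    true_loss = numerator / full_valid_tokens                  # 1.5 -- what eval_loss should report
    reported_loss = numerator / shard_valid_tokens             # 6.0 -- inflated by ~sp_world_size

    print(true_loss, reported_loss)                            # 1.5 6.0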

Training is unaffected because HF Trainer passes a globally-reduced
num_items_in_batch into compute_loss, bypassing the local-recompute branch.
Token accuracy is unaffected because its numerator and denominator both use the
same gathered tensors.

Fix: when num_items_in_batch is computed locally (i.e. was None on entry)
and sequence_parallel_size > 1, all_reduce the count across
sequence_parallel.sp_group before dividing, so the denominator matches the
gathered loss.
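
In sketch form, the patched division site looks roughly like the following (mirroring the diff quoted in the review below; dist is torch.distributed and sequence_parallel.sp_group is the SP process group used elsewhere in this PR):

    import torch.distributed as dist

    if num_items_in_batch is None:
        # labels here are the SP-sharded labels, so the local count only
        # covers 1 / sp_world_size of the sequence.
        num_items_in_batch = (labels[:, 1:] != -100).sum()
        if self.template.sequence_parallel_size > 1:
            # outputs.loss was gathered to full length via GatherLoss, so
            # sum the token counts over the SP group to match it.
            dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM,
                            group=sequence_parallel.sp_group)
    loss = outputs.loss.sum() / num_items_in_batch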

Experiment results

N/A — code-path fix; no model-quality change. The training loss path and token
accuracy are untouched. A multi-GPU smoke run with sequence_parallel_size > 1
and eval_strategy != 'no' will show eval_loss drop by a factor of
~sp_world_size relative to the previous behavior.


@gemini-code-assist (Bot) left a comment


Code Review

This pull request updates the Seq2SeqTrainer to handle loss calculation under sequence parallelism by aggregating the token count across ranks. A review comment pointed out that the current logic for skipping the first token is applied to every rank in the sequence parallel group, leading to an incorrect denominator, and provided a suggestion to only skip it on the first rank.

Comment on lines 199 to +204
    num_items_in_batch = (labels[:, 1:] != -100).sum()
    if self.template.sequence_parallel_size > 1:
        # labels are sharded by SP; outputs.loss was gathered
        # to full length via GatherLoss. Reduce the denominator
        # across the SP group so it matches the gathered loss.
        dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM, group=sequence_parallel.sp_group)

Severity: medium

The current implementation of num_items_in_batch calculation when sequence_parallel_size > 1 is slightly inaccurate. By applying labels[:, 1:] on every rank, it skips the first token of every shard instead of just the first token of the global sequence. This leads to an undercount of the denominator by sp_world_size - 1 tokens. While the error is small for long sequences, it should be corrected by only skipping the first token on the first rank of the sequence parallel group.

Suggested change

    - num_items_in_batch = (labels[:, 1:] != -100).sum()
    - if self.template.sequence_parallel_size > 1:
    -     # labels are sharded by SP; outputs.loss was gathered
    -     # to full length via GatherLoss. Reduce the denominator
    -     # across the SP group so it matches the gathered loss.
    -     dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM, group=sequence_parallel.sp_group)
    + if num_items_in_batch is None:
    +     if self.template.sequence_parallel_size > 1:
    +         # Only skip the first token on the first rank of the SP group
    +         if dist.get_rank(group=sequence_parallel.sp_group) == 0:
    +             num_items_in_batch = (labels[:, 1:] != -100).sum()
    +         else:
    +             num_items_in_batch = (labels != -100).sum()
    +         dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM, group=sequence_parallel.sp_group)
    +     else:
    +         num_items_in_batch = (labels[:, 1:] != -100).sum()
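
As a quick sanity check of the undercount described in the comment above (hedged; the shard lengths below are made up):

    sp_world_size = 4
    shard_len = 512                              # label tokens per SP shard (all supervised)
    full_len = sp_world_size * shard_len         # 2048 tokens in the full sequence

    correct = full_len - 1                       # skip only the global first token -> 2047
    per_rank = shard_len - 1                     # labels[:, 1:] applied on every rank -> 511
    after_all_reduce = sp_world_size * per_rank  # 2044

    print(correct - after_all_reduce)            # 3 == sp_world_size - 1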

@Jintao-Huang
Collaborator

thanks!

@Jintao-Huang merged commit 6ea7b40 into modelscope:main on Apr 21, 2026
3 checks passed