[bugfix] fix eval loss denominator under sequence_parallel #9152
Conversation
Code Review
This pull request updates `Seq2SeqTrainer` to handle the eval loss calculation under sequence parallelism by aggregating the token count across ranks. A review comment pointed out that the current logic for skipping the first token is applied on every rank of the sequence parallel group, leading to an incorrect denominator, and suggested skipping it only on the first rank.
```python
num_items_in_batch = (labels[:, 1:] != -100).sum()
if self.template.sequence_parallel_size > 1:
    # labels are sharded by SP; outputs.loss was gathered
    # to full length via GatherLoss. Reduce the denominator
    # across the SP group so it matches the gathered loss.
    dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM, group=sequence_parallel.sp_group)
```
The current implementation of the `num_items_in_batch` calculation when `sequence_parallel_size > 1` is slightly inaccurate. By applying `labels[:, 1:]` on every rank, it skips the first token of every shard instead of just the first token of the global sequence, undercounting the denominator by `sp_world_size - 1` tokens per sequence. While the error is small for long sequences, it should be corrected by only skipping the first token on the first rank of the sequence parallel group.
Suggested change:

```diff
-num_items_in_batch = (labels[:, 1:] != -100).sum()
-if self.template.sequence_parallel_size > 1:
-    # labels are sharded by SP; outputs.loss was gathered
-    # to full length via GatherLoss. Reduce the denominator
-    # across the SP group so it matches the gathered loss.
-    dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM, group=sequence_parallel.sp_group)
+if num_items_in_batch is None:
+    if self.template.sequence_parallel_size > 1:
+        # Only skip the first token on the first rank of the SP group
+        if dist.get_rank(group=sequence_parallel.sp_group) == 0:
+            num_items_in_batch = (labels[:, 1:] != -100).sum()
+        else:
+            num_items_in_batch = (labels != -100).sum()
+        dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM, group=sequence_parallel.sp_group)
+    else:
+        num_items_in_batch = (labels[:, 1:] != -100).sum()
```
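To make the undercount concrete, here is a small single-process sketch (plain tensors stand in for the per-rank label shards; not part of the PR) comparing the two counting schemes:

```python
import torch

sp_world_size = 4
# One global sequence of 16 supervised tokens (no -100s), sharded evenly across SP ranks.
full_labels = torch.arange(16).reshape(1, 16)
shards = full_labels.chunk(sp_world_size, dim=1)

# As currently written: every rank drops its first position before counting.
undercounted = sum((shard[:, 1:] != -100).sum() for shard in shards)

# As suggested: only rank 0 drops the first position.
corrected = sum(((shard[:, 1:] if rank == 0 else shard) != -100).sum()
                for rank, shard in enumerate(shards))

print(int(undercounted), int(corrected))  # 12 15 -> off by sp_world_size - 1 = 3
```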
thanks!
PR type
PR information
Under `sequence_parallel_size > 1`, `Seq2SeqTrainer.compute_loss` reports an inflated eval loss (~`sp_world_size`× the true value), while training loss and token accuracy remain correct.
Root cause: `_prepare_inputs` calls `sequence_parallel.prepare_inputs(inputs)`, which shards `inputs['labels']` to `1 / sp_world_size` of the full sequence. `per_token_loss_func_sp` computes per-token CE on the local shard and then gathers across SP ranks via `GatherLoss`, so `outputs.loss` has full-sequence length on every rank.
The division site then does:
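A sketch of that division point, assuming the local-recompute branch from the diff above (not a verbatim quote of the trainer code):

```python
if num_items_in_batch is None:
    # Recomputed locally, so it counts only the tokens in this rank's SP shard.
    num_items_in_batch = (labels[:, 1:] != -100).sum()
loss = loss / num_items_in_batch  # but the per-token loss was gathered over the full sequence
```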
The numerator covers the whole sequence (gathered), but the denominator only counts the local SP shard, so the reported eval loss is inflated by a factor of ~`sp_world_size`.

Training is unaffected because HF `Trainer` passes a globally-reduced `num_items_in_batch` into `compute_loss`, bypassing the local-recompute branch. Token accuracy is unaffected because its numerator and denominator both use the same gathered tensors.
Fix: when `num_items_in_batch` is computed locally (i.e. was `None` on entry) and `sequence_parallel_size > 1`, `all_reduce` the count across `sequence_parallel.sp_group` before dividing, so the denominator matches the gathered loss.
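A minimal sketch of the fixed branch, mirroring the diff at the top of this conversation (a fragment; it assumes the surrounding `compute_loss` context):

```python
if num_items_in_batch is None:
    num_items_in_batch = (labels[:, 1:] != -100).sum()
    if self.template.sequence_parallel_size > 1:
        # outputs.loss was gathered to full length via GatherLoss, so the token
        # count must also be summed over the SP group before it is used to divide.
        dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM, group=sequence_parallel.sp_group)
```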
Experiment results
N/A: this is a code-path fix with no model-quality change. The training loss path and token accuracy are untouched. A multi-GPU smoke run with `sequence_parallel_size > 1` and `eval_strategy != 'no'` will show `eval_loss` drop by a factor of ~`sp_world_size` relative to the previous behavior.