[Bugfix] Fix loss missing from logs when context parallelism is enabled#9380
Conversation
There was a problem hiding this comment.
Code Review
This pull request updates the Megatron trainer to include context parallel groups in data parallel operations for loss reduction and metric synchronization. Specifically, it modifies loss_func and _compute_channel_loss to use with_context_parallel=True when retrieving the data parallel group. A potential RuntimeError was identified in _compute_channel_loss where _all_reduce_metric could be called with an empty dictionary, and a code suggestion was provided to add a safety check.
| for key in sorted(set().union(*all_keys)): | ||
| new_metrics[key] = metrics[key] | ||
| new_metrics = self._all_reduce_metric(new_metrics, torch.distributed.ReduceOp.SUM) | ||
| new_metrics = self._all_reduce_metric(new_metrics, torch.distributed.ReduceOp.SUM, group=dp_cp_group) |
There was a problem hiding this comment.
The _all_reduce_metric method in the base class uses torch.stack on the dictionary values. If new_metrics is empty (which can happen if no ranks in the group have any channel-specific tokens in the current micro-batch), torch.stack will raise a RuntimeError. It is safer to check if the dictionary is non-empty before attempting the reduction.
| new_metrics = self._all_reduce_metric(new_metrics, torch.distributed.ReduceOp.SUM, group=dp_cp_group) | |
| if new_metrics: | |
| new_metrics = self._all_reduce_metric(new_metrics, torch.distributed.ReduceOp.SUM, group=dp_cp_group) |
The reporting_loss all-reduce only used get_data_parallel_group() which excludes CP ranks. When a CP rank's sequence chunk has no valid tokens (all labels=-100, common in SFT with long prompts), loss_mask.sum()=0 causing the loss key to be skipped entirely in _aggregated_metrics. Fix by using get_data_parallel_group(with_context_parallel=True) so loss and token counts are aggregated across both DP and CP dimensions.
f1c9256 to
89f2f09
Compare
|
thanks! |
PR type
Summary
losskey missing from training logs when context parallelism (CP) is enabledreporting_losswas all-reduced only overget_data_parallel_group()(excludes CP ranks), so CP ranks whose sequence chunk has no valid tokens (all labels=-100) would reportloss_mask.sum()=0, causing thelosskey to be skipped in_aggregated_metrics_compute_channel_losswhich used the DP-only group for key synchronization and metric reductionRoot Cause
This issue is caused by how Swift splits sequence data across CP ranks. Swift uses a zigzag pattern to distribute sequence chunks (e.g., with CP=2: GPU0 gets chunk_0+chunk_3, GPU1 gets chunk_1+chunk_2). In SFT scenarios where prompts are long and responses are short, the loss-bearing tokens (response part) tend to concentrate at one end of the sequence. As a result, certain CP ranks may receive chunks where all labels are -100 (prompt-only), making
loss_mask.sum()=0on those ranks. Sincereporting_losswas only all-reduced within the DP group (not across CP ranks), the aggregated token count remains zero on these ranks, and thelosskey gets skipped entirely during metric aggregation.Changes
trainer.py:69: Useget_data_parallel_group(with_context_parallel=True)forreporting_lossall-reducetrainer.py:98-104: Use DP+CP group for channel loss key synchronization and metric reduction