feat: CP-local loss and configurable CUDA memory profiling #1223

Merged
garrett4wade merged 1 commit into main from wht/feat/cp-local-loss-and-memory-profiling on Apr 23, 2026

Conversation

Collaborator

@yulangz yulangz commented Apr 21, 2026

Summary

This PR introduces two independent features for Megatron-based SFT training:

  1. CP-local cross-entropy loss — eliminates the expensive logits all-gather across Context Parallel (CP) ranks
  2. Configurable CUDA memory snapshot profiling — torch.cuda.memory-based flame-graph profiling, controllable via YAML config

Feature 1: CP-Local Loss

Problem

When Context Parallelism (CP) is enabled, the standard post-processing in postprocess_packed_seqs_context_parallel performs an all-gather to reconstruct the full logits tensor across all CP ranks. For large-vocabulary models (e.g., Qwen3-30B with vocab_size=151,936), this causes OOM because the full [total_tokens, vocab_size] logits tensor must be materialized on each rank.

For a Qwen3-30B-A3B SFT job with seq_len=128K and CP=2:

  • Full logits tensor: 128K × 151,936 × bf16 ≈ 37 GiB per rank — which, together with weights, activations, gradients, and optimizer state, exceeds the 80 GiB of GPU memory

Solution

Compute cross-entropy loss locally on each CP rank using only the local logit shard, avoiding the all-gather entirely.

Key changes:

  • packed_context_parallel.py: Added split_packed_seqs_for_context_parallel() to split labels/loss_mask using the same interleaved pattern as the input splitting (see the sketch after this list). Added gather_output flag to postprocess_packed_seqs_context_parallel() to skip the all-gather.
  • megatron_engine.py:
    • forward_backward_batch() accepts cp_local_loss flag
    • When enabled, labels and loss_mask are pre-rolled and split to match the local CP shard before passing to the loss function
    • Local cu_seqlens are scaled by 1/cp_size to match the local token count
    • _cp_local_labels and _loss_mask_pre_rolled keys are injected into the input dict
    • eval_batch uses dp_group(with_context_parallel=True) for correct loss aggregation across DP×CP ranks
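
For intuition, here is a minimal sketch of the interleaved split referenced above. The zigzag chunk assignment (rank r keeps chunks r and 2·cp_size−1−r of each sequence, Megatron's load-balancing pattern for causal attention) and the helper's name and signature are assumptions for illustration, not the repository code:

import torch

def split_packed_seqs_for_cp_sketch(
    packed: torch.Tensor,      # [1, total_tokens], concatenated (packed) sequences
    cu_seqlens: torch.Tensor,  # [num_seqs + 1], cumulative sequence boundaries
    cp_rank: int,
    cp_size: int,
) -> torch.Tensor:
    """Keep only this CP rank's interleaved shard of every packed sequence.

    Assumes each (padded) sequence length is divisible by 2 * cp_size,
    which the zigzag split requires.
    """
    shards = []
    for i in range(cu_seqlens.numel() - 1):
        seq = packed[0, cu_seqlens[i] : cu_seqlens[i + 1]]
        chunks = seq.chunk(2 * cp_size)
        # Pairing chunk r with chunk 2*cp_size-1-r balances causal-attention
        # FLOPs across CP ranks.
        shards.append(torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]]))
    return torch.cat(shards).unsqueeze(0)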

Why this is correct: Each CP rank holds a fixed interleaved slice of each sequence (the same pattern used to split the inputs). By pre-computing labels (roll(-1) on the full sequence, then split) before the forward pass, each rank can independently compute cross-entropy on its local tokens. The global loss is the mean over all valid tokens across all CP ranks, which equals the sum of per-rank loss sums divided by the global valid-token count — mathematically equivalent to the all-gather approach.
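
A toy numerical check of that equivalence (the shapes and the two-way split below are illustrative only):

import torch

torch.manual_seed(0)
per_token_loss = torch.randn(16)   # per-token CE over the full packed sequence
loss_mask = torch.rand(16) > 0.3   # which tokens count toward the loss

# Global formulation: masked mean over all tokens.
global_loss = per_token_loss[loss_mask].sum() / loss_mask.sum()

# CP-local formulation: each "rank" sums over its shard; divide once by the
# global valid-token count.
local_sum, local_count = 0.0, 0
for shard in torch.arange(16).chunk(2):  # pretend cp_size == 2
    local_sum = local_sum + per_token_loss[shard][loss_mask[shard]].sum()
    local_count = local_count + loss_mask[shard].sum()

assert torch.allclose(global_loss, local_sum / local_count)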

Files changed

  • areal/engine/megatron_utils/packed_context_parallel.py — new split_packed_seqs_for_context_parallel, gather_output parameter
  • areal/engine/megatron_engine.py — cp_local_loss logic in forward_backward_batch, train_batch, eval_batch; label routing in _process_megatron_output

Feature 2: Configurable Memory Profiling

Problem

CUDA memory snapshot profiling (flame graph) was hardcoded to always run on steps 0 and 1. This should be a configurable feature, disabled by default.

Solution

Follow the existing perf_tracer configuration pattern: add an optional config dataclass, None means disabled.

Config:

from dataclasses import dataclass, field

@dataclass
class MemoryProfileConfig:
    profile_steps: list[int] = field(default_factory=lambda: [0, 1])
    max_entries: int = 100000

Usage in YAML (opt-in):

memory_profile:
  profile_steps: [0, 1]
  max_entries: 100000

No memory_profile key → feature is completely disabled (default).

Key changes:

  • areal/api/cli_args.py — new MemoryProfileConfig, added memory_profile field to BaseExperimentConfig
  • areal/api/engine_api.py — added start_memory_profile(max_entries) and stop_memory_profile(path) API
  • areal/engine/megatron_engine.py — implemented start_memory_profile / stop_memory_profile using torch.cuda.memory._record_memory_history / _dump_snapshot (see the sketch after this list)
  • areal/trainer/sft/lm_engine.py — RPC forwarding for start_memory_profile(max_entries)
  • areal/trainer/sft_trainer.py — config-driven profiling in training loop, snapshots saved to {log_dir}/memory_snapshots/step_{N}.pickle
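
A minimal sketch of how these hooks map onto PyTorch's private, version-dependent snapshot API; the bodies below are illustrative rather than the exact engine code:

import torch

def start_memory_profile(max_entries: int = 100000) -> None:
    # Begin recording CUDA allocator events with stack traces; max_entries
    # caps the size of the recorded history.
    torch.cuda.memory._record_memory_history(max_entries=max_entries)

def stop_memory_profile(path: str) -> None:
    # Dump the snapshot (viewable at https://pytorch.org/memory_viz) and
    # stop recording.
    torch.cuda.memory._dump_snapshot(path)
    torch.cuda.memory._record_memory_history(enabled=None)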

Test

For Qwen3-30B-A3B at 16K context length with (attn:d1p1t2c4|ffn:d1p1t1e8) parallelism, CP-local loss saves approximately 5 GB of memory (for reference, a full bf16 logits tensor at 16K × 151,936 is about 4.6 GB), and the saving grows linearly with context length.

[screenshot: memory snapshot with CP-local loss] [screenshot: memory snapshot without CP-local loss]

Copilot AI review requested due to automatic review settings April 21, 2026 08:48
Contributor

gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces CUDA memory snapshot profiling and implements local loss calculation for context parallel (CP) training. Key changes include the addition of a MemoryProfileConfig dataclass, the split_packed_seqs_for_context_parallel utility for handling interleaved packed sequences, and updates to the Megatron engine to support CP-local labels and masks. Feedback suggests making the memory profiling rank configurable instead of hardcoding it to rank 0 to better support distributed debugging. Additionally, it is recommended to vectorize the splitting logic in split_packed_seqs_for_context_parallel to avoid potential performance bottlenecks caused by Python loops during microbatch processing.

Comment thread areal/engine/megatron_engine.py Outdated
        return DeviceRuntimeInfo.get_current()

    def start_memory_profile(self, max_entries: int = 100000) -> None:
        if self.rank == 0:
Contributor


medium

Hardcoding self.rank == 0 for memory profiling is restrictive. In distributed training with Megatron (PP/TP/CP), memory pressure and OOM issues often vary significantly across different ranks (e.g., the last pipeline stage or specific tensor parallel ranks). Consider making the profiling rank configurable or allowing it on the local rank 0 of each node to provide better visibility into the cluster's memory state.

Comment thread areal/engine/megatron_engine.py Outdated
Comment on lines +730 to +757
cp_size = mpu.get_context_parallel_world_size()
if cp_local_loss and cp_size > 1 and cu_seqlens is not None:
    padded_cu_seqlens = mb_input.padded_mb["cu_seqlens"]
    rolled_ids = torch.roll(
        mb_input.padded_mb["input_ids"], shifts=-1, dims=-1
    )
    cp_labels = split_packed_seqs_for_context_parallel(
        rolled_ids, padded_cu_seqlens
    )
    padded_loss_mask = mb_input.padded_mb.get("loss_mask")
    if padded_loss_mask is not None:
        rolled_mask = torch.roll(padded_loss_mask, shifts=-1, dims=-1)
        cp_loss_mask = split_packed_seqs_for_context_parallel(
            rolled_mask, padded_cu_seqlens
        )
    else:
        cp_loss_mask = torch.zeros(
            output.shape[0],
            dtype=torch.bool,
            device=output.device,
        )
    cp_cu_seqlens = padded_cu_seqlens // cp_size
    cp_inputs = dict(mb_input.orig_mb)
    cp_inputs["_cp_local_labels"] = cp_labels
    cp_inputs["loss_mask"] = cp_loss_mask
    cp_inputs["_loss_mask_pre_rolled"] = True
    cp_inputs["cu_seqlens"] = cp_cu_seqlens
    return output, functools.partial(_process_output, cp_inputs)
Contributor


medium

The CP-local loss logic performs torch.roll and split_packed_seqs_for_context_parallel (which contains a Python loop over the batch size) for every microbatch in the last pipeline stage. While functional, this could become a performance bottleneck for large batch sizes or high-frequency microbatching. Consider vectorizing the interleaved splitting logic in split_packed_seqs_for_context_parallel to avoid the Python loop.

Collaborator Author


Agreed that vectorization is possible here, but split_packed_seqs_for_context_parallel is not a hot path — it runs once per micro-batch during loss computation, not in the inner attention loop. The current loop-based implementation is clear and correct. We can optimize this in a follow-up if profiling shows it matters.

Contributor

Copilot AI left a comment


Pull request overview

This PR adds two opt-in improvements for Megatron-based SFT training: a context-parallel (CP) local loss path to avoid CP logits all-gather, and a YAML-configurable CUDA memory snapshot profiler.

Changes:

  • Introduces CP-local packed-sequence splitting utilities and an option to skip CP all-gather in packed forward/postprocess.
  • Updates Megatron engine training/eval flow to support CP-local label/mask routing and correct loss aggregation in eval.
  • Adds a memory_profile config and wiring to start/stop CUDA memory snapshot recording from the SFT training loop.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.

Summary per file:

  • areal/trainer/sft_trainer.py — Triggers memory profiling start/stop and snapshot writing per configured steps.
  • areal/trainer/sft/lm_engine.py — Adds controller RPC hooks for memory profiling; adjusts SFT loss mask rolling behavior.
  • areal/engine/megatron_utils/packed_context_parallel.py — Adds packed-tensor CP split helper and a flag to skip CP output all-gather.
  • areal/engine/megatron_engine.py — Implements CP-local loss routing and adds engine-level memory snapshot APIs.
  • areal/api/engine_api.py — Extends the TrainEngine API surface with memory profiling hooks.
  • areal/api/cli_args.py — Adds MemoryProfileConfig and a memory_profile field to experiment config.
  • areal/dataset/hhrlhf.py — Formatting-only change to dataset length filtering lambda.


Comment on lines +819 to +824
        self.forward_backward_batch(
            mb_list,
            process_output,
            forward_only=False,
            cp_local_loss=(mpu.get_context_parallel_world_size() > 1),
        )

Copilot AI Apr 21, 2026


cp_local_loss is force-enabled whenever context_parallel_world_size() > 1. However, the CP-local path only rolls/splits input_ids (labels) and loss_mask/cu_seqlens, and it also mutates semantics by rolling loss_mask. This will break non-SFT training paths (e.g., PPO/RW) whose loss functions expect additional per-token tensors in inputs (advantages/old_logprobs/returns/etc.) and an un-rolled loss_mask. Make CP-local loss opt-in (plumb a config/flag down to train_batch/eval_batch) or generalize the CP-local input rewriting to split all token-aligned tensors without changing their semantics for other algorithms.

Collaborator Author


This is not a concern — train_batch and eval_batch (which call forward_backward_batch) are only used by the SFT path. The RL training path uses different methods (forward_backward_func directly via Megatron pipeline). So the CP-local logic is scoped correctly to SFT only.

In the latest commit, cp_local_loss has been removed as a parameter entirely — the engine now infers CP-local mode internally from mpu.get_context_parallel_world_size() > 1.

Comment on lines +752 to +756
cp_inputs = dict(mb_input.orig_mb)
cp_inputs["_cp_local_labels"] = cp_labels
cp_inputs["loss_mask"] = cp_loss_mask
cp_inputs["_loss_mask_pre_rolled"] = True
cp_inputs["cu_seqlens"] = cp_cu_seqlens

Copilot AI Apr 21, 2026


In the CP-local branch, cp_inputs = dict(mb_input.orig_mb) keeps the unsplit original microbatch tensors (and possibly unpadded lengths) but only replaces loss_mask/cu_seqlens and injects _cp_local_labels. Any other packed, token-aligned fields in orig_mb (e.g., PPO/RW training targets) will have mismatched shapes relative to the CP-sharded output, leading to incorrect loss computation or runtime shape errors. Consider building cp_inputs from padded_mb and splitting every token-aligned tensor with split_packed_seqs_for_context_parallel, or restrict this path to the SFT loss that only consumes labels/loss_mask.

Collaborator Author


This is by design. cp_inputs = dict(mb_input.orig_mb) creates a shallow copy of the original microbatch dict, then we overwrite the fields that need CP-local values (_cp_local_labels, loss_mask, cu_seqlens). The SFT loss function (compute_packed_sft_loss) only consumes these replaced fields — it does not use any of the original unsplit tensors. The remaining fields in orig_mb (like metadata) are harmlessly carried through.

Comment on lines +183 to +188
if (
    config.memory_profile is not None
    and global_step in config.memory_profile.profile_steps
):
    self.actor.start_memory_profile(config.memory_profile.max_entries)


Copilot AI Apr 21, 2026


SFTTrainer unconditionally calls self.actor.start_memory_profile(...) / stop_memory_profile(...) when the config is set, but the v1 controller (LMController) currently does not define these methods (only LMControllerV2 does). This will raise AttributeError for non-v2 runs. Either add these methods to LMController as well (dispatch via _custom_function_call with rpc_meta={"broadcast": True}), or guard these calls based on controller version/capability.

Comment thread areal/trainer/sft/lm_engine.py Outdated
Comment on lines +102 to +106
n_valid = loss_mask.count_nonzero()
if n_valid == 0:
    loss = logprobs.sum() * 0.0
else:
    loss = -logprobs.sum() / n_valid

Copilot AI Apr 21, 2026


loss_mask.count_nonzero() returns a 0-d tensor, so if n_valid == 0: will raise RuntimeError: Boolean value of Tensor... at runtime. Convert to a Python number (e.g., n_valid = int(... .item())) or use n_valid.item() == 0 before the conditional, and keep the tensor form only for arithmetic/division as needed.

Collaborator Author


Fixed in the latest commit. The n_valid == 0 branch has been removed entirely — the loss is now computed as loss = -logprobs.sum() / (1e-5 + loss_mask.count_nonzero()), which handles the all-zero case gracefully via the epsilon term.

Comment thread areal/trainer/sft/lm_engine.py Outdated
Comment on lines +79 to +82
self._custom_function_call("start_memory_profile", max_entries)

def stop_memory_profile(self, path: str):
self._custom_function_call("stop_memory_profile", path)

Copilot AI Apr 21, 2026


LMControllerV2 inherits from GatewayTrainController, which does not define _custom_function_call. These methods will raise AttributeError when invoked (e.g., in SFTTrainer.train). Implement these via gateway endpoints (similar to train_batch/step_lr_scheduler), or route through an existing dispatch helper on GatewayTrainController if one exists.

Suggested change:

-        self._custom_function_call("start_memory_profile", max_entries)
-
-    def stop_memory_profile(self, path: str):
-        self._custom_function_call("stop_memory_profile", path)
+        payload = {
+            "args": serialize_value([max_entries]),
+            "kwargs": serialize_value({}),
+        }
+        self._gateway_post_result("/sft/start_memory_profile", payload)
+
+    def stop_memory_profile(self, path: str):
+        payload = {
+            "args": serialize_value([path]),
+            "kwargs": serialize_value({}),
+        }
+        self._gateway_post_result("/sft/stop_memory_profile", payload)

Collaborator Author


LMControllerV2.start_memory_profile and stop_memory_profile use _custom_function_call, which was already present in the pre-existing code (see the same class's existing methods). It is not introduced by this PR — it is inherited from the GatewayTrainController base class.

Comment thread areal/trainer/sft_trainer.py Outdated
if (
    config.memory_profile is not None
    and global_step in config.memory_profile.profile_steps
):
    self.actor.start_memory_profile(config.memory_profile.max_entries)
Collaborator


(1) This method call won't work for the single controller mode because the train controller does not implement this method.
(2) Not implemented for the RL/RW trainer. We should either amend the code or open an issue about this.

Collaborator Author


Both issues addressed:

  1. Added start_memory_profile and stop_memory_profile to LMController (V1) for single-controller mode.
  2. Ported memory profiling support to PPOTrainer (rl_trainer.py) and RWTrainer (rw_trainer.py) with the same config-driven pattern as SFTTrainer.

Comment thread areal/engine/megatron_engine.py Outdated
Comment on lines +867 to +872
eval_dp_group = (
    mpu.get_data_parallel_group(with_context_parallel=True)
    if cp_local
    else mpu.get_data_parallel_group()
)
return aggregate_eval_losses(losses, eval_dp_group)
Collaborator


(1) mpu.get_data_parallel_group(with_context_parallel=True) covers the else branch as well, so it alone is sufficient.
(2) Why should we also gather across the CP group? The data should be duplicated across CP groups. Summing the losses up would upscale the loss. Conversely, is it because we have incorrectly downscaled the training loss?

Collaborator Author


Good question. Using with_context_parallel=True is always safe:

  1. CP=1: The CP group is trivially size-1, so get_data_parallel_group(with_context_parallel=True) returns the same group as get_data_parallel_group() — it is idempotent.
  2. CP>1 (CP-local loss): Each CP rank computes loss over its local shard. aggregate_eval_losses does a SUM all-reduce. With with_context_parallel=True, the reduction spans DP×CP ranks, correctly averaging the CP-local losses across all CP ranks.

On the training side, DistributedOptimizer uses intra_dp_cp_group (DP×CP) for gradient reduce-scatter, so CP-local gradients are mathematically equivalent to the global gradient.

Since with_context_parallel=True is always correct regardless of CP size, we can simplify to always use it. Updated in the latest commit.
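
A sketch of why the DP×CP reduction is safe (the function name and the sum/count split are assumptions; the repository's aggregate_eval_losses may differ in detail):

import torch
import torch.distributed as dist

def aggregate_eval_losses_sketch(
    local_loss_sum: torch.Tensor,     # scalar: sum of CE over local valid tokens
    local_valid_count: torch.Tensor,  # scalar: number of local valid tokens
    group: dist.ProcessGroup,         # DP group obtained with with_context_parallel=True
) -> torch.Tensor:
    # SUM-reduce numerator and denominator over the DP×CP group, then divide
    # once. With CP=1 the CP dimension is trivial, so this group coincides
    # with the plain DP group and the result is unchanged.
    stats = torch.stack([local_loss_sum, local_valid_count]).float()
    dist.all_reduce(stats, op=dist.ReduceOp.SUM, group=group)
    return stats[0] / stats[1].clamp_min(1.0)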

yulangz added a commit that referenced this pull request Apr 22, 2026
- Rename MemoryProfileConfig → MemoryProfilerConfig for consistency
- All ranks dump memory snapshots with parallelism info in filename
- stop_memory_profile accepts snapshot_dir instead of file path
- Remove cp_local_loss parameter; infer from CP world size internally
- Pre-roll loss_mask in LMEngine before entering engine, remove
  _loss_mask_pre_rolled flag and loss_mask else fallback
- Simplify loss computation: use loss_mask.count_nonzero() directly
- eval_batch always uses with_context_parallel=True (idempotent for CP=1)
- Add memory profile methods to LMController (V1)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@yulangz yulangz force-pushed the wht/feat/cp-local-loss-and-memory-profiling branch from b759b18 to 423b516 on April 22, 2026 08:03
@yulangz yulangz requested a review from nuzant as a code owner April 22, 2026 08:03
@yulangz yulangz force-pushed the wht/feat/cp-local-loss-and-memory-profiling branch from 423b516 to 55e63a9 on April 22, 2026 08:06
@garrett4wade garrett4wade force-pushed the wht/feat/cp-local-loss-and-memory-profiling branch from ffb1e54 to 430a1f9 on April 23, 2026 08:29
@garrett4wade garrett4wade merged commit d58cca5 into main Apr 23, 2026
6 checks passed
@garrett4wade garrett4wade deleted the wht/feat/cp-local-loss-and-memory-profiling branch April 23, 2026 08:31
yulangz added a commit that referenced this pull request Apr 24, 2026
After PR #1223 introduced CP-local loss, ratio metrics in
compute_packed_sft_loss (loss/entropy/ppl/vocab_*) started reporting
the average over rank 0's CP slice rather than the global average,
because both numerator and denominator were CP-local tensors and
export only reduced across DP. Counts were already fixed in PR #1242,
but the ratios were not -- empirically loss/ppl/max can drift up to
25% (e.g. ppl/max 2.57 vs 3.20) on long-context SFT where prompt /
completion tokens land in different CP slices.

Fix the reporting without reintroducing the expensive logits
all-gather that #1223 removed: at export time, the per-key reduce
sees only scalar numerator/denominator (already .sum()-ed), so
all-reducing across DP + CP costs a few bytes -- not 37GB of logits.

Key changes:
- stats_tracker: add per-key reduce_group override (kw-only) on
  denominator/scalar/stat; _avg/_min/_max/_sum/SCALAR honor it via
  _effective_reduce_group; reset clears it; fix latent reduce_types
  pop bug in single-key export path.
- megatron_engine: in CP-local forward_step, expose
  _cp_reduce_group (CP) and _cp_dp_reduce_group (DP+CP) on cp_inputs.
- lm_engine: use _cp_reduce_group to all-reduce per-sequence
  seqlogp/valid-count so ppl is CP-invariant; use
  _cp_dp_reduce_group as reduce_group for loss/entropy/vocab_*
  stats so the global mean is reported.

Verified on a 64-GPU CP=2 SFT replay (Qwen3-30B-A3B + scale-swe data,
seed=1, BS=128): with the fix, loss/avg, ppl/avg/max, vocab_*/avg
match the pre-CP-local reference run to within 0.06%; grad_norm and
n_tokens/n_valid_tokens are unchanged (the latter remains as fixed
in PR #1242).

Refs: #1242
SathyaGnanakumar pushed a commit to danielkiely/AReaL that referenced this pull request Apr 29, 2026
…g CP logits (inclusionAI#1223)

Co-authored-by: 博惟 <bowei.fw@antgroup.com>
garrett4wade pushed a commit that referenced this pull request May 6, 2026
After PR #1223 introduced CP-local loss, compute_packed_sft_loss
started recording `n_tokens`, `n_valid_tokens` and `prompt_tokens`
using the CP-split `loss_mask` / `logprobs`. These denominators are
summed only across the DP group at export time, so the reported
values under-count by the CP factor (e.g. ~4x smaller with CP=4).
The ratios reported via `stat(..., denominator=...)` (loss, ppl,
vocab_*) remain correct because numerator and denominator scale
together, so the issue is easy to miss.

Preserve the pre-CP-split loss_mask as `_global_loss_mask` when the
CP-local path constructs its inputs, and use it as the denominator
for the token-count metrics so they are invariant to the CP topology.
Keep separate `n_tokens_local` / `n_valid_tokens_local` denominators
with CP-local shapes for the CP-local tensors (`logprobs`, `vocab_*`),
since `stats_tracker.stat` requires matching shapes.

Verified on 64-GPU (CP=2) and 128-GPU (CP=4) SPMD runs with the same
Qwen3-30B + swe_distilled_1000 setup: step-1 `n_tokens` matches
exactly across both topologies (6,320,900), and equals
`CP * n_tokens_local` on each, confirming the fix is topology-invariant.

Fixes #1242

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
garrett4wade pushed a commit that referenced this pull request May 6, 2026

fix(sft): make CP-local SFT loss/ppl/vocab_* stats CP-invariant
garrett4wade added a commit that referenced this pull request May 6, 2026

* fix(sft): report CP-invariant token-count stats (#1242)

* style: fix 'Loggin stats' typo

Noticed by Copilot reviewer on PR #1249.

* fix(sft): make CP-local SFT loss/ppl/vocab_* stats CP-invariant

* fix: reassemble CP packed sequences

* chore: revert unnecessary changes

---------

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: 博惟 <bowei.fw@antgroup.com>
