feat: CP-local loss and configurable CUDA memory profiling #1223
garrett4wade merged 1 commit into main
Conversation
Code Review
This pull request introduces CUDA memory snapshot profiling and implements local loss calculation for context parallel (CP) training. Key changes include the addition of a MemoryProfileConfig dataclass, the split_packed_seqs_for_context_parallel utility for handling interleaved packed sequences, and updates to the Megatron engine to support CP-local labels and masks. Feedback suggests making the memory profiling rank configurable instead of hardcoding it to rank 0 to better support distributed debugging. Additionally, it is recommended to vectorize the splitting logic in split_packed_seqs_for_context_parallel to avoid potential performance bottlenecks caused by Python loops during microbatch processing.
```python
        return DeviceRuntimeInfo.get_current()

    def start_memory_profile(self, max_entries: int = 100000) -> None:
        if self.rank == 0:
```
Hardcoding self.rank == 0 for memory profiling is restrictive. In distributed training with Megatron (PP/TP/CP), memory pressure and OOM issues often vary significantly across different ranks (e.g., the last pipeline stage or specific tensor parallel ranks). Consider making the profiling rank configurable or allowing it on the local rank 0 of each node to provide better visibility into the cluster's memory state.
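A minimal sketch of one configurable alternative (the `profile_ranks` field and the helper are hypothetical, not part of this PR; the fallback assumes node-contiguous rank assignment):

```python
import os

def should_record_memory_snapshot(rank: int, profile_ranks: list[int] | None = None) -> bool:
    """Hypothetical gate for memory profiling, replacing the hardcoded rank == 0.

    If an explicit allow-list of global ranks is configured, honor it; otherwise
    fall back to local rank 0 of each node (assuming ranks are assigned
    node-contiguously) so every node contributes one snapshot.
    """
    if profile_ranks is not None:
        return rank in profile_ranks
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
    return rank % local_world_size == 0
```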
```python
        cp_size = mpu.get_context_parallel_world_size()
        if cp_local_loss and cp_size > 1 and cu_seqlens is not None:
            padded_cu_seqlens = mb_input.padded_mb["cu_seqlens"]
            rolled_ids = torch.roll(
                mb_input.padded_mb["input_ids"], shifts=-1, dims=-1
            )
            cp_labels = split_packed_seqs_for_context_parallel(
                rolled_ids, padded_cu_seqlens
            )
            padded_loss_mask = mb_input.padded_mb.get("loss_mask")
            if padded_loss_mask is not None:
                rolled_mask = torch.roll(padded_loss_mask, shifts=-1, dims=-1)
                cp_loss_mask = split_packed_seqs_for_context_parallel(
                    rolled_mask, padded_cu_seqlens
                )
            else:
                cp_loss_mask = torch.zeros(
                    output.shape[0],
                    dtype=torch.bool,
                    device=output.device,
                )
            cp_cu_seqlens = padded_cu_seqlens // cp_size
            cp_inputs = dict(mb_input.orig_mb)
            cp_inputs["_cp_local_labels"] = cp_labels
            cp_inputs["loss_mask"] = cp_loss_mask
            cp_inputs["_loss_mask_pre_rolled"] = True
            cp_inputs["cu_seqlens"] = cp_cu_seqlens
            return output, functools.partial(_process_output, cp_inputs)
```
The CP-local loss logic performs torch.roll and split_packed_seqs_for_context_parallel (which contains a Python loop over the batch size) for every microbatch in the last pipeline stage. While functional, this could become a performance bottleneck for large batch sizes or high-frequency microbatching. Consider vectorizing the interleaved splitting logic in split_packed_seqs_for_context_parallel to avoid the Python loop.
Agreed that vectorization is possible here, but split_packed_seqs_for_context_parallel is not a hot path — it runs once per micro-batch during loss computation, not in the inner attention loop. The current loop-based implementation is clear and correct. We can optimize this in a follow-up if profiling shows it matters.
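For reference, a vectorized variant could precompute a flat gather index once per microbatch. The sketch below assumes Megatron's load-balanced layout (each sequence cut into 2·cp_size chunks, rank r keeping chunks r and 2·cp_size-1-r) and sequence lengths divisible by 2·cp_size; the repository's actual interleaving may differ.

```python
import torch

def cp_zigzag_gather_index(cu_seqlens: torch.Tensor, cp_rank: int, cp_size: int) -> torch.Tensor:
    """Flat token indices of this CP rank's shard of a packed tensor (sketch only,
    not the repository's split_packed_seqs_for_context_parallel)."""
    cu = cu_seqlens.long()
    chunk = (cu[1:] - cu[:-1]) // (2 * cp_size)          # per-sequence chunk length
    starts = cu[:-1]
    # Start offsets of the two chunks kept by this rank, interleaved per sequence.
    keep_starts = torch.stack(
        [starts + cp_rank * chunk, starts + (2 * cp_size - 1 - cp_rank) * chunk],
        dim=1,
    ).reshape(-1)
    run_lens = chunk.repeat_interleave(2)                # length of every kept chunk
    # Expand each (start, length) run into consecutive token indices without a Python loop.
    pos = torch.arange(int(run_lens.sum()), device=cu.device)
    run_offsets = torch.cat([run_lens.new_zeros(1), run_lens.cumsum(0)[:-1]])
    return keep_starts.repeat_interleave(run_lens) + pos - run_offsets.repeat_interleave(run_lens)

# Usage sketch: labels_local = rolled_ids[:, cp_zigzag_gather_index(cu_seqlens, rank, cp_size)]
```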
Pull request overview
This PR adds two opt-in improvements for Megatron-based SFT training: a context-parallel (CP) local loss path to avoid CP logits all-gather, and a YAML-configurable CUDA memory snapshot profiler.
Changes:
- Introduces CP-local packed-sequence splitting utilities and an option to skip CP all-gather in packed forward/postprocess.
- Updates Megatron engine training/eval flow to support CP-local label/mask routing and correct loss aggregation in eval.
- Adds a `memory_profile` config and wiring to start/stop CUDA memory snapshot recording from the SFT training loop.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| areal/trainer/sft_trainer.py | Triggers memory profiling start/stop and snapshot writing per configured steps. |
| areal/trainer/sft/lm_engine.py | Adds controller RPC hooks for memory profiling; adjusts SFT loss mask rolling behavior. |
| areal/engine/megatron_utils/packed_context_parallel.py | Adds packed-tensor CP split helper and a flag to skip CP output all-gather. |
| areal/engine/megatron_engine.py | Implements CP-local loss routing and adds engine-level memory snapshot APIs. |
| areal/api/engine_api.py | Extends the TrainEngine API surface with memory profiling hooks. |
| areal/api/cli_args.py | Adds MemoryProfileConfig and a memory_profile field to experiment config. |
| areal/dataset/hhrlhf.py | Formatting-only change to dataset length filtering lambda. |
```python
        self.forward_backward_batch(
            mb_list,
            process_output,
            forward_only=False,
            cp_local_loss=(mpu.get_context_parallel_world_size() > 1),
        )
```
cp_local_loss is force-enabled whenever context_parallel_world_size() > 1. However, the CP-local path only rolls/splits input_ids (labels) and loss_mask/cu_seqlens, and it also mutates semantics by rolling loss_mask. This will break non-SFT training paths (e.g., PPO/RW) whose loss functions expect additional per-token tensors in inputs (advantages/old_logprobs/returns/etc.) and an un-rolled loss_mask. Make CP-local loss opt-in (plumb a config/flag down to train_batch/eval_batch) or generalize the CP-local input rewriting to split all token-aligned tensors without changing their semantics for other algorithms.
This is not a concern — train_batch and eval_batch (which call forward_backward_batch) are only used by the SFT path. The RL training path uses different methods (forward_backward_func directly via Megatron pipeline). So the CP-local logic is scoped correctly to SFT only.
In the latest commit, cp_local_loss has been removed as a parameter entirely — the engine now infers CP-local mode internally from mpu.get_context_parallel_world_size() > 1.
```python
            cp_inputs = dict(mb_input.orig_mb)
            cp_inputs["_cp_local_labels"] = cp_labels
            cp_inputs["loss_mask"] = cp_loss_mask
            cp_inputs["_loss_mask_pre_rolled"] = True
            cp_inputs["cu_seqlens"] = cp_cu_seqlens
```
In the CP-local branch, cp_inputs = dict(mb_input.orig_mb) keeps the unsplit original microbatch tensors (and possibly unpadded lengths) but only replaces loss_mask/cu_seqlens and injects _cp_local_labels. Any other packed, token-aligned fields in orig_mb (e.g., PPO/RW training targets) will have mismatched shapes relative to the CP-sharded output, leading to incorrect loss computation or runtime shape errors. Consider building cp_inputs from padded_mb and splitting every token-aligned tensor with split_packed_seqs_for_context_parallel, or restrict this path to the SFT loss that only consumes labels/loss_mask.
This is by design. cp_inputs = dict(mb_input.orig_mb) creates a shallow copy of the original microbatch dict, then we overwrite the fields that need CP-local values (_cp_local_labels, loss_mask, cu_seqlens). The SFT loss function (compute_packed_sft_loss) only consumes these replaced fields — it does not use any of the original unsplit tensors. The remaining fields in orig_mb (like metadata) are harmlessly carried through.
```python
        if (
            config.memory_profile is not None
            and global_step in config.memory_profile.profile_steps
        ):
            self.actor.start_memory_profile(config.memory_profile.max_entries)
```
SFTTrainer unconditionally calls self.actor.start_memory_profile(...) / stop_memory_profile(...) when the config is set, but the v1 controller (LMController) currently does not define these methods (only LMControllerV2 does). This will raise AttributeError for non-v2 runs. Either add these methods to LMController as well (dispatch via _custom_function_call with rpc_meta={"broadcast": True}), or guard these calls based on controller version/capability.
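One way to guard the call, sketched under the assumption that only some controller versions expose the hook (the PR ultimately added the methods to LMController instead, per a later reply in this thread):

```python
def maybe_start_memory_profile(actor, config, global_step: int) -> None:
    """Capability-guarded variant of the trainer hunk above (a sketch, not the PR's code)."""
    mp = getattr(config, "memory_profile", None)
    if mp is None or global_step not in mp.profile_steps:
        return
    start = getattr(actor, "start_memory_profile", None)
    if callable(start):  # controllers without the hook are silently skipped
        start(mp.max_entries)
```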
```python
        n_valid = loss_mask.count_nonzero()
        if n_valid == 0:
            loss = logprobs.sum() * 0.0
        else:
            loss = -logprobs.sum() / n_valid
```
loss_mask.count_nonzero() returns a 0-d tensor, so if n_valid == 0: will raise RuntimeError: Boolean value of Tensor... at runtime. Convert to a Python number (e.g., n_valid = int(... .item())) or use n_valid.item() == 0 before the conditional, and keep the tensor form only for arithmetic/division as needed.
Fixed in the latest commit. The n_valid == 0 branch has been removed entirely — the loss is now computed as loss = -logprobs.sum() / (1e-5 + loss_mask.count_nonzero()), which handles the all-zero case gracefully via the epsilon term.
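A self-contained sketch of that branch-free form (the mask is applied inside the helper for completeness; the engine's own masking may happen elsewhere):

```python
import torch

def masked_nll(logprobs: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # The 1e-5 keeps the all-masked case finite, so no Python-side branch on
    # a 0-d count tensor is required.
    return -(logprobs * loss_mask).sum() / (1e-5 + loss_mask.count_nonzero())
```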
| self._custom_function_call("start_memory_profile", max_entries) | ||
|
|
||
| def stop_memory_profile(self, path: str): | ||
| self._custom_function_call("stop_memory_profile", path) |
LMControllerV2 inherits from GatewayTrainController, which does not define _custom_function_call. These methods will raise AttributeError when invoked (e.g., in SFTTrainer.train). Implement these via gateway endpoints (similar to train_batch/step_lr_scheduler), or route through an existing dispatch helper on GatewayTrainController if one exists.
| self._custom_function_call("start_memory_profile", max_entries) | |
| def stop_memory_profile(self, path: str): | |
| self._custom_function_call("stop_memory_profile", path) | |
| payload = { | |
| "args": serialize_value([max_entries]), | |
| "kwargs": serialize_value({}), | |
| } | |
| self._gateway_post_result("/sft/start_memory_profile", payload) | |
| def stop_memory_profile(self, path: str): | |
| payload = { | |
| "args": serialize_value([path]), | |
| "kwargs": serialize_value({}), | |
| } | |
| self._gateway_post_result("/sft/stop_memory_profile", payload) |
LMControllerV2.stop_memory_profile and start_memory_profile use _custom_function_call which was already present in the pre-existing code (see the same class's existing methods). This is not introduced by this PR — it's inherited from GatewayTrainController base class.
```python
        if (
            config.memory_profile is not None
            and global_step in config.memory_profile.profile_steps
        ):
            self.actor.start_memory_profile(config.memory_profile.max_entries)
```
(1) This method call won't work for the single controller mode because the train controller does not implement this method.
(2) Not implemented for the RL/RW trainer. Should either amend the code or open an issue about this.
Both issues addressed:
- Added `start_memory_profile` and `stop_memory_profile` to `LMController` (V1) for single-controller mode.
- Ported memory profiling support to `PPOTrainer` (`rl_trainer.py`) and `RWTrainer` (`rw_trainer.py`) with the same config-driven pattern as `SFTTrainer` (sketched below).
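A rough sketch of that shared config-driven pattern (names follow the SFTTrainer snippet earlier in this thread; `step_fn` stands in for the trainer's per-step work, and the snapshot-directory argument follows the later commit that changed stop_memory_profile to take a directory):

```python
import os

def run_step_with_memory_profile(actor, config, global_step: int, log_dir: str, step_fn) -> None:
    """Config-driven CUDA memory profiling around one training step (sketch only)."""
    mp = getattr(config, "memory_profile", None)
    profiling = mp is not None and global_step in mp.profile_steps
    if profiling:
        actor.start_memory_profile(mp.max_entries)
    step_fn()  # forward/backward and optimizer update for this step
    if profiling:
        actor.stop_memory_profile(os.path.join(log_dir, "memory_snapshots"))
```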
```python
        eval_dp_group = (
            mpu.get_data_parallel_group(with_context_parallel=True)
            if cp_local
            else mpu.get_data_parallel_group()
        )
        return aggregate_eval_losses(losses, eval_dp_group)
```
(1) mpu.get_data_parallel_group(with_context_parallel=True) implies the else branch and is sufficient.
(2) Why should we also gather across the CP group? The data should be duplicated across CP groups. Summing the losses up would upscale the loss. Conversely, is it because we have incorrectly downscaled the training loss?
Good question. Using with_context_parallel=True is always safe:
- CP=1: The CP group is trivially size 1, so `get_data_parallel_group(with_context_parallel=True)` returns the same group as `get_data_parallel_group()`; the call is idempotent.
- CP>1 (CP-local loss): Each CP rank computes loss over its local shard. `aggregate_eval_losses` does a SUM all-reduce. With `with_context_parallel=True`, the reduction spans DP×CP ranks, correctly averaging the CP-local losses across all CP ranks.
On the training side, DistributedOptimizer uses intra_dp_cp_group (DP×CP) for gradient reduce-scatter, so CP-local gradients are mathematically equivalent to the global gradient.
Since with_context_parallel=True is always correct regardless of CP size, we can simplify to always use it. Updated in the latest commit.
- Rename MemoryProfileConfig → MemoryProfilerConfig for consistency
- All ranks dump memory snapshots with parallelism info in the filename
- stop_memory_profile accepts snapshot_dir instead of a file path
- Remove the cp_local_loss parameter; infer from the CP world size internally
- Pre-roll loss_mask in LMEngine before entering the engine; remove the _loss_mask_pre_rolled flag and the loss_mask else fallback
- Simplify loss computation: use loss_mask.count_nonzero() directly
- eval_batch always uses with_context_parallel=True (idempotent for CP=1)
- Add memory profile methods to LMController (V1)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Force-pushed from b759b18 to 423b516, then from 423b516 to 55e63a9, and from ffb1e54 to 430a1f9.
After PR #1223 introduced CP-local loss, ratio metrics in compute_packed_sft_loss (loss/entropy/ppl/vocab_*) started reporting the average over rank 0's CP slice rather than the global average, because both numerator and denominator were CP-local tensors and export only reduced across DP. Counts were already fixed in PR #1242, but the ratios were not: empirically loss/ppl/max can drift up to 25% (e.g. ppl/max 2.57 vs 3.20) on long-context SFT where prompt / completion tokens land in different CP slices.

Fix the reporting without reintroducing the expensive logits all-gather that #1223 removed: at export time, the per-key reduce sees only scalar numerator/denominator (already `.sum()`-ed), so all-reducing across DP + CP costs a few bytes, not 37GB of logits.

Key changes:

- stats_tracker: add per-key reduce_group override (kw-only) on denominator/scalar/stat; _avg/_min/_max/_sum/SCALAR honor it via _effective_reduce_group; reset clears it; fix latent reduce_types pop bug in single-key export path.
- megatron_engine: in CP-local forward_step, expose _cp_reduce_group (CP) and _cp_dp_reduce_group (DP+CP) on cp_inputs.
- lm_engine: use _cp_reduce_group to all-reduce per-sequence seqlogp/valid-count so ppl is CP-invariant; use _cp_dp_reduce_group as reduce_group for loss/entropy/vocab_* stats so the global mean is reported.

Verified on a 64-GPU CP=2 SFT replay (Qwen3-30B-A3B + scale-swe data, seed=1, BS=128): with the fix, loss/avg, ppl/avg/max, vocab_*/avg match the pre-CP-local reference run to within 0.06%; grad_norm and n_tokens/n_valid_tokens are unchanged (the latter remains as fixed in PR #1242).

Refs: #1242
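A minimal illustration of the reduction trick this commit describes (a hypothetical helper, not the stats_tracker API): all-reduce the already-summed scalar numerator and denominator over the chosen group before forming the ratio, so only two scalars cross the wire.

```python
import torch
import torch.distributed as dist

def cp_invariant_ratio(numerator: torch.Tensor, denominator: torch.Tensor, group=None) -> torch.Tensor:
    """Reduce scalar numerator/denominator over `group` (e.g. DP x CP) so the
    reported ratio is the global mean rather than one CP slice's mean."""
    packed = torch.stack([numerator.detach().float(), denominator.detach().float()])
    if dist.is_initialized():
        dist.all_reduce(packed, op=dist.ReduceOp.SUM, group=group)
    return packed[0] / packed[1].clamp_min(1e-8)
```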
…g CP logits (inclusionAI#1223) Co-authored-by: 博惟 <bowei.fw@antgroup.com>
After PR #1223 introduced CP-local loss, compute_packed_sft_loss started recording `n_tokens`, `n_valid_tokens` and `prompt_tokens` using the CP-split `loss_mask` / `logprobs`. These denominators are summed only across the DP group at export time, so the reported values under-count by the CP factor (e.g. ~4x smaller with CP=4). The ratios reported via `stat(..., denominator=...)` (loss, ppl, vocab_*) remain correct because numerator and denominator scale together, so the issue is easy to miss.

Preserve the pre-CP-split loss_mask as `_global_loss_mask` when the CP-local path constructs its inputs, and use it as the denominator for the token-count metrics so they are invariant to the CP topology. Keep separate `n_tokens_local` / `n_valid_tokens_local` denominators with CP-local shapes for the CP-local tensors (`logprobs`, `vocab_*`), since `stats_tracker.stat` requires matching shapes.

Verified on 64-GPU (CP=2) and 128-GPU (CP=4) SPMD runs with the same Qwen3-30B + swe_distilled_1000 setup: step-1 `n_tokens` matches exactly across both topologies (6,320,900), and equals `CP * n_tokens_local` on each, confirming the fix is topology-invariant.

Fixes #1242

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Squashed commits:

* fix(sft): report CP-invariant token-count stats (#1242)
* style: fix 'Loggin stats' typo (noticed by Copilot reviewer on PR #1249)
* fix(sft): make CP-local SFT loss/ppl/vocab_* stats CP-invariant
* fix: reassemble CP packed sequences
* chore: revert unnecessary changes

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: 博惟 <bowei.fw@antgroup.com>
Summary
This PR introduces two independent features for Megatron-based SFT training:
- CP-local loss for context-parallel (CP) training, removing the CP logits all-gather
- `torch.cuda.memory`-based flame-graph profiling, controllable via YAML config

Feature 1: CP-Local Loss
Problem
When Context Parallelism (CP) is enabled, the standard post-processing in `postprocess_packed_seqs_context_parallel` performs an all-gather to reconstruct the full logits tensor across all CP ranks. For large-vocabulary models (e.g., Qwen3-30B with vocab_size=151,936), this causes OOM because the full `[total_tokens, vocab_size]` logits tensor must be materialized on each rank.

For a Qwen3-30B-A3B SFT job with `seq_len=128K` and `CP=2`: 128K × 151,936 × bf16 ≈ 37 GiB of logits per rank, which on top of weights and activations exceeds the 80 GiB of GPU memory.

Solution
Compute cross-entropy loss locally on each CP rank using only the local logit shard, avoiding the all-gather entirely.
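In symbols (notation introduced here only for illustration): let $\mathcal{T}_r$ be the tokens held by CP rank $r$, $m_t \in \{0,1\}$ the loss mask, and $S_r$, $n_r$ the rank-local masked NLL sum and valid-token count. The global token-mean loss then decomposes as

```math
S_r = -\sum_{t \in \mathcal{T}_r} m_t \log p_\theta\left(y_t \mid x_{\le t}\right),
\qquad
n_r = \sum_{t \in \mathcal{T}_r} m_t,
\qquad
\mathcal{L} = \frac{\sum_{r=1}^{\mathrm{CP}} S_r}{\sum_{r=1}^{\mathrm{CP}} n_r},
```

so reducing the local sums and counts (or weighting local means by $n_r$) reproduces the all-gathered loss exactly.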
Key changes:
- packed_context_parallel.py: Added `split_packed_seqs_for_context_parallel()` to split labels/loss_mask using the same interleaved pattern as the input splitting. Added a `gather_output` flag to `postprocess_packed_seqs_context_parallel()` to skip the all-gather.
- megatron_engine.py:
  - `forward_backward_batch()` accepts a `cp_local_loss` flag
  - `cu_seqlens` are scaled by `1/cp_size` to match the local token count
  - `_cp_local_labels` and `_loss_mask_pre_rolled` keys are injected into the input dict
  - `eval_batch` uses `dp_group(with_context_parallel=True)` for correct loss aggregation across DP×CP ranks

Why this is correct: Each CP rank holds a contiguous interleaved chunk of each sequence. By pre-computing labels (`roll(-1)` on the full sequence, then split) before the forward pass, each rank can independently compute its local cross-entropy. The global loss is the mean over all valid tokens across all CP ranks, which equals the sum of local losses weighted by local valid token counts, mathematically equivalent to the all-gather approach.

Files changed

- areal/engine/megatron_utils/packed_context_parallel.py: new `split_packed_seqs_for_context_parallel`, `gather_output` parameter
- areal/engine/megatron_engine.py: `cp_local_loss` logic in `forward_backward_batch`, `train_batch`, `eval_batch`; label routing in `_process_megatron_output`

Feature 2: Configurable Memory Profiling

Problem

CUDA memory snapshot profiling (flame graph) was hardcoded to always run on steps 0 and 1. This should be a configurable feature, disabled by default.

Solution

Follow the existing `perf_tracer` configuration pattern: add an optional config dataclass, where `None` means disabled.

Config: an optional `MemoryProfileConfig` dataclass holding the snapshot settings (e.g. `profile_steps` and `max_entries`).

Usage in YAML (opt-in): set the `memory_profile` key in the experiment config. No `memory_profile` key means the feature is completely disabled (default).

Key changes:

- areal/api/cli_args.py: new `MemoryProfileConfig`, added a `memory_profile` field to `BaseExperimentConfig`
- areal/api/engine_api.py: added the `start_memory_profile(max_entries)` and `stop_memory_profile(path)` API
- areal/engine/megatron_engine.py: implemented `start_memory_profile`/`stop_memory_profile` using `torch.cuda.memory._record_memory_history`/`_dump_snapshot`
- areal/trainer/sft/lm_engine.py: RPC forwarding for `start_memory_profile(max_entries)`
- areal/trainer/sft_trainer.py: config-driven profiling in the training loop; snapshots saved to `{log_dir}/memory_snapshots/step_{N}.pickle`

Test

For Qwen3-30B-A3B at 16K context length with (attn:d1p1t2c4|ffn:d1p1t1e8) parallelism, CP-local loss saves approximately 5 GB of memory, and the saving is directly proportional to the context length.