DFlash: enable eval_from_cache + per-position (per-k) train/eval metrics#82
Merged
yubofredwang merged 2 commits into main on Apr 23, 2026
Conversation
Why: the DFlash trainer was logging only scalar train/avg_loss and
train/avg_acc, and eval_from_cache was stubbed to return {}, so DFlash
runs lacked the per-position breakdown, simulated_acc_len, and eval
metrics that Eagle3 runs provide and that we rely on in wandb for
diagnosing draft quality across positions in the proposed block.
Changes
- torchspec/models/dflash.py: DFlashModel.forward now additionally
returns loss_per_position / acc_per_position (shape [block_size]),
computed under no_grad using the existing binary_eval_mask (no decay
bias) and the already-computed loss_per_token / correct. Index 0 is
the anchor slot (always zero count) and is sliced off downstream.
- torchspec/training/dflash_trainer.py:
- _forward / _train_step propagate the per-position tensors.
- _aggregate_metrics emits train/ploss_i, train/acc_i (i=0..B-2,
re-indexed so acc_0 = first predicted token, matching Eagle3)
and train/simulated_acc_len = cumulative product of per-position
accs.
- Replaces the eval stub with real eval_forward, eval_from_cache,
and _aggregate_eval_metrics mirroring Eagle3Trainer. Eval loss is
decay-weighted using self.loss_decay_gamma so eval/avg_loss is
directly comparable to train/avg_loss; eval also emits
eval/avg_acc, eval/simulated_acc_len, eval/ploss_i, eval/acc_i.
- Stale "eval hangs in colocate/SGLang mode" comment removed; the
current controller/loop pipeline invokes eval_from_cache every
eval_interval steps and is non-colocate in the DFlash training
configs.
- tests/test_dflash.py: updated 8 call sites to unpack the new 4-tuple
return from DFlashModel.forward; added a shape assertion on the new
per-position tensors.
- configs/dflash_qwen3_8b_repro.yaml, configs/sglang_qwen3_8b_dflash.yaml:
set mooncake.enable_hard_pin: true so batch_remove(force=True) is
the sole deletion path and the master-side TTL is bypassed; requires
mooncake-transfer-engine >= 0.3.10.post1 (already pinned).
Tests: pytest tests/test_dflash.py -> 62 passed.
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
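The per-position aggregation described above can be sketched in plain Python. This is a hedged illustration, not the actual implementation: the real code operates on torch tensors, and the function name and argument layout here are hypothetical.

```python
def per_position_metrics(loss_per_token, correct, mask):
    """Masked per-position means over a [batch][block_size] layout.

    Sketch of what DFlashModel.forward is described as doing: average
    loss_per_token and correct over the batch at each block position,
    counting only entries where the binary eval mask is 1, then slice off
    index 0 (the anchor slot, which always has zero count).
    """
    block_size = len(loss_per_token[0])
    loss_pp, acc_pp = [], []
    for j in range(block_size):
        count = sum(row[j] for row in mask)
        if count == 0:
            # Anchor slot (and any fully masked position) contributes zeros.
            loss_pp.append(0.0)
            acc_pp.append(0.0)
        else:
            loss_pp.append(sum(l[j] * m[j] for l, m in zip(loss_per_token, mask)) / count)
            acc_pp.append(sum(c[j] * m[j] for c, m in zip(correct, mask)) / count)
    # Drop the anchor slot so index 0 is the first *predicted* token,
    # matching the re-indexing described for train/acc_i.
    return loss_pp[1:], acc_pp[1:]
```

After this slicing, `acc_pp[0]` corresponds to `train/acc_0`, consistent with Eagle3 semantics.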
Pull request overview
This PR brings DFlash training/eval logging up to parity with Eagle3 by adding per-position metrics to the DFlash model/trainer and implementing a real eval_from_cache path, plus enabling Mooncake hard-pin in DFlash-related configs.
Changes:
- Extend `DFlashModel.forward` to also return per-position loss/accuracy tensors (`[block_size]`) and propagate them through DFlash training metric aggregation.
- Implement `DFlashTrainer.eval_forward`, `eval_from_cache`, and `_aggregate_eval_metrics` to compute eval metrics from the CPU eval cache and emit per-position eval stats.
- Enable `mooncake.enable_hard_pin: true` in DFlash YAML configs and document the dependency requirement.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `torchspec/training/dflash_trainer.py` | Adds per-position train metrics plus a full eval-from-cache implementation and eval metric aggregation. |
| `torchspec/models/dflash.py` | Returns per-position loss/accuracy tensors from the model forward pass. |
| `tests/test_dflash.py` | Updates tests to unpack the expanded forward return tuple and asserts per-position tensor shapes. |
| `configs/sglang_qwen3_8b_dflash.yaml` | Enables Mooncake hard-pin for this DFlash config. |
| `configs/dflash_qwen3_8b_repro.yaml` | Enables Mooncake hard-pin for this DFlash repro config. |
Comment on lines +349 to +353:

```python
        # so weights become exp(-i/gamma).
        weights = torch.exp(-k.float() / gamma)
    else:
        weights = torch.ones_like(pred_loss_pp)
    weighted_avg_loss = (pred_loss_pp * weights).sum().item() / weights.sum().item()
```
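A self-contained, plain-Python version of the decay weighting shown in this hunk. The function name and argument shapes are illustrative assumptions; in the trainer, `gamma` is `self.loss_decay_gamma` and the losses are torch tensors.

```python
import math

def decay_weighted_loss(pred_loss_pp, gamma):
    """Decay-weighted average of per-position eval losses.

    Position k gets weight exp(-k / gamma), so the eval scalar matches a
    decay-weighted training objective. A falsy gamma falls back to a
    plain (uniform-weight) mean.
    """
    n = len(pred_loss_pp)
    if gamma:
        weights = [math.exp(-k / gamma) for k in range(n)]
    else:
        weights = [1.0] * n
    return sum(l * w for l, w in zip(pred_loss_pp, weights)) / sum(weights)
```

With uniform weights this reduces to the arithmetic mean; with a small `gamma` the earliest positions dominate the scalar.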
Comment on lines +355 to +359:

```python
metrics: dict = {
    "eval/avg_loss": weighted_avg_loss,
    "eval/avg_acc": pred_acc_pp.mean().item(),
    "eval/simulated_acc_len": simulated_acc_len,
}
```
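For context on the `simulated_acc_len` scalar in this hunk: a common way to turn per-position acceptance rates into an expected accepted length is to sum their cumulative products (position i only counts if all earlier positions were accepted). This is a hedged sketch under that assumption; the PR text describes the metric as the "cumulative product of per-position accs", and the trainer's exact formula is not shown on this page.

```python
def simulated_acc_len(acc_pp):
    """Expected accepted length from per-position acceptance rates.

    Sums the running cumulative product of acc_pp: each term is the
    probability of reaching (and accepting) that position.
    """
    total, running = 0.0, 1.0
    for a in acc_pp:
        running *= a
        total += running
    return total
```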
Comment on lines +79 to +80:

```yaml
# batch_remove(force=True) (see mooncake/eagle_store.py). Requires
# mooncake-transfer-engine >= 0.3.10.post1.
```
```yaml
global_segment_size: 16GB
local_buffer_size: 4GB
# Hard-pin: master-side TTL is disabled; we rely on our explicit
# batch_remove(force=True) (see mooncake/eagle_store.py). Requires
```
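Putting the fragments above together, the resulting `mooncake` section would look roughly like this. This is a hypothetical consolidated view: only the fields shown in the snippets plus `enable_hard_pin` (described in the PR text) are included, and the surrounding structure of the actual YAML files may differ.

```yaml
mooncake:
  global_segment_size: 16GB
  local_buffer_size: 4GB
  # Hard-pin: master-side TTL is disabled; cleanup happens only via our
  # explicit batch_remove(force=True) (see mooncake/eagle_store.py).
  # Requires mooncake-transfer-engine >= 0.3.10.post1.
  enable_hard_pin: true
```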
Preserve per-position counts so train and eval aggregate DFlash metrics with the correct token weighting. This keeps ploss/acc breakdowns and eval scalars aligned with the actual objective when masks or chunk lengths are sparse. Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Summary
- `DFlashModel.forward` now returns `loss_per_position` and `acc_per_position` (shape `[block_size]`).
- `DFlashTrainer._aggregate_metrics` slices off the anchor slot and emits `train/ploss_i`, `train/acc_i` (for `i = 0..B-2`, so `acc_0` is the first predicted token, matching Eagle3 semantics) plus `train/simulated_acc_len` (cumulative product of per-position accs).
- Replaces the `eval_from_cache` → `{}` stub with `eval_forward` / `eval_from_cache` / `_aggregate_eval_metrics` that mirror `Eagle3Trainer`. Eval loss is decay-weighted with `self.loss_decay_gamma` so `eval/avg_loss` is directly comparable to `train/avg_loss`; eval also emits `eval/avg_acc`, `eval/simulated_acc_len`, and per-`i` `eval/ploss_i`, `eval/acc_i`. The stale "eval hangs in colocate/SGLang mode" comment is gone; the controller already calls `eval_from_cache` every `eval_interval`, and all current DFlash training configs are non-colocated.
- `configs/dflash_qwen3_8b_repro.yaml` and `configs/sglang_qwen3_8b_dflash.yaml` set `mooncake.enable_hard_pin: true`. With force-delete (landed in Refactor Mooncake Store: force delete + hard pin #73), the TTL path is no longer our cleanup path; hard-pin makes `remove(force=True)` the only way an object leaves the store. Requires `mooncake-transfer-engine >= 0.3.10.post1`, already pinned in `pyproject.toml`.

Why
The linked wandb decagon runs log per-TTT-position metrics and the eval suite, but DFlash runs were showing only scalar `train/avg_loss` / `train/avg_acc` with no eval panel — making it impossible to diagnose which predicted position is failing when a run diverges (historically around steps 100–500 on m27). This brings DFlash to Eagle3 parity in wandb.
Test plan
`pytest tests/test_dflash.py` -> 62 passed.