fix(slow_eval): chunkwise CI fn slow-eval on multi-host GPU (bf16 + gather)#888
Closed
danbraunai-goodfire wants to merge 3 commits into
Closed
fix(slow_eval): chunkwise CI fn slow-eval on multi-host GPU (bf16 + gather)#888danbraunai-goodfire wants to merge 3 commits into
danbraunai-goodfire wants to merge 3 commits into
Conversation
… hosts
The slow/plot eval tier crashed every multi-host GPU LM run that uses a chunkwise
(transformer) CI fn, at the first slow eval. Two independent bugs:
1. The CI fn was read out in fp32, but the CI transformer's attention requests the
cuDNN flash impl, which is bf16/fp16-only on GPU ("Q must be fp16/bf16/fp8, got
float32"). The fp32 readout was unintentional — run the CI fn in bf16 (COMPUTE_DT)
like training + fast eval, and cast the readout statistics to fp32 (matplotlib
rejects bfloat16).
2. accumulate_site_reductions did np.asarray on the dp-sharded histogram sample, which
spans non-addressable devices on >1 process. Gather it across processes
(density/sums are already all-reduced).
CPU tests never caught either (attn_implementation() is "xla" off-GPU; single-process
np.asarray is addressable). The hand-rolled reference in test_slow_eval now mirrors the
bf16 readout.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ted eval.py idiom)
Collaborator
Author
|
Superseded by #905 (merged), which adopted this PR's exact approach for both bugs:
Both landed together in #905 ( |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
The slow/plot eval tier crashes every multi-host GPU LM run that uses a chunkwise (transformer) CI fn, at the first slow eval. Two independent bugs, both fixed here in
slow_eval.py:cuDNN-fp32 attention. The slow tier read the CI fn out in fp32, but the CI transformer's attention requests the cuDNN flash impl (
attn_implementation()), which only accepts fp16/bf16/fp8 on GPU:The fp32 readout was unintentional. Fix: run the CI fn in bf16 (
COMPUTE_DT), exactly like training and fast eval (eval.py/hidden_acts_eval.py), and cast the readout statistics to fp32 (matplotlib rejectsbfloat16arrays).np.asarrayon a process-sharded array.accumulate_site_reductionsdidnp.asarray(flat_lower[site])/flat_logits[site], which keep the dp-sharded batch axis. On >1 process those span non-addressable devices:Fix:
multihost_utils.process_allgather(..., tiled=True). (The density/sum reductions are all-reduced → already addressable.)Motivation and Context
Together these abort every multi-GPU LM decomposition at its first slow eval (the gRPC coordination cascade these exceptions trigger on the other ranks shows up as "connection refused" noise). The chunkwise CI fn is the first CI fn with attention, so it's the first to hit (1); the MLP CI fns have no attention.
Supersedes the fp32→XLA-fallback approach in #885: per Oli the fp32 readout was unintentional, so this removes it (run bf16) rather than working around it. #885 can be closed in favor of this.
ci_fn.pyis left untouched.How Has This Been Tested?
make type/ basedpyright clean.param_decomp/tests/test_slow_eval.py: 26 passed (MPLBACKEND=agg). The hand-rolled reference intest_reductions_match_hand_rolled_per_componentnow mirrors the bf16 readout (a borderline near->0.0count flipped by one under bf16 weight rounding).Does this PR introduce a breaking change?
No. Training and fast-eval paths are unchanged; only the slow tier switches its CI-fn readout from fp32 to bf16 and gathers the sharded histogram sample across processes.
🤖 Generated with Claude Code