Skip to content

fix(slow_eval): chunkwise CI fn slow-eval on multi-host GPU (bf16 + gather)#888

Closed
danbraunai-goodfire wants to merge 3 commits into
feature/jaxfrom
fix/chunkwise-ci-slow-eval-multihost
Closed

fix(slow_eval): chunkwise CI fn slow-eval on multi-host GPU (bf16 + gather)#888
danbraunai-goodfire wants to merge 3 commits into
feature/jaxfrom
fix/chunkwise-ci-slow-eval-multihost

Conversation

@danbraunai-goodfire

Copy link
Copy Markdown
Collaborator

Description

The slow/plot eval tier crashes every multi-host GPU LM run that uses a chunkwise (transformer) CI fn, at the first slow eval. Two independent bugs, both fixed here in slow_eval.py:

  1. cuDNN-fp32 attention. The slow tier read the CI fn out in fp32, but the CI transformer's attention requests the cuDNN flash impl (attn_implementation()), which only accepts fp16/bf16/fp8 on GPU:

    NotImplementedError: Q must be fp16/bf16/fp8_e4m3fn/fp8_e5m2, got float32
    

    The fp32 readout was unintentional. Fix: run the CI fn in bf16 (COMPUTE_DT), exactly like training and fast eval (eval.py / hidden_acts_eval.py), and cast the readout statistics to fp32 (matplotlib rejects bfloat16 arrays).

  2. np.asarray on a process-sharded array. accumulate_site_reductions did np.asarray(flat_lower[site]) / flat_logits[site], which keep the dp-sharded batch axis. On >1 process those span non-addressable devices:

    RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices ...
    

    Fix: multihost_utils.process_allgather(..., tiled=True). (The density/sum reductions are all-reduced → already addressable.)

Motivation and Context

Together these abort every multi-GPU LM decomposition at its first slow eval (the gRPC coordination cascade these exceptions trigger on the other ranks shows up as "connection refused" noise). The chunkwise CI fn is the first CI fn with attention, so it's the first to hit (1); the MLP CI fns have no attention.

Supersedes the fp32→XLA-fallback approach in #885: per Oli the fp32 readout was unintentional, so this removes it (run bf16) rather than working around it. #885 can be closed in favor of this. ci_fn.py is left untouched.

How Has This Been Tested?

  • make type / basedpyright clean.
  • param_decomp/tests/test_slow_eval.py: 26 passed (MPLBACKEND=agg). The hand-rolled reference in test_reductions_match_hand_rolled_per_component now mirrors the bf16 readout (a borderline near->0.0 count flipped by one under bf16 weight rounding).
  • On and-btdr (2 nodes / 16 GPU): pre-fix the chunkwise pile-4L runs abort at the first slow eval; with this fix the slow tier (CI histograms, activation density, component-CI plots) renders and training continues.

Does this PR introduce a breaking change?

No. Training and fast-eval paths are unchanged; only the slow tier switches its CI-fn readout from fp32 to bf16 and gathers the sharded histogram sample across processes.

🤖 Generated with Claude Code

danbraunai-goodfire and others added 3 commits June 23, 2026 16:31
… hosts

The slow/plot eval tier crashed every multi-host GPU LM run that uses a chunkwise
(transformer) CI fn, at the first slow eval. Two independent bugs:

1. The CI fn was read out in fp32, but the CI transformer's attention requests the
   cuDNN flash impl, which is bf16/fp16-only on GPU ("Q must be fp16/bf16/fp8, got
   float32"). The fp32 readout was unintentional — run the CI fn in bf16 (COMPUTE_DT)
   like training + fast eval, and cast the readout statistics to fp32 (matplotlib
   rejects bfloat16).

2. accumulate_site_reductions did np.asarray on the dp-sharded histogram sample, which
   spans non-addressable devices on >1 process. Gather it across processes
   (density/sums are already all-reduced).

CPU tests never caught either (attn_implementation() is "xla" off-GPU; single-process
np.asarray is addressable). The hand-rolled reference in test_slow_eval now mirrors the
bf16 readout.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@danbraunai-goodfire

Copy link
Copy Markdown
Collaborator Author

Superseded by #905 (merged), which adopted this PR's exact approach for both bugs:

  1. cuDNN-fp32 attention — CI-fn readout now runs in training precision (bf16) via cast_floating(ci_fn, COMPUTE_DT), not fp32. (slow_eval.py:123, :215)
  2. np.asarray on a process-sharded array — histogram sample now gathered with multihost_utils.process_allgather(..., tiled=True). (slow_eval.py:176, :179)

Both landed together in #905 (413549642), so this branch adds nothing over current feature/jax. (This also closes out the #885#888#905 chain.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant