Skip to content

fix: wire --metal flag into remote FFN/MoE paths, add post-FFN norms,…#115

Open
MELDApps wants to merge 1 commit into
chrishayuk:mainfrom
MELDApps:fix/remote-ffn-output
Open

fix: wire --metal flag into remote FFN/MoE paths, add post-FFN norms,…#115
MELDApps wants to merge 1 commit into
chrishayuk:mainfrom
MELDApps:fix/remote-ffn-output

Conversation

@MELDApps
Copy link
Copy Markdown

Fixes for bugs found during distributed inference testing across M1 Mac Mini + HP ProDesk.

Closes #114

Changes

  • Bug 1: Add stdout flush after token print loop in run_with_remote_ffn and run_with_moe_shards
  • Bug 2: Apply pre-FFN norm before sending to remote server in prefill closure
  • Bug 3: Apply post-FFN norm to server response in both prefill and decode closures
  • Bug 4: Wire --metal flag into run_with_remote_ffn (was always using CPU backend)
  • Bug 4b: Wire --metal flag into run_with_moe_shards (was always using CPU backend)

Note: multi-token decode zero hidden state issue remains unresolved - see linked issue for full details.

… fix stdout flush

- Bug 1: Add stdout flush after token print loop in run_with_remote_ffn and run_with_moe_shards
- Bug 2: Apply pre-FFN norm (apply_norm_for_ffn) before sending to remote server in prefill
- Bug 3: Apply post-FFN norm (apply_post_ffn_norm) to server response in both prefill and decode
- Bug 4: Wire --metal flag into run_with_remote_ffn (was always using CPU backend)
- Bug 4b: Wire --metal flag into run_with_moe_shards (was always using CPU backend)

Note: multi-token decode still produces zero hidden states in Metal backend
decode_token_with_moe - root cause unresolved, see GitHub issue
@chrishayuk
Copy link
Copy Markdown
Owner

Hey @MELDApps — thanks for digging into these in #114. Reviewing the diff I found three things blocking a clean merge:

  1. A break; left mid-decode-loop in generate_with_remote_ffn (lines ~155-159 of the diff). That exits the loop after a single decode token — which is exactly the "multi-token decode produces zeros" symptom Multi-token decode produces zeros in remote inference paths (--ffn and --moe-shards) #114 documents, so I think this was bisect/debug code that didn't get cleaned up.
  2. Two eprintln!(\"[debug] ...\") lines (also in generate_with_remote_ffn and generate_with_remote_ffn_batch).
  3. The Q8K quantisation fast path (forward_single_q8k) was deleted from the decode moe_fn and replaced with the unconditional naive forward() call. That's a perf regression and orthogonal to the norm bug.

The other four fixes are great and I wanted to land them, so I opened #122 as a cherry-pick that keeps:

  • Stdout flush in both run_with_remote_ffn and run_with_moe_shards
  • --metal wiring in run_with_remote_ffn
  • Pre-FFN norm in the prefill moe_fn
  • The new apply_post_ffn_norm helper (with one ordering fix the unit tests caught — see PR description)

Your commit attribution is preserved on the first commit of #122. Will close this in favour of that. Happy to revisit the decode-loop norm gap (keeping the Q8K fast path) as a follow-up — that's the real fix for the multi-token zero issue.

Note for the type-ascription change in run_with_moe_shardsdefault_backend() is currently CPU-only (see larql-compute/src/lib.rs:165), so wiring --metal into MoE shards is more involved than the type ascription; it needs the same if metal { metal_backend() } else { cpu_backend() } shape we landed in run_with_remote_ffn. Want to send a follow-up for that?

chrishayuk pushed a commit that referenced this pull request May 22, 2026
…h stdout

Cherry-pick of the four legitimate fixes from #115 against current
main. Drops the decode-loop changes (`break;`, debug eprintlns,
Q8K-fast-path removal) which were author-acknowledged work-in-progress
and would have regressed performance + correctness.

Issues found during distributed inference testing across M1 Mac Mini
+ HP ProDesk:

  * `run_with_remote_ffn` and `run_with_moe_shards` print tokens via
    `print!` without flushing. Add `std::io::Write::flush(&mut stdout())`
    after each token loop so users see output immediately rather than
    only at process exit.
  * `run_with_remote_ffn` always called `larql_compute::default_backend()`
    which is CPU-only, ignoring `--metal`. Wire the flag through:
    `metal_backend()` when set (with CPU fallback if Metal init
    fails), `cpu_backend()` otherwise.
  * Remote FFN servers return the raw FFN result without applying the
    pre/post FFN norms a local forward would. On post-norm archs
    (Gemma 3 / 4) this causes the remote path's residual stream to
    diverge from the local one. Add `apply_post_ffn_norm` next to the
    existing `apply_norm_for_ffn`, and call both in the prefill
    `moe_fn` closure of `generate_with_remote_ffn`.

The matching changes inside `generate_with_remote_ffn`'s decode loop
(also missing pre/post-FFN norms) are intentionally left for a
follow-up so this PR can land cleanly without the broken Q8K rewrite.

Supersedes #115.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Multi-token decode produces zeros in remote inference paths (--ffn and --moe-shards)

2 participants