fix: wire --metal flag into remote FFN/MoE paths, add post-FFN norms,…#115
MELDApps wants to merge 1 commit into
… fix stdout flush

- Bug 1: Add stdout flush after token print loop in run_with_remote_ffn and run_with_moe_shards
- Bug 2: Apply pre-FFN norm (apply_norm_for_ffn) before sending to remote server in prefill
- Bug 3: Apply post-FFN norm (apply_post_ffn_norm) to server response in both prefill and decode
- Bug 4: Wire --metal flag into run_with_remote_ffn (was always using CPU backend)
- Bug 4b: Wire --metal flag into run_with_moe_shards (was always using CPU backend)

Note: multi-token decode still produces zero hidden states in Metal backend decode_token_with_moe - root cause unresolved, see GitHub issue
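For readers unfamiliar with Bug 1: Rust's `stdout` is line-buffered, so tokens written with `print!` and no trailing newline can sit in the buffer until a newline or an explicit flush. The following is a minimal sketch of the pattern, not the code from this PR; `print_tokens` and its signature are hypothetical, and it is generic over `Write` only so that it can be exercised without a terminal.

```rust
use std::io::{self, Write};

/// Hypothetical helper: write each token without a newline, then flush once
/// after the loop so the text becomes visible immediately.
fn print_tokens<W: Write>(out: &mut W, tokens: &[&str]) -> io::Result<()> {
    for tok in tokens {
        // Equivalent to `print!("{tok}")` when `out` is stdout.
        write!(out, "{tok}")?;
    }
    // Without this call, line-buffered stdout may hold the tokens back
    // until a newline is written or the process exits.
    out.flush()
}
```

In a real generation loop this would be called with `&mut std::io::stdout().lock()`. Flushing once after the loop matches the description above; flushing after every token would trade some throughput for smoother streaming.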
Hey @MELDApps — thanks for digging into these in #114. Reviewing the diff I found three things blocking a clean merge:
The other four fixes are great and I wanted to land them, so I opened #122 as a cherry-pick that keeps:
Your commit attribution is preserved on the first commit of #122. Will close this in favour of that. Happy to revisit the decode-loop norm gap (keeping the Q8K fast path) as a follow-up; that's the real fix for the multi-token zero issue. Note for the type-ascription change in
…h stdout

Cherry-pick of the four legitimate fixes from #115 against current main. Drops the decode-loop changes (`break;`, debug eprintlns, Q8K-fast-path removal) which were author-acknowledged work-in-progress and would have regressed performance + correctness.

Issues found during distributed inference testing across M1 Mac Mini + HP ProDesk:

* `run_with_remote_ffn` and `run_with_moe_shards` print tokens via `print!` without flushing. Add `std::io::Write::flush(&mut stdout())` after each token loop so users see output immediately rather than only at process exit.
* `run_with_remote_ffn` always called `larql_compute::default_backend()` which is CPU-only, ignoring `--metal`. Wire the flag through: `metal_backend()` when set (with CPU fallback if Metal init fails), `cpu_backend()` otherwise.
* Remote FFN servers return the raw FFN result without applying the pre/post FFN norms a local forward would. On post-norm archs (Gemma 3 / 4) this causes the remote path's residual stream to diverge from the local one. Add `apply_post_ffn_norm` next to the existing `apply_norm_for_ffn`, and call both in the prefill `moe_fn` closure of `generate_with_remote_ffn`.

The matching changes inside `generate_with_remote_ffn`'s decode loop (also missing pre/post-FFN norms) are intentionally left for a follow-up so this PR can land cleanly without the broken Q8K rewrite.

Supersedes #115.
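To illustrate why the missing norms make the remote path diverge, here is a rough sketch of one FFN step in a post-norm layout: normalise the residual before the FFN, normalise the FFN output, then add it back to the residual. Everything in it is an assumption made for illustration: `rms_norm` and `ffn_step_with_norms` are hypothetical names (not `apply_norm_for_ffn` or `apply_post_ffn_norm` from the codebase), the norm is a plain RMSNorm (real checkpoints may scale the weights differently), and the remote server is stood in for by a closure.

```rust
/// Plain RMSNorm: divide by the root-mean-square of `x`, then scale by `weight`.
/// Assumption: the model's actual norm may differ in detail.
fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
}

/// Hypothetical client-side wrapper around a remote FFN call.
/// `remote_ffn` stands in for the network round trip and returns the raw FFN output.
fn ffn_step_with_norms<F>(
    residual: &[f32],
    pre_weight: &[f32],
    post_weight: &[f32],
    eps: f32,
    remote_ffn: F,
) -> Vec<f32>
where
    F: Fn(&[f32]) -> Vec<f32>,
{
    // Pre-FFN norm is applied locally before the activations are sent.
    let normed = rms_norm(residual, pre_weight, eps);
    // The server returns the raw FFN result with no norm applied.
    let raw = remote_ffn(&normed);
    // Post-FFN norm is applied to the response before the residual add.
    let post = rms_norm(&raw, post_weight, eps);
    residual.iter().zip(&post).map(|(r, p)| r + p).collect()
}
```

Skipping either `rms_norm` call changes the magnitude of what is added to the residual stream, which is consistent with the divergence described above when only one of the two paths applies the norms.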
Fixes for bugs found during distributed inference testing across M1 Mac Mini + HP ProDesk.
Closes #114
Changes
Note: multi-token decode zero hidden state issue remains unresolved - see linked issue for full details.