perf(rms_norm): use fused reduce_l2_norm path (~48× faster)#1
Closed
sbryngelson wants to merge 2 commits into
Closed
perf(rms_norm): use fused reduce_l2_norm path (~48× faster)#1sbryngelson wants to merge 2 commits into
sbryngelson wants to merge 2 commits into
Conversation
The rms_norm lowering reshaped to [M,D,1,1] and ran reduce_sum over the
channel axis, which falls off the ANE's fast reduction tile past ~256 rows:
it ran 16-27x slower than layer_norm (despite RMS being structurally cheaper)
and scaled super-linearly with row count (6882us at 1024x1024, 13.7ms at
2048x1024).
Re-lower through the same fused reduce_l2_norm over the last axis that l2_norm
already uses, since the two are mathematically identical:
rms(x) = x / sqrt(mean(x^2)+eps) * g = x * sqrt(D) / sqrt(sum(x^2)) * g
eps becomes a safe-divide floor on the norm (as in l2_norm) and the sqrt(D)
rescale is folded into the gamma weight, so the op count drops too.
Measured on M5 (H17s), gamma=1:
1024x1024: 6882us -> 168us (41x), max err vs fp32 RMS 0.007
2048x1024: 13665us -> 257us (53x)
Full pytest suite: 527 passed.
sbryngelson
added a commit
that referenced
this pull request
Jun 17, 2026
Drop ASCII-art dividers, colourise example output, ASCII-only source
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
rms_normlowered to a reshape to[M,D,1,1]followed byreduce_sumover thechannel axis. That reduction falls off the ANE's fast reduction tile past
~256 rows, so
rms_normran 16–27× slower thanlayer_norm— even thoughRMS-norm is structurally cheaper (one reduction, no centering) — and scaled
super-linearly with row count.
Measured on M5 (H17s),
gamma=1:Fix
Re-lower through the same fused
reduce_l2_normover the last axis thatl2_normalready uses, since the two are mathematically identical:epsbecomes a safe-divide floor on the norm (same pattern asl2_norm), andthe
√Drescale is folded into thegammaweight, so the emitted op count dropsas well (no
[M,D,1,1]reshape, no separate square/reduce_sum/scale/rsqrt chain).Correctness / tests
test_nn_blocks::rms_norm_linear_silutolerance is 0.03).tests/test_builder_guards.py(2D guard) unchanged and passing.527 passedon M5/H17s.RMS-norm is ubiquitous in modern transformers (LLaMA/Qwen/etc.), so this is a
high-impact, low-risk lowering fix.