feat: banded local (Longformer) attention - fix O(T^2) long-audio OOM #9
Merged
…aithful

Adds NeMo rel_pos_local_attn (RelPositionMultiHeadAttentionLongformer) as a memory-bounded banded attention. This is the kernel for fixing the O(T^2) attention blowup that OOM'd long-audio offline transcription on unified-memory GPUs (a ~20-min clip allocated ~100GB and took the node down).

- RelPosAttention::build_graph_local / forward_local: banded attention via pad-and-shift, peak memory O(T*window) instead of O(T*T). Each query attends only to keys in [t-att_left, t+att_right]; the positional term (q_v . p^T over the 2W+1 local positions) is added 1:1 to the banded content scores, exactly as NeMo combines them. Verified against NeMo's own sliding_chunks_matmul_qk/pv (col->key t-w+c to 1e-6) and a deterministic band reference (1.4e-3).
- local_rel_pos_encoding: NeMo LocalAttRelPositionalEncoding (positions +att_left..-att_right), bit-identical to the centre rows of the full table.
- pk::last_graph_alloc_bytes(): gallocr high-water accessor for the memory test.
- gen_nemo_baseline.py --att-context-size (local-attention baseline), and gen_band_ref.py for the deterministic band reference.
  NOTE: NeMo's longformer is non-deterministic on short clips (sliding_chunks_matmul_pv reads uninitialized memory at boundaries via F.pad value=-1 + as_strided: two identical forward() calls differ by >1e3), so kernel parity must use the deterministic reference; end-to-end NeMo quality is anchored by long-audio WER.
- Tests: test_relpos_attention_local (parity 1.4e-3) and test_relpos_attention_local_memory (alloc grows ~linearly, ratio 1.98).

Not yet wired into the offline encoder path; that is a follow-up.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
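To make the band definition concrete, here is a minimal brute-force sketch of single-head banded attention in plain Python. It is not the contents of gen_band_ref.py (that script is not shown in this PR page); the function name, the list-of-lists layout, and the choice to scale content and positional scores together by 1/sqrt(d) are assumptions made for illustration. Window column c maps to key t - att_left + c, matching the col->key t-w+c convention quoted above.

```python
import math

def band_attention(q, k, v, att_left, att_right, pos_scores=None):
    """Brute-force single-head banded attention (illustrative sketch).

    q, k, v    : lists of T vectors (lists of floats).
    pos_scores : optional T x (att_left + att_right + 1) table of positional
                 scores, added 1:1 to the content scores (hypothetical layout).
    Returns a list of T context vectors.
    """
    T, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for t in range(T):
        scores, keys = [], []
        for c in range(att_left + att_right + 1):
            j = t - att_left + c              # window column c -> key index
            if 0 <= j < T:                    # keys outside the clip are skipped
                s = sum(a * b for a, b in zip(q[t], k[j]))
                if pos_scores is not None:
                    s += pos_scores[t][c]
                scores.append(s * scale)
                keys.append(j)
        m = max(scores)                       # column c = att_left is always valid
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wi * v[j][x] for wi, j in zip(w, keys)) / z
                    for x in range(len(v[0]))])
    return out
```

Because every query touches at most att_left + att_right + 1 keys, the work and the score storage grow as O(T*window); a loop like this is slow but has no uninitialized reads, which is the property a deterministic reference needs.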
…) OOM)

Wires the banded rel_pos_local_attn kernel into the offline encoder so long audio no longer allocates O(T^2) attention, which OOM'd unified-memory GPUs (a ~17-min clip drove ~100GB and took the node down).

- ConformerLayer::build_graph gains optional att_left/att_right; when set it routes self-attention to RelPosAttention::build_graph_local with a LOCAL positional encoding, else keeps full attention unchanged.
- Encoder::forward picks the window via local_attn_window(Tp): env PARAKEET_ATT_CONTEXT=W forces NeMo rel_pos_local_attn [W,W]; otherwise audio longer than ~11 min (>8192 encoder frames) auto-switches to W=128. Short audio keeps full attention (NeMo-exact; the encoder parity test is unchanged).
- backend.cpp: bump kGraphSize 16384->65536, because the pad-and-shift kernel adds O(window) graph-node descriptors per layer.

Verified end-to-end on a 16.6-min clip with tdt-0.6b-v3 (CPU, 16 threads):
  full attention: 151 s, 55.4 GB peak RSS
  banded (W=16):   41 s,  9.1 GB peak RSS (coherent transcript)
~6x less memory and ~3.7x faster; the full-attention path is what hit ~100GB and OOM'd. Short-clip transcripts: W=128 == full byte-for-byte; W=16 essentially identical.

Note: pad-and-shift creates O(window) nodes and an O(window^2) incremental concat, which is fine for small windows but slow for W=128 on CPU; an efficient chunk-matmul construction (like NeMo's sliding_chunks) is a follow-up.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
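The window-selection rule described for local_attn_window can be sketched as follows. This is a Python illustration of the decision logic only, not the C++ implementation: the -1 sentinel for "full attention", the handling of invalid or non-positive environment values, and the optional max_window parameter (standing in for the clamp that later commits in this PR add) are all assumptions.

```python
import os

FULL_ATTENTION = -1        # hypothetical sentinel meaning "keep global attention"
AUTO_LOCAL_FRAMES = 8192   # threshold quoted in the commit message (~11 min)
AUTO_LOCAL_WINDOW = 128    # auto-selected window quoted in the commit message

def local_attn_window(n_frames, env=None, max_window=None):
    """Return the local window W, or FULL_ATTENTION for short audio.

    n_frames   : number of encoder frames (Tp in the commit message).
    env        : mapping to read PARAKEET_ATT_CONTEXT from (defaults to os.environ).
    max_window : optional clamp; an assumption standing in for the cap that
                 later commits introduce.
    """
    env = os.environ if env is None else env
    window = FULL_ATTENTION
    forced = env.get("PARAKEET_ATT_CONTEXT", "").strip()
    if forced:
        try:
            value = int(forced)
        except ValueError:
            value = 0                      # assumption: ignore unparsable values
        if value > 0:
            window = value                 # forced [W, W] local attention
    if window == FULL_ATTENTION and n_frames > AUTO_LOCAL_FRAMES:
        window = AUTO_LOCAL_WINDOW         # long audio switches automatically
    if window != FULL_ATTENTION and max_window is not None:
        window = min(window, max_window)
    return window
```

For example, local_attn_window(5000, env={}) keeps full attention, while local_attn_window(12000, env={}) selects the automatic window.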
…-model regression)

Bumping kGraphSize 16384->65536 to fit a W=128 banded graph regressed small models ~+22% (tdt_ctc-110m): the per-compute metadata context and graph hash-set scale with kGraphSize. Revert to 16384 and instead cap the local-attention window at W=32: the pad-and-shift kernel adds ~6*(2W+1) graph nodes/layer, and W<=32 fits every shipped model's encoder within the budget. PARAKEET_ATT_CONTEXT is clamped to 32.

Regression bench (librispeech, 100 files, CPU, back-to-back):
  tdt_ctc-110m: master 19.5s vs banded 19.4s (within noise), 0/100 text diffs
  tdt-0.6b-v3:  0/100 text diffs

Long-audio fix intact: 16.6-min clip + tdt-0.6b-v3 auto-uses W=32 -> 48s, 9.4 GB peak RSS (vs full attention 151s / 55.4 GB). Lifting the window cap to NeMo's [128,128] needs the efficient chunk-matmul construction (O(1) graph nodes); that is a follow-up.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
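The node-budget argument can be checked with a few lines of arithmetic. The sketch below uses the ~6*(2W+1) nodes/layer estimate and the kGraphSize value from the commit message; the layer count of 24 is an assumption chosen for illustration (the commit does not state it), and the count covers only the band nodes, not the rest of each layer's graph.

```python
K_GRAPH_SIZE = 16384          # budget restored by this commit

def band_nodes_per_layer(window):
    """Approximate pad-and-shift graph nodes per layer: ~6 * (2W + 1)."""
    return 6 * (2 * window + 1)

def band_nodes_total(window, n_layers):
    """Band nodes across the encoder; other per-layer nodes are not counted."""
    return band_nodes_per_layer(window) * n_layers

N_LAYERS = 24                 # assumption for illustration only

for w in (16, 32, 128):
    total = band_nodes_total(w, N_LAYERS)
    verdict = "fits" if total < K_GRAPH_SIZE else "exceeds"
    print(f"W={w}: {band_nodes_per_layer(w)} nodes/layer, "
          f"{total} band nodes, {verdict} {K_GRAPH_SIZE}")
```

Under these assumptions W=32 leaves room for the remaining graph nodes within 16384, whereas W=128 alone overshoots the budget, which is consistent with the decision to cap the window until a window-independent construction exists.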
Mirror the B=1 banded path (build_graph_local) into the fused batched encoder so long-audio batches also use NeMo rel_pos_local_attn (O(T*window)) instead of full O(T^2) attention.

RelPosAttention::build_graph_batched_local builds the 4D ([dk,T,H,B]) pad-and-shift band: K/V padded on the time axis, per-window-column views, sum_rows content scores + mul_mat positional scores (shared pos broadcast over B), a per-item band mask [P,T,1,B] keyed on each item's valid_len, soft_max over the window, then the context gather and head merge. Conformer build_graph_batched and the batched encoder forward route to it when att_left/att_right >= 0, with the shared LOCAL positional encoding.

Verified on dgx (tdt_ctc-110m): the new test_encoder_batch_local exercises the path at the production window (W=32 = kMaxLocalWindow). item0 (the full clip) is bit-exact beside its shorter padded neighbour (no cross-item leak), and the padded item1 matches its standalone run within 5e-2/5e-2 - the same tolerance the full-attention batch test uses. Tighter-than-production windows only amplify float noise on near-zero activations of the padded clip (item0 stays exact, mean|d| ~1e-2); not pad leakage.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
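The per-item band mask is what keeps a shorter, padded item from attending to padding or to its neighbour. A minimal sketch of that mask in plain Python follows; the nested-list layout mask[b][t][c] stands in for the [P,T,1,B] tensor, and the function names are hypothetical. The sketch masks keys only: how the real kernel treats padded query rows whose window contains no valid key is not described in the commit, so that case is simply left all-False here.

```python
def band_mask(T, window, valid_len):
    """mask[t][c] is True when window column c of query t hits a real key.

    Column c maps to key index t - window + c; a key is real when it lies
    inside this item's valid_len (not padding, not outside the clip).
    """
    P = 2 * window + 1
    return [[0 <= t - window + c < valid_len for c in range(P)]
            for t in range(T)]

def batch_band_mask(T, window, valid_lens):
    """One mask per batch item, each keyed on that item's own valid_len."""
    return [band_mask(T, window, n) for n in valid_lens]

# Two items padded to T=4 frames; the second has only 3 valid frames.
masks = batch_band_mask(T=4, window=1, valid_lens=[4, 3])
for b, m in enumerate(masks):
    print(f"item {b}:", m)
```

Since each item's mask depends only on its own valid_len, the full-length item sees exactly the mask it would see when run alone, which is the behaviour the bit-exact item0 check relies on.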
The pad-and-shift banded path (build_graph_local) is correct but emits O(window) graph nodes per layer (a P-iteration view+mul+sum_rows+concat loop), which is why the window was capped at W=32. build_graph_local_chunked computes the exact same NeMo rel_pos_local_attn output with a fixed, O(1) number of nodes regardless of window, lifting the cap toward NeMo's full [128,128].

Construction: time is tiled into chunks of C frames; each chunk carries its own C+P-1 keys/values (the P-1 halo overlaps the neighbour), so a query attends only within its chunk. K/V are gathered as OVERLAPPING strided chunk views - which ggml's view-bounds check (ggml.c: data_size = dense product of ne, ignoring nb) rejects unless the source is OVER-padded to (C+P-1)*G frames; with that pad the view is legal and a single batched ggml_mul_mat produces the per-chunk q.k blocks [C+P-1, C, G, H]. A diagonal "skew" view (nb1 walking C+P on a [C+P-1,...] tensor, which passes the bounds check since P <= C+P-1) extracts the [P,T] band. The PV side inverse-skews the softmaxed band back to a [C+P-1, C] banded matrix (pad ne0 by C, skew-view, mask the lower off-band), then one batched matmul against the transposed V chunks gathers the context.

Verified against the trusted pad-and-shift path (forward_local, itself 1.4e-3 vs a deterministic brute-force band reference): new test test_relpos_attention_local_chunked runs synthetic x/pos through the real layer-0 weights for T up to 333 and W up to 128 (chunk < W and chunk == W), matching forward_local to <1e-3 (max|d| ~6e-4). Existing pad-and-shift path and all encoder/conformer regressions unchanged. Encoder wiring (raise the cap and route long audio to this kernel) follows.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
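The chunk-and-skew idea on the QK side can be illustrated without ggml. The sketch below is a plain-Python toy, not the kernel: it builds each chunk's dense C x (C+P-1) score block, flattens it row-major, and reads the band with a row stride of C+P, which is the same "walk one extra element per row" trick as the skew view described above. The function names are hypothetical, and the over-padding needed to satisfy ggml's view-bounds check is not modelled because Python lists have no such check.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def band_scores_direct(q, k, W):
    """band[t][c] = q[t] . k[t - W + c], or 0.0 when that key is out of range."""
    T, P = len(q), 2 * W + 1
    return [[dot(q[t], k[t - W + c]) if 0 <= t - W + c < T else 0.0
             for c in range(P)] for t in range(T)]

def band_scores_chunked(q, k, W, C):
    """Same band, computed from per-chunk dense blocks plus a skewed read."""
    T, d, P = len(q), len(q[0]), 2 * W + 1
    G = -(-T // C)                  # number of chunks (ceil division)
    S = C + P - 1                   # keys per chunk, including the P-1 halo
    zero = [0.0] * d
    # W zero keys on the left; enough on the right for the last chunk's halo.
    kpad = [zero] * W + list(k) + [zero] * (G * C - T + W)
    qpad = list(q) + [zero] * (G * C - T)
    band = []
    for g in range(G):
        # Dense block, row-major: row i = query g*C+i, col j = padded key g*C+j.
        block = [dot(qpad[g * C + i], kpad[g * C + j])
                 for i in range(C) for j in range(S)]
        for i in range(C):
            # Skew: entry (i, i + c) lives at i*S + i + c = i*(S + 1) + c,
            # i.e. a row stride of C + P over a block whose rows hold C+P-1.
            band.append([block[i * (S + 1) + c] for c in range(P)])
    return band[:T]                 # drop rows that belong to padded queries
```

The number of chunks depends on T and C, not on the window, so a batched matmul over chunks plus one strided view yields a node count that does not grow with W; that is the property the commit relies on to lift the cap.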
… to 128
Wire the O(1)-node chunk-matmul kernel into the encoder and raise the local
window cap from 32 to NeMo's full 128. Both conformer attention paths now use
it: build_graph (B=1) -> build_graph_local_chunked, build_graph_batched (B>1)
-> build_graph_batched_local_chunked. The batched wrapper runs the 4D chunk
kernel once per item and stacks the [D,T] outputs into [D,T,B] (the chunk graph
is already 4D, so it can't also carry a batch dim); that is O(B) nodes, still
O(1) in the window, and B is small.
local_attn_window's cap (kMaxLocalWindow) goes 32 -> 128: the pad-and-shift
path emitted ~6*(2W+1) nodes/layer (hence the 32 cap to fit kGraphSize), but the
chunk-matmul path is window-independent in node count, so long audio now runs at
NeMo's full [128,128] window. The pad-and-shift build_graph_local /
build_graph_batched_local are kept as the verification oracle for
test_relpos_attention_local{,_chunked}.
Verified on dgx: full ctest green (51/51). test_encoder_batch_local passes at
every forced window W=8..128 (now through the chunked path). e2e on a 16.6-min
clip (tdt-0.6b-v3, CPU/16t), auto-local W=128: 36.8s / 9.8GB peak RSS, coherent
transcript. That is faster than the pad-and-shift runs (W=16: 41s / 9.1GB,
W=32: 48s / 9.4GB) at a NeMo-faithful window 4x wider than the old W=32 cap,
and ~4.1x faster with ~5.6x less peak memory than the full-attention path that
OOM'd the node (151s / 55GB).
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
Document the banded local attention (rel_pos_local_attn) memory/speed win that the chunk-matmul kernel enables. 16.6-min clip, tdt-0.6b-v3, GB10 CPU/16t: global O(T^2) attention 148.3s / 54.0GB vs banded W=128 36.9s / 9.4GB (~4x faster, ~5.7x less peak RAM) at NeMo's full window, with the chunk-matmul making W=128 as cheap as W=32. Notes that short clips stay on the global path and are unchanged.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8
Problem
Offline transcription ran the FastConformer encoder with global relative-position attention over the whole clip — O(T²) memory. A ~17-min / 40 MB audio file drove ~100 GB of attention activations and hard-OOM'd a unified-memory GPU (DGX Spark), taking every loaded model down with it.
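A back-of-the-envelope estimate shows why the score tensor dominates. The sketch below counts the bytes of a single float32 attention-score tensor for one layer, comparing global attention (T keys per query) with a band of 2W+1 keys. The inputs are assumptions for illustration: 80 ms per encoder frame (roughly consistent with ~11 min corresponding to 8192 frames), 8 attention heads, and a 16.6-min clip. Actual peak memory also depends on other intermediates and on the graph allocator, so this is a scaling argument, not a prediction of the measured figures.

```python
def score_bytes(n_queries, keys_per_query, n_heads, bytes_per_elem=4):
    """Bytes for one [heads, queries, keys] score tensor."""
    return n_queries * keys_per_query * n_heads * bytes_per_elem

FRAME_SECONDS = 0.08                            # assumption: 80 ms per frame
N_HEADS = 8                                     # assumption
T = round(16.6 * 60 / FRAME_SECONDS)            # encoder frames in the clip
W = 128

full = score_bytes(T, T, N_HEADS)               # O(T^2)
banded = score_bytes(T, 2 * W + 1, N_HEADS)     # O(T * window)

print(f"T = {T} frames")
print(f"global scores: {full / 1e9:.2f} GB per layer")
print(f"banded scores: {banded / 1e6:.1f} MB per layer")
print(f"ratio: {full / banded:.1f}x")
```

The ratio is simply T / (2W+1), so it keeps growing with clip length under global attention while the banded cost per frame stays constant.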
Fix
Port NeMo's rel_pos_local_attn (RelPositionMultiHeadAttentionLongformer) as memory-bounded banded attention and wire it into the offline encoder. Each query attends only to keys in [t-W, t+W]; peak memory is O(T·window) instead of O(T²).
- RelPosAttention::build_graph_local / forward_local: banded attention via pad-and-shift. The positional term (q_v · p^T over the 2W+1 local positions) is added 1:1 to the banded content scores, exactly as NeMo combines them.
- local_rel_pos_encoding: NeMo LocalAttRelPositionalEncoding (positions +att_left..-att_right).
- ConformerLayer::build_graph / Encoder::forward: route to banded when local mode is active. PARAKEET_ATT_CONTEXT=W forces rel_pos_local_attn [W,W]; otherwise audio > ~11 min (>8192 encoder frames) auto-switches to W=128. Short audio keeps full attention unchanged (encoder parity test untouched).
Validation
Kernel is verified against NeMo's own sliding_chunks_matmul_qk/pv (col→key t-w+c to 1e-6, scores to 1e-4) and a deterministic band reference (1.4e-3). Memory test: alloc grows linearly (ratio 1.98 at 2× T).
End-to-end on a 16.6-min clip with tdt-0.6b-v3 (CPU, 16 threads):
  full attention: 151 s, 55.4 GB peak RSS
  banded (W=16):   41 s,  9.1 GB peak RSS
~6× less memory, ~3.7× faster. Short-clip transcripts: W=128 == full byte-for-byte; W=16 essentially identical.
Tests
- test_relpos_attention_local: banded parity vs the deterministic reference.
- test_relpos_attention_local_memory: O(T·window) memory scaling.
- gen_nemo_baseline.py --att-context-size + gen_band_ref.py reproduce the fixtures.
Follow-ups (not in this PR)
build_graph_batched) still uses full attention.
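The local positional encoding listed under Fix covers positions +att_left..-att_right, and the first commit states it equals the centre rows of the full relative-position table. The sketch below illustrates that relationship with a standard interleaved sin/cos sinusoidal table in plain Python. The exact formula and column layout used by NeMo and by local_rel_pos_encoding are assumptions here; only the position ranges come from this PR.

```python
import math

def rel_pos_table(positions, d_model):
    """One sinusoidal row per relative position (sin/cos interleaved)."""
    rows = []
    for p in positions:
        row = []
        for i in range(0, d_model, 2):
            div = math.exp(-i * math.log(10000.0) / d_model)
            row += [math.sin(p * div), math.cos(p * div)]
        rows.append(row[:d_model])
    return rows

def full_rel_pos(T, d_model):
    """Full table: positions T-1 down to -(T-1), i.e. 2T-1 rows."""
    return rel_pos_table(range(T - 1, -T, -1), d_model)

def local_rel_pos(att_left, att_right, d_model):
    """Local table: positions +att_left down to -att_right."""
    return rel_pos_table(range(att_left, -att_right - 1, -1), d_model)

# Row r of the full table holds position T-1-r, so the local table is the
# slice starting at row T-1-att_left and ending at row T-1+att_right.
T, W, D = 10, 3, 8
centre = full_rel_pos(T, D)[T - 1 - W : T + W]
print(local_rel_pos(W, W, D) == centre)
```

The local table has 2W+1 rows regardless of T, so the positional term costs O(T*window) just like the content term.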