fix: tile subsampling for long audio to avoid ggml 2^31 tensor overflow on GPU#19
Conversation
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… tiling test invariant
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
0cc3194 to
236c688
Compare
Post-rebase GPU re-verification + perf regression checkAfter rebasing onto current 1. End-to-end re-run on the rebased commit (the private 51-min repro file was deleted for privacy, so used a ~52-min synthetic clip = the
2. Performance regression A/B — the only change on the common (short/normal) execution path is the
Both deltas are within run-to-run noise (rep spreads overlap), and Conclusion: no performance or accuracy regression; long audio that crashed on |
Problem
Transcribing audio longer than ~44 min on GPU crashed. There were two distinct CUDA limits in the long-audio path, both invisible to CPU tests (CPU has neither limit):
(n_mels/2)·(T/2)·conv_channelselements. Fortdt-0.6b-v3(n_mels=128, conv_channels=256) a 51-min clip is 2,521,890,816 > INT_MAX, and ggml's CUDA unary (relu) kernel indexes elements withint→ wraps negative →invalid configuration argumentinggml_cuda_op_relu. (Same 2³¹ wall PyTorch hits, canUse32BitIndexMath not working properly in Conv2D layer pytorch/pytorch#80020; NeMo chunks the subsampling conv viasubsampling_conv_chunking_factor.)gridDim.ycap. Once past (1), the encoder's banded local-attention over-pads K/V to a contiguous axisLk = (C+P-1)·ceil(T'/C) ≈ 77kfor T'=38,481, and ggml's CUDApadkernel mapsne1straight togridDim.y, which CUDA caps at 65535 →PAD failed / invalid argument.Fix
(1) Tile the subsampling stage over time (
Subsampling::forward_tiled,Encoder::forward_batch_tiled) so no conv tensor exceeds 2³¹, then run the unchanged conformer stack on the full sequence. The subsampler's receptive field is ±7 mel frames, so tiling with an 8-frame halo is bit-exact on interior frames. Done in our code (no ggml change) — covers CPU/Metal too and bounds the ~10 GB activation spike.Model::transcribe_*route long audio to the tiled path above a model-derived threshold (safe_mel_window) via onesubsampling_tile_forhelper; both batched and single-clip (CLI /transcribe_path) paths are wired.PARAKEET_SUBSAMPLING_TILE=<frames>forces it (testing).(2) Grid-stride the ggml-cuda
padkernel (third_party/ggml-patches/0004-cuda-pad-grid-stride.patch) so it handlesne1/ne2·ne3 > 65535. A kernel audit confirmedpadis the only op in the long-audio encoder that routes a large dim through the cappedgridDim.y/z(softmax/norm usegridDim.x; add/mul auto-fall-back to a flattenedx; im2col already grid-strides; cpy/scale/concat are x-only/int64). The fix is perf-neutral (when a dim ≤ 65535 the stride loop runs exactly once → identical launch geometry) and general — it lifts the ceiling to ~23 h (next limit is an unrelated bin_bcast int32 index). This is what PyTorch already does, and it's upstreamable.Validation
test_subsampling_tiling—forward_tiledvsforward: single-tile bit-exact; multi-tile worst per-frame rel ~1.8e-5.test_encoder_long—forward_batch_tiledvsforward_batch: injection layout verified (large-tile worstrel 3.4e-3).test_transcribe_tiled— full pipeline, fused vs forced-tiled, identical non-empty transcripts on batched and single-clip paths.libggml-cuda.so, GB10): 20-min banded clip +0.01%, 60-s clip +0.53% (both within run-to-run noise),proc_mstranscribe-only, output byte-identical. Re-verified end-to-end on the rebased PR head with a 52-min synthetic clip (exit 0, no CUDA error). See comment for the table.tdt-0.6b-v3, the 51-min file: PASS. Was a crash onmaster; now transcribes in ~16 s (≈192× realtime) to a complete 9,196-word transcript, exit 0, no CUDA error.🤖 Generated with Claude Code