Use fp32 accumulation in SkipLayerNorm/EmbedLayerNorm CUDA kernels #28682
Pull request overview
This PR improves numerical stability of the CUDA fused normalization kernels used in transformer models by switching mean/variance (and RMS) accumulation to FP32 when inputs are FP16/BF16, reducing overflow/NaN risk while keeping outputs in the original type. It also adds local profiling utilities to measure kernel performance with Nsight Systems.
Changes:
- Promote SkipLayerNorm and EmbedLayerNorm CUDA kernel intermediate accumulation (mean/variance/RMS) from FP16/BF16 to FP32.
- Refactor `layer_norm.cuh` helpers to operate on `float` thread/block reduction values (and remove half/bfloat16 reduction overloads).
- Add nsys-based profiling + parsing scripts for reproducible kernel timing analysis.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh | Switch block-reduce inputs/outputs and normalization math to FP32 intermediates for improved stability. |
| onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu | Accumulate skip layer norm statistics in FP32; remove now-unneeded maybe2half epsilon casting. |
| onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu | Accumulate embedding sum statistics in FP32 and pass epsilon as float. |
| onnxruntime/test/python/transformers/profile_skip_layer_norm.py | New Python profiler that builds a minimal SLN model and benchmarks it (optionally with NVTX ranges). |
| onnxruntime/test/python/transformers/profile_skip_layer_norm.sh | New wrapper script to run nsys profile and parse results. |
| onnxruntime/test/python/transformers/parse_nsys.py | New SQLite parser for nsys --export=sqlite kernel timing summaries. |
- Fix SQL injection in parse_nsys.py: use parameterized queries instead of string interpolation for kernel pattern matching
- Add --nvtx-range option to parse_nsys.py to filter kernels by NVTX range (e.g., 'benchmark'), eliminating the need for --skip-first to exclude warmup
- Update parse_nsys.py description/epilog to reflect current purpose
- Remove pip install nvtx from shell script; just check availability and warn
- Fix CodeQL import warning: use 'from onnx import save_model' instead of 'import onnx' + 'onnx.save'
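The parameterized-query fix mentioned in the first item can be sketched as follows. This is a minimal illustration, not the code in parse_nsys.py: the `kernels` table, its columns, and the helper names `make_demo_db` and `kernel_durations` are hypothetical and do not reflect the schema that nsys actually exports.

```python
import sqlite3


def make_demo_db() -> sqlite3.Connection:
    """Build a tiny in-memory table (hypothetical schema, for illustration only)."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE kernels (name TEXT, duration_ns INTEGER)")
    conn.executemany(
        "INSERT INTO kernels VALUES (?, ?)",
        [
            ("SkipLayerNormKernelSmall", 1200),
            ("EmbedLayerNormKernel", 3400),
            ("GemmKernel", 9000),
        ],
    )
    conn.commit()
    return conn


def kernel_durations(conn: sqlite3.Connection, pattern: str):
    """Return rows whose name contains `pattern`.

    The pattern is bound through a `?` placeholder, so it is always treated
    as data and never spliced into the SQL text.
    """
    query = "SELECT name, duration_ns FROM kernels WHERE name LIKE ? ORDER BY name"
    return conn.execute(query, (f"%{pattern}%",)).fetchall()


if __name__ == "__main__":
    db = make_demo_db()
    print(kernel_durations(db, "SkipLayerNorm"))
    # A hostile pattern is matched literally and finds nothing.
    print(kernel_durations(db, "x' OR '1'='1"))
```

With string interpolation, the second pattern would change the structure of the WHERE clause; with a bound parameter it is just an unusual substring to search for.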
kunal-vaishnavi left a comment
Shall we remove the strict mode flag then if FP32 accumulation is now always used?
Maybe later. That flag will be slower, so it can be deprecated later.
Description
Use fp32 accumulation in SkipLayerNormalization, SkipSimplifiedLayerNormalization, and EmbedLayerNormalization CUDA kernels to avoid overflow and improve numerical accuracy when processing fp16/bf16 data.
The original implementation accumulated mean and variance statistics in the input data type (fp16/bf16), which can overflow for large hidden sizes or when input values have large magnitude. This change promotes all intermediate accumulation (mean, variance, normalization math) to fp32, matching the approach used by TensorRT-LLM's LayerNorm kernels.
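The effect of the accumulation type can be reproduced on the CPU with a small simulation. The sketch below is an assumption-laden simplification: it accumulates sequentially in a single loop, whereas the CUDA kernels use per-thread partial sums and a block reduction, and it models fp16/fp32 by rounding through Python's `struct` formats (`'e'` and `'f'`). The helper names `round_to` and `mean_of_squares` are made up for this example.

```python
import math
import struct


def round_to(fmt: str, x: float) -> float:
    """Round a Python float to a struct format: 'e' is fp16, 'f' is fp32."""
    try:
        return struct.unpack(fmt, struct.pack(fmt, x))[0]
    except OverflowError:
        # The value exceeds the format's finite range; model it as +/-inf.
        return math.copysign(math.inf, x)


def mean_of_squares(values, fmt: str) -> float:
    """Accumulate sum(x * x / ld), rounding every intermediate to `fmt`."""
    ld = len(values)
    acc = 0.0
    for x in values:
        sq = round_to(fmt, x * x)        # x^2 in the accumulation type
        term = round_to(fmt, sq / ld)    # x^2 / ld
        acc = round_to(fmt, acc + term)  # running sum
    return acc


if __name__ == "__main__":
    ld = 4096
    large = [300.0] * ld  # 300^2 = 90000 is beyond the fp16 maximum of 65504
    ones = [1.0] * ld     # each term is 2^-12, small relative to the running sum

    print("large, fp16 accumulate:", mean_of_squares(large, "e"))
    print("large, fp32 accumulate:", mean_of_squares(large, "f"))
    print("ones,  fp16 accumulate:", mean_of_squares(ones, "e"))
    print("ones,  fp32 accumulate:", mean_of_squares(ones, "f"))
```

The first input shows the overflow case (the fp16 result becomes infinite), and the second shows precision loss: once the fp16 running sum is large enough, further small terms are rounded away and the result stalls below the true value of 1.0.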
Motivation
Accumulating `x²/ld` across thousands of elements in fp16 easily overflows or loses precision.
Key Changes
| File | Changes |
|---|---|
| layer_norm.cuh | `LayerNorm`, `SimplifiedLayerNorm`, `LayerNormSmall`, `SimplifiedLayerNormSmall` to accept and operate on `float` for thread_data, epsilon, mu, rsigma. Removed unused `KeyValuePairSum` overloads for half/bfloat16. |
| skip_layer_norm_impl.cu | `SkipLayerNormKernel` and `SkipLayerNormKernelSmall` to accumulate in fp32 (`cub::KeyValuePair<float, float>`). Removed `maybe2half` helper (no longer needed). |
| embed_layer_norm_impl.cu | `T` to `float`, accumulation to use `float` thread_data. |
| profile_skip_layer_norm.py | |
| profile_skip_layer_norm.sh | |
| parse_nsys.py | |

Performance Results
Profiled on an NVIDIA GPU with nsys (B=1, seq_len=2048, fp16 data, 200 iterations, first 5 warmup iterations skipped):
No measurable performance regression. The kernel is memory-bandwidth-bound, so fp32 arithmetic is completely hidden behind memory latency.
Testing
cd onnxruntime/test/python/transformers
nsys profile -o sln_fp16 --export=sqlite python profile_skip_layer_norm.py --mode fp16 --warmup 5 --repeat 100
python parse_nsys.py sln_fp16.sqlite --skip-first 5

Related PRs
#28442
#15660