## Background

The current W8A8 (FP8 weights × FP8 activations) GEMV has ~1-2% relative error, which can cause quality degradation in LLM inference because errors accumulate across layers.
## Current State
| Kernel | Relative Error | Speed | Use Case |
|---|---|---|---|
| BF16 | ~0.3% | baseline | Production |
| W8A8 Fast | ~1-2% | ~10 µs | Benchmark only |
| W8A8 Accurate | <0.5% | ~20 µs | Production (TODO) |
## Problem

- LLM decode is autoregressive, so per-step errors compound.
- ~1-2% error per layer across 32 layers adds up to significant degradation (see the worst-case bound below).
- Softmax is sensitive to small logit differences.
- The Fast version is essentially benchmark-only, not production-ready.
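As a rough worst-case bound, assuming per-layer errors accumulate multiplicatively and never cancel:

$$(1 + 0.01)^{32} - 1 \approx 37.5\%$$

versus $(1 + 0.003)^{32} - 1 \approx 10\%$ at BF16's error level. In practice errors partially cancel, but the gap is large enough to motivate an Accurate variant.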
## Proposed Solution
### Accuracy Improvements
- **Smaller scale blocks**: 128 → 32 elements
  - More scales = better range coverage per block
  - Trade-off: 4× more scale memory
- **2-level scaling** (like official NVFP4; sketched below):
  - Per-block E4M3 scale + global FP32 scale
  - Better dynamic range coverage
- **Kahan summation** (optional; sketched below):
  - Reduces accumulation error
  - Trade-off: more registers and ops
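A minimal CUDA sketch of the proposed dequantization path, combining the 32-element scale blocks with the two-level (per-block E4M3 × global FP32) scaling. The memory layout, kernel name, and signature are assumptions for illustration, not the actual implementation; a real kernel would use vectorized loads:

```cuda
#include <cuda_fp8.h>
#include <cuda_bf16.h>

// Assumed layout (illustrative only):
//   W  : [M, K] row-major FP8 E4M3 weights, K a multiple of 32
//   x  : [K]    FP8 E4M3 activations
//   sW : [M, K/32] per-block E4M3 scales for W (level 1)
//   sx : [K/32]    per-block E4M3 scales for x (level 1)
//   gW, gx : global FP32 scales (level 2)
constexpr int SCALE_BLOCK = 32;  // proposed block size (down from 128)

__global__ void gemv_fp8_two_level(const __nv_fp8_e4m3* W,
                                   const __nv_fp8_e4m3* x,
                                   const __nv_fp8_e4m3* sW,
                                   const __nv_fp8_e4m3* sx,
                                   float gW, float gx,
                                   __nv_bfloat16* y, int M, int K) {
    const int row  = blockIdx.x;   // one warp (32 threads) per output row
    const int lane = threadIdx.x;
    const int nblk = K / SCALE_BLOCK;
    float acc = 0.0f;              // FP32 accumulator
    for (int b = 0; b < nblk; ++b) {
        // Level 1: combine the two per-block E4M3 scales once per block.
        float s = float(sW[row * nblk + b]) * float(sx[b]);
        int k = b * SCALE_BLOCK + lane;  // each lane covers one element per block
        acc += float(W[row * K + k]) * float(x[k]) * s;
    }
    // Reduce the 32 per-lane partial sums within the warp.
    for (int off = 16; off > 0; off >>= 1)
        acc += __shfl_down_sync(0xffffffffu, acc, off);
    if (lane == 0)
        y[row] = __float2bfloat16(acc * gW * gx);  // Level 2: global FP32 scales
}
```

Launched as `gemv_fp8_two_level<<<M, 32>>>(...)`. Because the per-block scale is applied inside the loop, the E4M3 values only need to cover one 32-element block's dynamic range, which is where the accuracy gain over 128-element blocks comes from.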
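If Kahan summation is enabled, the plain `acc += ...` in the loop above would become a compensated add. A sketch (the helper name is hypothetical):

```cuda
// Kahan (compensated) add: `c` carries the low-order bits lost by each
// addition into `sum`. This add-only sequence is safe under CUDA's default
// FMA contraction, but it must not be algebraically simplified.
__device__ __forceinline__ void kahan_add(float v, float& sum, float& c) {
    float t = v - c;      // reintroduce the previously lost low-order bits
    float s = sum + t;    // the low-order bits of t may be lost here
    c = (s - sum) - t;    // recover exactly what was lost
    sum = s;
}
```

In the loop above: initialize `float comp = 0.0f;` next to `acc` and call `kahan_add(float(W[row * K + k]) * float(x[k]) * s, acc, comp);`. This matches the trade-off noted in the list: one extra register and three extra FP ops per element.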
### Target

- Relative error: <0.5% (close to BF16's ~0.3%)
- Speed: ~1.5-2× slower than the Fast version, acceptable given the ~50% memory-bandwidth savings of FP8 weights over BF16
## Implementation

```
gemv/fp8/fp8/sm120/
├── fp8_gemv.cu      # Existing Fast version
└── fp8_accurate.cu  # New Accurate version
```
### API

```
# Fast (benchmark)
gemv_fp8_fp8_bf16_sm120(A, B, scale_A, scale_B, C)

# Accurate (production)
gemv_fp8_fp8_bf16_accurate_sm120(A, B, scale_A, scale_B, C)
```
## Related