feat: W8A8 Accurate GEMV kernel (<0.5% error) #123

@m96-chan

Description

Background

The current W8A8 (FP8/FP8) GEMV kernel has ~1-2% relative error, which can degrade LLM inference quality as errors accumulate across layers.

Current State

| Kernel | Error | Speed | Use case |
|---|---|---|---|
| BF16 | ~0.3% | baseline | Production |
| W8A8 Fast | ~1-2% | 10 µs | Benchmark only |
| W8A8 Accurate | <0.5% | ~20 µs | Production (TODO) |

Problem

  • LLM decode is autoregressive - errors compound
  • 1-2% error per layer × 32 layers = significant degradation
  • Softmax is sensitive to small logit differences
  • Fast version is essentially benchmark-only, not production-ready
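The compounding claim can be sanity-checked with quick arithmetic (the 1.5% per-layer figure below is illustrative, taken from the middle of the 1-2% range):

```python
# Illustrative: how a ~1.5% per-layer relative error can grow over 32 layers.
per_layer_err = 0.015
layers = 32

# Worst case: errors are correlated and compound multiplicatively.
worst_case = (1 + per_layer_err) ** layers - 1

# Optimistic case: errors are independent and grow like sqrt(layers).
rms_case = per_layer_err * layers ** 0.5

print(f"worst case: {worst_case:.1%}")  # ~61%
print(f"RMS case:   {rms_case:.1%}")    # ~8.5%
```

Even the optimistic independent-error estimate is well above the BF16 baseline, which is why the Fast kernel is benchmark-only.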

Proposed Solution

Accuracy Improvements

  1. Smaller scale blocks: 128 → 32 elements
    • More scales = finer range coverage per block
    • Trade-off: 4x more scale memory
  2. 2-level scaling (like official NVFP4):
    • Per-block E4M3 scale + global FP32 scale
    • Better dynamic-range coverage
  3. Kahan summation (optional):
    • Reduces accumulation error
    • Trade-off: more registers and ops
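The three ideas above can be sketched in NumPy (the block size, the crude E4M3 emulation, and the scale layout are assumptions for illustration, not the kernel's actual code):

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def quant_e4m3(x):
    """Crude E4M3 emulation: clamp to +/-448, keep 4 significant bits.

    Ignores subnormals and NaN handling; enough to illustrate
    quantization error, not bit-exact FP8.
    """
    x = np.clip(np.asarray(x, dtype=np.float64), -E4M3_MAX, E4M3_MAX)
    m, e = np.frexp(x)             # x = m * 2**e with 0.5 <= |m| < 1
    m = np.round(m * 16.0) / 16.0  # 1 implicit + 3 explicit mantissa bits
    return np.ldexp(m, e)

def quantize(x, block=32):
    """2-level scaling: per-block E4M3 scale + one global FP32 scale."""
    xb = x.reshape(-1, block)
    amax = np.abs(xb).max(axis=1, keepdims=True)
    gscale = float(amax.max()) / E4M3_MAX   # global FP32 scale
    bscale = quant_e4m3(amax / amax.max())  # per-block E4M3 scale
    q = quant_e4m3(xb / (bscale * gscale))  # FP8 payload
    return q, bscale, gscale

def dequantize(q, bscale, gscale):
    return (q * bscale * gscale).reshape(-1)

def kahan_dot(a, b):
    """Dot product with Kahan compensated summation."""
    s = c = 0.0
    for x, y in zip(a, b):
        y2 = float(x) * float(y) - c
        t = s + y2
        c = (t - s) - y2  # compensation: the low-order bits lost in s + y2
        s = t
    return s

rng = np.random.default_rng(0)
n = 1024
a = rng.uniform(0.1, 1.0, n)  # positive data keeps the reference dot away from 0
b = rng.uniform(0.1, 1.0, n)

ref = float(a @ b)
aq = dequantize(*quantize(a))
bq = dequantize(*quantize(b))
rel_err = abs(kahan_dot(aq, bq) - ref) / ref
print(f"relative error: {rel_err:.3%}")
```

Because the payload is quantized with the already-rounded block scale, the scale's rounding error largely cancels at dequantization; the residual error comes from payload rounding and occasional clamping, which is what the smaller 32-element blocks keep in check.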

Target

  • Relative error: <0.5% (close to the BF16 baseline's ~0.3%)
  • Speed: ~1.5-2x slower than the Fast version (acceptable given the ~50% memory-bandwidth savings over BF16)

Implementation

gemv/fp8/fp8/sm120/
├── fp8_gemv.cu        # Existing Fast version
└── fp8_accurate.cu    # New Accurate version

API

# Fast (benchmark)
gemv_fp8_fp8_bf16_sm120(A, B, scale_A, scale_B, C)

# Accurate (production)
gemv_fp8_fp8_bf16_accurate_sm120(A, B, scale_A, scale_B, C)

Related
