Add k-bit blockwise quantization (K=2-5) with warp-level CUDA kernels#1858
Add k-bit blockwise quantization (K=2-5) with warp-level CUDA kernels#1858TimDettmers wants to merge 9 commits intomainfrom
Conversation
Implements Stages 0-5 of the k-bit quantization plan from cuda-spec.md: - Pure Python reference (quantize_kbit_ref, dequantize_kbit_ref) with 57 passing tests - CUDA kernels using __ballot_sync bit-plane packing and __shfl_sync codebook lookup - Test kernels (pack/unpack, memory format, codebook lookup) and production kernels - All C interface symbols exported and loadable via ctypes CUDA kernels compile but are not yet executable due to an RDC device linking issue where template instantiations in kernels.cu are not pulled into the final fatbinary. See KBIT_PROGRESS.md for diagnosis and recommended fix (move kernel bodies into ops.cu or a new self-contained file). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The "invalid device function" error was caused by mismatched kernel declarations in kernels.cuh (without __restrict__) vs definitions in ops.cu (with __restrict__). With CUDA separable compilation (-rdc=true), this created conflicting host stubs in the function registration. Fix: remove forward declarations from kernels.cuh, keep kernel definitions and launch wrappers together in ops.cu. Also added CUDA_RESOLVE_DEVICE_SYMBOLS ON to CMakeLists.txt. All 157 tests now pass: Stage 0 (Python ref), Stages 1-3 (CUDA test kernels), Stage 4 (quantize), Stage 5 (dequantize) -- covering K=2-5, fp16/bf16/fp32, various tensor sizes, and analytical error bounds. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Stage 6: Error analysis on 1M+ elements (analytical bounds, MSE, SQNR) - Stage 7: Cross-validation against existing NF4 dequant - Stage 8: Performance benchmarks (bandwidth utilization, throughput scaling) - Python API: quantize_kbit(), dequantize_kbit(), create_normal_float_codebook() in functional.py with torch.library registration in _ops.py and CUDA kernel dispatch in backends/cuda/ops.py - Codebook caching per (k, device) pair Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Not needed in the final branch. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Vectorized dequant kernel (half2 stores, 4 blocks/warp) gives 1.23-1.29x speedup over scalar kernel, reaching 80-87% of peak HBM bandwidth. Routes fp16 output through vectorized path; bf16/fp32 use scalar fallback. E4M4 uint8 absmax (bias=11, IEEE-style subnormals) reduces absmax storage from 4 bytes to 1 byte per block. K=4 drops from 5.0 to 4.25 bits/elem, matching NF4 bs=64 storage. SQNR degradation is <0.4 dB across all K values. Decode uses direct IEEE 754 bit construction for zero overhead on the dequant hot path. 240 tests passing (22 new E4M4 tests). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Remove scalar dequant kernel (vectorized is strictly better) - Remove fp32 absmax dequant path; E4M4 uint8 is now the default, fp16 absmax kept as an option - Remove Stage 1-3 test scaffolding kernels (pack/unpack, memory format, codebook lookup) and their C wrappers - Dequant always produces fp16 at the CUDA level; bf16/fp32 output via cast in Python - Net removal of 334 lines; 188 tests passing Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace half2-specific vectorized kernel with a generic version templated on T (output type) and ABSMAX_T (absmax format). Scalar stores via (T)val; hardware coalesces warp writes. No fp16 regression (within benchmark noise). Native bf16 and fp32 output at the kernel level — no Python-side cast needed. Add output dtype correctness tests (bf16/fp32 match fp16) and asymmetric codebook tests (all-positive, all-negative, skewed, non-uniform spacing, duplicate entries). 222 tests passing. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Apply ruff lint fix (unused variable), ruff format, and clang-format to pass CI pre-commit hooks. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The error bound was using a flat 1.25x multiplier on the quantization error, but E4M4 absmax quantization adds up to 1/16 (6.25%) absolute scale error. For K=5 where the codebook gap is ~0.0625, this E4M4 error is 2x the quantization error itself, exceeding the 1.25x margin. Fix by computing the bound correctly as (max_gap/2 + 1/16) * absmax, which adds both error sources instead of scaling one by a fixed factor. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
TimDettmers
left a comment
There was a problem hiding this comment.
PR Review: #1858 — Add k-bit blockwise quantization (K=2-5) with warp-level CUDA kernels
Adds a complete k-bit blockwise quantization subsystem: CUDA kernels using __ballot_sync bit-plane packing and __shfl_sync codebook lookup (blocksize=32), E4M4 uint8 absmax format, Python API (quantize_kbit/dequantize_kbit), torch.library op definitions, CUDA backend dispatch, and 222 tests covering correctness, error analysis, NF4 cross-validation, and benchmarks.
No blocking issues.
The CUDA kernel architecture is sound — warp-level primitives for packing/lookup/reduction eliminate shared memory usage entirely, the multi-block dequant amortizes codebook load, and the E4M4 decode uses direct IEEE 754 bit construction for zero overhead. Template instantiations cover all type combinations. The Python-level API follows existing patterns (torch.library.define, register_kernel, register_fake) and integrates cleanly with the existing quantization stack.
Suggestion: The create_normal_float_codebook function requires scipy at runtime (for norm.ppf), which is only in the [test] extra. If kbit quantization is intended for production use, the codebooks could be hardcoded (like NF4's get_4bit_type) to remove the scipy dependency, or scipy could be moved to a non-test dependency group.
Suggestion: The dequant CUDA backend dispatch (backends/cuda/ops.py:140-143) silently converts fp32 absmax to E4M4 before passing to the kernel. This is convenient but means a user who explicitly passes fp32 absmax gets E4M4 precision without opting in. A warning or explicit parameter might be clearer.
- Security: Clear
- Downstream impact: None (additive — new API, no changes to existing functions)
- Tests: Adequate (222 tests including reference cross-validation, error bounds, edge cases, benchmarks, asymmetric codebooks)
- CI: All pass (Lint, CPU x4 platforms, CUDA T4+L40S x3 CUDA versions, Windows CUDA, ROCm builds, XPU builds)
- torch.compile: Compatible (uses
torch.libraryregistration withregister_fakeshape functions) - Serialization: N/A (new format, no backward compat concerns yet)
|
|
||
| # --------------------------------------------------------------------------- | ||
| # K-bit blockwise quantization (K=2..5, blocksize=32) | ||
| # --------------------------------------------------------------------------- |
There was a problem hiding this comment.
The create_normal_float_codebook function imports scipy at runtime. Since scipy is only in the [test] extra, calling quantize_kbit() without a pre-built codebook in a production install will raise ImportError. Consider hardcoding the codebooks (like get_4bit_type does for NF4) or documenting the scipy requirement.
| torch.float16: "fp16abs", | ||
| } | ||
|
|
||
|
|
There was a problem hiding this comment.
When the user passes fp32 absmax, the dequant dispatch silently encodes it to E4M4 before calling the kernel. This is a lossy conversion the caller may not expect — they passed fp32 precision but get E4M4 precision. Consider either warning or documenting this behavior.
Add complete PR review posting workflow to pr_review_guide.md (Section 17): - Review format scaled to PR complexity (brief for clean, detailed for issues) - GitHub posting via gh CLI (--comment for most, --request-changes for security only, never --approve) - Inline comments via gh api with JSON temp file approach and field reference - Early termination path for trivial PRs (docs/style/test-only, Section 4.2) - Re-review workflow for author follow-ups - Workflow diagram updated to show early termination branch Fix linting_guide.md consistency: - Quick Reference no longer presents ruff-only as equivalent to pre-commit - Recommended Workflow uses pre-commit as the primary command, not optional Add agent reference documents for PR review prerequisites: - architecture_guide.md, code_standards.md, api_surface.md - downstream_integrations.md, security_guide.md, kbit_gemm_context.md Update CLAUDE.md PR review section to mention posting instructions. Update issue_maintenance_guide.md with expanded triage patterns. Tested end-to-end on PR #1858 (k-bit quantization) as the first review using this pipeline. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
__ballot_syncpacking,__shfl_synccodebook lookup)quantize_kbit(),dequantize_kbit(),create_normal_float_codebook(),encode_absmax_e4m4(),decode_absmax_e4m4()Architecture
Each quantization block of 32 elements maps to exactly one CUDA warp. This enables:
__ballot_sync: K calls produce K uint32 words with zero bit waste for any K value. No word-boundary issues for odd K (unlike sequential packing).__shfl_sync: Each lane holds one codebook entry (up to 2^K=32 entries fit in warp width). Lookup is register-to-register (~5 cycles), no shared memory needed.__shfl_down_sync: 5 reduction steps, no shared memory, no__syncthreads().Zero shared memory used. Zero warp divergence in the hot path. Templated on output type (fp16/bf16/fp32) and absmax format (E4M4 uint8/fp16).
E4M4 absmax format
uint8 micro-float with 4-bit exponent, 4-bit mantissa, bias=11, IEEE-style subnormals. Range [6.1e-5, 31.0]. Mean encode/decode relative error: 1.1%, 95th percentile: 2.4%. SQNR degradation vs fp32 absmax: <0.4 dB across all K values. Decode uses direct IEEE 754 bit construction (
__uint_as_float) for zero overhead on the dequant hot path.Benchmarks (RTX 4090, 67M elements, E4M4 absmax)
Dequant kernel throughput
Comparison with existing NF4 (fp16 output)
K=4 is at parity with NF4 in kernel throughput (0.97x) while using 0.25 fewer bits/element. K=2 and K=3 are faster due to less data to read. Both kernels are bandwidth-bound at 68-78% of peak HBM.
Quality (SQNR, 1M elements, normal distribution)
K=4 achieves comparable quality to NF4 at 0.25 fewer bits/element (4.25 vs 4.50). K=3 offers 3.25 bits/element (4.9x compression) with 15 dB SQNR. K=2 provides 7.1x compression for extreme quantization.
Storage comparison
Key design decisions
__ballot_sync/__shfl_sync__ballot_sync__shfl_syncfrom lane registers_syncwarp primitivesAPI
The default codebook is a symmetric normal-float codebook (quantiles of N(0,1), normalized to [-1, 1]). Unlike the existing NF4 codebook which is asymmetric (7 negative, 1 zero, 8 positive), this codebook has equal representation on both sides (8 negative, 8 positive, no explicit zero). Custom codebooks of any shape can be passed via the
codebookparameter.Test plan
222 tests covering:
Files changed
csrc/ops.cu— CUDA kernels and launchers (+229 lines)csrc/pythonInterface.cpp— C interface wrappers (+125 lines)bitsandbytes/functional.py— Python API (+194 lines)bitsandbytes/_ops.py— torch.library op definitions (+44 lines)bitsandbytes/backends/cuda/ops.py— CUDA backend dispatch (+90 lines)tests/test_kbit_quantization.py— Test suite (+1372 lines)