Highlights
- Implemented missing non-inference backward paths:
  - Triton `rwkv6` backward (fixed-length + `cu_seqlens` varlen)
  - Triton `rwkv7` backward (fixed-length + `cu_seqlens` varlen), with `rwkv7_mul` backprop flowing through RWKV7
  - CUDA `blocksparse_attention` backward via CUDA-side dense analytical fallback (removed `NotImplementedError`)
  - XLA recurrent varlen (`cu_seqlens`) backward with packed-varlen/state normalization
- Added `forward_autotune_only` context manager in `ejkernel.ops` / `ejkernel.ops.config` to disable backward autotune validation while keeping forward autotune active.
Quantized Matmul / Backend Work
- Added TPU Pallas hybrid QMM path:
  - Packed / predecode / XLA fallback dispatch
  - Shared forward + dX kernel family and custom VJP dX path
  - TPU legality gates and memory-aware heuristics
- Expanded QMM backend normalization across Triton/CUDA/CuTe/XLA/Pallas.
- CUDA QMM improvements:
  - cuBLASLt/CUTLASS GEMM backend options (`EJKERNEL_QMM_CUDA_GEMM`)
  - NF4 exact table lookup path
- Expanded affine group-size coverage.
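For readers unfamiliar with NF4, the following minimal sketch shows what dequantizing by table lookup looks like in plain Python. It is an illustration only and says nothing about how the CUDA kernel is written. The 16 code values are the NF4 constants popularized by QLoRA/bitsandbytes; the nibble order (high nibble first), the per-group scale layout, and the function name `dequantize_nf4` are assumptions made for this example.

```python
# NF4 code book: 16 fixed values in [-1, 1], indexed by the 4-bit code.
NF4_TABLE = (
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
)


def dequantize_nf4(packed, scales, group_size):
    """Decode packed 4-bit codes into floats.

    packed:     bytes, two codes per byte (assumed: high nibble first)
    scales:     one float scale per group of `group_size` decoded values
    group_size: number of consecutive values sharing a scale
    """
    out = []
    for byte in packed:
        for code in (byte >> 4, byte & 0x0F):
            scale = scales[len(out) // group_size]
            # Exact lookup: the code indexes the table directly.
            out.append(NF4_TABLE[code] * scale)
    return out


# Example: codes 0, 15, 7, 0 sharing a single scale of 2.0.
values = dequantize_nf4(bytes([0x0F, 0x70]), scales=[2.0], group_size=4)
```

Because each code indexes the table directly, the decoded value involves no approximation beyond the final multiplication by the group scale.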
Runtime & Compatibility Fixes
- Fixed Triton QMM tracer leak by removing global decode-table state.
- Added JAX compatibility shim for `pl.ANY` movement.
- Added single-device XLA fallback path for ring attention (`axis_name=None`).
- Fixed recurrent varlen forward unpacking bug.
- Flash attention `FlashBias` typing improvements.
Developer Experience / Quality
- Refactored many kernel interfaces into split `*_impl_fwd.py` / `*_impl_bwd.py`.
- Added/expanded docstrings across kernels/callib/ops.
- Added structured unsupported-feature errors via `EjkernelRuntimeError`.
- Minor style cleanup (consolidated multiline errors, sorted exports).
Test Coverage Added/Expanded
- JIT grad/VJP + numerical sanity checks for:
  - RWKV6/RWKV7/RWKV7_MUL backward paths
  - CUDA blocksparse attention backward
  - XLA recurrent varlen backward
- Expanded QMM routing/cache/memory-capacity/grad parity coverage across backends.
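To make "numerical sanity check" and "`cu_seqlens` varlen" concrete, here is a small pure-Python sketch, not taken from the ejkernel test suite. It runs a toy scalar recurrence `h_t = w * h_{t-1} + x_t` over sequences packed back to back, where `cu_seqlens` holds the cumulative boundaries, and compares a hand-written backward pass against a central finite difference. All function names are hypothetical.

```python
def recurrent_fwd(x, w, cu_seqlens):
    """Run the recurrence independently on each packed sequence."""
    y = [0.0] * len(x)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        h = 0.0  # state resets at every sequence boundary
        for t in range(start, end):
            h = w * h + x[t]
            y[t] = h
    return y


def recurrent_bwd(x, w, cu_seqlens, dy):
    """Analytic gradients of sum(dy * y) with respect to x and w."""
    y = recurrent_fwd(x, w, cu_seqlens)
    dx = [0.0] * len(x)
    dw = 0.0
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        dh = 0.0  # no gradient flows across sequence boundaries
        for t in range(end - 1, start - 1, -1):
            dh = dy[t] + w * dh
            dx[t] = dh
            h_prev = y[t - 1] if t > start else 0.0
            dw += dh * h_prev
    return dx, dw


def numeric_grad(f, v, eps=1e-6):
    """Central finite-difference derivative of a scalar function."""
    return (f(v + eps) - f(v - eps)) / (2 * eps)


# Two sequences of lengths 2 and 3 packed into one buffer.
x = [1.0, 2.0, 3.0, 4.0, 5.0]
cu_seqlens = [0, 2, 5]
w = 0.5
dx, dw = recurrent_bwd(x, w, cu_seqlens, dy=[1.0] * len(x))
dw_numeric = numeric_grad(lambda v: sum(recurrent_fwd(x, v, cu_seqlens)), w)
```

A check of this kind passes when `dw` and `dw_numeric` agree within a small tolerance and when the gradient for the last token of one sequence is unaffected by the tokens of the next.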
What's Changed
Full Changelog: v0.0.50...v0.0.55