CUDA and CuTe Have Arrived

@erfanzar released this 09 Feb 20:56

Highlights

  • Implemented previously missing backward (non-inference) paths:
      • Triton rwkv6 backward (fixed-length and cu_seqlens varlen)
      • Triton rwkv7 backward (fixed-length and cu_seqlens varlen), with rwkv7_mul backprop flowing through RWKV7
      • CUDA blocksparse_attention backward via a CUDA-side dense analytical fallback (no longer raises NotImplementedError)
      • XLA recurrent varlen (cu_seqlens) backward with packed-varlen/state normalization
  • Added a forward_autotune_only context manager (in ejkernel.ops and ejkernel.ops.config) that disables backward autotune validation while keeping forward autotuning active.
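Several of the backward paths above accept packed "varlen" input described by cu_seqlens. The sketch below assumes the common FlashAttention-style convention: cu_seqlens has batch + 1 entries and starts at 0. It shows how packing works and why a recurrent kernel must reset its state at each boundary. It is plain Python for illustration only. The helper names (make_cu_seqlens, pack, unpack, segment_cumsum) are hypothetical and are not ejkernel functions.

```python
from itertools import accumulate


def make_cu_seqlens(lengths):
    """Cumulative offsets [0, l0, l0 + l1, ...]; length is batch + 1."""
    return [0, *accumulate(lengths)]


def pack(sequences):
    """Concatenate variable-length sequences into one flat token stream."""
    packed = [tok for seq in sequences for tok in seq]
    return packed, make_cu_seqlens(len(seq) for seq in sequences)


def unpack(packed, cu_seqlens):
    """Sequence i is the slice packed[cu_seqlens[i]:cu_seqlens[i + 1]]."""
    return [packed[s:e] for s, e in zip(cu_seqlens, cu_seqlens[1:])]


def segment_cumsum(packed, cu_seqlens):
    """Toy recurrence (running sum) whose state resets at every boundary.

    A varlen recurrent kernel has the same obligation in both forward and
    backward: state from sequence i must never leak into sequence i + 1.
    """
    out = []
    for s, e in zip(cu_seqlens, cu_seqlens[1:]):
        state = 0
        for x in packed[s:e]:
            state += x
            out.append(state)
    return out


packed, cu = pack([[1, 2, 3], [4, 5], [6]])
print(packed, cu)                   # [1, 2, 3, 4, 5, 6] [0, 3, 5, 6]
print(segment_cumsum(packed, cu))   # [1, 3, 6, 4, 9, 6]
```

Packing avoids padding waste. The cost is that every kernel, and its backward, must honor the boundaries encoded in cu_seqlens.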

Quantized Matmul / Backend Work

  • Added TPU Pallas hybrid QMM path:
      • Packed / predecode / XLA fallback dispatch
      • Shared forward + dX kernel family and custom VJP dX path
      • TPU legality gates and memory-aware heuristics
  • Expanded QMM backend normalization across Triton/CUDA/CuTe/XLA/Pallas.
  • CUDA QMM improvements:
      • cuBLASLt/CUTLASS GEMM backend options (EJKERNEL_QMM_CUDA_GEMM)
      • NF4 exact table lookup path
      • Expanded affine group-size coverage
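The QMM items above mention two dequantization schemes: NF4 via exact table lookup and affine per-group quantization. The sketch below illustrates both in plain Python. The NF4 code book is the 16-value table published with QLoRA and used by bitsandbytes. Several details are assumptions for illustration and may not match ejkernel's packed format: the nibble order (low nibble first), the per-group absmax / scale / bias layout, and all function names.

```python
# NF4 code book (QLoRA / bitsandbytes): 16 normal-float levels in [-1, 1].
NF4_TABLE = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]


def unpack_nibbles(packed_bytes):
    """Split each byte into two 4-bit codes (low nibble first, by assumption)."""
    codes = []
    for b in packed_bytes:
        codes.append(b & 0x0F)
        codes.append(b >> 4)
    return codes


def dequant_nf4(packed_bytes, absmax, group_size):
    """Exact table lookup: w = NF4_TABLE[code] * absmax[group]."""
    codes = unpack_nibbles(packed_bytes)
    return [NF4_TABLE[c] * absmax[i // group_size] for i, c in enumerate(codes)]


def quantize_affine(weights, group_size, bits=4):
    """Per-group min/max quantization to unsigned `bits`-bit codes."""
    levels = (1 << bits) - 1
    codes, scales, biases = [], [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / levels or 1.0  # constant group: avoid dividing by 0
        scales.append(scale)
        biases.append(lo)
        codes.extend(round((w - lo) / scale) for w in group)
    return codes, scales, biases


def dequant_affine(codes, scales, biases, group_size):
    """Affine dequantization: w = q * scale[g] + bias[g]."""
    return [q * scales[i // group_size] + biases[i // group_size]
            for i, q in enumerate(codes)]


print(dequant_nf4(bytes([0xF0]), absmax=[2.0], group_size=64))  # [-2.0, 2.0]
```

In affine mode, the group size trades accuracy against storage: each element's round-trip error is bounded by half its group's scale, and smaller groups mean tighter scales but more scale/bias metadata.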

Runtime & Compatibility Fixes

  • Fixed Triton QMM tracer leak by removing global decode-table state.
  • Added JAX compatibility shim for pl.ANY movement.
  • Added single-device XLA fallback path for ring attention (axis_name=None).
  • Fixed recurrent varlen forward unpacking bug.
  • Improved FlashBias typing in flash attention.
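The single-device ring-attention fallback (axis_name=None) is sound for the following reason. Ring attention processes key/value blocks one at a time and merges the partial results with an online (streaming) softmax. With only one device, the "ring" holds a single block, and the result equals ordinary attention. The sketch below checks that equivalence for a single query vector in plain Python. It illustrates the merging rule only and is not ejkernel's implementation; the function names are hypothetical.

```python
import math


def full_attention(q, keys, values):
    """Reference: softmax(q . k) weighted sum of values, one query vector."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


def blockwise_attention(q, key_blocks, value_blocks):
    """Online-softmax accumulation over KV blocks, one block per ring step.

    With a single device (no ring axis) there is exactly one block.
    """
    dim = len(value_blocks[0][0])
    running_max = -math.inf
    denom = 0.0
    acc = [0.0] * dim
    for keys, values in zip(key_blocks, value_blocks):
        scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
        new_max = max(running_max, max(scores))
        correction = math.exp(running_max - new_max)  # rescale earlier partials
        denom *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, values):
            w = math.exp(s - new_max)
            denom += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]


q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
ref = full_attention(q, keys, values)
two_blocks = blockwise_attention(q, [keys[:2], keys[2:]], [values[:2], values[2:]])
one_block = blockwise_attention(q, [keys], [values])  # the axis_name=None case
```

Whether the KV sequence is split across several devices or kept in one block, the merged result matches the reference up to floating-point rounding.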

Developer Experience / Quality

  • Refactored many kernel interfaces into separate *_impl_fwd.py / *_impl_bwd.py modules.
  • Added/expanded docstrings across kernels/callib/ops.
  • Added structured unsupported-feature errors via EjkernelRuntimeError.
  • Minor style cleanup (consolidated multiline errors, sorted exports).

Test Coverage Added/Expanded

  • JIT grad/VJP + numerical sanity checks for:
      • RWKV6/RWKV7/RWKV7_MUL backward paths
      • CUDA blocksparse attention backward
      • XLA recurrent varlen backward
  • Expanded QMM routing/cache/memory-capacity/grad parity coverage across backends.
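The grad/VJP sanity checks above compare an analytic backward pass against a numerical estimate. The sketch below shows that idea in plain Python for the dX path of a matmul y = x @ w^T against a fixed (for example, already dequantized) weight w. It uses a probe cotangent g and central finite differences. This is a generic technique, not ejkernel's test suite, which runs under JAX jit/grad. The helper names are hypothetical.

```python
def matmul(a, b):
    """(m x k) @ (k x n) on nested lists."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(m):
    return [list(row) for row in zip(*m)]


def loss(x, w, g):
    """Scalar probe L = sum(g * (x @ w^T)); its gradient w.r.t. x is g @ w."""
    y = matmul(x, transpose(w))
    return sum(gi * yi for grow, yrow in zip(g, y) for gi, yi in zip(grow, yrow))


def analytic_dx(w, g):
    """VJP of y = x @ w^T with respect to x, given cotangent g."""
    return matmul(g, w)


def numeric_dx(x, w, g, eps=1e-6):
    """Central finite differences, one input element at a time."""
    grad = [[0.0] * len(x[0]) for _ in x]
    for i in range(len(x)):
        for j in range(len(x[0])):
            x_plus = [row[:] for row in x]
            x_minus = [row[:] for row in x]
            x_plus[i][j] += eps
            x_minus[i][j] -= eps
            grad[i][j] = (loss(x_plus, w, g) - loss(x_minus, w, g)) / (2 * eps)
    return grad


def max_abs_diff(a, b):
    return max(abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))


x = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]                     # m=2, k=3
w = [[1.0, 2.0, 3.0], [0.0, -1.0, 1.0], [2.0, 0.5, -1.0], [1.0, 1.0, 1.0]]  # n=4
g = [[1.0, 0.0, -1.0, 2.0], [0.5, 1.0, 0.0, -1.0]]             # m x n cotangent
err = max_abs_diff(analytic_dx(w, g), numeric_dx(x, w, g))
```

Using a random-looking cotangent g, rather than differentiating every output separately, checks the full VJP with a single scalar loss. This is the same pattern grad-based tests use.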

What's Changed

  • Add native CUDA kernels, quantized matmul, and kernel error handling by @erfanzar in #3

Full Changelog: v0.0.50...v0.0.55