RFC: Multi-GPU Unified Interface for cuLA
Status: Draft
Authors: cuLA maintainers
Created: 2026-04-08
Table of Contents
- Summary
- Motivation
- Design Goals and Non-Goals
- SageAttention Interface Analysis
- Architecture: Three-Level Stack
- Level 1 — Single-GPU Unified API
- Level 2 — Multi-GPU Wrapper
- Hardware Dispatch Tables
- API Reference (Complete Signatures)
- Migration Path from Current API
- Future Work: nvshmem Kernel-Level Overlap
- Open Questions
- Summary of Proposed Changes
1. Summary
This RFC proposes a multi-GPU unified interface for the cuLA library, organized as two new Python layers above the existing raw kernels:
- Level 1 (cula.api): Single-GPU unified API — one function per algorithm family (cula_kda, cula_lightning, cula_decode) that auto-dispatches to the correct backend (SM90 / SM100 / SM103 / fallback) and handles all bookkeeping internally.
- Level 2 (cula.dist): Multi-GPU wrapper that composes Level 1 with torch.distributed to implement Tensor Parallelism (TP), Sequence/Context Parallelism (SP/CP), and state hand-off for pipeline parallelism.
The design is modelled on SageAttention's sageattn entry point — one function that does the right thing, with advanced options via **kwargs — while remaining consistent with cuLA's BSHD ([B, S, H, D]) tensor layout and [B, H, V, K] float32 K-last state conventions.
2. Motivation
2.1 Need for a unified interface with auto-dispatch
cuLA provides multiple kernel implementations for different GPU architectures (Hopper fused, Blackwell fused, modular fallback), each exposed as separate entry points. A unified interface that auto-dispatches to the optimal kernel based on the current device would simplify caller code and reduce integration friction for downstream frameworks (SGLang, vLLM, FLA).
2.2 No multi-GPU path
There is currently no first-class multi-GPU support in cuLA core. FLA provides CP via FLACPContext, but this couples multi-GPU logic to FLA's orchestration code. This dependency needs to be severed, and no replacement multi-GPU story has yet been proposed — a gap this RFC addresses.
2.3 LLM serving needs
LLM serving frameworks (SGLang, vLLM) increasingly use:
- TP to distribute model width across GPUs (each GPU handles H/TP_SIZE heads).
- SP/CP to distribute context length across GPUs (each GPU handles T/SP_SIZE tokens).
- Both combined for large models with very long sequences.
For recurrent architectures (linear attention), the stateful nature adds a unique challenge: the KV state at chunk boundaries must be passed between GPUs.
3. Design Goals and Non-Goals
Goals
- One function per algorithm family. No architecture-conditional imports at the call site.
- Auto-dispatch by hardware. SM90 → Hopper fused; SM100/SM103 → Blackwell fused; fallback → modular path.
- Internalize bookkeeping. chunk_indices, cu_seqlens normalization, initial_state dtype checking, default scale — all handled inside Level 1.
- Tensor Parallelism. Shard across head dimension H with no kernel signature changes.
- Sequence/Context Parallelism. Shard across token dimension T with correct state hand-off via ring-style communication.
- Compatible with FLA/SGLang conventions. BSHD ([B, S, H, D]), [B, H, V, K] float32 K-last state, cu_seqlens int32, same parameter names.
Non-Goals
- Kernel-level nvshmem overlap (covered as an independent future work).
4. SageAttention Interface Analysis
SageAttention (thu-ml/SageAttention) provides a clean reference for unified attention API design. Key findings:
4.1 Public API Structure
SageAttention exports 6 functions from the main package + 1 from a separate sub-package:
| Function | Role |
| --- | --- |
| sageattn(q, k, v, ...) | High-level auto-dispatch entry point |
| sageattn_qk_int8_pv_fp16_triton(...) | Triton backend, FP16 PV |
| sageattn_qk_int8_pv_fp16_cuda(...) | CUDA backend, FP16 PV (SM80+) |
| sageattn_qk_int8_pv_fp8_cuda(...) | CUDA backend, FP8 PV (SM89+) |
| sageattn_qk_int8_pv_fp8_cuda_sm90(...) | CUDA backend, FP8 PV (SM90 only) |
| sageattn_varlen(...) | Variable-length sequences |
| sageattn3_blackwell(...) | SageAttention3 for Blackwell (separate package) |
4.2 The sageattn Entry Point — Design Patterns Worth Adopting
def sageattn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    return_lse: bool = False,
    **kwargs: Any,
) -> torch.Tensor:
Pattern 1: Hardware auto-dispatch. sageattn inspects q.device to detect compute capability and dispatches:
| Architecture | Dispatches To |
| --- | --- |
| SM80 (A100) | sageattn_qk_int8_pv_fp16_cuda |
| SM86 (RTX 3090 Ti) | sageattn_qk_int8_pv_fp16_triton |
| SM89 (RTX 4090) | sageattn_qk_int8_pv_fp8_cuda |
| SM90 (H100) | sageattn_qk_int8_pv_fp8_cuda_sm90 |
Pattern 2: tensor_layout parameter. Supports both "HND" (head-first BHSD) and "NHD" (seq-first BSHD). cuLA adopts only BSHD.
Pattern 3: **kwargs forwarding. Backend-specific options are forwarded to the selected backend via **kwargs. cuLA currently has unified parameter signatures across backends, but preserves this pattern for future architecture-specific tuning knobs.
Pattern 4: Backend availability at import time. Each CUDA backend's availability is checked by attempting to import compiled C extensions:
SM80_ENABLED = True/False # from sm80_compile
SM89_ENABLED = True/False # from sm89_compile
SM90_ENABLED = True/False # from sm90_compile
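The same probe can be expressed as a small helper. A minimal, self-contained sketch — the helper name is illustrative, not existing SageAttention or cuLA API:

```python
def backend_available(module_name: str) -> bool:
    """Probe a compiled-extension module without hard-failing at import time."""
    try:
        __import__(module_name)
        return True
    except ImportError:  # also catches ModuleNotFoundError
        return False
```

In cuLA's Level 1 this would set flags such as `SM90_ENABLED = backend_available("cula._C_sm90")` at import time (the extension module name here is a placeholder).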
Pattern 5: Uniform parameter naming. All backends share q, k, v, tensor_layout, is_causal, sm_scale, return_lse — identical names and semantics. Backend-specific params are additive.
4.3 Backend-Specific Parameters: Currently Unified, Reserved for Future
SageAttention's backends differ by quantization strategy, each introducing unique parameters (e.g., pv_accum_dtype for FP8, attn_mask for Triton only). cuLA's backends differ by GPU architecture (Hopper fused / Blackwell fused / modular fallback), but the parameter signatures are currently unified — all backends accept the same set of algorithm-level parameters (scale, initial_state, cu_seqlens, safe_gate, etc.). Behavioral differences are handled internally (e.g., Hopper fused forces safe_gate=True).
However, future hardware generations may introduce architecture-specific tuning knobs (e.g., tile sizes, accumulation precision, pipeline depth). The **kwargs forwarding pattern is preserved in Level 1 to accommodate this — any unrecognized kwargs are forwarded to the selected backend, allowing new backend-specific parameters to be added without changing the Level 1 signature.
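A runnable sketch of this forwarding behavior, with stub backends standing in for real kernels (all names and the pipeline_depth knob are illustrative assumptions):

```python
def _hopper_backend(q, scale, **kwargs):
    # A future arch-specific knob arrives via kwargs without a signature change.
    pipeline_depth = kwargs.get("pipeline_depth", 2)
    return ("hopper", scale, pipeline_depth)

def _fallback_backend(q, scale, **kwargs):
    # The fallback ignores knobs it does not understand.
    return ("fallback", scale, None)

def level1_entry(q, *, scale=None, use_hopper=True, **kwargs):
    # Level 1 resolves defaults, selects a backend, then forwards
    # unrecognized kwargs verbatim.
    backend = _hopper_backend if use_hopper else _fallback_backend
    return backend(q, scale=scale, **kwargs)
```

The Level 1 signature never changes when a backend grows a new tuning parameter; only callers that opt in pass it.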
4.4 Variable-Length Interface
def sageattn_varlen(
    q: torch.Tensor,             # [total_q_tokens, num_qo_heads, head_dim]
    k: torch.Tensor,             # [total_kv_tokens, num_kv_heads, head_dim]
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,  # [batch_size + 1]
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    smooth_k: bool = True,
    **kwargs,
) -> torch.Tensor:
Notable: no tensor_layout (always NHD-style packed), no return_lse, uses per-block Triton kernels exclusively.
4.5 Key Differences from cuLA
| Aspect | SageAttention | cuLA |
| --- | --- | --- |
| Algorithm type | Softmax attention (quantized) | Linear attention (recurrent) |
| Stateful | No | Yes — initial_state / final_state |
| Backends differ by | Quantization strategy (INT8/FP8) | GPU architecture + fusion level |
| Layout choice | tensor_layout ("HND"=BHSD / "NHD"=BSHD) | Always BSHD ([B, S, H, D]) |
| Varlen | Separate function (sageattn_varlen) | Integrated via cu_seqlens parameter |
4.6 Patterns Adopted for cuLA
| SageAttention pattern | cuLA Level 1 adoption |
| --- | --- |
| Single entry point per family | cula_kda, cula_lightning, cula_decode |
| Auto-dispatch by compute capability | _get_backend(device, table, fallback) |
| Backend availability via try-import | Lazy import of Level 0 backends |
| **kwargs for backend-specific options | All Level 1 functions accept **kwargs |
| sm_scale / scale default | K ** -0.5 |
| Uniform parameter naming | initial_state, output_final_state, cu_seqlens |
Not adopted: tensor_layout — cuLA only supports BSHD ([B, S, H, D]) layout, consistent with SGLang/FLA conventions. No head-first BHSD variant.
5. Architecture: Three-Level Stack
┌─────────────────────────────────────────────────────────────────────────┐
│ Level 2: cula.dist │
│ cula_kda_tp, cula_kda_cp, cula_kda_tp_cp, cula_lightning_cp, ... │
│ (multi-GPU wrappers using torch.distributed / nvshmem) │
└─────────────────────────────┬───────────────────────────────────────────┘
│ calls
┌─────────────────────────────▼───────────────────────────────────────────┐
│ Level 1: cula.api │
│ cula_kda, cula_lightning, cula_decode │
│ (single-GPU, auto-dispatches to correct backend, handles bookkeeping) │
└─────────────────────────────┬───────────────────────────────────────────┘
│ calls
┌─────────────────────────────▼───────────────────────────────────────────┐
│ Level 0: cula.ops / cula.kda / cula.lightning (current codebase) │
│ chunk_kda, kda_prefill_hopper, lightning_attn_fwd, kda_decode, ... │
│ (raw kernels, no auto-dispatch, explicit hardware requirements) │
└─────────────────────────────────────────────────────────────────────────┘
Level 0 is the current codebase, untouched by this RFC. Level 1 and Level 2 are new modules.
6. Level 1 — Single-GPU Unified API
New module: cula/api.py
6.1 Hardware Dispatch Mechanism
Dispatch table keyed by (sm_major, sm_minor):
# Conceptual — cula/api/_dispatch.py
from typing import Callable

from cula.utils import get_device_sm_version

_dispatch_cache: dict[str, Callable] = {}

def _get_backend(device, table: dict, fallback: Callable) -> Callable:
    """
    Look up the best kernel for *device* in *table*.
    Falls back to *fallback* if the architecture is not in the table.
    Cached per device string.
    """
    key_str = str(device)
    if key_str not in _dispatch_cache:
        sm = get_device_sm_version(device)
        _dispatch_cache[key_str] = table.get(sm, fallback)
    return _dispatch_cache[key_str]
6.2 cula_kda — KDA Family
@torch.compiler.disable
def cula_kda(
    q: torch.Tensor,     # [B, T, H, K] bf16
    k: torch.Tensor,     # [B, T, H, K] bf16
    v: torch.Tensor,     # [B, T, H, V] bf16
    g: torch.Tensor,     # [B, T, H, K] or [B, T, H] fp32
    beta: torch.Tensor,  # [B, T, H] fp32
    *,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,  # [B, H, V, K] fp32, K-last layout
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    use_gate_in_kernel: bool = False,
    cu_seqlens: torch.Tensor | None = None,     # int32 or int64 → normalized to int32
    safe_gate: bool = False,
    lower_bound: float | None = None,
    disable_recompute: bool = False,
    return_intermediate_states: bool = False,
    device: torch.device | str | None = None,
    **kwargs,  # forwarded to selected backend; reserved for future arch-specific tuning
) -> tuple[torch.Tensor, torch.Tensor | None]:
Behavior contract:
- scale defaults to K ** -0.5.
- device defaults to q.device.
- cu_seqlens int64 is accepted and normalized to int32 internally.
- chunk_indices is computed internally from cu_seqlens; callers must not pass it.
- g: [B, T, H] (scalar gate, FLA convention) is expanded to [B, T, H, K] internally when needed.
- initial_state, if provided, must be [B, H, V, K] float32 K-last layout — enforced internally.
- **kwargs are forwarded to the selected backend, reserved for future arch-specific tuning.
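For illustration, one plausible pure-Python construction of chunk_indices from cu_seqlens — the real implementation works on device tensors, and the helper name and output shape here are assumptions:

```python
def build_chunk_indices(cu_seqlens: list[int], chunk_size: int = 64) -> list[tuple[int, int]]:
    """Map each chunk in the packed varlen batch to (sequence_id, local_chunk_id).

    cu_seqlens follows the FLA convention of cumulative token offsets, e.g.
    [0, 100, 160] packs a 100-token and a 60-token sequence.
    """
    indices = []
    for seq_id in range(len(cu_seqlens) - 1):
        seq_len = cu_seqlens[seq_id + 1] - cu_seqlens[seq_id]
        num_chunks = (seq_len + chunk_size - 1) // chunk_size  # ceil division
        for local_chunk_id in range(num_chunks):
            indices.append((seq_id, local_chunk_id))
    return indices
```

Because this is derived data, exposing it to callers would only create an opportunity for inconsistency — hence the "computed internally" rule above.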
Dispatch table:
- SM90 → cula.kda.hopper_fused_fwd.cula_kda_prefill (safe_gate forced True)
- SM100 → cula.kda.blackwell_fused_fwd (once available), fallback to chunk_kda
- SM103 → same as SM100
- other → cula.kda.chunk.chunk_kda (modular fallback)
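The table above, expressed as data suitable for a `_get_backend`-style lookup. String stubs stand in for the real kernel callables so the lookup logic is runnable stand-alone:

```python
# Stubbed dispatch table mirroring the mapping above; real entries are callables.
_KDA_PREFILL_TABLE = {
    (9, 0): "hopper_fused",      # cula.kda.hopper_fused_fwd.cula_kda_prefill
    (10, 0): "blackwell_fused",  # cula.kda.blackwell_fused_fwd (once available)
    (10, 3): "blackwell_fused",  # SM103: same as SM100
}
_FALLBACK = "chunk_kda"          # cula.kda.chunk.chunk_kda

def select_kda_backend(sm_version: tuple[int, int]) -> str:
    # Unlisted architectures fall through to the modular path.
    return _KDA_PREFILL_TABLE.get(sm_version, _FALLBACK)
```

Keeping the mapping as a dict keyed by (sm_major, sm_minor) means adding a new architecture is a one-line table change, with no branching logic to touch.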
6.3 cula_lightning — Lightning Attention Family
@torch.compiler.disable
def cula_lightning(
    Q: torch.Tensor,      # [B, T, H, D] bf16
    K: torch.Tensor,      # [B, T, H, D] bf16
    V: torch.Tensor,      # [B, T, H, D] bf16
    decay: torch.Tensor,  # [H] fp32, per-head scalar decay (log-space)
    *,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,          # [N, H, D, D] fp32
    output_final_state: bool = False,
    cu_seqlens: torch.Tensor | None = None,
    state_pool: torch.Tensor | None = None,             # [P, H, D, D] fp32, inplace update
    initial_state_indices: torch.Tensor | None = None,  # [N] int32
    chunk_size: int = 64,
    device: torch.device | str | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
Behavior contract:
- State I/O follows the [B, H, V, K] K-last convention. Level 1 applies the BHKV→BHVK transpose internally for any kernel not yet migrated.
- cu_seqlens varlen: when provided, B=1 and sequences are packed. Accepts int32 or int64.
- state_pool + initial_state_indices is the cuLA-native indirect state indexing; initial_state is the FLA-style convention. When initial_state is passed with cu_seqlens, Level 1 converts to pool + indices internally.
- Dispatches to LinearAttentionChunkwiseDecay (Blackwell) or the FLA Triton fallback.
6.4 cula_decode — Unified Decode
@torch.compiler.disable
def cula_decode(
    q: torch.Tensor,      # [N, 1, H, K] or [1, N, H, K]
    k: torch.Tensor,      # [N, 1, H, K] or [1, N, H, K]
    v: torch.Tensor,      # [N, 1, H, V] or [1, N, H, V]
    state: torch.Tensor,  # [P, H, V, K] fp32, updated IN-PLACE
    *,
    mode: str = "kda",    # "kda" or "lightning"
    state_indices: torch.Tensor | None = None,  # [N] int32, maps batch → pool slot
    scale: float | None = None,
    device: torch.device | str | None = None,
    **kwargs,  # mode-specific: decay (lightning); reserved for future arch-specific tuning
) -> tuple[torch.Tensor, torch.Tensor]:
Returns (o, state) where state is the same tensor (mutated in-place).
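The indirect state indexing can be illustrated with plain Python containers. decode_step below is a toy stand-in for the real kernel — the function name, the scalar "state", and the accumulate-and-emit update rule are assumptions for illustration only:

```python
def decode_step(tokens, pool, state_indices):
    """One decode step: batch entry b reads and writes pool slot state_indices[b].

    Toy update rule: the state accumulates the token value and the output is
    the running sum, mimicking the in-place pool mutation of the real kernel.
    """
    outputs = []
    for b, tok in enumerate(tokens):
        slot = state_indices[b]
        pool[slot] += tok          # in-place state update, like the real [P, H, V, K] pool
        outputs.append(pool[slot])
    return outputs

pool = [0.0, 0.0, 0.0]             # toy [P] state pool (real: [P, H, V, K] fp32)
out = decode_step([1.0, 2.0], pool, state_indices=[2, 0])
```

The pool outlives any single request: slots are assigned per sequence, so batches of different compositions can decode against the same pool without copying state.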
6.5 Common Conventions
| Convention | Value |
| --- | --- |
| Tensor layout | BSHD ([B, S, H, D]) |
| Prefill state | [B, H, V, K] float32, K-last |
| Decode state pool | [P, H, V, K] float32, K-last, in-place update |
| cu_seqlens type | int32 or int64 accepted; normalized to int32 |
| scale default | head_dim_K ** -0.5 |
| chunk_indices | Never exposed to caller; computed internally |
| initial_state dtype | float32 enforced |
| Thread safety | Each call independent; no global mutable state except JIT cache |
7. Level 2 — Multi-GPU Wrapper (cula.dist)
Status: To be designed. The detailed API signatures and implementation for Level 2 are deferred to a follow-up RFC.
Level 2 (cula/dist.py) will compose Level 1 single-GPU calls with torch.distributed (and potentially nvshmem) to support multi-GPU execution. The key parallelism modes to be addressed:
| Mode | Axis Sharded | Key Challenge |
| --- | --- | --- |
| TP | Head H | Simplest — linear attention is head-independent, no inter-rank state transfer needed |
| SP/CP | Token T | Requires ring-style state hand-off ([B, H, V, K] float32) between ranks via isend/irecv |
| PP | Layer | State management delegated to training framework; Level 1's initial_state/final_state contract is sufficient |
| Hybrid (TP+CP) | Head H + Token T | 2D process group; SP state ring only communicates the local head shard |
Linear attention vs. FlashAttention CP communication advantage: FlashAttention's ring attention must transfer both partial output O and log-sum-exp LSE between ranks (payload scales with T_chunk). Linear attention's CP only transfers the fixed-size recurrent state [B, H, V, K] (typically 16-64 MB regardless of sequence length), with no cross-rank numerical correction needed.
Areas requiring further design:
- Detailed function signatures for cula_kda_tp, cula_kda_cp, cula_kda_tp_cp, cula_lightning_cp, etc.
- State buffer allocation and management utilities across ranks
- Varlen support in distributed settings (sequence-complete vs. sequence-split partitioning)
- cu_seqlens global vs. local convention in SP mode
- Gradient flow through CP state transfer for training
- nvshmem kernel-level compute-communication overlap as a future optimization
8. Hardware Dispatch Tables
KDA Prefill
| SM Version | Kernel | Notes |
| --- | --- | --- |
| (9, 0) | cula.kda.hopper_fused_fwd.cula_kda_prefill | SM90a, CUTLASS TMA warp-spec, safe_gate forced True |
| (10, 0) | cula.kda.blackwell_fused_fwd (once ready), fallback chunk_kda | SM100a |
| (10, 3) | same as (10, 0) | SM103 (B300) |
| other | cula.kda.chunk.chunk_kda | Modular fallback (any GPU) |
The existing get_kda_fused_fwd(device) in cula/utils.py already implements this pattern for the fused path; Level 1 extends it to handle the modular fallback.
Lightning Prefill
| SM Version | Kernel | Notes |
| --- | --- | --- |
| (10, 0) | cula.ops.lightning_attn.LinearAttentionChunkwiseDecay | SM100 CuTe DSL |
| (10, 3) | same | SM103 |
| other | FLA Triton fallback | Any GPU |
Decode
| Mode | Kernel | Notes |
| --- | --- | --- |
| "kda" | cula.kda.kda_decode.kda_decode | CuTe DSL, JIT-compiled per device |
| "lightning" | cula.lightning.la_decode.linear_attention_decode | CuTe DSL, JIT-compiled per device |
9. API Reference (Complete Signatures)
cula.api module
# cula/api.py
import torch
from typing import Optional

@torch.compiler.disable
def cula_kda(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    *,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    use_gate_in_kernel: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    safe_gate: bool = False,
    lower_bound: Optional[float] = None,
    disable_recompute: bool = False,
    return_intermediate_states: bool = False,
    device: Optional[torch.device] = None,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

@torch.compiler.disable
def cula_lightning(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    decay: torch.Tensor,
    *,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    state_pool: Optional[torch.Tensor] = None,
    initial_state_indices: Optional[torch.Tensor] = None,
    chunk_size: int = 64,
    device: Optional[torch.device] = None,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

@torch.compiler.disable
def cula_decode(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    state: torch.Tensor,
    *,
    mode: str = "kda",
    state_indices: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    device: Optional[torch.device] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]: ...
cula.dist module
# cula/dist.py
import torch
import torch.distributed as dist
from typing import Optional

# --- TP wrappers ---

def cula_kda_tp(
    q, k, v, g, beta, *,
    tp_group: Optional[dist.ProcessGroup] = None,
    tp_size: int = 1,
    tp_rank: int = 0,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

def cula_lightning_tp(
    Q, K, V, decay, *,
    tp_group: Optional[dist.ProcessGroup] = None,
    tp_size: int = 1,
    tp_rank: int = 0,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

# --- CP/SP wrappers ---

def cula_kda_cp(
    q, k, v, g, beta, *,
    sp_group: dist.ProcessGroup,
    sp_size: int,
    sp_rank: int,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

def cula_lightning_cp(
    Q, K, V, decay, *,
    sp_group: dist.ProcessGroup,
    sp_size: int,
    sp_rank: int,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

# --- Hybrid TP+CP ---

def cula_kda_tp_cp(
    q, k, v, g, beta, *,
    tp_group: dist.ProcessGroup,
    sp_group: dist.ProcessGroup,
    tp_size: int,
    sp_size: int,
    tp_rank: int,
    sp_rank: int,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

# --- Utilities ---

def alloc_state_buffer(
    batch_size: int, num_heads: int, head_dim_k: int, head_dim_v: int, *,
    tp_size: int = 1, dtype: torch.dtype = torch.float32, device: str = "cuda",
) -> torch.Tensor: ...

def alloc_decode_state_pool(
    pool_size: int, num_heads: int, head_dim_k: int, head_dim_v: int, *,
    tp_size: int = 1, dtype: torch.dtype = torch.float32, device: str = "cuda",
) -> torch.Tensor: ...

def split_cu_seqlens_for_sp(
    cu_seqlens: torch.Tensor, sp_rank: int, sp_size: int, max_seqlen: int,
) -> tuple[torch.Tensor, torch.Tensor]: ...
10. Migration Path from Current API
12.1 Level 0 → Level 1 (single-GPU users)
| Current | Level 1 replacement |
| --- | --- |
| from cula.kda import chunk_kda; chunk_kda(q, k, v, g, beta, ...) | from cula.api import cula_kda; cula_kda(q, k, v, g, beta, ...) |
| from cula.kda import kda_prefill_hopper; kda_prefill_hopper(...) | from cula.api import cula_kda (auto-dispatches on SM90) |
| from cula.utils import get_kda_fused_fwd; fn = get_kda_fused_fwd(device); fn(...) | from cula.api import cula_kda; cula_kda(...) |
| lightning_attn_fwd(Q, K, V, decay, ...) | from cula.api import cula_lightning; cula_lightning(Q, K, V, decay, ...) |
Level 0 functions are not removed. Level 1 calls them internally. Users needing low-level options continue to import from Level 0.
12.2 FLA CP context → Level 2 CP
| Current (cula-fla) | Level 2 replacement |
| --- | --- |
| FLACPContext + chunk_kda(..., cp_context=ctx) | cula_kda_cp(q, k, v, g, beta, sp_group=pg, sp_size=N, sp_rank=r, ...) |
The cp_context kwarg is still accepted by cula_kda (Level 1) for backward compatibility but is deprecated.
12.3 cula.__init__.py re-exports
# cula/__init__.py
from cula.api import cula_kda, cula_lightning, cula_decode
11. Future Work: nvshmem Kernel-Level Overlap
The SP ring protocol (§7) uses PyTorch distributed ops (isend/irecv) outside the CUDA kernel. On multi-node clusters where state transfer traverses InfiniBand, latency may exceed attention compute time for short sequences.
nvshmem enables kernel-level communication: a CUDA kernel can initiate point-to-point transfers without returning to the CPU:
- Kernel computes final state at end of local chunk.
- Issues nvshmemx_putmem_signal_on_stream to push state to rank r+1.
- On rank r+1, kernel waits on the nvshmem signal before reading the initial state.
This collapses the compute-communicate boundary to within a single kernel launch.
A placeholder nvshmem_backend kwarg is reserved in cula_kda_cp; passing nvshmem_backend=True will raise NotImplementedError until this work is complete.
12. Open Questions
Q1: Should cula.api live in cula core or cula-fla?
cula.api Level 1 currently depends on cula-fla for the modular fallback path (which imports FLA). Recommendation: cula.api lives in cula core with lazy imports. If cula-fla is not installed, Level 1 still works on SM90/SM100 using fused kernels; the modular fallback degrades gracefully with a warning.
Q2: Who owns cu_seqlens dtype normalization?
Level 0 asserts int32; FLA uses int64. Recommendation: keep Level 0 strict (int32), add normalization in Level 1.
Q3: State layout migration from existing [N, H, K, V] prefill kernels?
This RFC adopts [B, H, V, K] (K-last) as the single canonical state layout across both prefill and decode, because K-last is SMEM bank-conflict-friendly for decode kernels (threads reading different V-rows at the same K-offset land on different banks). Existing prefill kernels (chunk_kda, cula_kda_prefill) that output [N, H, K, V] (K-first / FLA convention) will need migration: Level 1 handles the transpose internally for any un-migrated kernel during the transition period.
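The transition-period transpose Level 1 would apply, sketched on nested lists. In practice this is a tensor `.transpose(-1, -2).contiguous()` on the trailing two dims; the helper name here is illustrative, and the pure-Python version just makes the index swap explicit:

```python
def kv_to_vk(state_kv):
    """Transpose a trailing [K, V] block to [V, K] (K-first → K-last).

    state_kv[k][v] becomes result[v][k]; applied per (batch, head) slice.
    """
    num_k = len(state_kv)
    num_v = len(state_kv[0])
    return [[state_kv[k][v] for k in range(num_k)] for v in range(num_v)]
```

During migration the transpose cost is paid only on the state tensors at chunk boundaries, not on the activations, so the overhead is small relative to prefill compute.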
Q4: Varlen + SP: sequence-complete vs. sequence-split?
Recommendation: implement sequence-complete first (simpler, covers inference). Add sequence-split as sp_mode="split" kwarg in a follow-up, preserving backward compatibility.
Q5: Gradient flow through CP state transfer?
Training with SP requires gradient backward through the state ring. Out of scope for this RFC but must be designed before training with SP is possible.
Summary of Proposed Changes
| File | Action | Description |
| --- | --- | --- |
| cula/api.py | New | Level 1: cula_kda, cula_lightning, cula_decode, dispatch logic |
| cula/dist.py | New | Level 2: TP/CP/SP wrappers, state buffer utilities |
| cula/__init__.py | Modify | Re-export Level 1 functions |
| cula/utils.py | Modify | Extend get_kda_fused_fwd to return modular fallback instead of NotImplementedError |