
WIP-RFC: Multi-GPU Unified Interface for cuLA #50

@icavan


RFC: Multi-GPU Unified Interface for cuLA

Status: Draft

Authors: cuLA maintainers

Created: 2026-04-08


Table of Contents

  1. Summary
  2. Motivation
  3. Design Goals and Non-Goals
  4. SageAttention Interface Analysis
  5. Architecture: Three-Level Stack
  6. Level 1 — Single-GPU Unified API
  7. Level 2 — Multi-GPU Wrapper
  8. Hardware Dispatch Tables
  9. API Reference (Complete Signatures)
  10. Migration Path from Current API
  11. Future Work: nvshmem Kernel-Level Overlap
  12. Open Questions

1. Summary

This RFC proposes a multi-GPU unified interface for the cuLA library, organized as two new Python layers above the existing raw kernels:

  • Level 1 (cula.api): Single-GPU unified API — one function per algorithm family (cula_kda, cula_lightning, cula_decode) that auto-dispatches to the correct backend (SM90 / SM100 / SM103 / fallback) and handles all bookkeeping internally.

  • Level 2 (cula.dist): Multi-GPU wrapper that composes Level 1 with torch.distributed to implement Tensor Parallelism (TP), Sequence/Context Parallelism (SP/CP), and state hand-off for pipeline parallelism.

The design is modelled on SageAttention's sageattn entry point — one function that does the right thing, with advanced options via **kwargs — while remaining consistent with cuLA's BSHD ([B, S, H, D]) tensor layout and [B, H, V, K] float32 K-last state conventions.


2. Motivation

2.1 Need for a unified interface with auto-dispatch

cuLA provides multiple kernel implementations for different GPU architectures (Hopper fused, Blackwell fused, modular fallback), each exposed as separate entry points. A unified interface that auto-dispatches to the optimal kernel based on the current device would simplify caller code and reduce integration friction for downstream frameworks (SGLang, vLLM, FLA).

2.2 No multi-GPU path

There is currently no first-class multi-GPU support in cuLA core. FLA provides CP via FLACPContext, but this couples multi-GPU logic to FLA's orchestration code. This dependency needs to be severed, and no replacement multi-GPU story has yet been proposed.

2.3 LLM serving needs

LLM serving frameworks (SGLang, vLLM) increasingly use:

  • TP to distribute model width across GPUs (each GPU handles H/TP_SIZE heads).
  • SP/CP to distribute context length across GPUs (each GPU handles T/SP_SIZE tokens).
  • Both combined for large models with very long sequences.

For recurrent architectures (linear attention), the stateful nature adds a unique challenge: the KV state at chunk boundaries must be passed between GPUs.
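The boundary-state requirement is easiest to see in a minimal pure-PyTorch sketch of chunked linear attention (decay/gating omitted for brevity; function name hypothetical). Each chunk's output depends on the state accumulated over all earlier chunks, so under SP/CP that state must travel from rank r to rank r+1:

```python
import torch

def chunked_linear_attn(q, k, v, chunk_size: int):
    """Reference chunked linear attention (no decay/gating).
    State uses the [B, H, V, K] float32 K-last convention."""
    B, T, H, K = q.shape
    state = torch.zeros(B, H, v.shape[-1], K, dtype=torch.float32)
    outs = []
    for s in range(0, T, chunk_size):
        qc, kc, vc = q[:, s:s+chunk_size], k[:, s:s+chunk_size], v[:, s:s+chunk_size]
        # inter-chunk part: earlier tokens contribute only through the state
        o = torch.einsum('bthk,bhvk->bthv', qc, state)
        # intra-chunk part: causal attention inside the chunk
        att = torch.einsum('bthk,bshk->bhts', qc, kc).tril()
        o = o + torch.einsum('bhts,bshv->bthv', att, vc)
        # final state of this chunk = initial state of the next chunk
        # (under CP: the tensor handed from rank r to rank r+1)
        state = state + torch.einsum('bthv,bthk->bhvk', vc, kc)
        outs.append(o)
    return torch.cat(outs, dim=1), state
```

The chunked result matches the full quadratic formulation exactly; only the fixed-size state crosses the chunk (and hence rank) boundary.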


3. Design Goals and Non-Goals

Goals

  1. One function per algorithm family. No architecture-conditional imports at the call site.
  2. Auto-dispatch by hardware. SM90 → Hopper fused; SM100/SM103 → Blackwell fused; fallback → modular path.
  3. Internalize bookkeeping. chunk_indices, cu_seqlens normalization, initial_state dtype checking, default scale — all handled inside Level 1.
  4. Tensor Parallelism. Shard across head dimension H with no kernel signature changes.
  5. Sequence/Context Parallelism. Shard across token dimension T with correct state hand-off via ring-style communication.
  6. Compatible with FLA/SGLang conventions. BSHD ([B, S, H, D]), [B, H, V, K] float32 K-last state, cu_seqlens int32, same parameter names.

Non-Goals

  • Kernel-level nvshmem overlap (covered as an independent future work).

4. SageAttention Interface Analysis

SageAttention (thu-ml/SageAttention) provides a clean reference for unified attention API design. Key findings:

4.1 Public API Structure

SageAttention exports 6 functions from the main package + 1 from a separate sub-package:

Function Role
sageattn(q, k, v, ...) High-level auto-dispatch entry point
sageattn_qk_int8_pv_fp16_triton(...) Triton backend, FP16 PV
sageattn_qk_int8_pv_fp16_cuda(...) CUDA backend, FP16 PV (SM80+)
sageattn_qk_int8_pv_fp8_cuda(...) CUDA backend, FP8 PV (SM89+)
sageattn_qk_int8_pv_fp8_cuda_sm90(...) CUDA backend, FP8 PV (SM90 only)
sageattn_varlen(...) Variable-length sequences
sageattn3_blackwell(...) SageAttention3 for Blackwell (separate package)

4.2 The sageattn Entry Point — Design Patterns Worth Adopting

def sageattn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    return_lse: bool = False,
    **kwargs: Any,
) -> torch.Tensor:

Pattern 1: Hardware auto-dispatch. sageattn inspects q.device to detect compute capability and dispatches:

Architecture Dispatches To
SM80 (A100) sageattn_qk_int8_pv_fp16_cuda
SM86 (RTX 3090 Ti) sageattn_qk_int8_pv_fp16_triton
SM89 (RTX 4090) sageattn_qk_int8_pv_fp8_cuda
SM90 (H100) sageattn_qk_int8_pv_fp8_cuda_sm90

Pattern 2: tensor_layout parameter. Supports both "HND" (head-first BHSD) and "NHD" (seq-first BSHD). cuLA adopts only the BSHD layout.

Pattern 3: **kwargs forwarding. Backend-specific options are forwarded to the selected backend via **kwargs. cuLA currently has unified parameter signatures across backends, but preserves this pattern for future architecture-specific tuning knobs.

Pattern 4: Backend availability at import time. Each CUDA backend's availability is checked by attempting to import compiled C extensions:

try:
    from . import _qattn_sm80  # illustrative; real module names differ
    SM80_ENABLED = True
except ImportError:
    SM80_ENABLED = False
# SM89_ENABLED and SM90_ENABLED follow the same try-import pattern

Pattern 5: Uniform parameter naming. All backends share q, k, v, tensor_layout, is_causal, sm_scale, return_lse — identical names and semantics. Backend-specific params are additive.

4.3 Backend-Specific Parameters: Currently Unified, Reserved for Future

SageAttention's backends differ by quantization strategy, each introducing unique parameters (e.g., pv_accum_dtype for FP8, attn_mask for Triton only). cuLA's backends differ by GPU architecture (Hopper fused / Blackwell fused / modular fallback), but the parameter signatures are currently unified — all backends accept the same set of algorithm-level parameters (scale, initial_state, cu_seqlens, safe_gate, etc.). Behavioral differences are handled internally (e.g., Hopper fused forces safe_gate=True).

However, future hardware generations may introduce architecture-specific tuning knobs (e.g., tile sizes, accumulation precision, pipeline depth). The **kwargs forwarding pattern is preserved in Level 1 to accommodate this — any unrecognized kwargs are forwarded to the selected backend, allowing new backend-specific parameters to be added without changing the Level 1 signature.
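A minimal sketch of the forwarding contract (all names hypothetical): Level 1 consumes what it recognizes and passes the rest through unchanged, so a backend can grow a knob like tile_size without touching the unified signature:

```python
def select_backend_stub():
    """Stand-in for _get_backend: returns a backend with one extra knob."""
    def backend(q, *, scale, tile_size=128, **_):
        # tile_size is a backend-private knob unknown to Level 1
        return {"scale": scale, "tile_size": tile_size}
    return backend

def cula_kda_stub(q, *, scale=None, **kwargs):
    """Level-1 stub: recognized args are handled here, the rest forwarded."""
    if scale is None:
        scale = 0.125  # e.g. K ** -0.5 for K = 64
    return select_backend_stub()(q, scale=scale, **kwargs)
```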

4.4 Variable-Length Interface

def sageattn_varlen(
    q: torch.Tensor,            # [total_q_tokens, num_qo_heads, head_dim]
    k: torch.Tensor,            # [total_kv_tokens, num_kv_heads, head_dim]
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor, # [batch_size + 1]
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    smooth_k: bool = True,
    **kwargs,
) -> torch.Tensor:

Notable: no tensor_layout (always NHD-style packed), no return_lse, uses per-block Triton kernels exclusively.

4.5 Key Differences from cuLA

Aspect SageAttention cuLA
Algorithm type Softmax attention (quantized) Linear attention (recurrent)
Stateful No Yes — initial_state / final_state
Backends differ by Quantization strategy (INT8/FP8) GPU architecture + fusion level
Layout choice tensor_layout ("HND"=BHSD / "NHD"=BSHD) Always BSHD ([B, S, H, D])
Varlen Separate function (sageattn_varlen) Integrated via cu_seqlens parameter

4.6 Patterns Adopted for cuLA

SageAttention pattern cuLA Level 1 adoption
Single entry point per family cula_kda, cula_lightning, cula_decode
Auto-dispatch by compute capability _get_backend(device, table, fallback)
Backend availability via try-import Lazy import of Level 0 backends
**kwargs for backend-specific options All Level 1 functions accept **kwargs
sm_scale / scale default K ** -0.5
Uniform parameter naming initial_state, output_final_state, cu_seqlens

Not adopted: tensor_layout — cuLA only supports BSHD ([B, S, H, D]) layout, consistent with SGLang/FLA conventions. No head-first BHSD variant.


5. Architecture: Three-Level Stack

┌─────────────────────────────────────────────────────────────────────────┐
│  Level 2: cula.dist                                                     │
│  cula_kda_tp, cula_kda_cp, cula_kda_tp_cp, cula_lightning_cp, ...       │
│  (multi-GPU wrappers using torch.distributed / nvshmem)                 │
└─────────────────────────────┬───────────────────────────────────────────┘
                              │  calls
┌─────────────────────────────▼───────────────────────────────────────────┐
│  Level 1: cula.api                                                      │
│  cula_kda, cula_lightning, cula_decode                                  │
│  (single-GPU, auto-dispatches to correct backend, handles bookkeeping)  │
└─────────────────────────────┬───────────────────────────────────────────┘
                              │  calls
┌─────────────────────────────▼───────────────────────────────────────────┐
│  Level 0: cula.ops / cula.kda / cula.lightning  (current codebase)      │
│  chunk_kda, kda_prefill_hopper, lightning_attn_fwd, kda_decode, ...     │
│  (raw kernels, no auto-dispatch, explicit hardware requirements)        │
└─────────────────────────────────────────────────────────────────────────┘

Level 0 is the current codebase, untouched by this RFC. Level 1 and Level 2 are new modules.


6. Level 1 — Single-GPU Unified API

New module: cula/api.py

6.1 Hardware Dispatch Mechanism

Dispatch table keyed by (sm_major, sm_minor):

# Conceptual — cula/api/_dispatch.py

from typing import Callable

from cula.utils import get_device_sm_version

_dispatch_cache: dict[str, Callable] = {}

def _get_backend(device, table: dict, fallback: Callable) -> Callable:
    """
    Look up the best kernel for *device* in *table*.
    Falls back to *fallback* if the architecture is not in the table.
    Cached per device string.
    """
    key_str = str(device)
    if key_str not in _dispatch_cache:
        sm = get_device_sm_version(device)
        _dispatch_cache[key_str] = table.get(sm, fallback)
    return _dispatch_cache[key_str]
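The lookup degrades to the fallback for any architecture missing from the table. A self-contained demonstration with canned SM values in place of a real device query (all names hypothetical):

```python
from typing import Callable

def make_get_backend(table: dict, fallback: Callable, sm_of: Callable):
    """Mirror of _get_backend with the device query injected for testing."""
    cache: dict[str, Callable] = {}
    def get_backend(device: str) -> Callable:
        if device not in cache:
            cache[device] = table.get(sm_of(device), fallback)
        return cache[device]
    return get_backend

hopper = lambda: "hopper_fused"
modular = lambda: "modular_fallback"
get_backend = make_get_backend(
    {(9, 0): hopper}, modular,
    sm_of=lambda d: {"cuda:0": (9, 0), "cuda:1": (12, 0)}[d],
)
```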

6.2 cula_kda — KDA Family

@torch.compiler.disable
def cula_kda(
    q: torch.Tensor,              # [B, T, H, K] bf16
    k: torch.Tensor,              # [B, T, H, K] bf16
    v: torch.Tensor,              # [B, T, H, V] bf16
    g: torch.Tensor,              # [B, T, H, K] or [B, T, H] fp32
    beta: torch.Tensor,           # [B, T, H] fp32
    *,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,   # [B, H, V, K] fp32, K-last layout
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    use_gate_in_kernel: bool = False,
    cu_seqlens: torch.Tensor | None = None,       # int32 or int64 → normalized to int32
    safe_gate: bool = False,
    lower_bound: float | None = None,
    disable_recompute: bool = False,
    return_intermediate_states: bool = False,
    device: torch.device | str | None = None,
    **kwargs,  # forwarded to selected backend; reserved for future arch-specific tuning
) -> tuple[torch.Tensor, torch.Tensor | None]:

Behavior contract:

  • scale defaults to K ** -0.5.
  • device defaults to q.device.
  • cu_seqlens int64 is accepted and normalized to int32 internally.
  • chunk_indices is computed internally from cu_seqlens; callers must not pass it.
  • g: [B, T, H] (scalar gate, FLA convention) is expanded to [B, T, H, K] internally when needed.
  • initial_state, if provided, must be [B, H, V, K] float32 K-last layout — enforced internally.
  • **kwargs are forwarded to the selected backend, reserved for future arch-specific tuning.

Dispatch table:

SM90     → cula.kda.hopper_fused_fwd.cula_kda_prefill (safe_gate forced True)
SM100    → cula.kda.blackwell_fused_fwd (once available), fallback to chunk_kda
SM103    → same as SM100
other    → cula.kda.chunk.chunk_kda (modular fallback)

6.3 cula_lightning — Lightning Attention Family

@torch.compiler.disable
def cula_lightning(
    Q: torch.Tensor,              # [B, T, H, D] bf16
    K: torch.Tensor,              # [B, T, H, D] bf16
    V: torch.Tensor,              # [B, T, H, D] bf16
    decay: torch.Tensor,          # [H] fp32, per-head scalar decay (log-space)
    *,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,   # [N, H, D, D] fp32
    output_final_state: bool = False,
    cu_seqlens: torch.Tensor | None = None,
    state_pool: torch.Tensor | None = None,       # [P, H, D, D] fp32, inplace update
    initial_state_indices: torch.Tensor | None = None,  # [N] int32
    chunk_size: int = 64,
    device: torch.device | str | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:

Behavior contract:

  • State I/O follows [B, H, V, K] K-last convention. Level 1 applies BHKV→BHVK transpose internally for any kernel not yet migrated.
  • cu_seqlens varlen: when provided, B=1 and sequences are packed. Accepts int32 or int64.
  • state_pool + initial_state_indices is the cuLA-native indirect state indexing; initial_state is the FLA-style convention. When initial_state is passed with cu_seqlens, Level 1 converts to pool + indices internally.
  • Dispatches to LinearAttentionChunkwiseDecay (Blackwell) or FLA Triton fallback.

6.4 cula_decode — Unified Decode

@torch.compiler.disable
def cula_decode(
    q: torch.Tensor,              # [N, 1, H, K] or [1, N, H, K]
    k: torch.Tensor,              # [N, 1, H, K] or [1, N, H, K]
    v: torch.Tensor,              # [N, 1, H, V] or [1, N, H, V]
    state: torch.Tensor,          # [P, H, V, K] fp32, updated IN-PLACE
    *,
    mode: str = "kda",            # "kda" or "lightning"
    state_indices: torch.Tensor | None = None,  # [N] int32, maps batch → pool slot
    scale: float | None = None,
    device: torch.device | str | None = None,
    **kwargs,  # mode-specific: decay (lightning); reserved for future arch-specific tuning
) -> tuple[torch.Tensor, torch.Tensor]:

Returns (o, state) where state is the same tensor (mutated in-place).
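A pure-PyTorch reference for one decode step clarifies the in-place contract (simplified sketch: KDA gating and lightning decay omitted, update-then-readout order assumed, [N, 1, ...] input layout only):

```python
import torch

def decode_reference(q, k, v, state, state_indices=None, scale=None):
    """One decode step against a [P, H, V, K] fp32 pool, mutated in place.
    q/k: [N, 1, H, K], v: [N, 1, H, V]."""
    q, k, v = q.squeeze(1), k.squeeze(1), v.squeeze(1)
    N, H, K = q.shape
    scale = K ** -0.5 if scale is None else scale
    idx = state_indices.long() if state_indices is not None else torch.arange(N)
    s = state[idx]                                        # gather pool slots
    s = s + torch.einsum('nhv,nhk->nhvk', v.float(), k.float())  # rank-1 update
    o = torch.einsum('nhk,nhvk->nhv', q.float() * scale, s)      # readout
    state[idx] = s                                        # in-place write-back
    return o.unsqueeze(1), state
```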

6.5 Common Conventions

Convention Value
Tensor layout BSHD ([B, S, H, D])
Prefill state [B, H, V, K] float32, K-last
Decode state pool [P, H, V, K] float32, K-last, in-place update
cu_seqlens type int32 or int64 accepted; normalized to int32
scale default head_dim_K ** -0.5
chunk_indices Never exposed to caller; computed internally
initial_state dtype float32 enforced
Thread safety Each call independent; no global mutable state except JIT cache

7. Level 2 — Multi-GPU Wrapper (cula.dist)

Status: To be designed. The detailed API signatures and implementation for Level 2 are deferred to a follow-up RFC.

Level 2 (cula/dist.py) will compose Level 1 single-GPU calls with torch.distributed (and potentially nvshmem) to support multi-GPU execution. The key parallelism modes to be addressed:

Mode Axis Sharded Key Challenge
TP Head H Simplest — linear attention is head-independent, no inter-rank state transfer needed
SP/CP Token T Requires ring-style state hand-off ([B, H, V, K] float32) between ranks via isend/irecv
PP Layer State management delegated to training framework; Level 1's initial_state/final_state contract is sufficient
Hybrid (TP+CP) Head H + Token T 2D process group; SP state ring only communicates the local head shard

Linear attention vs. FlashAttention CP communication advantage: FlashAttention's ring attention must transfer both partial output O and log-sum-exp LSE between ranks (payload scales with T_chunk). Linear attention's CP only transfers the fixed-size recurrent state [B, H, V, K] (typically 16-64 MB regardless of sequence length), with no cross-rank numerical correction needed.
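A single-process simulation of the state ring makes the composition rule concrete (decay-free case, where per-rank states compose additively; in a real deployment each hop is a dist.isend/dist.irecv of the fixed-size state tensor):

```python
import torch

def simulate_state_ring(local_states):
    """incoming[r] = state rank r receives before processing its chunk.
    local_states[r] is rank r's local contribution sum(k^T v) over its
    tokens; without decay, composition across ranks is a running sum."""
    incoming = [torch.zeros_like(local_states[0])]  # rank 0: zero state
    running = torch.zeros_like(local_states[0])
    for s in local_states[:-1]:
        running = running + s             # final state leaving this rank
        incoming.append(running.clone())  # what the next rank receives
    return incoming
```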

Areas requiring further design:

  • Detailed function signatures for cula_kda_tp, cula_kda_cp, cula_kda_tp_cp, cula_lightning_cp, etc.
  • State buffer allocation and management utilities across ranks
  • Varlen support in distributed settings (sequence-complete vs. sequence-split partitioning)
  • cu_seqlens global vs. local convention in SP mode
  • Gradient flow through CP state transfer for training
  • nvshmem kernel-level compute-communication overlap as a future optimization

8. Hardware Dispatch Tables

KDA Prefill

SM Version Kernel Notes
(9, 0) cula.kda.hopper_fused_fwd.cula_kda_prefill SM90a, CUTLASS TMA warp-spec, safe_gate forced True
(10, 0) cula.kda.blackwell_fused_fwd (once ready), fallback chunk_kda SM100a
(10, 3) same as (10, 0) SM103 (B300)
other cula.kda.chunk.chunk_kda Modular fallback (any GPU)

The existing get_kda_fused_fwd(device) in cula/utils.py already implements this pattern for the fused path; Level 1 extends it to handle the modular fallback.

Lightning Prefill

SM Version Kernel Notes
(10, 0) cula.ops.lightning_attn.LinearAttentionChunkwiseDecay SM100 CuTe DSL
(10, 3) same SM103
other FLA Triton fallback Any GPU

Decode

Mode Kernel Notes
"kda" cula.kda.kda_decode.kda_decode CuTe DSL, JIT-compiled per device
"lightning" cula.lightning.la_decode.linear_attention_decode CuTe DSL, JIT-compiled per device

9. API Reference (Complete Signatures)

cula.api module

# cula/api.py

import torch
from typing import Optional

@torch.compiler.disable
def cula_kda(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    *,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    use_gate_in_kernel: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    safe_gate: bool = False,
    lower_bound: Optional[float] = None,
    disable_recompute: bool = False,
    return_intermediate_states: bool = False,
    device: Optional[torch.device] = None,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

@torch.compiler.disable
def cula_lightning(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    decay: torch.Tensor,
    *,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    state_pool: Optional[torch.Tensor] = None,
    initial_state_indices: Optional[torch.Tensor] = None,
    chunk_size: int = 64,
    device: Optional[torch.device] = None,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

@torch.compiler.disable
def cula_decode(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    state: torch.Tensor,
    *,
    mode: str = "kda",
    state_indices: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    device: Optional[torch.device] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]: ...

cula.dist module

# cula/dist.py

import torch
import torch.distributed as dist
from typing import Optional

# --- TP wrappers ---

def cula_kda_tp(
    q, k, v, g, beta, *,
    tp_group: Optional[dist.ProcessGroup] = None,
    tp_size: int = 1,
    tp_rank: int = 0,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

def cula_lightning_tp(
    Q, K, V, decay, *,
    tp_group: Optional[dist.ProcessGroup] = None,
    tp_size: int = 1,
    tp_rank: int = 0,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

# --- CP/SP wrappers ---

def cula_kda_cp(
    q, k, v, g, beta, *,
    sp_group: dist.ProcessGroup,
    sp_size: int,
    sp_rank: int,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

def cula_lightning_cp(
    Q, K, V, decay, *,
    sp_group: dist.ProcessGroup,
    sp_size: int,
    sp_rank: int,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

# --- Hybrid TP+CP ---

def cula_kda_tp_cp(
    q, k, v, g, beta, *,
    tp_group: dist.ProcessGroup,
    sp_group: dist.ProcessGroup,
    tp_size: int,
    sp_size: int,
    tp_rank: int,
    sp_rank: int,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...

# --- Utilities ---

def alloc_state_buffer(
    batch_size: int, num_heads: int, head_dim_k: int, head_dim_v: int, *,
    tp_size: int = 1, dtype: torch.dtype = torch.float32, device: str = "cuda",
) -> torch.Tensor: ...

def alloc_decode_state_pool(
    pool_size: int, num_heads: int, head_dim_k: int, head_dim_v: int, *,
    tp_size: int = 1, dtype: torch.dtype = torch.float32, device: str = "cuda",
) -> torch.Tensor: ...

def split_cu_seqlens_for_sp(
    cu_seqlens: torch.Tensor, sp_rank: int, sp_size: int, max_seqlen: int,
) -> tuple[torch.Tensor, torch.Tensor]: ...

10. Migration Path from Current API

10.1 Level 0 → Level 1 (single-GPU users)

Current → Level 1 replacement
from cula.kda import chunk_kda; chunk_kda(q, k, v, g, beta, ...) → from cula.api import cula_kda; cula_kda(q, k, v, g, beta, ...)
from cula.kda import kda_prefill_hopper; kda_prefill_hopper(...) → from cula.api import cula_kda (auto-dispatches on SM90)
from cula.utils import get_kda_fused_fwd; fn = get_kda_fused_fwd(device); fn(...) → from cula.api import cula_kda; cula_kda(...)
lightning_attn_fwd(Q, K, V, decay, ...) → from cula.api import cula_lightning; cula_lightning(Q, K, V, decay, ...)

Level 0 functions are not removed. Level 1 calls them internally. Users needing low-level options continue to import from Level 0.

10.2 FLA CP context → Level 2 CP

Current (cula-fla) → Level 2 replacement
FLACPContext + chunk_kda(..., cp_context=ctx) → cula_kda_cp(q, k, v, g, beta, sp_group=pg, sp_size=N, sp_rank=r, ...)

The cp_context kwarg is still accepted by cula_kda (Level 1) for backward compatibility but is deprecated.

10.3 cula.__init__.py re-exports

# cula/__init__.py
from cula.api import cula_kda, cula_lightning, cula_decode

11. Future Work: nvshmem Kernel-Level Overlap

The SP ring protocol (§7) uses PyTorch distributed ops (isend/irecv) outside the CUDA kernel. On multi-node clusters where state transfer traverses InfiniBand, this host-driven latency may exceed attention compute time for short sequences.

nvshmem enables kernel-level communication: a CUDA kernel can initiate point-to-point transfers without returning to the CPU:

  1. Kernel computes final state at end of local chunk.
  2. Issues nvshmemx_putmem_signal_on_stream to push state to rank r+1.
  3. On rank r+1, kernel waits on nvshmem signal before reading initial state.

This collapses the compute-communicate boundary to within a single kernel launch.

A placeholder nvshmem_backend kwarg is reserved in cula_kda_cp; passing nvshmem_backend=True will raise NotImplementedError until this work is complete.


12. Open Questions

Q1: Should cula.api live in cula core or cula-fla?

cula.api Level 1 currently depends on cula-fla for the modular fallback path (which imports FLA). Recommendation: cula.api lives in cula core with lazy imports. If cula-fla is not installed, Level 1 still works on SM90/SM100 using fused kernels; the modular fallback degrades gracefully with a warning.

Q2: Who owns cu_seqlens dtype normalization?

Level 0 asserts int32; FLA uses int64. Recommendation: keep Level 0 strict (int32), add normalization in Level 1.

Q3: State layout migration from existing [N, H, K, V] prefill kernels?

This RFC adopts [B, H, V, K] (K-last) as the single canonical state layout across both prefill and decode, because K-last is SMEM bank-conflict-friendly for decode kernels (threads reading different V-rows at the same K-offset land on different banks). Existing prefill kernels (chunk_kda, cula_kda_prefill) that output [N, H, K, V] (K-first / FLA convention) will need migration: Level 1 handles the transpose internally for any un-migrated kernel during the transition period.
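The transitional transpose is a one-liner, but the `.contiguous()` is load-bearing since kernels consume raw strides (function name hypothetical):

```python
import torch

def to_k_last(state_k_first: torch.Tensor) -> torch.Tensor:
    """[N, H, K, V] (K-first, FLA) -> [N, H, V, K] (K-last, canonical)."""
    return state_k_first.transpose(-1, -2).contiguous()
```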

Q4: Varlen + SP: sequence-complete vs. sequence-split?

Recommendation: implement sequence-complete first (simpler, covers inference). Add sequence-split as sp_mode="split" kwarg in a follow-up, preserving backward compatibility.

Q5: Gradient flow through CP state transfer?

Training with SP requires gradient backward through the state ring. Out of scope for this RFC but must be designed before training with SP is possible.


Summary of Proposed Changes

File Action Description
cula/api.py New Level 1: cula_kda, cula_lightning, cula_decode, dispatch logic
cula/dist.py New Level 2: TP/CP/SP wrappers, state buffer utilities
cula/__init__.py Modify Re-export Level 1 functions
cula/utils.py Modify Extend get_kda_fused_fwd to return the modular fallback instead of raising NotImplementedError
