Prefill Output Pretransposed States [V, K] for Decode #28

@icavan

Description

Motivation

cuLA's decode kernel uses a pretransposed state layout [V, K] (K-last, i.e., [B*H, V, K]), which is bank-conflict-friendly: with K contiguous in SMEM, threads in a warp reading different V-rows at the same K-offset land on different banks. This layout should be kept — it is efficient for decode.
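To make the pretransposed layout concrete, here is a minimal numpy sketch (illustrative sizes, not from the repo) showing that K-last means K is the unit-stride axis, so each V-row is one contiguous run of K elements:

```python
import numpy as np

B, H, K, V = 2, 4, 128, 128  # illustrative sizes
# Pretransposed decode state: [B*H, V, K], K-last (K contiguous).
state = np.zeros((B * H, V, K), dtype=np.float32)

# Strides in bytes: K is the unit-stride axis (4 bytes = one fp32),
# and each V-row is a contiguous K*4-byte run.
assert state.strides == (V * K * 4, K * 4, 4)
assert state[0].flags["C_CONTIGUOUS"]
```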

However, the prefill/chunk kernels (chunk_gated_delta_rule_fwd_h, kda_fwd_prefill) currently output final state in [K, V] layout (K-first, FLA convention: [N, H, K, V]). This means every prefill → decode transition requires a caller-side transpose:

# test_la_decode.py:86-91
state_cute = (
    state_4d.clone()
    .permute(0, 1, 3, 2)  # [B, H, K, V] → [B, H, V, K]
    .reshape(B * H, D, D)
    .contiguous()
)

This transpose is a full GPU memcpy (64 KB per head at K=V=128, fp32), and it happens on every prefill → decode transition in serving. For B=256, H=64, total transpose traffic is ~1 GiB of pure wasted bandwidth.
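A quick back-of-envelope check of those numbers (fp32, K = V = 128):

```python
# Per-head state size: K x V fp32 values.
K = V = 128
bytes_per_head = K * V * 4            # 4 bytes per fp32
assert bytes_per_head == 64 * 1024    # 64 KB per head

# Total transpose traffic at the quoted serving shape.
B, H = 256, 64
total = B * H * bytes_per_head
assert total == 1 << 30               # exactly 1 GiB
```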

The goal is to make the prefill kernels optionally output [V, K] state directly, so decode can consume it zero-copy.

Background

  • Prefill state output: chunk_delta_h.py writes final_state as [N, H, K, V] (chunk_delta_h.py:2029)
  • Hopper fused prefill: kda_fwd_prefill returns final_state (hopper_fused_fwd.py:106)
  • Decode kernel expects: [B*H, V, K] (pretransposed) (la_decode.py:86)
  • Caller-side transpose: test_la_decode.py:86-91
  • Decode tile config: TILE_V=8, TILE_K=128, 4 warps, 2-stage cp.async pipeline — designed for K-contiguous (K-last) access
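The layout mismatch in the list above can be demonstrated with a small numpy sketch (illustrative shapes): the permute of a [K, V]-layout state is a non-contiguous view, so today's path must materialize a full copy, whereas a state stored [V, K] reshapes into the decode layout as a zero-copy view:

```python
import numpy as np

N, H, K, V = 1, 2, 128, 128
prefill_kv = np.zeros((N, H, K, V), dtype=np.float32)  # FLA convention

# Today: decode needs [N*H, V, K], so the permute must materialize a copy.
swapped = prefill_kv.transpose(0, 1, 3, 2)        # view, non-contiguous
assert not swapped.flags["C_CONTIGUOUS"]
decode_state = np.ascontiguousarray(swapped).reshape(N * H, V, K)  # full copy

# Goal: prefill writes [N, H, V, K] directly; reshape is then a zero-copy view.
prefill_vk = np.zeros((N, H, V, K), dtype=np.float32)
view = prefill_vk.reshape(N * H, V, K)
assert np.shares_memory(view, prefill_vk)         # no copy
```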

Tasks

Phase 1: Analysis

  • Trace all paths that produce final_state in prefill (chunk_gated_delta_rule_fwd_h, kda_fwd_prefill, hopper_fused_fwd) — identify where the [K, V] layout is materialized (kernel-level write pattern vs. post-kernel reshape)
  • Determine whether the prefill kernel's state accumulation loop naturally writes [K, V] or if it can be trivially reordered to write [V, K] — e.g., swapping the inner/outer loop or transposing the register tile before writeback
  • Quantify the transpose cost: measure .permute(0,1,3,2).contiguous() latency for typical sizes (B=1..256, H=64, K=V=128, fp32)
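For the transpose-cost measurement, a CPU numpy analogue is sketched below (the real benchmark would time `.permute(0,1,3,2).contiguous()` on device with CUDA events; sizes here are kept small so the sketch runs anywhere):

```python
import time
import numpy as np

H, K, V = 64, 128, 128
for B in (1, 8, 32):
    s = np.zeros((B, H, K, V), dtype=np.float32)
    t0 = time.perf_counter()
    # Equivalent of .permute(0,1,3,2).contiguous(): swap K/V, force a copy.
    out = np.ascontiguousarray(s.transpose(0, 1, 3, 2))
    dt = time.perf_counter() - t0
    print(f"B={B:3d}: {s.nbytes / 1e6:8.1f} MB transposed in {dt * 1e3:7.2f} ms")
```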

Phase 2: Prefill Kernel Modification

  • Add a transpose_state (or state_layout) parameter to chunk_gated_delta_rule_fwd_h — when enabled, write final_state as [N, H, V, K] instead of [N, H, K, V]
  • Implement the transposed writeback in the CuTe DSL kernel: either swap the store loop order, or transpose the register tile in-register before the final store (register transpose is free compared to GMEM transpose)
  • Add the same option to kda_fwd_prefill / Hopper fused path if applicable
  • Ensure initial_state input still accepts [K, V] (standard FLA convention) — the transpose only applies to the output
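The intended semantics of the flag can be pinned down with a toy numpy reference (the function name, flag name, and rank-1 update form are illustrative, not the repo's actual kernel): the input stays [K, V], accumulation is unchanged, and only the writeback swaps axes.

```python
import numpy as np

def chunk_fwd_reference(initial_state_kv, updates, transpose_state=False):
    """Toy model of the proposed flag: accumulate state in [K, V]
    (FLA convention); write back [K, V] or pretransposed [V, K]."""
    S = initial_state_kv.copy()          # input is [K, V] regardless of flag
    for k_vec, v_vec in updates:         # rank-1, delta-rule-style updates
        S += np.outer(k_vec, v_vec)
    return S.T.copy() if transpose_state else S  # .T models the writeback swap

K, V = 8, 8
rng = np.random.default_rng(0)
init = rng.standard_normal((K, V)).astype(np.float32)
ups = [(rng.standard_normal(K).astype(np.float32),
        rng.standard_normal(V).astype(np.float32)) for _ in range(3)]

out_kv = chunk_fwd_reference(init, ups, transpose_state=False)
out_vk = chunk_fwd_reference(init, ups, transpose_state=True)
assert np.allclose(out_vk, out_kv.T)  # same state, swapped layout
```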

Phase 3: End-to-End Integration

  • Wire the transpose_state flag through the chunk forward orchestration (chunk_fwd.py) so it's exposed to callers
  • Update decode call sites to consume pretransposed state directly — remove the .permute(0,1,3,2).contiguous() workaround
  • Update tests to verify prefill [V, K] output feeds directly into decode without transpose
  • Benchmark the end-to-end prefill → decode pipeline with and without the transpose elimination
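A numerical sanity check for the integration tests: the decode math is layout-independent, so feeding a prefill-emitted [V, K] state straight into decode must match the old transpose-then-decode path exactly. A numpy sketch (illustrative sizes; decode modeled as the matvec o[v] = Σ_k S[v,k]·q[k]):

```python
import numpy as np

K = V = 16
rng = np.random.default_rng(1)
S_kv = rng.standard_normal((K, V)).astype(np.float32)  # prefill output today
q = rng.standard_normal(K).astype(np.float32)

# Current path: permute+contiguous to [V, K], then decode.
S_vk = np.ascontiguousarray(S_kv.T)
o_transposed = S_vk @ q

# Target path: prefill emits [V, K] directly; decode consumes it as-is.
o_direct = S_kv.T @ q  # stand-in for the prefill-emitted [V, K] state

assert np.allclose(o_transposed, o_direct)
```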

Alternative Approaches

A. V-Last [K, V] Decode Kernel

Instead of changing prefill output, write a new decode kernel that natively consumes [K, V] (V-last / V-contiguous) state, using an SMEM swizzle to avoid the bank conflicts that V-last access would otherwise cause with the current tile config. This eliminates the transpose by adapting the decode side.

Recommended Approach

Primary: Prefill transposed writeback (Phase 2 above) — minimal code change, zero decode regression, register transpose is essentially free.

Secondary: If prefill modification proves difficult (e.g., fused Hopper kernel has rigid output layout), fall back to Alternative A (V-Last decode kernel with SMEM swizzle). This is more work but self-contained on the decode side.

Success Criteria

  • Prefill can output [V, K] state directly via a flag — no post-kernel transpose needed
  • Decode consumes prefill output zero-copy — no .permute().contiguous() in the serving path
  • No prefill performance regression from transposed writeback (target: within 1%)
  • End-to-end latency improvement measurable at large batch sizes (B>=64)
