
GroupQueryAttention: Support omitted key/value inputs for shared KV cache architectures #28318

@justinchuby

Description

Use Case

Models like Gemma 4 use shared KV caches where certain decoder layers reuse the key/value projections from earlier layers instead of computing their own. This is an efficiency optimization that reduces memory and compute — only a subset of layers perform K/V projection, and the rest borrow from a designated source layer.

We'd like to use the com.microsoft.GroupQueryAttention op for these shared-KV layers by passing the borrowed KV via past_key/past_value with omitted (or empty) key/value inputs. The op spec marks key and value as optional, suggesting this should be supported.

Current Behavior (ORT 1.25.0)

We tested three approaches on both CUDAExecutionProvider and CPUExecutionProvider; each approach behaves the same on both EPs:

1. Omitted k/v (empty string input names in ONNX graph) — ❌ FAILS

Input 'past_key' dimension 3 should be same as head_size, got 32 expected 16

When key is omitted, the kernel cannot infer head_size from the key tensor shape. It falls back to an incorrect computation, causing a shape validation error against past_key.
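The "expected 16" is consistent with the kernel assuming a packed-QKV layout when key is absent and deriving head_size from the query hidden size; this is our reading of the error message, not a confirmed code path:

# Assumed (unconfirmed) fallback when `key` is omitted: treat `query` as packed QKV
num_heads, kv_num_heads, head_dim = 8, 4, 32
q_hidden = num_heads * head_dim                                  # 256, query-only hidden size

packed_head_size = q_hidden // (num_heads + 2 * kv_num_heads)    # 256 // 16 = 16
print(packed_head_size)   # 16 -> the "expected 16" in the error above

# What we ask for instead: head_size taken from past_key.shape[3] (= 32 here)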

2. Zero-length k/v (shape [B, 0, kv_hidden]) — ❌ FAILS

Input 'query' and 'key' shall have same dim 1 (sequence length)

The CUDA kernel requires key.shape[1] == query.shape[1], rejecting zero-length key tensors.
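For reference, the zero-length tensors in this approach were built along the sequence axis; a minimal sketch (variable names illustrative):

import numpy as np

batch, seq_len, kv_num_heads, head_dim = 1, 4, 4, 32

# Zero-length key/value along dim 1 (sequence): shape [B, 0, kv_hidden]
empty_key = np.zeros((batch, 0, kv_num_heads * head_dim), dtype=np.float16)
empty_value = np.zeros_like(empty_key)
# Rejected because key.shape[1] (0) != query.shape[1] (seq_len = 4)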

3. Zero-valued k/v (shape [B, seq_len, kv_hidden], all zeros) — ⚠️ Works with caveats

This executes successfully but introduces softmax dilution: zero-key tokens get attention score 0 (q · 0 = 0), which translates to a non-zero softmax weight of e^(0 - max_score) / denominator. When multiplied by zero values, these tokens contribute nothing to the output directly, but they dilute the attention weights on the real (past) tokens.

Measured dilution (decode step, seq_len=1, 1 zero token added):

| past_seq_len | Attention weight on zero token | Relative output difference |
|--------------|--------------------------------|----------------------------|
| 16           | 3.03%                          | 3.03%                      |
| 64           | 1.05%                          | 1.05%                      |
| 256          | 0.24%                          | 0.24%                      |
| 1024         | 0.06%                          | 0.06%                      |

This is acceptable for long-sequence decode but problematic for short sequences or accuracy-sensitive applications.
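The dilution is easy to reproduce in isolation. Below is a minimal numpy sketch of a single decode step with one zero-key/zero-value token appended; the exact percentages depend on the score distribution, so they will not match the table above exactly:

import numpy as np

np.random.seed(0)
head_dim, past_seq_len = 32, 64
scale = head_dim ** -0.5

q = np.random.randn(head_dim)                      # single decode-step query (seq_len = 1)
k_past = np.random.randn(past_seq_len, head_dim)
v_past = np.random.randn(past_seq_len, head_dim)

scores = (k_past @ q) * scale

# Reference: softmax over the real past tokens only
w_ref = np.exp(scores - scores.max())
w_ref /= w_ref.sum()
out_ref = w_ref @ v_past

# One zero-key token appended: its score is q . 0 = 0, which still receives
# a non-zero softmax weight and dilutes the weights on the real tokens.
scores_z = np.concatenate([scores, [0.0]])
w_z = np.exp(scores_z - scores_z.max())
w_z /= w_z.sum()
out_z = w_z[:-1] @ v_past                          # zero values add nothing directly

print("attention weight on zero token:", w_z[-1])
print("relative output difference:",
      np.linalg.norm(out_z - out_ref) / np.linalg.norm(out_ref))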

Requested Behavior

When key and value are omitted (optional inputs) but past_key and past_value are provided:

  1. Infer head_size from past_key.shape[3] instead of key.shape[2] / kv_num_heads
  2. Skip KV concatenation: set present_key = past_key and present_value = past_value (no new tokens to append)
  3. Compute attention normally using only the past KV cache as the full KV sequence

This would make the op truly support its declared optional inputs.
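A sketch of the requested semantics in numpy terms (illustrative only, not ORT kernel code; causal masking and seqlens_k handling are omitted):

import numpy as np

def gqa_with_shared_kv(query, past_key, past_value, num_heads, scale):
    """Requested behavior when key/value are omitted but past_key/past_value are given."""
    batch, seq_len, q_hidden = query.shape
    kv_num_heads, past_len, head_size = past_key.shape[1:]      # 1. head_size from past_key.shape[3]
    assert q_hidden == num_heads * head_size

    present_key, present_value = past_key, past_value            # 2. no concat, nothing to append

    # 3. ordinary GQA attention over the past KV cache as the full KV sequence
    q = query.reshape(batch, seq_len, num_heads, head_size).transpose(0, 2, 1, 3)
    group = num_heads // kv_num_heads
    k = np.repeat(present_key, group, axis=1)                    # expand KV heads to query heads
    v = np.repeat(present_value, group, axis=1)

    scores = (q @ k.transpose(0, 1, 3, 2)) * scale
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    out = (w @ v).transpose(0, 2, 1, 3).reshape(batch, seq_len, q_hidden)
    return out, present_key, present_value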

Benefits

For shared-KV architectures like Gemma 4, this would:

  • Eliminate per-layer Transpose + Reshape ops — shared KV is already in BNSH format matching past_key/past_value, but currently must be converted to 3D [B, S, kv_hidden] format for the standard Attention op
  • Enable a consistent GQA path for all layers — currently shared layers must fall back to the generic Attention op, preventing GQA optimization across the full model
  • Provide native sliding window support via the local_window_size attribute, rather than requiring manual attention mask construction

Current Workaround

We use the standard Attention op (not GQA) for shared KV layers, with explicit Transpose and Reshape ops to convert the shared KV from BNSH [B, kv_heads, seq, head_dim] to 3D [B, seq, kv_hidden] format.
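In numpy terms the conversion amounts to a single transpose plus reshape (names illustrative):

import numpy as np

batch, kv_num_heads, seq, head_dim = 1, 4, 16, 32
shared_kv_bnsh = np.random.randn(batch, kv_num_heads, seq, head_dim).astype(np.float16)

# Transpose BNSH -> BSNH, then merge heads into kv_hidden = kv_num_heads * head_dim
shared_kv_3d = (shared_kv_bnsh
                .transpose(0, 2, 1, 3)
                .reshape(batch, seq, kv_num_heads * head_dim))
# shared_kv_3d now has the [B, S, kv_hidden] layout the standard Attention op expects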

Reproduction

Minimal ONNX model demonstrating the failure with omitted k/v:

import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

batch, seq_len, past_seq_len = 1, 4, 16
num_heads, kv_num_heads, head_dim = 8, 4, 32

inputs = [
    helper.make_tensor_value_info("query", TensorProto.FLOAT16,
        [batch, seq_len, num_heads * head_dim]),
    helper.make_tensor_value_info("past_key", TensorProto.FLOAT16,
        [batch, kv_num_heads, past_seq_len, head_dim]),
    helper.make_tensor_value_info("past_value", TensorProto.FLOAT16,
        [batch, kv_num_heads, past_seq_len, head_dim]),
    helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [batch]),
    helper.make_tensor_value_info("total_seq_len", TensorProto.INT32, [1]),
]

# k/v omitted via empty string input names
node = helper.make_node(
    "GroupQueryAttention",
    inputs=["query", "", "", "past_key", "past_value",
            "seqlens_k", "total_seq_len"],
    outputs=["attn_output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=num_heads, kv_num_heads=kv_num_heads,
    scale=float(head_dim ** -0.5), do_rotary=0,
)

graph = helper.make_graph([node], "test", inputs, [
    helper.make_tensor_value_info("attn_output", TensorProto.FLOAT16, None),
    helper.make_tensor_value_info("present_key", TensorProto.FLOAT16, None),
    helper.make_tensor_value_info("present_value", TensorProto.FLOAT16, None),
])
model = helper.make_model(graph, opset_imports=[
    helper.make_opsetid("", 21),
    helper.make_opsetid("com.microsoft", 1),
])
model.ir_version = 10

sess = ort.InferenceSession(model.SerializeToString(),
    providers=[("CUDAExecutionProvider", {})])

sess.run(None, {
    "query": np.random.randn(batch, seq_len,
        num_heads * head_dim).astype(np.float16),
    "past_key": np.random.randn(batch, kv_num_heads,
        past_seq_len, head_dim).astype(np.float16),
    "past_value": np.random.randn(batch, kv_num_heads,
        past_seq_len, head_dim).astype(np.float16),
    "seqlens_k": np.array([past_seq_len - 1], dtype=np.int32),
    "total_seq_len": np.array([past_seq_len], dtype=np.int32),
})
# Error: Input 'past_key' dimension 3 should be same as head_size,
#        got 32 expected 16

Environment

  • ORT: 1.25.0
  • CUDA: 13.0
  • OS: Linux
  • EPs tested: CUDAExecutionProvider, CPUExecutionProvider (same behavior)
