Use Case
Models like Gemma 4 use shared KV caches where certain decoder layers reuse the key/value projections from earlier layers instead of computing their own. This is an efficiency optimization that reduces memory and compute — only a subset of layers perform K/V projection, and the rest borrow from a designated source layer.
We'd like to use the com.microsoft.GroupQueryAttention op for these shared-KV layers by passing the borrowed KV via past_key/past_value with omitted (or empty) key/value inputs. The op spec marks key and value as optional, suggesting this should be supported.
Current Behavior (ORT 1.25.0)
We tested three approaches on both CUDAExecutionProvider and CPUExecutionProvider. All produce the same results on both EPs:
1. Omitted k/v (empty string input names in ONNX graph) — ❌ FAILS
Input 'past_key' dimension 3 should be same as head_size, got 32 expected 16
When key is omitted, the kernel cannot infer head_size from the key tensor shape. It falls back to an incorrect computation, causing a shape validation error against past_key.
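The "expected 16" is consistent with the kernel assuming a packed-QKV layout and deriving head_size from the query hidden size. This is our reading of the error, not confirmed against the kernel source; with the dimensions from the repro below:

```python
# Hypothetical reconstruction of the fallback (assumption, not confirmed from ORT source):
num_heads, kv_num_heads, head_dim = 8, 4, 32
q_hidden_size = num_heads * head_dim                                   # 256
fallback_head_size = q_hidden_size // (num_heads + 2 * kv_num_heads)   # 256 // 16 = 16
# past_key.shape[3] is 32, so validation fails: "got 32 expected 16"
```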
2. Zero-length k/v (shape [B, 0, kv_hidden]) — ❌ FAILS
Input 'query' and 'key' shall have same dim 1 (sequence length)
The CUDA kernel requires key.shape[1] == query.shape[1], rejecting zero-length key tensors.
3. Zero-valued k/v (shape [B, seq_len, kv_hidden], all zeros) — ⚠️ Works with caveats
This executes successfully but introduces softmax dilution: zero-key tokens get attention score 0 (q · 0 = 0), which translates to a non-zero softmax weight of e^(0 - max_score) / denominator. When multiplied by zero values, these tokens contribute nothing to the output directly, but they dilute the attention weights on the real (past) tokens.
Measured dilution (decode step, seq_len=1, 1 zero token added):
| past_seq_len | Attention weight on zero token | Relative output difference |
|---|---|---|
| 16 | 3.03% | 3.03% |
| 64 | 1.05% | 1.05% |
| 256 | 0.24% | 0.24% |
| 1024 | 0.06% | 0.06% |
This is acceptable for long-sequence decode but problematic for short sequences or accuracy-sensitive applications.
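A minimal numpy sketch of the dilution mechanism (illustrative random scores, not the measured model, so the printed weights only roughly track the table above):

```python
import numpy as np

def zero_token_weight(n_real: int, seed: int = 0) -> float:
    # Illustrative real-token scores; the appended zero-key token scores q . 0 = 0.
    rng = np.random.default_rng(seed)
    scores = np.append(rng.standard_normal(n_real), 0.0)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return float(w[-1])  # softmax weight diverted to the zero token

for n in (16, 64, 256, 1024):
    print(n, f"{zero_token_weight(n):.2%}")
```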
Requested Behavior
When key and value are omitted (optional inputs) but past_key and past_value are provided:
- Infer head_size from past_key.shape[3] instead of key.shape[2] / kv_num_heads
- Skip KV concatenation — present_key = past_key (no new tokens to append)
- Compute attention normally using only the past KV cache as the full KV sequence
This would make the op truly support its declared optional inputs.
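For reference, a small numpy sketch of the semantics we are asking for when key/value are omitted (our illustration of the math, not ORT kernel code; assumes no rotary embedding and no masking over the past):

```python
import numpy as np

def gqa_past_only(query, past_key, past_value, num_heads):
    # query:             [B, S, num_heads * head_dim]
    # past_key / value:  [B, kv_num_heads, past_seq, head_dim] (BNSH)
    B, S, _ = query.shape
    _, kv_num_heads, past_seq, head_dim = past_key.shape     # head_size from past_key.shape[3]
    q = query.reshape(B, S, num_heads, head_dim).transpose(0, 2, 1, 3)
    repeat = num_heads // kv_num_heads
    k = np.repeat(past_key, repeat, axis=1)                   # expand KV heads to match Q heads
    v = np.repeat(past_value, repeat, axis=1)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    out = (w @ v).transpose(0, 2, 1, 3).reshape(B, S, num_heads * head_dim)
    # present_key = past_key, present_value = past_value (no concatenation step)
    return out, past_key, past_value
```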
Benefits
For shared-KV architectures like Gemma 4, this would:
- Eliminate per-layer Transpose + Reshape ops — shared KV is already in BNSH format matching past_key/past_value, but currently must be converted to 3D [B, S, kv_hidden] format for the standard Attention op
- Enable a consistent GQA path for all layers — currently shared layers must fall back to the generic Attention op, preventing GQA optimization across the full model
- Provide native sliding window support via the local_window_size attribute, rather than requiring manual attention mask construction
Current Workaround
We use the standard Attention op (not GQA) for shared KV layers, with explicit Transpose and Reshape ops to convert the shared KV from BNSH [B, kv_heads, seq, head_dim] to 3D [B, seq, kv_hidden] format.
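For illustration, the per-layer conversion amounts to the following (numpy stand-in for the inserted Transpose + Reshape nodes; the helper name is ours):

```python
import numpy as np

def bnsh_to_bsh(shared_kv: np.ndarray) -> np.ndarray:
    # [B, kv_heads, seq, head_dim] -> Transpose(perm=[0, 2, 1, 3]) -> Reshape
    b, kv_heads, seq, head_dim = shared_kv.shape
    return shared_kv.transpose(0, 2, 1, 3).reshape(b, seq, kv_heads * head_dim)
```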
Reproduction
Minimal ONNX model demonstrating the failure with omitted k/v:
import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

batch, seq_len, past_seq_len = 1, 4, 16
num_heads, kv_num_heads, head_dim = 8, 4, 32

inputs = [
    helper.make_tensor_value_info("query", TensorProto.FLOAT16,
                                  [batch, seq_len, num_heads * head_dim]),
    helper.make_tensor_value_info("past_key", TensorProto.FLOAT16,
                                  [batch, kv_num_heads, past_seq_len, head_dim]),
    helper.make_tensor_value_info("past_value", TensorProto.FLOAT16,
                                  [batch, kv_num_heads, past_seq_len, head_dim]),
    helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [batch]),
    helper.make_tensor_value_info("total_seq_len", TensorProto.INT32, [1]),
]

# k/v omitted via empty string input names
node = helper.make_node(
    "GroupQueryAttention",
    inputs=["query", "", "", "past_key", "past_value",
            "seqlens_k", "total_seq_len"],
    outputs=["attn_output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=num_heads, kv_num_heads=kv_num_heads,
    scale=float(head_dim ** -0.5), do_rotary=0,
)

graph = helper.make_graph([node], "test", inputs, [
    helper.make_tensor_value_info("attn_output", TensorProto.FLOAT16, None),
    helper.make_tensor_value_info("present_key", TensorProto.FLOAT16, None),
    helper.make_tensor_value_info("present_value", TensorProto.FLOAT16, None),
])
model = helper.make_model(graph, opset_imports=[
    helper.make_opsetid("", 21),
    helper.make_opsetid("com.microsoft", 1),
])
model.ir_version = 10

sess = ort.InferenceSession(model.SerializeToString(),
                            providers=[("CUDAExecutionProvider", {})])

# No new KV tokens: seqlens_k / total_seq_len describe the past cache only.
sess.run(None, {
    "query": np.random.randn(batch, seq_len,
                             num_heads * head_dim).astype(np.float16),
    "past_key": np.random.randn(batch, kv_num_heads,
                                past_seq_len, head_dim).astype(np.float16),
    "past_value": np.random.randn(batch, kv_num_heads,
                                  past_seq_len, head_dim).astype(np.float16),
    "seqlens_k": np.array([past_seq_len - 1], dtype=np.int32),
    "total_seq_len": np.array([past_seq_len], dtype=np.int32),
})
# Error: Input 'past_key' dimension 3 should be same as head_size,
# got 32 expected 16
Environment
- ORT: 1.25.0
- CUDA: 13.0
- OS: Linux
- EPs tested: CUDAExecutionProvider, CPUExecutionProvider (same behavior)