# Grouped-Query Attention (GQA)

**Grouped-Query Attention (GQA)** is a variant of multi-head attention (MHA) that improves efficiency by **reducing the number of key and value heads** while keeping the number of query heads, so that each key/value head is shared by a group of query heads. It was introduced by Google Research in the 2023 paper "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" [arxiv.org/abs/2305.13245](https://arxiv.org/abs/2305.13245).

Standard MHA gives every query head its own key and value head. It achieves the highest quality, but it is also the most expensive in memory and compute, especially for the KV cache during autoregressive inference. At the other extreme, multi-query attention (MQA) shares a single key/value head across all query heads, trading quality for efficiency. GQA strikes a balance by splitting the query heads into groups, each of which shares one key/value head. This reduces the parameters of the key/value projections and, more importantly, the size of the KV cache, which gives faster inference and better scalability for large models and long sequences while keeping quality close to MHA.
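To make the KV-cache saving concrete, here is a back-of-the-envelope calculation. It is a sketch that assumes fp16/bf16 storage (2 bytes per element) and a hypothetical 70B-class configuration (80 layers, 64 query heads of size 128, a 4096-token context, batch size 1); the numbers are illustrative, not measurements of any particular model.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    """Bytes needed to cache keys and values for seq_len tokens.

    The leading factor 2 counts both K and V; bytes_per_elem=2 assumes fp16/bf16.
    The number of query heads does not appear: only k/v heads are cached.
    """
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


# Hypothetical 70B-class configuration (assumed for illustration): 80 layers,
# 64 query heads of size 128, 4096-token context, batch size 1.
num_layers, head_dim, seq_len = 80, 128, 4096

for name, num_kv_heads in [("MHA", 64), ("GQA", 8), ("MQA", 1)]:
    gib = kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len) / 2**30
    print(f"{name}: {num_kv_heads:2d} k/v heads -> {gib:.2f} GiB")
```

With these assumed values the cache is 10.00 GiB for MHA, 1.25 GiB with 8 k/v heads and about 0.16 GiB for MQA: it scales linearly with the number of k/v heads, independent of the number of query heads.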

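In practice, the number of query heads and the number of k/v heads are tunable hyperparameters. The number of query heads must be divisible by the number of k/v heads, and the embedding dimension must be divisible by the number of query heads (together, these also make it divisible by the number of k/v heads). For latency-sensitive use cases, raise the ratio num_query_heads / num_kv_heads, so that more query heads share each k/v head and the KV cache shrinks. For a more expressive model, lower the ratio so that fewer query heads share the same k/v head. A ratio of 1 recovers MHA, and a single k/v head recovers MQA.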
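The two divisibility rules, and the way query heads are assigned to k/v heads, fit in a small helper. `gqa_layout` is a hypothetical name used only for this sketch; the mapping assumes the `repeat_interleave` convention used in the code below, where consecutive query heads share a k/v head.

```python
def gqa_layout(embed_dim, num_query_heads, num_kv_heads):
    """Validate a GQA head configuration and return (head_dim, group_size, mapping).

    mapping[h] is the index of the k/v head that query head h reads from.
    """
    if embed_dim % num_query_heads != 0:
        raise ValueError("embed_dim must be divisible by num_query_heads")
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    head_dim = embed_dim // num_query_heads        # shared by query, key and value heads
    group_size = num_query_heads // num_kv_heads   # query heads per k/v head
    mapping = [h // group_size for h in range(num_query_heads)]
    return head_dim, group_size, mapping


print(gqa_layout(512, 8, 2))  # (64, 4, [0, 0, 0, 0, 1, 1, 1, 1])
```

Setting `num_kv_heads == num_query_heads` gives a group size of 1 (MHA), and `num_kv_heads == 1` maps every query head to k/v head 0 (MQA).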

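GQA is adopted by several major LLMs, including LLaMA 2 70B, LLaMA 3, and Qwen2/Qwen3. Other well-known models take related but different routes: PaLM uses MQA, and DeepSeek-V2/V3 use Multi-head Latent Attention (MLA).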

## Code

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```
```python
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        d_k = query.size(-1)
        # (..., seq_q, d_k) @ (..., d_k, seq_k) -> (..., seq_q, seq_k), scaled by sqrt(d_k)
        scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
        if mask is not None:
            # Positions where mask == 0 get -inf, i.e. zero weight after softmax
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        output = torch.matmul(attention_weights, value)
        return output, attention_weights
```
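As a sanity check on the cell above, its math can be compared with PyTorch's fused `F.scaled_dot_product_attention` (available from PyTorch 2.0). The sketch re-implements the same computation as a plain function without dropout, i.e. what the module computes in `eval()` mode. A boolean mask with `True` for allowed positions works for both: the built-in treats `True` as "may attend", and the cell masks out positions where `mask == 0`.

```python
import torch
import torch.nn.functional as F


def manual_attention(q, k, v, mask=None):
    """Same math as ScaledDotProductAttention above, without dropout (its eval() behavior)."""
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return F.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
# Random (batch, heads, seq_len, head_dim) tensors
q, k, v = (torch.randn(2, 4, 5, 8) for _ in range(3))
# Causal boolean mask: True = may attend; broadcasts over batch and heads
causal = torch.tril(torch.ones(5, 5, dtype=torch.bool))

ours = manual_attention(q, k, v, mask=causal)
builtin = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)  # PyTorch >= 2.0
print(torch.allclose(ours, builtin, atol=1e-5))
```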

```python
class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_query_heads, num_kv_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_query_heads == 0, "embed_dim must be divisible by num_query_heads"
        assert num_query_heads % num_kv_heads == 0, "num_query_heads must be divisible by num_kv_heads"

        self.embed_dim = embed_dim
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.q_head_dim = embed_dim // num_query_heads
        # k/v heads must have the same size as query heads so that q @ k^T is defined
        self.kv_head_dim = self.q_head_dim

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        # k/v projections are smaller than in MHA: num_kv_heads * kv_head_dim outputs instead of embed_dim
        self.k_proj = nn.Linear(embed_dim, num_kv_heads * self.kv_head_dim)
        self.v_proj = nn.Linear(embed_dim, num_kv_heads * self.kv_head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.attn = ScaledDotProductAttention(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch_size, seq_length, embed_dim) - input tensor
            mask: optional attention mask broadcastable to (batch_size, num_query_heads, seq_length, seq_length),
                  e.g. (batch_size, 1, seq_length, seq_length); positions where mask == 0 are masked out
        Returns:
            output: (batch_size, seq_length, embed_dim) - output tensor after attention
            attn_weights: (batch_size, num_query_heads, seq_length, seq_length)
        """
        batch_size, seq_length, _ = x.shape

        # Project x to queries (weight matrix plus bias), split embed_dim into (num_query_heads, q_head_dim),
        # then swap the seq and heads dimensions: q is (batch_size, num_query_heads, seq_length, q_head_dim)
        q = self.q_proj(x).view(batch_size, seq_length, self.num_query_heads, self.q_head_dim).transpose(1, 2)

        # Project x to keys/values, split into (num_kv_heads, kv_head_dim), then swap the seq and heads
        # dimensions: k and v are (batch_size, num_kv_heads, seq_length, kv_head_dim)
        k = self.k_proj(x).view(batch_size, seq_length, self.num_kv_heads, self.kv_head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_length, self.num_kv_heads, self.kv_head_dim).transpose(1, 2)

        # Repeat each k/v head so that every group of consecutive query heads sees the same k/v head
        if self.num_query_heads != self.num_kv_heads:
            k = k.repeat_interleave(self.num_query_heads // self.num_kv_heads, dim=1)
            v = v.repeat_interleave(self.num_query_heads // self.num_kv_heads, dim=1)

        # Apply attention per query head
        attn_output, attn_weights = self.attn(q, k, v, mask)

        # Concatenate heads back into embed_dim and project the output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.embed_dim)
        output = self.out_proj(attn_output)

        return output, attn_weights
```
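The `repeat_interleave` step can also be expressed by folding the query heads into `(num_kv_heads, heads_per_group)` and letting each k/v head broadcast over its group. The sketch below uses hypothetical function names and leaves out projections, masking and dropout; it only checks that the two formulations agree. Whether broadcasting actually avoids materializing the expanded k/v depends on the backend, so treat this as an equivalent formulation rather than a guaranteed memory optimization.

```python
import torch
import torch.nn.functional as F


def gqa_repeat(q, k, v):
    """q: (B, Hq, L, D); k, v: (B, Hkv, L, D). Expands k/v with repeat_interleave."""
    group = q.size(1) // k.size(1)
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    weights = F.softmax(q @ k.transpose(-2, -1) / q.size(-1) ** 0.5, dim=-1)
    return weights @ v  # (B, Hq, L, D)


def gqa_grouped(q, k, v):
    """Same result via broadcasting: query heads are folded into (Hkv, group)."""
    B, Hq, L, D = q.shape
    Hkv = k.size(1)
    qg = q.view(B, Hkv, Hq // Hkv, L, D)   # consecutive query heads share a k/v head
    kg, vg = k.unsqueeze(2), v.unsqueeze(2)  # (B, Hkv, 1, L, D) broadcasts over the group
    weights = F.softmax(qg @ kg.transpose(-2, -1) / D ** 0.5, dim=-1)
    return (weights @ vg).reshape(B, Hq, L, D)


torch.manual_seed(0)
q = torch.randn(2, 8, 6, 16)  # 8 query heads
k = torch.randn(2, 2, 6, 16)  # 2 k/v heads -> groups of 4 query heads
v = torch.randn(2, 2, 6, 16)
print(torch.allclose(gqa_repeat(q, k, v), gqa_grouped(q, k, v), atol=1e-5))
```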

## Line-by-line Explanation

### def \_\_init\_\_(self, embed_dim, num_query_heads, num_kv_heads, dropout=0.1)

- First, make sure that the embedding dimension splits evenly into query heads and that the query heads split evenly into groups, one group per k/v head (together, these also make embed_dim divisible by num_kv_heads)
  - ```
    assert embed_dim % num_query_heads == 0, "embed_dim must be divisible by num_query_heads"
    assert num_query_heads % num_kv_heads == 0, "num_query_heads must be divisible by num_kv_heads"
    ```
- Store the embedding dimension, number of query heads and number of k/v heads as instance attributes
  - ```
    self.embed_dim = embed_dim
    self.num_query_heads = num_query_heads
    self.num_kv_heads = num_kv_heads
    ```
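- Calculate the head dimensions. Query, key and value heads all have the same size, `embed_dim // num_query_heads`, because every query head takes a dot product with a key head; GQA saves memory by using fewer k/v heads, not smaller ones
  - ```
    self.q_head_dim = embed_dim // num_query_heads
    self.kv_head_dim = self.q_head_dim
    ```
- Create the learnable projection layers (weight matrix plus bias). The query and output projections map `embed_dim` to `embed_dim`, while the key and value projections only produce `num_kv_heads * kv_head_dim` features, which is where GQA's parameter and KV-cache savings come from
  - ```
    self.q_proj = nn.Linear(embed_dim, embed_dim)
    self.k_proj = nn.Linear(embed_dim, num_kv_heads * self.kv_head_dim)
    self.v_proj = nn.Linear(embed_dim, num_kv_heads * self.kv_head_dim)
    self.out_proj = nn.Linear(embed_dim, embed_dim)
    ```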
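Because only the key and value projections shrink, the parameter saving is easy to compute. This sketch assumes `nn.Linear`'s default `bias=True`, as in the code above, and uses hypothetical sizes (embed_dim 4096, 32 query heads, 8 k/v heads).

```python
def attention_params(embed_dim, num_query_heads, num_kv_heads, bias=True):
    """Parameter count of the four projections (MHA when num_kv_heads == num_query_heads)."""
    head_dim = embed_dim // num_query_heads
    kv_dim = num_kv_heads * head_dim
    b = 1 if bias else 0
    q_and_out = 2 * (embed_dim * embed_dim + b * embed_dim)  # q_proj, out_proj: embed_dim -> embed_dim
    k_and_v = 2 * (embed_dim * kv_dim + b * kv_dim)          # k_proj, v_proj: embed_dim -> kv_dim
    return q_and_out + k_and_v


mha = attention_params(4096, 32, 32)
gqa = attention_params(4096, 32, 8)
print(mha, gqa, gqa / mha)  # 67125248 41953280 0.625
```

With these sizes the layer keeps 62.5% of the MHA parameters: q_proj and out_proj are unchanged, while k_proj and v_proj each shrink by the group size of 4.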

### def forward(self, x, mask=None)

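The input tensors are:
- x: (batch_size, seq_len, embed_dim)
- mask: optional attention mask for causal or padding masking; it must be broadcastable to the attention scores of shape (batch_size, num_query_heads, seq_len, seq_len), e.g. (batch_size, 1, seq_len, seq_len). A (batch_size, seq_len, seq_len) mask needs an extra heads dimension first, e.g. `mask.unsqueeze(1)`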
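A sketch of how a causal mask and a padding mask can be built in a shape that broadcasts against the `(batch_size, num_query_heads, seq_len, seq_len)` scores. The sequence lengths are hypothetical, and the convention (`True` = attend, `False`/0 = masked) matches `masked_fill(mask == 0, ...)` in `ScaledDotProductAttention`.

```python
import torch

batch_size, num_query_heads, seq_len = 2, 8, 5

# Causal mask: query position i may attend to key positions <= i; shape (1, 1, seq_len, seq_len)
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).view(1, 1, seq_len, seq_len)

# Padding mask from hypothetical per-sequence lengths: keys past each length are masked
lengths = torch.tensor([5, 3])
key_is_real = torch.arange(seq_len).view(1, -1) < lengths.view(-1, 1)  # (batch_size, seq_len)
padding = key_is_real.view(batch_size, 1, 1, seq_len)

mask = causal & padding  # (batch_size, 1, seq_len, seq_len), shared by all query heads
scores = torch.zeros(batch_size, num_query_heads, seq_len, seq_len)
masked = scores.masked_fill(mask == 0, float("-inf"))
print(mask.shape, masked.shape)
```

Keeping a singleton heads dimension lets one mask serve every query head. Since k and v are expanded to `num_query_heads` heads before attention, the k/v grouping needs no special handling in the mask.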

Line-by-line implementation:
- Get the batch size and sequence length from the input X's shape
  - `batch_size, seq_length, _ = x.shape`
- Compute the Q, K, V tensors and prepare them for GQA
  - ```
    q = self.q_proj(x).view(batch_size, seq_length, self.num_query_heads, self.q_head_dim).transpose(1, 2)
    k = self.k_proj(x).view(batch_size, seq_length, self.num_kv_heads, self.kv_head_dim).transpose(1, 2)
    v = self.v_proj(x).view(batch_size, seq_length, self.num_kv_heads, self.kv_head_dim).transpose(1, 2)
    ```
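    - First, put the input through the corresponding learnable linear projection layer
    - Then split the last dimension of each projection into a heads dimension and a head dimension: `(num_query_heads, q_head_dim)` for q and `(num_kv_heads, kv_head_dim)` for k and v
    - Finally, swap the seq_length and heads dimensions so that each head's `(seq_length, head_dim)` matrix occupies the last two dimensions, ready for batched matrix multiplication
- Share k/v heads across query heads. If there are fewer k/v heads than query heads, repeat each k/v head `num_query_heads // num_kv_heads` times along the heads dimension (dim 1), because the attention matmul is batched over that dimension and q, k and v must have the same number of heads there. `repeat_interleave` places the copies next to each other, so each group of consecutive query heads reads the same k/v head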
  - ```
    if self.num_query_heads != self.num_kv_heads:
        k = k.repeat_interleave(self.num_query_heads // self.num_kv_heads, dim=1)
        v = v.repeat_interleave(self.num_query_heads // self.num_kv_heads, dim=1)
    ```
      - after this, k, v tensor shape becomes `(batch_size, num_query_heads, seq_len, kv_head_dim)`
      - q tensor shape is `(batch_size, num_query_heads, seq_len, q_head_dim)`
- Compute scaled dot-product attention of each head
  -  `attn_output, attn_weights = self.attn(q, k, v, mask)`
     -  attn_output shape: `(batch_size, num_query_heads, seq_len, )