# Multi-Query Attention (MQA)

Multi-Query Attention (MQA) improves the inference efficiency of Multi-Head Attention (MHA) by sharing a single key head and a single value head across all query heads. This greatly reduces the memory bandwidth spent repeatedly loading the key and value tensors during incremental decoding, at the cost of only minor quality degradation.

It was introduced in Google's 2019 paper "Fast Transformer Decoding: One Write-Head is All You Need" ([arxiv.org/pdf/1911.02150](https://arxiv.org/pdf/1911.02150)).
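To get a feel for the size of the saving, the sketch below estimates how many bytes of keys and values a decoder would have to keep and reload, assuming it caches them for every previous position. The helper `kv_cache_bytes` and the model configuration are hypothetical and chosen only for illustration; they are not taken from the paper.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size, bytes_per_elem=2):
    """Bytes needed to hold cached keys and values for all layers and positions."""
    # The leading factor of 2 covers one key tensor plus one value tensor.
    return 2 * num_layers * batch_size * seq_len * num_kv_heads * head_dim * bytes_per_elem


# Hypothetical configuration (illustrative values only), 2 bytes per element as in fp16.
config = dict(num_layers=32, head_dim=128, seq_len=4096, batch_size=1)
num_query_heads = 32

mha_bytes = kv_cache_bytes(num_kv_heads=num_query_heads, **config)  # one K/V head per query head
mqa_bytes = kv_cache_bytes(num_kv_heads=1, **config)                # one shared K/V head

print(f"MHA: {mha_bytes / 2**20:.0f} MiB")       # MHA: 2048 MiB
print(f"MQA: {mqa_bytes / 2**20:.0f} MiB")       # MQA: 64 MiB
print(f"Reduction: {mha_bytes // mqa_bytes}x")   # Reduction: 32x
```

Under these assumptions the reduction factor equals the number of query heads, because the shared key and value head replaces one key/value pair per query head.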

## Code

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

```python
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        d_k = query.size(-1)
        # Scale by sqrt(d_k); a plain Python float avoids building a new tensor on every call
        scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
        if mask is not None:
            # Positions where mask == 0 are excluded from attention
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        output = torch.matmul(attention_weights, value)
        return output, attention_weights
```

```python
class MultiQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_query_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_query_heads == 0, "embed_dim must be divisible by num_query_heads"

        self.embed_dim = embed_dim
        self.num_query_heads = num_query_heads
        self.head_dim = embed_dim // num_query_heads # per head dimension

        self.q_proj = nn.Linear(embed_dim, embed_dim) # multiple query heads
        self.k_proj = nn.Linear(embed_dim, self.head_dim) # shared key head
        self.v_proj = nn.Linear(embed_dim, self.head_dim) # shared value head
        self.out_proj = nn.Linear(embed_dim, embed_dim) # output projection

        self.attention = ScaledDotProductAttention(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch_size, seq_len, embed_dim) - input token representations
            mask: optional attention mask, broadcastable to
                (batch_size, num_query_heads, seq_len, seq_len); 0 marks blocked positions
        Returns:
            output: (batch_size, seq_len, embed_dim)
            attn_weights: (batch_size, num_query_heads, seq_len, seq_len)
        """

        batch_size, seq_len, _ = x.shape

        # Project input x through query linear projection and reshape for multi-query heads
        q = self.q_proj(x).view(batch_size, seq_len, self.num_query_heads, self.head_dim).transpose(1, 2)
        # Project input x through shared key and value linear projections, then broadcast to share across query heads
        k = self.k_proj(x).unsqueeze(1) # (batch_size, 1, seq_len, head_dim)
        v = self.v_proj(x).unsqueeze(1) # (batch_size, 1, seq_len, head_dim)

        # Multi-head scaled dot-product attention. attn_output: (batch_size, num_query_heads, seq_len, head_dim)
        attn_output, attn_weights = self.attention(q, k, v, mask)
        # (batch_size, num_query_heads, seq_len, head_dim) -> (batch_size, seq_len, num_query_heads, head_dim)
        # -> reshape to (batch_size, seq_len, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)

        # Final output projection: (batch_size, seq_len, embed_dim)
        output = self.out_proj(attn_output)

        return output, attn_weights
```
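
The bandwidth saving from the introduction shows up during incremental decoding, where one token is processed per step and the keys and values of earlier tokens are reused. The following is a minimal sketch of that loop, not part of the class above. It assumes a key/value cache, and the names `w_q`, `w_k`, `w_v`, `k_cache` and `v_cache` are hypothetical; random weights stand in for trained projections, and biases, dropout and the output projection are left out.

```python
import torch

torch.manual_seed(0)
batch_size, num_query_heads, head_dim = 1, 4, 8
embed_dim = num_query_heads * head_dim

# Hypothetical random weights standing in for trained projections (no biases).
w_q = torch.randn(embed_dim, embed_dim)
w_k = torch.randn(embed_dim, head_dim)  # single shared key head
w_v = torch.randn(embed_dim, head_dim)  # single shared value head

# The cache holds one key/value head per position, not one per query head.
k_cache = torch.empty(batch_size, 1, 0, head_dim)
v_cache = torch.empty(batch_size, 1, 0, head_dim)

for step in range(3):
    x_t = torch.randn(batch_size, 1, embed_dim)  # representation of the newest token
    # Queries for the new token only: (batch_size, num_query_heads, 1, head_dim)
    q = (x_t @ w_q).view(batch_size, 1, num_query_heads, head_dim).transpose(1, 2)
    # Append the new token's shared key and value along the sequence dimension.
    k_cache = torch.cat([k_cache, (x_t @ w_k).unsqueeze(1)], dim=2)
    v_cache = torch.cat([v_cache, (x_t @ w_v).unsqueeze(1)], dim=2)
    # The size-1 head dimension of the cache broadcasts against all query heads.
    scores = q @ k_cache.transpose(-2, -1) / head_dim ** 0.5
    out = torch.softmax(scores, dim=-1) @ v_cache  # (batch_size, num_query_heads, 1, head_dim)
    print(step, tuple(k_cache.shape), tuple(out.shape))
```

No causal mask is needed in this loop because the cache only ever contains the current and earlier positions.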

## Explanation

Most of the code is the same as in [Multi-Head Attention (MHA)](../MHA/MHA.ipynb), so only the differences are explained here.

### def \_\_init\_\_(self, embed_dim, num_query_heads, dropout=0.1)

- Create the weight matrices (linear projections) differently for Q and for K/V, so that Q is split into multiple query heads while K and V are shared across those heads. Specifically, the output dimension of the Q projection is the full embedding dimension, while that of the K and V projections is a single head's dimension, `self.head_dim = embed_dim // num_query_heads`
  - ```
    self.q_proj = nn.Linear(embed_dim, embed_dim) # multiple query heads
    self.k_proj = nn.Linear(embed_dim, self.head_dim) # shared key head
    self.v_proj = nn.Linear(embed_dim, self.head_dim) # shared value head
    ```
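
As a rough illustration of what this change means for model size, the sketch below counts the parameters of the key and value projections in both layouts. The helper `count_params` and the sizes `embed_dim=512`, `num_query_heads=8` are assumptions made for this example only.

```python
import torch.nn as nn


def count_params(module):
    """Total number of weights and biases in a module."""
    return sum(p.numel() for p in module.parameters())


embed_dim, num_query_heads = 512, 8        # illustrative sizes
head_dim = embed_dim // num_query_heads    # 64

# MQA: K and V each project to a single head.
mqa_kv = nn.ModuleList([nn.Linear(embed_dim, head_dim), nn.Linear(embed_dim, head_dim)])
# MHA: K and V each project to the full embedding dimension.
mha_kv = nn.ModuleList([nn.Linear(embed_dim, embed_dim), nn.Linear(embed_dim, embed_dim)])

print(count_params(mqa_kv))  # 65664  = 2 * (512 * 64 + 64)
print(count_params(mha_kv))  # 525312 = 2 * (512 * 512 + 512)
```

The Q and output projections are unchanged, so only the K/V share of the layer shrinks, here by a factor of `num_query_heads`.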

### def forward(self, x, mask=None)

- Project x through the Q linear projection layer, then split and reshape the result for multi-head batching. This step is the same as in MHA.
- Project x through the K and V linear projection layers. Since these two are shared and each has the size of a single head, we don't split them. Instead we add an extra head dimension of size `1` so that they line up with the multi-head Q tensor. This enables **broadcasting**
  - ```
    k = self.k_proj(x).unsqueeze(1) # (batch_size, 1, seq_len, head_dim)
    v = self.v_proj(x).unsqueeze(1) # (batch_size, 1, seq_len, head_dim)
    ```
- Everything else works the same as MHA.
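
The broadcasting step can be checked in isolation. The sketch below, with arbitrary illustrative sizes, compares the broadcast score computation against one where the shared key is explicitly copied once per query head; the two should agree up to floating-point tolerance.

```python
import torch

torch.manual_seed(0)
batch_size, num_query_heads, seq_len, head_dim = 2, 4, 5, 8  # illustrative sizes

q = torch.randn(batch_size, num_query_heads, seq_len, head_dim)
k = torch.randn(batch_size, 1, seq_len, head_dim)  # single shared key head

# Broadcasting: matmul stretches the size-1 head dimension of k to num_query_heads.
scores_broadcast = torch.matmul(q, k.transpose(-2, -1))

# Explicit version: physically copy k once per query head, then multiply.
k_repeated = k.repeat(1, num_query_heads, 1, 1)
scores_explicit = torch.matmul(q, k_repeated.transpose(-2, -1))

print(tuple(scores_broadcast.shape))  # (2, 4, 5, 5)
print(torch.allclose(scores_broadcast, scores_explicit, atol=1e-5))
```

Broadcasting gives the same scores without materializing `num_query_heads` copies of the key tensor.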

Notes about the mask shape:
- For the mask to work with MQA, it must either match the shape of the multi-head attention scores or be broadcastable to it. This is the same as in MHA.
  - In a decoder-only Transformer the attention is self-attention only (no cross-attention), so the query length (`query_len`) and the key length (`key_len`) are equal and could both be called `seq_len`. The two names are kept separate here only for clarity.
- In the case of matching the shape, the mask has the shape
  - `(batch_size, num_query_heads, query_len, key_len)`
- In the case of broadcasting, it could be one of these
  - Causal mask (same for all batches/heads)
    - `(1, 1, query_len, key_len)`
  - Padding mask per token
    - `(batch_size, 1, 1, key_len)`
  - Fully customized (no broadcasting needed, since this is already the full shape)
    - `(batch_size, num_query_heads, query_len, key_len)`
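
The mask shapes listed above can be sketched as follows, using the same convention as the code (`0`/`False` means blocked). The sequence lengths in `lengths` are hypothetical, and padding is assumed to sit at the end of each sequence.

```python
import torch

batch_size, num_query_heads, seq_len = 2, 4, 5  # illustrative sizes

# Causal mask, shared by all batches and heads: (1, 1, query_len, key_len)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).view(1, 1, seq_len, seq_len)

# Padding mask per key token: (batch_size, 1, 1, key_len)
lengths = torch.tensor([5, 3])  # hypothetical number of real tokens per sequence
padding_mask = (torch.arange(seq_len)[None, :] < lengths[:, None]).view(batch_size, 1, 1, seq_len)

# Combining the two broadcasts to (batch_size, 1, query_len, key_len)
combined_mask = causal_mask & padding_mask

# Apply to scores of the full shape, exactly as ScaledDotProductAttention does.
scores = torch.zeros(batch_size, num_query_heads, seq_len, seq_len)
masked_scores = scores.masked_fill(combined_mask == 0, float("-inf"))
weights = torch.softmax(masked_scores, dim=-1)

print(tuple(combined_mask.shape))  # (2, 1, 5, 5)
print(weights[1, 0, 4])            # last query of the padded sequence attends to 3 keys
```

With all-zero scores, each query spreads its weight uniformly over the keys it is allowed to see, which makes the effect of the masks easy to read off.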