# **Grouped-Query Attention (GQA)**

This notebook contains a simple implementation of the GQA mechanism. The goal is to show the core steps involved. 
In subsequent notebooks, this will be formally incorporated into the GPT architecture that we have implemented so far.

In [29]:
import torch
import math

Let us set up some hyperparameters of our toy GPT model.

In [30]:
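# hyperparameters
batch_size = 2
seq_len = 5
d_model = 64  # embedding dimension

n_q_heads = 8   # number of query heads (= number of attention heads)
n_kv_heads = 2  # number of key/value heads, each shared by a group of query heads
head_dim = d_model // n_q_heads  # 64 // 8 = 8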

We must confirm that the number of KV heads is a factor of the number of query heads (i.e., `n_q_heads` is divisible by `n_kv_heads`). Otherwise, the query heads cannot be split into equal-sized groups that each share one KV head.

In [31]:
assert n_q_heads % n_kv_heads == 0, "Number of KV heads must be a factor of number of Q heads"
group_size = n_q_heads // n_kv_heads

In [32]:
print(f"[INFO] Number of query heads: {n_q_heads}")
print(f"[INFO] Number of key-value heads: {n_kv_heads}")
print(f"[INFO] Using group size: {group_size}")

[INFO] Number of query heads: 8
[INFO] Number of key-value heads: 2
[INFO] Using group size: 4
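With these numbers, every KV head serves a block of `group_size` consecutive query heads. The small sketch below, in plain Python with the notebook's hyperparameter values, spells out that mapping. It assumes consecutive query heads are grouped together, which is the ordering the notebook's later `repeat_interleave` cell produces.

```python
# Sketch: which KV head does each query head read from?
# Uses the same values as the notebook (8 query heads, 2 KV heads).
n_q_heads = 8
n_kv_heads = 2
group_size = n_q_heads // n_kv_heads  # 4

# Query head h uses KV head h // group_size:
# heads 0-3 -> KV head 0, heads 4-7 -> KV head 1.
kv_head_for_q_head = [h // group_size for h in range(n_q_heads)]
print(kv_head_for_q_head)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

The two extremes of this setup are familiar: `n_kv_heads == n_q_heads` (group size 1) is ordinary multi-head attention, and `n_kv_heads == 1` (one group containing all heads) is multi-query attention.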


Now, we will use some random Q, K and V tensors.


**Note**: assume these Q, K and V tensors have already been obtained by multiplying the input (`x`) with the query, key and value weight matrices, respectively.
- The number of query weight matrices (W_query) equals the number of query heads. This is also the number of attention heads, since each head gets its own query matrix.
- The number of key and value weight matrices (W_key and W_value), however, is smaller than the number of attention heads, because each head does not get its own K and V matrices. Instead, GQA splits the attention heads into groups, and all heads in a group share one K matrix and one V matrix.
- Here, there are 8 attention heads and the group size is 4, so the heads form 2 groups (one per KV head). The 4 heads within a group all use the same W_key and W_value.
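To make the note concrete, here is a minimal sketch of where tensors with these shapes could come from. It assumes each of W_query, W_key and W_value is a single fused `nn.Linear` whose output is split into heads; this is an illustrative choice, not necessarily how the notebook's GPT script implements its projections. The per-head matrices described above are then slices of these fused weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len, d_model = 2, 5, 64
n_q_heads, n_kv_heads = 8, 2
head_dim = d_model // n_q_heads  # 8

# Fused projections (assumption): one weight matrix per tensor, split into heads afterwards.
W_query = nn.Linear(d_model, n_q_heads * head_dim, bias=False)   # 64 -> 64
W_key = nn.Linear(d_model, n_kv_heads * head_dim, bias=False)    # 64 -> 16
W_value = nn.Linear(d_model, n_kv_heads * head_dim, bias=False)  # 64 -> 16

x = torch.randn(batch_size, seq_len, d_model)

# Project, split the last dim into heads, then move the head dim before the sequence dim.
Q = W_query(x).view(batch_size, seq_len, n_q_heads, head_dim).transpose(1, 2)
K = W_key(x).view(batch_size, seq_len, n_kv_heads, head_dim).transpose(1, 2)
V = W_value(x).view(batch_size, seq_len, n_kv_heads, head_dim).transpose(1, 2)
print(Q.shape, K.shape, V.shape)
# torch.Size([2, 8, 5, 8]) torch.Size([2, 2, 5, 8]) torch.Size([2, 2, 5, 8])

# The key (and value) projection has group_size = 4 times fewer weights than the query one.
print(W_query.weight.numel(), W_key.weight.numel())  # 4096 1024
```

The same factor applies at inference time: a KV cache stores K and V for each past token with `n_kv_heads` heads instead of `n_q_heads`, so here it is 4 times smaller than in MHA.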

In [33]:
# random Q, K, V
Q = torch.randn(batch_size, n_q_heads, seq_len, head_dim)
K = torch.randn(batch_size, n_kv_heads, seq_len, head_dim)
V = torch.randn(batch_size, n_kv_heads, seq_len, head_dim)
print(Q.shape, K.shape, V.shape)

torch.Size([2, 8, 5, 8]) torch.Size([2, 2, 5, 8]) torch.Size([2, 2, 5, 8])


Now we apply the core GQA trick: we expand the K and V tensors so that each attention head gets its own copy, with all heads in a group receiving the same K and V tensors.

In [34]:
# expand K, V to match Q heads (core GQA trick)
K = K.repeat_interleave(group_size, dim=1)
V = V.repeat_interleave(group_size, dim=1)
# now: (batch, n_q_heads, seq_len, head_dim)
print(Q.shape, K.shape, V.shape)

torch.Size([2, 8, 5, 8]) torch.Size([2, 8, 5, 8]) torch.Size([2, 8, 5, 8])
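The order of the copies matters. `repeat_interleave` repeats each KV head `group_size` times in a row, which is what makes query heads 0-3 share KV head 0 and heads 4-7 share KV head 1. The sketch below checks this on a fresh random tensor (same shapes as the notebook), contrasts it with `repeat`, and builds the same tensor with an equivalent `expand` + `reshape` formulation.

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, head_dim = 2, 5, 8
n_q_heads, n_kv_heads = 8, 2
group_size = n_q_heads // n_kv_heads

K = torch.randn(batch_size, n_kv_heads, seq_len, head_dim)

# repeat_interleave along dim 1 gives head order [k0, k0, k0, k0, k1, k1, k1, k1].
K_rep = K.repeat_interleave(group_size, dim=1)
for h in range(n_q_heads):
    assert torch.equal(K_rep[:, h], K[:, h // group_size])

# By contrast, K.repeat(1, group_size, 1, 1) tiles as [k0, k1, k0, k1, ...],
# which would pair query heads with the wrong KV heads.

# Equivalent construction: add a group axis, broadcast it with expand(),
# then merge it into the head axis (reshape copies, since the view is not contiguous).
K_exp = (
    K[:, :, None, :, :]
    .expand(batch_size, n_kv_heads, group_size, seq_len, head_dim)
    .reshape(batch_size, n_q_heads, seq_len, head_dim)
)
print(torch.equal(K_rep, K_exp))  # True
```

Either way, the expanded tensor is a real copy. The memory savings of GQA come from storing and projecting only `n_kv_heads` heads, not from this expansion step.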


The rest of the attention mechanism is the same as in MHA: we compute the attention scores from the Q and K tensors, normalize them into attention weights with a softmax, and obtain the final context vectors (`output`) as a weighted sum of the V tensor.
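Written per head, the computation in the next cell is the following. The notation is introduced here for illustration: $h$ indexes query heads from 0 as in the code, $G$ is `group_size`, $n_q$ is `n_q_heads`, $d_{\text{head}}$ is `head_dim`, and $g(h)$ is the KV head that query head $h$ receives after `repeat_interleave`.

$$
g(h) = \left\lfloor \frac{h}{G} \right\rfloor, \qquad
\text{output}_h = \operatorname{softmax}\!\left(\frac{Q_h K_{g(h)}^{\top}}{\sqrt{d_{\text{head}}}}\right) V_{g(h)}, \qquad h = 0, 1, \dots, n_q - 1
$$

With $G = 1$, $g(h) = h$ and this is exactly the MHA formula.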

In [35]:
# scaled dot-product attention
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(head_dim)
attn_weights = torch.softmax(attn_scores, dim=-1)
output = attn_weights @ V

print(output.shape)

torch.Size([2, 8, 5, 8])
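The cell above omits the causal mask that a GPT decoder applies. The sketch below, on fresh random tensors with the notebook's shapes, adds that mask and computes causal GQA in three ways: the notebook's repeat-then-attend approach, a grouped version that lets broadcasting share each KV head across its query heads without repeating K and V, and PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0 or later) as a reference.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, head_dim = 2, 5, 8
n_q_heads, n_kv_heads = 8, 2
group_size = n_q_heads // n_kv_heads

Q = torch.randn(batch_size, n_q_heads, seq_len, head_dim)
K = torch.randn(batch_size, n_kv_heads, seq_len, head_dim)
V = torch.randn(batch_size, n_kv_heads, seq_len, head_dim)

# Causal mask: True above the diagonal, i.e. position i may not attend to j > i.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# 1) Repeat K/V to one copy per query head, then standard masked attention.
K_rep = K.repeat_interleave(group_size, dim=1)
V_rep = V.repeat_interleave(group_size, dim=1)
scores = Q @ K_rep.transpose(-2, -1) / math.sqrt(head_dim)
scores = scores.masked_fill(causal_mask, float("-inf"))
out_repeat = torch.softmax(scores, dim=-1) @ V_rep

# 2) No repetition: view the query heads as (n_kv_heads, group_size) and add a
#    size-1 group axis to K/V so broadcasting shares each KV head within its group.
Q_grouped = Q.view(batch_size, n_kv_heads, group_size, seq_len, head_dim)
scores_g = Q_grouped @ K.unsqueeze(2).transpose(-2, -1) / math.sqrt(head_dim)
scores_g = scores_g.masked_fill(causal_mask, float("-inf"))
out_grouped = (torch.softmax(scores_g, dim=-1) @ V.unsqueeze(2)).reshape(
    batch_size, n_q_heads, seq_len, head_dim
)

# 3) Built-in kernel on the repeated K/V (default scale is 1 / sqrt(head_dim)).
out_sdpa = F.scaled_dot_product_attention(Q, K_rep, V_rep, is_causal=True)

print(torch.allclose(out_repeat, out_grouped, atol=1e-5))  # True
print(torch.allclose(out_repeat, out_sdpa, atol=1e-5))     # True
```

The grouped form works because `Q.view(...)` places query head `kv * group_size + g` at position `(kv, g)`, the same ordering `repeat_interleave` produces. Recent PyTorch releases also document an `enable_gqa` argument for `scaled_dot_product_attention` that accepts fewer KV heads directly; check the documentation for your installed version before relying on it.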


- For the full implementation of the GQA mechanism, check out the [gpt_with_kv_gqa](gpt_with_kv_gqa.py) script.
- The [README](README.md) file contains the results of comparing the MHA and GQA mechanisms.