In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [12]:
class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA) self-attention layer.

    ``num_heads`` query heads share ``num_kv_heads`` key/value heads: each K/V
    head is used by ``num_heads // num_kv_heads`` consecutive query heads.
    Special cases: ``num_kv_heads == num_heads`` is standard multi-head
    attention (MHA) and ``num_kv_heads == 1`` is multi-query attention (MQA).

    Args:
        d_model: Input/output feature dimension. Must be divisible by ``num_heads``.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads. Defaults to ``num_heads`` (MHA).
            Must be positive and divide ``num_heads``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, num_kv_heads = None, dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads  # Defaults to multi head attention

        # Fail fast: otherwise head_dim silently truncates and the .view() calls in
        # forward() raise an opaque shape error.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        assert self.num_kv_heads > 0, "num_kv_heads must be positive"
        assert self.num_heads % self.num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"

        self.head_dim = d_model // num_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # Dimension for each projection: queries keep all heads, keys/values only num_kv_heads
        self.q_proj_dim  = d_model
        self.kv_proj_dim = self.num_kv_heads * self.head_dim

        # Projections
        self.W_q = nn.Linear(self.q_proj_dim, self.q_proj_dim, bias=False)
        self.W_k = nn.Linear(self.q_proj_dim, self.kv_proj_dim, bias=False)
        self.W_v = nn.Linear(self.q_proj_dim, self.kv_proj_dim, bias=False)
        self.W_o = nn.Linear(self.q_proj_dim, self.q_proj_dim, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask = None):
        """Apply grouped-query self-attention to ``x``.

        Args:
            x: Input of shape ``[batch_size, seq_len, d_model]``.
            mask: Optional tensor broadcastable to
                ``[batch_size, num_heads, seq_len, seq_len]``; positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``[batch_size, seq_len, d_model]``.
        """
        batch_size, seq_len, _ = x.shape

        # Project and split into heads:
        # q: [batch_size, seq_len, num_heads, head_dim]
        # k, v: [batch_size, seq_len, num_kv_heads, head_dim]
        q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.W_k(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = self.W_v(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)

        # Repeat each K/V head num_queries_per_kv times so that query head h attends
        # with K/V head h // num_queries_per_kv:
        # [batch_size, seq_len, num_kv_heads, head_dim] -> [batch_size, seq_len, num_heads, head_dim]
        if self.num_queries_per_kv > 1:
            k = k.repeat_interleave(self.num_queries_per_kv, dim=2)
            v = v.repeat_interleave(self.num_queries_per_kv, dim=2)

        # [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot-product scores: [batch_size, num_heads, seq_len, seq_len]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Use the dtype's most negative finite value instead of a fixed -1e9,
            # which is not representable in float16 and makes masked_fill raise.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Attention-weighted sum of values: [batch_size, num_heads, seq_len, head_dim]
        context = torch.matmul(attn_weights, v)

        # Merge heads: [batch_size, num_heads, seq_len, head_dim] -> [batch_size, seq_len, d_model]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        return self.W_o(context)


In [26]:
class MultiQueryAttention(GroupedQueryAttention):
    """Multi-Query Attention (MQA): every query head shares a single key/value head.

    A thin specialisation of ``GroupedQueryAttention`` with ``num_kv_heads=1``;
    ``forward`` is inherited unchanged.

    Args:
        d_model: Input/output feature dimension.
        num_heads: Number of query heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout = 0.1):
        super().__init__(d_model, num_heads, num_kv_heads=1, dropout=dropout)


# Backward-compatible alias for the original, misspelled class name.
MultiQUeryAttention = MultiQueryAttention

In [27]:
# Demo configuration for comparing the attention variants.
batch_size, seq_len = 2, 10   # input is [batch_size, seq_len, d_model]
d_model, num_heads = 512, 8   # model width and number of query heads

In [28]:
torch.manual_seed(0)  # reproducible random input and weight initialisation

x = torch.randn(batch_size, seq_len, d_model)

# Standard Multi-Head Attention (MHA): one K/V head per query head
mha = GroupedQueryAttention(d_model, num_heads, num_kv_heads=num_heads)

# Grouped Query Attention (GQA) with 4 KV heads, each shared by num_heads // 4 query heads
gqa_kv_heads = 4
gqa = GroupedQueryAttention(d_model, num_heads, num_kv_heads=gqa_kv_heads)

# Multi-Query Attention (MQA): a single K/V head shared by all query heads
mqa = MultiQueryAttention(d_model, num_heads)

# Forward passes
mha_output = mha(x)
gqa_output = gqa(x)
mqa_output = mqa(x)

# Every variant maps [batch_size, seq_len, d_model] -> [batch_size, seq_len, d_model]
assert all(o.shape == x.shape for o in (mha_output, gqa_output, mqa_output)), "output shape must match input shape"

# Print shapes and parameter counts
mha_params = sum(p.numel() for p in mha.parameters())
gqa_params = sum(p.numel() for p in gqa.parameters())
mqa_params = sum(p.numel() for p in mqa.parameters())

print(f"MHA output shape: {mha_output.shape}")
print(f"GQA output shape: {gqa_output.shape}")
print(f"MQA output shape: {mqa_output.shape}")
print(f"MHA parameter count: {mha_params}")
print(f"GQA parameter count: {gqa_params}")
print(f"MQA parameter count: {mqa_params}")

# Compare parameter savings (only the K/V projections shrink; Q and output projections are unchanged)
print(f"GQA saves {(mha_params - gqa_params) / mha_params * 100:.2f}% parameters compared to MHA")
print(f"MQA saves {(mha_params - mqa_params) / mha_params * 100:.2f}% parameters compared to MHA")

MHA output shape: torch.Size([2, 10, 512])
GQA output shape: torch.Size([2, 10, 512])
MQA output shape: torch.Size([2, 10, 512])
MHA parameter count: 1048576
GQA parameter count: 786432
MQA parameter count: 589824
GQA saves 25.00% parameters compared to MHA
MQA saves 43.75% parameters compared to MHA
