### MHA -> MQA -> GQA
![关系](./pics/MHA_MQA_GQA.png)
### Coding: Multi-Head Attention

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention in which every head has its own Q, K and V.

    The input is projected to queries, keys and values of width
    ``hidden_dim``; each projection is split into ``head_num`` heads of
    ``head_dim = hidden_dim // head_num`` features. Scaled dot-product
    attention is computed per head, the heads are concatenated, and the
    result goes through a final linear projection.

    Args:
        hidden_dim: Feature width of the input and of the output.
        head_num: Number of attention heads; must divide ``hidden_dim``.
        dropout_rate: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``head_num``.
    """

    def __init__(self, hidden_dim, head_num, dropout_rate=0.1):
        super().__init__()
        # Fail at construction time: without this check a bad configuration
        # only surfaces later as an opaque view() size error inside forward().
        if hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"head_num ({head_num})"
            )

        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.att_dropout = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, hidden_dim]``.
            attention_mask: Optional tensor broadcastable to
                ``[batch_size, head_num, seq_len, seq_len]``. Entries equal
                to 0 mark positions that must not be attended to.

        Returns:
            Tensor of shape ``[batch_size, seq_len, hidden_dim]``.
        """
        batch_size, seq_len, _ = x.size()

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # [batch, seq, hidden] -> [batch, head_num, seq, head_dim]
        Q_state = q.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        K_state = k.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        V_state = v.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: [batch, head_num, seq, seq]
        attention_weight = Q_state @ K_state.transpose(-1, -2) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # -inf becomes a weight of exactly 0 after softmax.
            # NOTE: a query row whose keys are all masked yields NaN.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0,
                float("-inf")
            )

        attention_weight = F.softmax(attention_weight, dim=-1)
        attention_weight = self.att_dropout(attention_weight)

        output = attention_weight @ V_state

        # Merge heads: [batch, head_num, seq, head_dim] -> [batch, seq, hidden]
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.out_proj(output)

        return output
        
# Smoke test: batch of 3 sequences, 2 tokens each, hidden size 128.
x_MHA = torch.randn(3, 2, 128)

# Per-sequence key mask (0 = do not attend), given singleton head and query
# axes and then expanded to [batch, head_num, seq_len, seq_len].
attention_mask_MHA = torch.tensor(
    [
        [0, 1],
        [1, 1],
        [1, 0],
    ]
)[:, None, None, :].expand(3, 8, 2, 2)

# 8 heads -> head_dim = 128 // 8 = 16
MHA_net = MultiHeadAttention(128, 8)
MHA_net(x_MHA, attention_mask_MHA)


tensor([[[-5.5567e-01,  2.4076e-01,  4.6895e-01, -3.6922e-02,  3.9975e-04,
          -8.1978e-02, -3.2290e-02,  3.8538e-01,  4.8378e-01,  1.2396e-02,
           4.5937e-02, -6.3421e-01,  4.0401e-01, -8.3526e-01,  1.4382e-01,
          -2.6344e-01, -2.1243e-01, -7.8360e-02,  3.4908e-02, -1.7554e-01,
           3.0448e-01, -1.0930e-02, -1.3162e-01, -1.8322e-01,  8.2782e-02,
           8.8300e-01,  2.3892e-01,  1.4224e-01,  4.7802e-01,  4.7979e-01,
          -7.9672e-01,  4.7388e-01,  5.0549e-01, -2.7657e-01, -5.8219e-01,
           1.5372e-01, -6.4708e-01, -2.2582e-01,  4.8574e-01, -6.4262e-01,
          -2.1122e-01,  7.3353e-01,  3.3868e-01, -2.2551e-02,  4.5843e-02,
           7.1220e-01,  5.6761e-01, -1.4531e-01, -1.0196e+00,  3.3974e-01,
           3.1525e-01, -9.0425e-01,  7.6110e-02,  6.1162e-01, -3.1718e-01,
          -2.6062e-01, -6.2284e-01,  5.2902e-01,  3.9362e-01, -2.2538e-02,
           1.9929e-01, -6.9513e-01,  2.3361e-02,  1.7679e-01, -6.6484e-01,
          -5.3844e-01, -1

### Multi-Query Attention

In [9]:
class MultiQueryAttention(nn.Module):
    """Self-attention whose query heads share a reduced set of K/V heads.

    Queries are projected to ``head_num`` heads, while keys and values are
    projected to only ``num_key_value_head`` heads (1 by default, i.e. a
    single K/V head shared by every query head). The K/V heads are repeated
    along the head axis so that each query head has a matching K/V head.

    Args:
        hidden_dim: Feature width of the input and of the output.
        head_num: Number of query heads; must divide ``hidden_dim``.
        num_key_value_head: Number of key/value heads; must divide
            ``head_num``.
        attention_dropout: Dropout probability applied to the attention
            weights.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``head_num`` or
            ``head_num`` is not divisible by ``num_key_value_head``.
    """

    def __init__(self, hidden_dim, head_num, num_key_value_head=1, attention_dropout=0.1):
        super().__init__()
        # Validate up front; otherwise a bad configuration only fails inside
        # forward() with a view()/matmul shape error.
        if hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"head_num ({head_num})"
            )
        if head_num % num_key_value_head != 0:
            raise ValueError(
                f"head_num ({head_num}) must be divisible by "
                f"num_key_value_head ({num_key_value_head})"
            )

        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.num_key_value_head = num_key_value_head
        self.head_dim = hidden_dim // head_num

        self.q_proj = nn.Linear(hidden_dim, head_num * self.head_dim)
        # K and V are narrower than Q: only num_key_value_head heads each.
        self.k_proj = nn.Linear(hidden_dim, num_key_value_head * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, num_key_value_head * self.head_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.attention_dropout = nn.Dropout(attention_dropout)

    def forward(self, x, attention_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, hidden_dim]``.
            attention_mask: Optional tensor broadcastable to
                ``[batch_size, head_num, seq_len, seq_len]``. Entries equal
                to 0 mark positions that must not be attended to.

        Returns:
            Tensor of shape ``[batch_size, seq_len, hidden_dim]``.
        """
        batch_size, seq_len, _ = x.size()

        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # Q_state: [batch, head_num, seq, head_dim]
        # K_state / V_state: [batch, num_key_value_head, seq, head_dim]
        Q_state = Q.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        K_state = K.view(batch_size, seq_len, self.num_key_value_head, self.head_dim).transpose(1, 2)
        V_state = V.view(batch_size, seq_len, self.num_key_value_head, self.head_dim).transpose(1, 2)

        # Repeat each K/V head so the head axis matches the query heads.
        repeats = self.head_num // self.num_key_value_head
        K_state = K_state.repeat_interleave(repeats, dim=1)
        V_state = V_state.repeat_interleave(repeats, dim=1)

        # attention_weight: [batch, head_num, seq, seq]
        attention_weight = Q_state @ K_state.transpose(-1, -2) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # -inf becomes a weight of exactly 0 after softmax.
            # NOTE: a query row whose keys are all masked yields NaN.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0,
                float("-inf")
            )

        attention_weight = F.softmax(attention_weight, dim=-1)
        output_mid = self.attention_dropout(attention_weight) @ V_state

        # Merge heads: [batch, head_num, seq, head_dim]
        #                   -> [batch, seq, head_num * head_dim]
        output = output_mid.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        output = self.out_proj(output)
        return output

# Smoke test: batch of 3 sequences, 2 tokens each, hidden size 128.
x_MQA = torch.randn(3, 2, 128)
# Per-sequence key mask (0 = do not attend), expanded to
# [batch, head_num, seq_len, seq_len].
attention_mask_MQA = (
    torch.tensor(
        [
            [0, 1],
            [1, 1],
            [1, 0]
        ]
    ).unsqueeze(1).unsqueeze(2).expand(3, 8, 2, 2)
)
# 8 query heads -> head_dim = 128 // 8 = 16; one shared K/V head by default.
MQA_net = MultiQueryAttention(128, 8)
# Use this cell's own input: x_GQA is only defined in a later cell, so
# referencing it here raises NameError when the file runs top to bottom.
MQA_net(x_MQA, attention_mask_MQA).shape

torch.Size([3, 2, 128])

### Grouped-Query Attention

In [2]:
class GroupQueryAttention(nn.Module):
    """Grouped-query self-attention.

    ``head_num`` query heads are split into ``num_key_value_head`` groups,
    and every query head in a group attends with the same key/value head.

    Args:
        hidden_dim: Feature width of the input and of the output.
        head_num: Number of query heads; must divide ``hidden_dim``.
        num_key_value_head: Number of key/value heads; must divide
            ``head_num``.
        attention_dropout: Dropout probability applied to the attention
            weights.
    """

    def __init__(self, hidden_dim, head_num, num_key_value_head, attention_dropout = 0.1):
        super().__init__()
        assert hidden_dim % head_num == 0, "hidden_dim must be divisible by head_num"
        assert head_num % num_key_value_head == 0, "head_num must be divisible by num_key_value_head"

        # head_num * head_dim == hidden_dim
        head_dim = hidden_dim // head_num
        # K/V projections only produce num_key_value_head heads.
        kv_features = num_key_value_head * head_dim

        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = head_dim
        self.num_key_value_head = num_key_value_head

        # Q keeps the full width: head_num * head_dim == hidden_dim.
        self.q_proj = nn.Linear(hidden_dim, head_num * head_dim)
        self.k_proj = nn.Linear(hidden_dim, kv_features)
        self.v_proj = nn.Linear(hidden_dim, kv_features)
        # Output projection maps the concatenated heads back to hidden_dim.
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.attention_dropout = nn.Dropout(attention_dropout)

    def _split_heads(self, projected, batch_size, seq_len, n_heads):
        """Reshape [batch, seq, n_heads * head_dim] to [batch, n_heads, seq, head_dim]."""
        return projected.view(batch_size, seq_len, n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, attention_mask = None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, hidden_dim]``.
            attention_mask: Optional tensor broadcastable to
                ``[batch_size, head_num, seq_len, seq_len]``. Entries equal
                to 0 mark positions that must not be attended to.

        Returns:
            Tensor of shape ``[batch_size, seq_len, hidden_dim]``.
        """
        batch_size, seq_len, _ = x.shape

        queries = self._split_heads(self.q_proj(x), batch_size, seq_len, self.head_num)
        keys = self._split_heads(self.k_proj(x), batch_size, seq_len, self.num_key_value_head)
        values = self._split_heads(self.v_proj(x), batch_size, seq_len, self.num_key_value_head)

        # head_num and num_key_value_head may differ, so each K/V head is
        # repeated until the head axis matches the query heads.
        group_size = self.head_num // self.num_key_value_head
        keys = keys.repeat_interleave(group_size, dim = 1)
        values = values.repeat_interleave(group_size, dim = 1)

        # scores: [batch_size, head_num, seq_len, seq_len]
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # -inf becomes a weight of exactly 0 after softmax.
            scores = scores.masked_fill(attention_mask == 0, float("-inf"))

        probs = torch.softmax(scores, dim = -1)
        context = torch.matmul(self.attention_dropout(probs), values)

        # Merge heads: [batch_size, head_num, seq_len, head_dim]
        #                   -> [batch_size, seq_len, head_num * head_dim]
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.out_proj(merged)

# Smoke test: batch of 3 sequences, 2 tokens each, hidden size 128.
x_GQA = torch.randn(3, 2, 128)

# Per-sequence key mask (0 = do not attend), given singleton head and query
# axes and then expanded to [batch, head_num, seq_len, seq_len].
attention_mask_GQA = torch.tensor(
    [
        [0, 1],
        [1, 1],
        [1, 0],
    ]
)[:, None, None, :].expand(3, 8, 2, 2)

# 8 query heads sharing 4 K/V heads -> head_dim = 128 // 8 = 16
GQA_net = GroupQueryAttention(128, 8, 4)
GQA_net(x_GQA, attention_mask_GQA).shape

torch.Size([3, 2, 128])