## GroupQueryAttention

In [1]:
import torch
import torch.nn as nn
import math

In [4]:
# * Dropout is intentionally omitted; attention_mask is optional.
class GroupQueryAttention(nn.Module):
    """Grouped-query attention (GQA) over a single input sequence.

    Queries use ``nums_head`` heads while keys and values use only
    ``nums_key_value_head`` heads; each key/value head is shared by
    ``nums_head // nums_key_value_head`` consecutive query heads.

    Args:
        hidden_dim: Size of the last dimension of the input and output.
            Must be divisible by ``nums_head``.
        nums_head: Number of query heads.
        nums_key_value_head: Number of key/value heads, i.e. the number of
            groups the query heads are split into. Must divide ``nums_head``.
    """
    def __init__(self, hidden_dim, nums_head, nums_key_value_head):
        super().__init__()
        assert hidden_dim % nums_head == 0, \
            "hidden_dim must be divisible by nums_head"
        assert nums_head % nums_key_value_head == 0, \
            "nums_head must be divisible by nums_key_value_head"

        self.hidden_dim = hidden_dim
        self.nums_head = nums_head
        self.head_dim = hidden_dim // nums_head
        # * Number of groups nums_head is split into, which is also the
        # * number of distinct key/value heads.
        self.nums_key_value_head = nums_key_value_head

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        # * Key/value projections are narrower: one head_dim slice per group.
        self.k_proj = nn.Linear(hidden_dim, nums_key_value_head * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, nums_key_value_head * self.head_dim)
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, attention_mask=None):
        """Run grouped-query self-attention on ``x``.

        Args:
            x: Tensor of shape (batch, seq, hidden_dim).
            attention_mask: Optional tensor broadcastable to
                (batch, nums_head, seq, seq), e.g. (seq, seq) or
                (batch, 1, seq, seq). Positions where the mask equals 0
                (or False) are excluded from attention. ``None`` means
                every position attends to every position.

        Returns:
            Tensor of shape (batch, seq, hidden_dim).
        """
        batch, seq, _ = x.size()

        q_matrix = self.q_proj(x)
        k_matrix = self.k_proj(x)
        v_matrix = self.v_proj(x)

        # * q_head_matrix: (batch, nums_head, seq, head_dim)
        q_head_matrix = q_matrix.view(batch, seq, self.nums_head, self.head_dim).transpose(1, 2)

        # * k_head_matrix and v_head_matrix: (batch, nums_key_value_head, seq, head_dim)
        k_head_matrix = k_matrix.view(batch, seq, self.nums_key_value_head, self.head_dim).transpose(1, 2)
        v_head_matrix = v_matrix.view(batch, seq, self.nums_key_value_head, self.head_dim).transpose(1, 2)

        # * Expand keys/values so they have as many heads as the queries.
        # * repeat_interleave repeats each head in place (h0, h0, h1, h1, ...),
        # * so consecutive query heads share one key/value head; Tensor.repeat
        # * would instead tile the whole head axis (h0, h1, h0, h1, ...).
        # * k_head_matrix and v_head_matrix: (batch, nums_head, seq, head_dim)
        group_size = self.nums_head // self.nums_key_value_head
        k_head_matrix = torch.repeat_interleave(k_head_matrix, repeats=group_size, dim=1)
        v_head_matrix = torch.repeat_interleave(v_head_matrix, repeats=group_size, dim=1)

        # * attention_matrix: (batch, nums_head, seq, seq)
        attention_matrix = q_head_matrix @ k_head_matrix.transpose(2, 3) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # * Use the most negative finite value rather than -inf so that a
            # * fully masked row yields a uniform distribution instead of NaN.
            fill_value = torch.finfo(attention_matrix.dtype).min
            attention_matrix = attention_matrix.masked_fill(attention_mask == 0, fill_value)

        # * attention_weight: (batch, nums_head, seq, seq)
        attention_weight = torch.softmax(attention_matrix, dim=-1)

        # * The dropout layer is omitted here as well.

        # * mid_output: (batch, nums_head, seq, head_dim)
        mid_output = attention_weight @ v_head_matrix

        # * Back to (batch, seq, nums_head, head_dim); contiguous() is needed
        # * before view() can merge the head axes.
        mid_output = mid_output.transpose(1, 2).contiguous()

        # * mid_output: (batch, seq, hidden_dim)
        mid_output = mid_output.view(batch, seq, -1)

        output = self.output_proj(mid_output)

        return output




# Smoke test: push a random (batch=3, seq=2, hidden_dim=128) tensor through
# a module with 8 query heads sharing 4 key/value heads and print the
# shape of the result.
x = torch.rand((3, 2, 128))
net = GroupQueryAttention(hidden_dim=128, nums_head=8, nums_key_value_head=4)

print(net(x).size())

torch.Size([3, 2, 128])
