In [1]:
import math
import torch
import torch.nn as nn

![Grouped-Query Attention (GQA) diagram](../assets/GQA.png)

In [23]:
# Grouped-Query Attention (GQA). Dropout is not implemented; an optional attention mask is supported.
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention: query heads are split into groups that share one key/value head.

    With ``num_heads`` query heads and ``nums_key_value_heads`` key/value heads, each
    key/value head is shared by ``num_heads // nums_key_value_heads`` consecutive query
    heads. ``nums_key_value_heads == num_heads`` reduces to standard multi-head attention,
    ``nums_key_value_heads == 1`` to multi-query attention.

    Args:
        hidden_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of query heads.
        nums_key_value_heads: number of key/value heads; must divide ``num_heads``.
    """

    def __init__(self, hidden_dim, num_heads, nums_key_value_heads):
        super().__init__()
        assert hidden_dim % num_heads == 0  # every head gets an equal slice of hidden_dim
        assert num_heads % nums_key_value_heads == 0  # query heads split evenly into groups

        self.hidden_dim = hidden_dim
        self.nums_heads = num_heads
        self.nums_key_value_heads = nums_key_value_heads
        self.head_dim = hidden_dim // num_heads

        # q projects to all num_heads heads; k and v only to nums_key_value_heads heads,
        # so the k/v projections are smaller than the q projection.
        self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim)
        self.k_proj = nn.Linear(hidden_dim, nums_key_value_heads * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, nums_key_value_heads * self.head_dim)

        # num_heads * head_dim == hidden_dim (first assert), so the concatenated heads
        # have width hidden_dim going into the output projection.
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X, attention_mask=None):
        """Apply grouped-query self-attention.

        Args:
            X: input of shape (batch, seq, hidden_dim).
            attention_mask: optional tensor broadcastable to (batch, num_heads, seq, seq),
                e.g. a (seq, seq) causal mask or a (batch, 1, 1, seq) padding mask.
                Positions where the mask is 0/False are excluded from attention;
                nonzero/True positions are kept. ``None`` (default) attends everywhere.

        Returns:
            Tensor of shape (batch, seq, hidden_dim).
        """
        batch_size, seq, _ = X.size()

        # Project, then split into heads:
        # (batch, seq, heads * head_dim) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q = self.q_proj(X).view(batch_size, seq, self.nums_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(X).view(batch_size, seq, self.nums_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(X).view(batch_size, seq, self.nums_key_value_heads, self.head_dim).transpose(1, 2)

        # Expand each key/value head to cover its group of query heads. matmul broadcasting
        # cannot do this (head counts differ and neither is 1), so repeat explicitly.
        # repeat_interleave keeps groups contiguous: kv head i serves query heads
        # [i * group_size, (i + 1) * group_size).
        group_size = self.nums_heads // self.nums_key_value_heads
        k = k.repeat_interleave(group_size, dim=1)  # (batch, num_heads, seq, head_dim)
        v = v.repeat_interleave(group_size, dim=1)  # (batch, num_heads, seq, head_dim)

        # Scaled dot-product scores: (batch, num_heads, seq, seq).
        attention_score = (q @ k.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # Use the dtype's most negative finite value rather than -inf so that a fully
            # masked row becomes a uniform distribution instead of NaN after softmax.
            attention_score = attention_score.masked_fill(
                attention_mask == 0, torch.finfo(attention_score.dtype).min
            )

        attention_weight = torch.softmax(attention_score, dim=-1)

        output = attention_weight @ v  # (batch, num_heads, seq, head_dim)

        # Merge heads back to (batch, seq, hidden_dim). transpose leaves the tensor
        # non-contiguous, so reshape (not view) is needed before the output projection.
        output = output.transpose(1, 2).reshape(batch_size, seq, -1)
        return self.o_proj(output)






In [24]:
# Model configuration: 8 query heads share 4 key/value heads (2 query heads per group).
hidden_dim, num_heads, num_key_value_heads = 128, 8, 4

In [25]:
# Dummy input: a batch of 3 sequences of 2 tokens, each token embedded in hidden_dim
# dimensions. hidden_dim comes from the configuration cell and is deliberately not
# redefined here, so changing the configuration actually takes effect.
torch.manual_seed(0)  # reproducible input (and reproducible layer initialisation afterwards)
batch_size = 3
seq_length = 2
x = torch.rand(batch_size, seq_length, hidden_dim)

In [26]:
# Build the GQA layer from the configuration constants.
net = GroupedQueryAttention(
    hidden_dim=hidden_dim,
    num_heads=num_heads,
    nums_key_value_heads=num_key_value_heads,
)

In [27]:
# Sanity check: the attention layer preserves the (batch, seq, hidden_dim) input shape.
out = net(x)
assert out.shape == (batch_size, seq_length, hidden_dim)
out.shape

torch.Size([3, 2, 128])