In [1]:
import torch
from torch import nn
#创建多头注意力机制网络模型
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        ## 初始化Q、K、V投影矩阵
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        
        ## 输出线性层
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        
    def forward(self,query,attention_mask=None):
        batch_size = query.size()[0]
        
        #先传入线性层
        query = self.q_linear(query)
        key = self.k_linear(query)
        value = self.v_linear(query)
        
        #分割张量，产生多个头
        query = self.split_head(query)
        key = self.split_head(key)
        value = self.split_head(value)
        
        #使用缩放点积模型计算注意力分数
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))

        ## 对注意力分数进行归一化
        attention_probs = torch.softmax(attention_scores, dim=-1)
        
        output = torch.matmul(attention_probs, value)
        
        ## 对注意力输出进行拼接
        output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        #线性输出
        output = self.o_linear(output)
        
        return output,attention_probs
 
    #分割张量元素，得到多个头
    def split_head(self, x):
        splited=torch.chunk(x, num_heads, dim=-1)
        stacked=torch.stack((splited[:num_heads])).transpose(0,1)
        return stacked
    """
    (batch_size ,sequence_length ,hidden_size )->
    (batch_size ,num_heads ,sequence_length , head_dim))
    """

In [2]:
class MultiQueryAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(MultiQueryAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        ## 初始化Q、K、V投影矩阵
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        
        ## 输出线性层
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        
    def forward(self,query,attention_mask=None):
        batch_size = query.size()[0]
        
        #先传入线性层
        query = self.q_linear(query)
        key = self.k_linear(query)
        value = self.v_linear(query)
        
        #分割张量，key和value自有一个头
        query = self.split_head(query)
        key = self.split_head(key,head=1)
        value = self.split_head(value,head=1)
        
        #使用缩放点积模型计算注意力分数
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))
        
        ## 对注意力分数进行归一化
        attention_probs = torch.softmax(attention_scores, dim=-1)
        
        output = torch.matmul(attention_probs, value)
        
        # 对注意力输出进行拼接
        output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        #线性输出
        output = self.o_linear(output)
        
        return output,attention_probs
 
    #分割张量元素，得到多个头
    def split_head(self, x, head=None):
        if head==None:
            splited=torch.chunk(x, num_heads, dim=-1)
            stacked=torch.stack((splited[:num_heads])).transpose(0,1)
            return stacked
            """
            (batch_size ,sequence_length ,hidden_size )->
            (batch_size ,num_heads ,sequence_length , head_dim))
            """
        else:
            batch_size = x.size()[0]
            return x.view(batch_size, -1, head, self.head_dim).transpose(1,2)
            #(batch_size ,1 ,sequence_length , head_dim))

In [3]:
class GroupAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads,group_num):
        super(GroupAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.group_num=group_num
        
        ## 初始化Q、K、V投影矩阵
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        
        ## 输出线性层
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        
    def forward(self,query,attention_mask=None):
        batch_size = query.size()[0]
        
        #先传入线性层
        query = self.q_linear(query)
        key = self.k_linear(query)
        value = self.v_linear(query)
        
        #分割张量，key和value自有一个头
        query = self.split_head(query)
        key = self.split_head(key,head=self.group_num)
        value = self.split_head(value,head=self.group_num)
        
        #使用缩放点积模型计算注意力分数
        attention_scores = torch.matmul(query, key.transpose(-1, -2) / self.head_dim**0.5)
        
        ## 对注意力分数进行归一化
        attention_probs = torch.softmax(attention_scores, dim=-1)
        
        output = torch.matmul(attention_probs, value)
        
        # 对注意力输出进行拼接
        output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        #线性输出
        output = self.o_linear(output)
        
        return output,attention_probs
 
    #分割张量元素，得到多个头
    def split_head(self, x, head=None):
        batch_size = x.size()[0]
        if head is None:
            x=x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)#这样分割张量更方便
            return x
            """
            (batch_size ,sequence_length ,hidden_size )->
            (batch_size ,num_heads ,sequence_length , head_dim))
            """
        else:
            x=x.view(batch_size, -1, head, self.head_dim).transpose(1,2).repeat(1,self.num_heads //self.group_num,1,1)
            return x

In [4]:
batch_size = 4
seq_len = 10#序列长度
hidden_size = 512#元素长度
num_heads = 8
group_num=4

In [5]:
query = torch.randn(batch_size, seq_len, hidden_size)

# 创建多头注意力对象
multi_head_attention = MultiHeadAttention(hidden_size, num_heads)
# 计算多头自注意力，得到输出和注意力权重
MHA_attention,MHA_attention_probs= multi_head_attention(query)

MHA_attention.shape ,MHA_attention_probs.shape

(torch.Size([4, 10, 512]), torch.Size([4, 8, 10, 10]))

In [6]:
query = torch.randn(batch_size,  hidden_size)

# 创建多查询注意力对象
multi_query_attention=MultiQueryAttention(hidden_size, num_heads)
MQA_attention,MQA_attention_probs= multi_query_attention(query)

print(MQA_attention_probs.shape)
print(MQA_attention_probs.shape)
MQA_attention ,MQA_attention_probs.shape

torch.Size([4, 4, 8, 8])
torch.Size([4, 4, 8, 8])


(tensor([[[-0.0274, -0.0297,  0.0498,  ...,  0.0218,  0.0803,  0.0251],
          [-0.0412, -0.0445,  0.0269,  ..., -0.0015,  0.0863,  0.0086],
          [-0.0314, -0.0184,  0.0244,  ..., -0.0115,  0.0972, -0.0123],
          [-0.0853, -0.0346,  0.0305,  ..., -0.0049,  0.0899, -0.0071]],
 
         [[ 0.0053, -0.0075, -0.1005,  ..., -0.0866, -0.0021,  0.0288],
          [-0.0068,  0.0050, -0.0956,  ..., -0.0637,  0.0059, -0.0073],
          [-0.0111, -0.0027, -0.0949,  ..., -0.0733,  0.0098, -0.0011],
          [ 0.0029, -0.0149, -0.1058,  ..., -0.0759,  0.0048, -0.0007]],
 
         [[-0.0466, -0.0379,  0.0785,  ..., -0.1324,  0.1220,  0.0476],
          [-0.0412, -0.0339,  0.0413,  ..., -0.1403,  0.1458,  0.0595],
          [-0.0190, -0.0591,  0.0495,  ..., -0.1382,  0.1322,  0.0639],
          [-0.0296, -0.0357,  0.0555,  ..., -0.1452,  0.1161,  0.0492]],
 
         [[ 0.0259,  0.0818, -0.0408,  ..., -0.0390, -0.1329,  0.0785],
          [ 0.0331,  0.0915, -0.0596,  ..., -0.0375, -0

In [7]:
query = torch.randn(batch_size,  hidden_size)
# 创建分组查询注意力对象
group_attention=GroupAttention(hidden_size, num_heads,group_num)
GQA_attention,GQA_attention_probs= group_attention(query)
#输出和注意力权重
GQA_attention,GQA_attention_probs.shape

(tensor([[[ 0.1727, -0.0994,  0.0781,  ...,  0.0127,  0.0398, -0.1223]],
 
         [[ 0.0168,  0.2205, -0.0086,  ..., -0.0311,  0.1290, -0.1132]],
 
         [[-0.1978, -0.0029,  0.0129,  ...,  0.1279, -0.0065, -0.1365]],
 
         [[ 0.1422, -0.0264, -0.2135,  ..., -0.2205,  0.1208, -0.3180]]],
        grad_fn=<ViewBackward0>),
 torch.Size([4, 8, 1, 2]))