# In [None]:
import torch
import torch.nn as nn
import einops
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a batch of sequences.

    The input is projected to queries, keys and values. Each projection is
    split into ``num_heads`` heads of size ``hidden_size // num_heads`` and
    combined with scaled dot-product attention. The heads are then
    concatenated again and passed through an output projection.

    Args:
        hidden_size: Feature size of the input and of the output.
        num_heads: Number of attention heads; must be positive and divide
            ``hidden_size`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        # Validate up front; otherwise the head split fails later with an
        # obscure reshape error. `or` short-circuits before a modulo by zero.
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        # Per-head feature size d_k; attention scores are scaled by 1/sqrt(d_k).
        self.head_dim = hidden_size // num_heads

        # Projection matrices producing Q, K and V for all heads at once.
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        # Output projection applied after the heads are concatenated.
        self.out_linear = nn.Linear(hidden_size, hidden_size)

    def _split_heads(self, x):
        """Reshape ``(batch, seq, hidden)`` into ``(batch, heads, seq, head_dim)``."""
        batch_size, seq_len, _ = x.shape
        # The hidden axis is laid out as (heads, head_dim), heads outermost.
        return x.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x):
        """Reshape ``(batch, heads, seq, head_dim)`` back into ``(batch, seq, hidden)``."""
        batch_size, _, seq_len, _ = x.shape
        # transpose() yields a non-contiguous tensor, so reshape (not view) is required.
        return x.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)

    def forward(self, hidden_state, mask=None):
        """Apply multi-head self-attention.

        Args:
            hidden_state: Tensor of shape ``(batch, seq, hidden_size)``.
            mask: Optional tensor broadcastable to the score shape
                ``(batch, num_heads, seq, seq)``. Positions where
                ``mask == 0`` receive no attention weight.

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)``.
        """
        # Project the input and split every projection into heads.
        query = self._split_heads(self.q_linear(hidden_state))
        key = self._split_heads(self.k_linear(hidden_state))
        value = self._split_heads(self.v_linear(hidden_state))

        # Scaled dot-product scores: Q K^T / sqrt(d_k), where d_k is the
        # per-head size, not the full hidden size.
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / self.head_dim ** 0.5

        if mask is not None:
            # Use the dtype's most negative finite value: a literal -1e9
            # overflows (and raises) for float16 scores.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(mask == 0, fill_value)

        # Normalise over the key axis (last dimension).
        attention_probs = torch.softmax(attention_scores, dim=-1)
        output = torch.matmul(attention_probs, value)
        return self.out_linear(self._merge_heads(output))
         
        