### MultiHead-Self-Attention
前面介绍了怎么手写一个 Single Head Self-Attention，但在实际的训练过程中一般都会使用 Multi-Head Attention。其实 Multi-Head 也仅仅是：每个 Head 各自做完 Self-Attention 得到结果之后，把各个 Head 的结果拼接起来，然后再过一个 output 投影矩阵。

In [7]:
import math
import torch
import torch.nn as nn
# Default dropout probability applied to the attention weights.
DROPOUT_PROB = 0.1


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, split into
    ``nums_head`` heads of size ``hidden_dim // nums_head``, attended per
    head, concatenated again and passed through an output projection.

    Args:
        hidden_dim: Size of the last dimension of the input (and output).
        nums_head: Number of attention heads; must divide ``hidden_dim``.
        dropout_prob: Dropout probability applied to the attention weights.
            Defaults to ``DROPOUT_PROB``.

    Raises:
        AssertionError: If ``hidden_dim`` is not divisible by ``nums_head``.
    """

    def __init__(self, hidden_dim, nums_head, dropout_prob=DROPOUT_PROB) -> None:
        super().__init__()
        self.nums_head = nums_head
        self.hidden_dim = hidden_dim
        # The hidden dimension must split evenly across the heads.
        assert hidden_dim % nums_head == 0, (
            "hidden_dim must be divisible by nums_head"
        )
        # Per-head feature size after splitting hidden_dim into heads.
        self.head_dim = hidden_dim // nums_head

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

        self.att_drop = nn.Dropout(dropout_prob)
        # Projection applied to the concatenated head outputs.
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X, attention_mask=None):
        """Apply multi-head self-attention to ``X``.

        Args:
            X: Tensor of shape ``(batch, seq, hidden_dim)``.
            attention_mask: Optional tensor that must be broadcastable to the
                score shape ``(batch, nums_head, seq, seq)``. Positions where
                the mask equals 0 are excluded from attention. A key-padding
                mask of shape ``(batch, seq)`` therefore has to be reshaped by
                the caller first, e.g. ``mask[:, None, None, :]``.

        Returns:
            Tensor of shape ``(batch, seq, hidden_dim)``.
        """
        batch_size, seq_len, _ = X.size()
        Q = self.q_proj(X)
        K = self.k_proj(X)
        V = self.v_proj(X)

        # Split hidden_dim into heads:
        # (batch, seq, hidden_dim) -> (batch, nums_head, seq, head_dim)
        head_shape = (batch_size, seq_len, self.nums_head, self.head_dim)
        q_state = Q.view(head_shape).transpose(1, 2)
        k_state = K.view(head_shape).transpose(1, 2)
        v_state = V.view(head_shape).transpose(1, 2)

        # Demo output kept from the original notebook cell.
        print("q_state:", q_state.shape)

        # Scale by sqrt(head_dim), not sqrt(hidden_dim): each head only sees
        # head_dim features.
        attention_weight = (
            q_state @ k_state.transpose(-1, -2) / math.sqrt(self.head_dim)
        )

        # Masking has to happen before the softmax. The fill value is the most
        # negative finite number of the score dtype: a literal such as -1e20
        # does not fit into float16, while -inf would turn fully masked rows
        # into NaN. With a finite fill, a fully masked row becomes uniform.
        if attention_mask is not None:
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0, torch.finfo(attention_weight.dtype).min
            )

        # Normalize over the key dimension (the last one).
        attention_weight = torch.softmax(attention_weight, dim=-1)
        # Demo output kept from the original notebook cell.
        print(attention_weight)
        attention_weight = self.att_drop(attention_weight)
        output_mid = attention_weight @ v_state

        # Back to (batch, seq, nums_head, head_dim). transpose() returns a
        # non-contiguous view and view() needs a compatible memory layout, so
        # contiguous() is called first (reshape() would handle this itself).
        output_mid = output_mid.transpose(1, 2).contiguous()

        # Concatenate the heads: (batch, seq, hidden_dim)
        output = output_mid.view(batch_size, seq_len, -1)
        return self.out_proj(output)

# Key-padding mask, one row per batch element: 1 = attend, 0 = masked out.
padding_mask = torch.tensor([[0, 1], [0, 0], [1, 0]])
# (3, 2) -> (3, 1, 1, 2), then broadcast to (batch=3, heads=8, seq=2, seq=2)
# so that it lines up with the attention scores.
attention_mask = padding_mask[:, None, None, :].expand(3, 8, 2, 2)

# Random input of shape (batch=3, seq=2, hidden_dim=128), fed through an
# 8-head attention layer; the last expression shows the output shape.
x = torch.rand(3, 2, 128)
net = MultiHeadAttention(128, 8)
net(x, attention_mask).shape


q_state: torch.Size([3, 8, 2, 16])
tensor([[[[0.0000, 1.0000],
          [0.0000, 1.0000]],

         [[0.0000, 1.0000],
          [0.0000, 1.0000]],

         [[0.0000, 1.0000],
          [0.0000, 1.0000]],

         [[0.0000, 1.0000],
          [0.0000, 1.0000]],

         [[0.0000, 1.0000],
          [0.0000, 1.0000]],

         [[0.0000, 1.0000],
          [0.0000, 1.0000]],

         [[0.0000, 1.0000],
          [0.0000, 1.0000]],

         [[0.0000, 1.0000],
          [0.0000, 1.0000]]],


        [[[0.5000, 0.5000],
          [0.5000, 0.5000]],

         [[0.5000, 0.5000],
          [0.5000, 0.5000]],

         [[0.5000, 0.5000],
          [0.5000, 0.5000]],

         [[0.5000, 0.5000],
          [0.5000, 0.5000]],

         [[0.5000, 0.5000],
          [0.5000, 0.5000]],

         [[0.5000, 0.5000],
          [0.5000, 0.5000]],

         [[0.5000, 0.5000],
          [0.5000, 0.5000]],

         [[0.5000, 0.5000],
          [0.5000, 0.5000]]],


        [[[1.0000, 0.0000],
     

torch.Size([3, 2, 128])