### Multi-Head Attention
![Multi-Head Self Attention](./pics/Multi-head_attention.png)
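### 公式 (Formula)

$$\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^O$$

$$\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{(X W_i^Q)(X W_i^K)^\top}{\sqrt{d_k}}\right) X W_i^V, \qquad d_k = d_{\text{model}} / h$$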

In [7]:
import torch
import math
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is linearly projected to queries, keys and values, split into
    ``head_num`` heads of size ``hidden_dim // head_num``, attended per head,
    concatenated back to ``hidden_dim`` and passed through an output
    projection.

    Args:
        hidden_dim: Feature dimension of the input and the output.
        head_num: Number of attention heads; must evenly divide ``hidden_dim``.
        attention_dropout: Dropout probability applied to the attention
            weights after the softmax.

    Raises:
        ValueError: If ``head_num`` is not positive or does not divide
            ``hidden_dim``.
    """

    def __init__(self, hidden_dim, head_num, attention_dropout=0.1):
        super().__init__()
        # Validate up front: otherwise a bad pair only fails later inside
        # ``.view`` in ``forward`` with an opaque shape error.
        if head_num <= 0 or hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"head_num ({head_num})"
            )
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        # head_num * head_dim == hidden_dim
        self.head_dim = hidden_dim // head_num

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.attention_dropout = nn.Dropout(attention_dropout)

    def forward(self, x, attention_mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, hidden_dim]``.
            attention_mask: Optional tensor broadcastable to
                ``[batch_size, head_num, seq_len, seq_len]``. Positions where
                it equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``[batch_size, seq_len, hidden_dim]``.
        """
        batch_size, seq_len, _ = x.size()

        # Q, K, V shape: [batch_size, seq_len, hidden_dim]
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # [batch_size, seq_len, hidden_dim] -> [batch_size, head_num, seq_len, head_dim]
        # (hidden_dim is split into head_num * head_dim)
        Q_state = Q.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        K_state = K.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        V_state = V.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)

        # K^T: [batch_size, head_num, head_dim, seq_len]
        # scores: [batch_size, head_num, seq_len, seq_len]
        attention_weight = torch.matmul(
            Q_state, K_state.transpose(-1, -2)
        ) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # Use the dtype's most negative finite value rather than -inf.
            # Partially masked rows still get exactly zero weight on masked
            # positions. A fully masked row no longer becomes all -inf, which
            # would make softmax return NaN.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0,
                torch.finfo(attention_weight.dtype).min,
            )

        attention_weight = torch.softmax(attention_weight, dim=-1)
        attention_weight = self.attention_dropout(attention_weight)

        # [batch_size, head_num, seq_len, head_dim]
        output_mid = torch.matmul(attention_weight, V_state)

        # [batch_size, head_num, seq_len, head_dim] -> [batch_size, seq_len, hidden_dim]
        # contiguous() is required before view() after a transpose.
        output_mid = output_mid.transpose(1, 2).contiguous()
        output_mid = output_mid.view(batch_size, seq_len, -1)

        output = self.out_proj(output_mid)
        return output

# Demo input: 3 sequences of length 2 with 128-dim features, attended by 8 heads.
# Mask rows mark key positions per sequence: 0 = excluded, 1 = attended.
# Index with None to reshape [3, 2] -> [3, 1, 1, 2], then broadcast to
# [batch, heads, query_len, key_len] = [3, 8, 2, 2].
attention_mask = torch.tensor([[0, 1], [1, 1], [1, 0]])[:, None, None, :].expand(3, 8, 2, 2)

x = torch.randn(3, 2, 128)

# 128 features / 8 heads -> head_dim of 16
net = MultiHeadAttention(128, 8)
net(x, attention_mask)

tensor([[[ 2.0993e-01,  3.0547e-01, -2.3037e-05,  6.8221e-01, -1.2496e-02,
           2.5840e-01,  2.3654e-01,  1.3889e-01, -2.1414e-01,  4.3203e-01,
           3.9672e-01, -2.6677e-02, -2.9180e-01,  5.1848e-02, -2.3070e-01,
           9.0426e-02,  5.0347e-01,  3.0282e-01, -3.1114e-01,  2.1110e-02,
          -7.3000e-02, -2.7051e-01,  6.9139e-01, -9.5220e-02,  1.6011e-01,
          -1.2424e-01, -1.8378e-01, -1.4660e-01,  6.3168e-02, -1.4993e-01,
          -5.7103e-01,  2.2359e-01,  1.0795e-01,  1.0616e-01,  2.4938e-01,
           3.0051e-01, -5.3310e-01,  3.5771e-01, -1.8586e-01,  1.9984e-01,
          -1.7034e-01,  3.9097e-01,  7.2686e-02,  6.6708e-02,  6.4364e-02,
          -2.7825e-03,  4.1006e-02,  1.6895e-01,  1.6615e-01,  1.5083e-01,
          -2.4489e-01,  2.5287e-01,  2.8236e-01, -1.2226e-01, -6.2271e-02,
          -3.2344e-01, -1.4886e-01,  2.7082e-01,  2.2715e-01,  8.1514e-02,
           1.3073e-01,  3.6592e-02,  1.9625e-01, -4.5066e-01, -8.4207e-02,
          -4.4004e-02,  6