### Multi-Head Attention

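Given a set of Q, K, and V, we first project the three matrices into the same vector space (hidden_size).

Then, when learning the correlation between Q and K, we split both vectors along the feature dimension into n parts and compute the dot product on each part separately.

This yields a separate set of weights for each part of the V vectors.

This is n-head attention, illustrated below: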

![](md-img\多头注意力权重.jpg)
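
In symbols, the weight computation illustrated above can be sketched as follows. The notation is assumed for illustration and is not taken from the figure: $Q \in \mathbb{R}^{m \times h}$ and $K \in \mathbb{R}^{p \times h}$ are the projected queries and keys, $n$ is the number of heads, $d = h / n$ is the width of each part, and $Q_i$, $K_i$ are the $i$-th column slices. The softmax is applied row by row, and the $1/\sqrt{d}$ factor is the usual scaling of the dot product by the per-head width.

$$
Q = [Q_1, \dots, Q_n], \quad K = [K_1, \dots, K_n], \qquad
A_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d}}\right) \in \mathbb{R}^{m \times p}, \quad i = 1, \dots, n
$$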

The multi-head attention weights are then used to compute the final output:

![](md-img\多头注意力输出.jpg)
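
The output step can be written compactly under assumed notation (again not taken from the figure): $A_i \in \mathbb{R}^{m \times p}$ is the weight matrix of head $i$, that is, the row-wise softmax of the scaled dot products between the $i$-th slices of Q and K; $V_i$ is the $i$-th column slice of the projected V; and $W_O$ stands for a final linear layer that mixes the heads (bias omitted).

$$
\mathrm{head}_i = A_i V_i, \qquad
\mathrm{Output} = [\mathrm{head}_1, \dots, \mathrm{head}_n]\, W_O
$$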

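`In other words, multi-head attention can be understood as splitting the last dimension of Q, K, and V into several parts and applying single-head attention to each part separately`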

The effect is that the model can discover correlations between the different representation `subspaces` of each token in the sequence.
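
The view of multi-head attention as single-head attention applied to slices of the last dimension can be sketched directly. This is a minimal illustration, not the implementation used later: the helper names `single_head_attention` and `multi_head_by_slicing` are hypothetical, the inputs are assumed to be already projected to a common width, and the final linear layer is left out.

```python
import math
import torch

def single_head_attention(q, k, v):
    """Scaled dot-product attention for one head.

    q: (batch, num_queries, d); k and v: (batch, num_pairs, d).
    """
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(q.shape[-1])
    weights = torch.softmax(scores, dim=-1)   # each query's weights sum to 1
    return torch.bmm(weights, v)              # (batch, num_queries, d)

def multi_head_by_slicing(q, k, v, num_heads):
    # Split the last dimension into num_heads equal slices, one per head.
    q_parts = q.chunk(num_heads, dim=-1)
    k_parts = k.chunk(num_heads, dim=-1)
    v_parts = v.chunk(num_heads, dim=-1)
    heads = [single_head_attention(qi, ki, vi)
             for qi, ki, vi in zip(q_parts, k_parts, v_parts)]
    # Concatenate the per-head outputs back along the last dimension.
    return torch.cat(heads, dim=-1)

torch.manual_seed(0)
q = torch.randn(2, 4, 8)   # batch of 2, 4 queries, width 8 (assumed already projected)
k = torch.randn(2, 6, 8)   # 6 key-value pairs
v = torch.randn(2, 6, 8)
out = multi_head_by_slicing(q, k, v, num_heads=2)
print(out.shape)  # torch.Size([2, 4, 8])
```

With `num_heads=1` the function reduces to ordinary single-head attention, which makes the relationship between the two easy to check.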

<br>

### Code Implementation

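Although multi-head attention applies single-head attention to different slices of the feature dimension, the per-slice computations can run in parallel (in a single batched pass).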
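
One way to see why a single pass suffices is to fold the heads into the batch dimension and compare the result with an explicit loop over heads. The sketch below is only a check under assumed toy sizes; `fold_heads` is a hypothetical helper, and the linear projections are omitted so that only the attention step is compared.

```python
import math
import torch

torch.manual_seed(0)
batch, num_heads, d = 2, 3, 4          # d is the per-head width
nq, npairs = 5, 7
q = torch.randn(batch, nq, num_heads * d)
k = torch.randn(batch, npairs, num_heads * d)
v = torch.randn(batch, npairs, num_heads * d)

def fold_heads(x):
    # (batch, n, num_heads * d) -> (batch * num_heads, n, d)
    b, n, _ = x.shape
    return x.reshape(b, n, num_heads, d).permute(0, 2, 1, 3).reshape(b * num_heads, n, d)

# One batched pass: every head is treated as an extra batch entry.
qf, kf, vf = fold_heads(q), fold_heads(k), fold_heads(v)
weights = torch.softmax(torch.bmm(qf, kf.transpose(1, 2)) / math.sqrt(d), dim=-1)
batched = torch.bmm(weights, vf)       # (batch * num_heads, nq, d)
batched = batched.reshape(batch, num_heads, nq, d).permute(0, 2, 1, 3).reshape(batch, nq, -1)

# Reference: process the heads one at a time.
outs = []
for h in range(num_heads):
    sl = slice(h * d, (h + 1) * d)     # columns belonging to head h
    w = torch.softmax(q[..., sl] @ k[..., sl].transpose(1, 2) / math.sqrt(d), dim=-1)
    outs.append(w @ v[..., sl])
looped = torch.cat(outs, dim=-1)

# The two results should agree up to floating-point tolerance.
print(torch.allclose(batched, looped, atol=1e-5))
```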

The code is as follows:

```python
import torch
from torch import nn
import math
```

```python
class MultiHeadAttention(nn.Module):
    # Note: in this class hidden_size is the width of EACH head,
    # so the projections have total width hidden_size * num_heads.
    def __init__(self, query_size, key_size, value_size, hidden_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.linear_q = nn.Linear(query_size, hidden_size * num_heads)
        self.linear_k = nn.Linear(key_size, hidden_size * num_heads)
        self.linear_v = nn.Linear(value_size, hidden_size * num_heads)
        self.softmax = nn.Softmax(dim=2)
        self.linear_o = nn.Linear(hidden_size * num_heads, hidden_size * num_heads)

    # query shape: (batch_size, num_querys, query_size)
    # key shape:   (batch_size, num_pairs, key_size)
    # value shape: (batch_size, num_pairs, value_size)
    def forward(self, query, key, value):
        batch_size = query.shape[0]
        num_querys = query.shape[1]
        num_pairs = key.shape[1]

        query = self.linear_q(query)    # (batch_size, num_querys, hidden_size * num_heads)
        key = self.linear_k(key)        # (batch_size, num_pairs, hidden_size * num_heads)
        value = self.linear_v(value)    # (batch_size, num_pairs, hidden_size * num_heads)

        # Split into heads
        query = query.reshape(batch_size, num_querys, self.num_heads, -1)    # (batch_size, num_querys, num_heads, hidden_size)
        key = key.reshape(batch_size, num_pairs, self.num_heads, -1)         # (batch_size, num_pairs, num_heads, hidden_size)
        value = value.reshape(batch_size, num_pairs, self.num_heads, -1)     # (batch_size, num_pairs, num_heads, hidden_size)

        # Compute scaled dot-product attention weights for each head
        # (heads are folded into the batch dimension so one bmm covers all of them)
        query = query.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, num_querys, -1)     # (batch_size * num_heads, num_querys, hidden_size)
        key = key.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.hidden_size, -1)  # (batch_size * num_heads, hidden_size, num_pairs)
        score = torch.bmm(query, key) / math.sqrt(self.hidden_size)    # (batch_size * num_heads, num_querys, num_pairs)
        weight = self.softmax(score)    # (batch_size * num_heads, num_querys, num_pairs)

        # Compute the output of each head
        value = value.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, num_pairs, -1)  # (batch_size * num_heads, num_pairs, hidden_size)
        output = torch.bmm(weight, value)    # (batch_size * num_heads, num_querys, hidden_size)

        # Concatenate the heads
        output = output.reshape(batch_size, self.num_heads, num_querys, -1)        # (batch_size, num_heads, num_querys, hidden_size)
        output = output.permute(0, 2, 1, 3).reshape(batch_size, num_querys, -1)    # (batch_size, num_querys, num_heads * hidden_size)

        # Final linear layer that mixes the heads
        return self.linear_o(output)
```
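
For comparison, PyTorch also ships a built-in `nn.MultiheadAttention` module. The sketch below only shows how its shapes line up with the ones used above; the sizes are arbitrary values chosen for illustration. Two differences are worth keeping in mind: its `embed_dim` is the total width shared by all heads (so each head gets `embed_dim / num_heads`), whereas `hidden_size` in the class above is the width of each head, and it has no separate query size, so the query must already have `embed_dim` features.

```python
import torch
from torch import nn

torch.manual_seed(0)
# embed_dim is the TOTAL width across heads: 16 / 4 = 4 features per head.
# kdim and vdim play the role of key_size and value_size.
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, kdim=10, vdim=12, batch_first=True)

query = torch.randn(2, 5, 16)   # (batch_size, num_querys, embed_dim)
key = torch.randn(2, 7, 10)     # (batch_size, num_pairs, kdim)
value = torch.randn(2, 7, 12)   # (batch_size, num_pairs, vdim)

output, weights = mha(query, key, value)
print(output.shape)    # torch.Size([2, 5, 16])
print(weights.shape)   # torch.Size([2, 5, 7]), averaged over heads by default
```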