# 1. 多头注意力
![img](https://zh.d2l.ai/_images/multi-head-attention.svg)

公式：
$$
\mathbf{h}_{i}=f\left(\mathbf{W}_{i}^{(q)} \mathbf{q}, \mathbf{W}_{i}^{(k)} \mathbf{k}, \mathbf{W}_{i}^{(v)} \mathbf{v}\right) \in \mathbb{R}^{p_{v}} \\

\text{query: } \mathbf{q} \in \mathbb{R}^{d_{q}} \\
\text{key: } \mathbf{k} \in \mathbb{R}^{d_{k}} \\
\text{value: } \mathbf{v} \in \mathbb{R}^{d_{v}} \\

\text{可学习的参数：} \\
\mathbf{W}_{i}^{(q)} \in \mathbb{R}^{p_{q} \times d_{q}} \\

\mathbf{W}_{i}^{(k)} \in \mathbb{R}^{p_{k} \times d_{k}} \\

\mathbf{W}_{i}^{(v)} \in \mathbb{R}^{p_{v} \times d_{v}} \\

\text{注意力汇聚函数（如加性注意力或缩放点积注意力）：} f \\

\text{输出层对 } h \text{ 个头的输出连结后做线性变换，其可学习参数为：} \mathbf{W}_{o} \in \mathbb{R}^{p_{o} \times h p_{v}} \\

\mathbf{W}_{o}\left[\begin{array}{c}
\mathbf{h}_{1} \\
\vdots \\
\mathbf{h}_{h}
\end{array}\right] \in \mathbb{R}^{p_{o}}

$$

In [12]:
import math
import torch
from torch import nn
from d2l import torch as d2l

In [16]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Queries, keys and values are each projected to ``num_hiddens`` features
    by ``W_q`` / ``W_k`` / ``W_v``. The feature axis is then split into
    ``num_heads`` heads of ``num_hiddens // num_heads`` features each. All
    heads are run through ``d2l.DotProductAttention`` in parallel by folding
    the head axis into the batch axis. The per-head outputs are concatenated
    back to ``num_hiddens`` features and projected by ``W_o``.

    Args:
        key_size (int): Feature size of the input keys.
        query_size (int): Feature size of the input queries.
        value_size (int): Feature size of the input values.
        num_hiddens (int): Output size of every projection (q/k/v/o); must be
            divisible by ``num_heads``.
        num_heads (int): Number of attention heads (positive).
        dropout (float): Dropout probability given to the attention module.
        bias (bool): Whether the linear projections have a bias term.
        **kwargs: Accepted for signature compatibility and ignored.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``num_hiddens``.
    """

    def __init__(
            self,
            key_size,
            query_size,
            value_size,
            num_hiddens,
            num_heads,
            dropout,
            bias=False,
            **kwargs
    ):
        super().__init__()
        # Fail fast: otherwise the head split in transpose_qkv only fails
        # later, inside forward, with an opaque reshape error.
        if num_heads <= 0 or num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_hiddens ({num_hiddens}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout=dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        """Run multi-head attention.

        Args:
            queries (torch.Tensor): shape (batch_size, num_queries, query_size).
            keys (torch.Tensor): shape (batch_size, num_kvpairs, key_size).
            values (torch.Tensor): shape (batch_size, num_kvpairs, value_size).
            valid_lens (torch.Tensor or None): per-example valid lengths, with
                batch_size as the leading dimension; ``None`` disables
                masking. The accepted trailing shapes are whatever
                ``d2l.DotProductAttention`` accepts.

        Returns:
            torch.Tensor: shape (batch_size, num_queries, num_hiddens).
        """
        # Each becomes (batch_size * num_heads, seq_len, num_hiddens / num_heads).
        queries = self.transpose_qkv(self.W_q(queries), self.num_heads)
        keys = self.transpose_qkv(self.W_k(keys), self.num_heads)
        values = self.transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # transpose_qkv places the num_heads heads of batch element b in
            # consecutive rows b * num_heads ... b * num_heads + num_heads - 1,
            # so each valid length is repeated num_heads times in the same
            # order.
            valid_lens = torch.repeat_interleave(valid_lens, self.num_heads, dim=0)

        # (batch_size * num_heads, num_queries, num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # (batch_size, num_queries, num_hiddens)
        output_concat = self.transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

    def transpose_qkv(self, X, num_heads):
        """Split the feature axis into heads and fold the heads into the batch.

        Args:
            X (torch.Tensor): shape (batch_size, seq_len, num_hiddens), where
                ``seq_len`` is the number of queries or key-value pairs.
            num_heads (int): Number of heads; must divide ``num_hiddens``.

        Returns:
            torch.Tensor: shape (batch_size * num_heads, seq_len,
            num_hiddens / num_heads). Row ``b * num_heads + h`` holds head
            ``h`` of batch element ``b``.
        """
        batch_size, seq_len, _ = X.shape
        # (batch_size, seq_len, num_heads, num_hiddens / num_heads)
        X = X.reshape(batch_size, seq_len, num_heads, -1)
        # (batch_size, num_heads, seq_len, num_hiddens / num_heads)
        X = X.permute(0, 2, 1, 3)
        # (batch_size * num_heads, seq_len, num_hiddens / num_heads)
        return X.reshape(-1, X.shape[2], X.shape[3])

    def transpose_output(self, X, num_heads):
        """Undo ``transpose_qkv``: unfold heads from the batch and concatenate them.

        Args:
            X (torch.Tensor): shape (batch_size * num_heads, seq_len,
                num_hiddens / num_heads).
            num_heads (int): Number of heads used in ``transpose_qkv``.

        Returns:
            torch.Tensor: shape (batch_size, seq_len, num_hiddens).
        """
        # (batch_size, num_heads, seq_len, num_hiddens / num_heads)
        X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
        # (batch_size, seq_len, num_heads, num_hiddens / num_heads)
        X = X.permute(0, 2, 1, 3)
        # (batch_size, seq_len, num_hiddens)
        return X.reshape(X.shape[0], X.shape[1], -1)


In [17]:
num_hiddens = 100
num_heads = 5
# Keys, queries and values all carry num_hiddens features in this demo.
attention = MultiHeadAttention(
    key_size=num_hiddens,
    query_size=num_hiddens,
    value_size=num_hiddens,
    num_hiddens=num_hiddens,
    num_heads=num_heads,
    dropout=0.5,
)
# Switch to evaluation mode so dropout is inactive for the shape check.
attention.eval()

MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)

In [18]:
batch_size = 2
num_queries = 4
num_kvpairs = 6
# One valid length per batch element.
valid_lens = torch.tensor([3, 2])
# Queries: (batch_size, num_queries, num_hiddens);
# keys/values: (batch_size, num_kvpairs, num_hiddens).
X = torch.ones(batch_size, num_queries, num_hiddens)
Y = torch.ones(batch_size, num_kvpairs, num_hiddens)
attention(X, Y, Y, valid_lens).shape

torch.Size([2, 4, 100])