# Multi-Head Attention

In multi-head attention, the queries, keys, and values are first transformed by $h$ independently learned linear projections. The $h$ projected sets of queries, keys, and values are then fed into attention pooling in parallel. Finally, the $h$ pooling outputs, one per head, are concatenated and transformed by another learned linear projection to produce the final output.

![Multi-head attention, where multiple heads are concatenated and then linearly transformed](../images/10/multi-head-attention.svg)

Mathematically, given a query $\mathbf{q}\in\mathbb{R}^{d_q}$, a key $\mathbf{k}\in\mathbb{R}^{d_k}$, and a value $\mathbf{v}\in\mathbb{R}^{d_v}$, each attention head $\mathbf{h}_{i}$ ($i=1,\ldots,h$) is computed as

$$\mathbf{h}_{i} = f(\mathbf{W}_{i}^{(q)}\mathbf{q}, \mathbf{W}_{i}^{(k)}\mathbf{k}, \mathbf{W}_{i}^{(v)}\mathbf{v}) \in \mathbb{R}^{p_v}$$

where $\mathbf{W}_{i}^{(q)}\in\mathbb{R}^{p_{q}\times d_{q}}$, $\mathbf{W}_{i}^{(k)}\in\mathbb{R}^{p_{k}\times d_{k}}$, and $\mathbf{W}_{i}^{(v)}\in\mathbb{R}^{p_{v}\times d_{v}}$ are learnable parameters, and $f$ is attention pooling, such as additive attention or scaled dot-product attention.
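The snippet below computes a single head $\mathbf{h}_i$ for one query. It is a sketch that relies on assumptions the text does not fix:

* $f$ is scaled dot-product attention.
* Pooling runs over $n$ key-value pairs. The formula writes a single $\mathbf{k}$ and $\mathbf{v}$, but pooling normally covers a set of pairs.
* All sizes are arbitrary.
* The helper name `scaled_dot_product_pooling` is a hypothetical name of ours.

```python
import math
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, not from the text).
d_q, d_k, d_v = 6, 5, 4
p_q = p_k = 3  # dot-product scoring requires p_q == p_k
p_v = 2
n = 7          # number of key-value pairs

# Projection matrices of head i, stored as plain tensors for clarity.
W_q = torch.randn(p_q, d_q)
W_k = torch.randn(p_k, d_k)
W_v = torch.randn(p_v, d_v)

q = torch.randn(d_q)     # one query
K = torch.randn(n, d_k)  # n keys, one per row
V = torch.randn(n, d_v)  # n values, one per row

def scaled_dot_product_pooling(q, K, V):
    """Hypothetical helper: softmax(K q / sqrt(dim of q)) weighted sum of rows of V."""
    scores = K @ q / math.sqrt(q.shape[0])  # shape (n,)
    weights = torch.softmax(scores, dim=0)  # shape (n,), sums to 1
    return weights @ V                      # shape (p_v,)

# h_i = f(W_q q, W_k k, W_v v), with the projections applied to every key/value row.
h_i = scaled_dot_product_pooling(W_q @ q, K @ W_k.T, V @ W_v.T)
print(h_i.shape)  # torch.Size([2])
```

Dot-product scoring needs the projected query and the projected keys to share a dimension. That is why the sketch sets $p_q = p_k$.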

The multi-head attention output is another linear transformation of the concatenated $h$ heads, with learnable parameters $\mathbf{W}_{o}\in\mathbb{R}^{p_{o}\times h p_v}$:

$$\mathbf{W}_{o}\begin{bmatrix}
 \mathbf{h}_{1}\\
 \vdots \\
\mathbf{h}_{h}
\end{bmatrix}\in\mathbb{R}^{p_{o}}$$
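The sketch below combines the two formulas for one query. It loops over $h$ heads, each with its own randomly initialized projection matrices (purely illustrative). It then stacks the head outputs into a vector of length $h p_v$ and applies $\mathbf{W}_o$. As before, $f$ is assumed to be scaled dot-product attention, and the sizes are arbitrary.

```python
import math
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, not from the text).
d_q, d_k, d_v = 6, 5, 4
h = 3                # number of heads
p_q = p_k = p_v = 2
p_o = 8
n = 7                # number of key-value pairs

def f(q, K, V):
    """Scaled dot-product attention pooling for a single query."""
    weights = torch.softmax(K @ q / math.sqrt(q.shape[0]), dim=0)
    return weights @ V

# Independent projections for each head i = 1, ..., h.
W_q = [torch.randn(p_q, d_q) for _ in range(h)]
W_k = [torch.randn(p_k, d_k) for _ in range(h)]
W_v = [torch.randn(p_v, d_v) for _ in range(h)]
W_o = torch.randn(p_o, h * p_v)

q, K, V = torch.randn(d_q), torch.randn(n, d_k), torch.randn(n, d_v)

heads = [f(W_q[i] @ q, K @ W_k[i].T, V @ W_v[i].T) for i in range(h)]
stacked = torch.cat(heads)  # [h_1; ...; h_h], shape (h * p_v,)
output = W_o @ stacked      # shape (p_o,)
print(stacked.shape, output.shape)  # torch.Size([6]) torch.Size([8])
```

The explicit loop follows the math one head at a time. The implementation later in this section avoids the loop by folding all heads into a single batched tensor.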

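In our implementation, we choose scaled dot-product attention for each head and set $p_{q}=p_{k}=p_{v}=p_{o}/h$. This keeps the computational and parameter cost from growing with the number of heads. With this choice, the query, key, and value projections of all $h$ heads can be computed at once by linear layers with $p_q h = p_k h = p_v h = p_o$ outputs, which is the quantity `num_hiddens` in the code.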
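Setting $p_q h = p_o$ means the $h$ query projections $\mathbf{W}_i^{(q)}$ can be stacked row-wise into one $p_o\times d_q$ matrix. The sketch below uses arbitrary sizes. It checks numerically that one bias-free `nn.Linear` followed by a reshape to $(h, p_q)$ matches applying each $p_q$-row block of its weight separately. This is our reading of why the implementation uses a single `nn.Linear` with `num_hiddens` outputs per projection and then splits the features across heads in `transpose_qkv`.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions): p_q = p_o / h.
d_q, h, p_o = 10, 5, 20
p_q = p_o // h  # 4

# One linear layer producing the projected queries of all heads at once.
W_all = nn.Linear(d_q, p_o, bias=False)
q = torch.randn(d_q)

with torch.no_grad():
    fused = W_all(q).reshape(h, p_q)  # row i: projected query of head i
    # Row block i of the (p_o x d_q) weight plays the role of W_i^(q).
    per_head = torch.stack([W_all.weight[i * p_q:(i + 1) * p_q] @ q
                            for i in range(h)])

print(torch.allclose(fused, per_head))  # True
```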

## Implementation

```python
import torch
from torch import nn
from d2l import torch as d2l
```
```python
#@save
class MultiHeadAttention(nn.Module):
    """Multi-head attention."""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False):
        super().__init__()
        # Each head gets `num_hiddens` / `num_heads` features
        assert num_hiddens % num_heads == 0
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of input `queries`, `keys`, or `values`:
        # (`batch_size`, no. of queries or key-value pairs,
        # `query_size`, `key_size`, or `value_size`)
        # Shape of `valid_lens`:
        # (`batch_size`,) or (`batch_size`, no. of queries)
        # After projecting and transposing, shape of `queries`, `keys`, or
        # `values`: (`batch_size` * `num_heads`, no. of queries or key-value
        # pairs, `num_hiddens` / `num_heads`)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # On axis 0, copy each example's valid length(s) `num_heads`
            # times in a row, matching the layout produced by `transpose_qkv`
            valid_lens = torch.repeat_interleave(valid_lens,
                                                 repeats=self.num_heads, dim=0)

        # Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
        # `num_hiddens` / `num_heads`)
        output = self.attention(queries, keys, values, valid_lens)

        # Shape of `output_concat`: (`batch_size`, no. of queries, `num_hiddens`)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
```
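`transpose_qkv` folds the heads into the batch axis in example-major order, so row `b * num_heads + j` holds head `j` of example `b`. The short check below uses an arbitrary `num_heads = 3`. It shows why the valid lengths are expanded with `torch.repeat_interleave` rather than `Tensor.repeat`: only the former keeps each example's length next to that example's heads.

```python
import torch

valid_lens = torch.tensor([3, 2])  # one valid length per example (batch_size = 2)
num_heads = 3                      # arbitrary choice for illustration

# Each example's length is copied num_heads times in a row,
# matching the example-major layout of the merged batch axis.
interleaved = torch.repeat_interleave(valid_lens, repeats=num_heads, dim=0)
print(interleaved)  # tensor([3, 3, 3, 2, 2, 2])

# `repeat` tiles the whole tensor instead, pairing lengths with the wrong heads.
tiled = valid_lens.repeat(num_heads)
print(tiled)        # tensor([3, 2, 3, 2, 3, 2])
```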

```python
#@save
def transpose_qkv(X, num_heads):
    """Transposition for parallel computation of multiple attention heads."""
    # Shape of input `X`:
    # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).
    # Shape of output `X`:
    # (`batch_size`, no. of queries or key-value pairs, `num_heads`,
    # `num_hiddens` / `num_heads`)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # Shape of output `X`:
    # (`batch_size`, `num_heads`, no. of queries or key-value pairs,
    # `num_hiddens` / `num_heads`)
    X = X.permute(0, 2, 1, 3)

    # Shape of `output`:
    # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
    # `num_hiddens` / `num_heads`)
    return X.reshape(-1, X.shape[2], X.shape[3])

#@save
def transpose_output(X, num_heads):
    """Reverse the operation of `transpose_qkv`."""
    # (`batch_size` * `num_heads`, no. of queries, `num_hiddens` / `num_heads`)
    # -> (`batch_size`, `num_heads`, no. of queries, `num_hiddens` / `num_heads`)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    # -> (`batch_size`, no. of queries, `num_heads`, `num_hiddens` / `num_heads`)
    X = X.permute(0, 2, 1, 3)
    # -> (`batch_size`, no. of queries, `num_hiddens`)
    return X.reshape(X.shape[0], X.shape[1], -1)
```
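To see the two helpers in action, the snippet below copies them so that it runs on its own. It traces the shapes for the sizes used in the test at the end of this section (`batch_size = 2`, 4 queries, `num_hiddens = 100`, `num_heads = 5`). `X` is filled with distinct values, an arbitrary choice that lets us check which features land in which head. It also confirms that `transpose_output` inverts `transpose_qkv` exactly.

```python
import torch

def transpose_qkv(X, num_heads):
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    X = X.permute(0, 2, 1, 3)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

batch_size, num_queries, num_hiddens, num_heads = 2, 4, 100, 5
X = torch.arange(batch_size * num_queries * num_hiddens, dtype=torch.float32)
X = X.reshape(batch_size, num_queries, num_hiddens)

Y = transpose_qkv(X, num_heads)
print(Y.shape)  # torch.Size([10, 4, 20])

# Row 1 of the merged axis is example 0, head 1: features 20..39.
print(torch.equal(Y[1], X[0, :, 20:40]))  # True

# transpose_output undoes transpose_qkv exactly.
print(torch.equal(transpose_output(Y, num_heads), X))  # True
```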

## Self-Attention

Given a sequence of input tokens $\mathbf{x}_{1},\ldots,\mathbf{x}_{n}$ with each $\mathbf{x}_{i}\in\mathbb{R}^{d}$, self-attention uses the same tokens as queries, keys, and values. It outputs a sequence of the same length, $\mathbf{y}_{1},\ldots,\mathbf{y}_{n}$, where

$$\mathbf{y}_{i}=f(\mathbf{x}_{i},(\mathbf{x}_{1},\mathbf{x}_{1}),...,(\mathbf{x}_{n},\mathbf{x}_{n}))$$
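Stack the tokens as the rows of $\mathbf{X}\in\mathbb{R}^{n\times d}$ and take $f$ to be scaled dot-product attention. A single head without learned projections then computes $\mathrm{softmax}(\mathbf{X}\mathbf{X}^\top/\sqrt{d})\,\mathbf{X}$. The sketch below assumes exactly this simplified setting, with random tokens and arbitrary $n$ and $d$. It shows that the output keeps the input length $n$.

```python
import math
import torch

torch.manual_seed(0)
n, d = 5, 8            # arbitrary sizes
X = torch.randn(n, d)  # rows are the tokens x_1, ..., x_n

# Every token is a query; the pairs (x_j, x_j) are the key-value pairs.
weights = torch.softmax(X @ X.T / math.sqrt(d), dim=1)  # (n, n), rows sum to 1
Y = weights @ X                                          # rows are y_1, ..., y_n
print(Y.shape)  # torch.Size([5, 8])
```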

```python
num_hiddens, num_heads = 100, 5
# Assume the token dimension `d` equals `num_hiddens`, so the query, key,
# and value sizes are all `num_hiddens`
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
```

```python
attention.eval()  # disable dropout
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
attention(X, X, X, valid_lens).shape
```

```
torch.Size([2, 4, 100])
```
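As an optional cross-check, PyTorch ships its own `nn.MultiheadAttention` layer, and the sketch below feeds it the same input. That layer has its own parameters, so only shapes are comparable, not values. It does not take `valid_lens`. Instead it takes a boolean `key_padding_mask` in which `True` marks key positions to ignore, and the sketch derives that mask from `valid_lens`.

```python
import torch
from torch import nn

num_hiddens, num_heads = 100, 5
batch_size, num_queries = 2, 4
valid_lens = torch.tensor([3, 2])

# PyTorch's built-in layer (separate parameters from the class above).
mha = nn.MultiheadAttention(num_hiddens, num_heads, dropout=0.5,
                            bias=False, batch_first=True)
mha.eval()  # disable dropout, as with attention.eval() above

X = torch.ones((batch_size, num_queries, num_hiddens))
# True marks key positions to ignore: position j is masked when j >= valid length.
key_padding_mask = torch.arange(num_queries)[None, :] >= valid_lens[:, None]

output, weights = mha(X, X, X, key_padding_mask=key_padding_mask)
print(output.shape)   # torch.Size([2, 4, 100])
print(weights.shape)  # torch.Size([2, 4, 4]), averaged over heads by default
```

The mask hides key position 3 for the first example and positions 2 and 3 for the second (0-based). The corresponding columns of `weights` are therefore zero, which is the effect `valid_lens` has in the implementation above.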