In this notebook, we will take a closer look at how attention is implemented in PyTorch. For that purpose, we will first do a naive implementation of multi-head attention, using the formula 

\begin{align*}
\textrm{MultiHeadAttention}(Q, K, V) = (\textrm{head}_1, \dots, \textrm{head}_h) W^O
\end{align*}

where each head is given by

$$
\textrm{head}_i = \textrm{softmax}\left(\frac{Q W_i^Q (K W^K_i)^t}{\sqrt{d_k}}\right) (V W_i^V), \quad i = 1, \dots, h
$$

as the [original transformer paper](https://arxiv.org/abs/1706.03762) specifies it. It is straightforward to implement this as a module using PyTorch. Once we have this, we will learn how to extract the weight matrices that appear in this formula from the [PyTorch multi-head attention layer](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) and will verify our understanding by running our implementation and the PyTorch module on the same input to see that they yield the same results (on non-batched input).
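
To make the shapes in the formula concrete before moving to the full module, here is a minimal sketch of a single head in the self-attention case, where query, key and value all come from the same input. The sizes and the names `W_q`, `W_k` and `W_v` are chosen for illustration only.

```python
import math
import torch

torch.manual_seed(0)
L, D, d_k = 3, 4, 2                      # illustrative sizes, not taken from the text
X = torch.randn(L, D)                    # self-attention: Q = K = V = X
W_q, W_k, W_v = (torch.randn(D, d_k) for _ in range(3))

# Scaled dot products between projected queries and keys: L x L
scores = (X @ W_q) @ (X @ W_k).t() / math.sqrt(d_k)
# Softmax over the last dimension turns every row into a probability distribution
weights = torch.softmax(scores, dim=-1)
# Weighted sum of the projected values: L x d_k
head = weights @ (X @ W_v)
print(scores.shape, head.shape)
```

Each row of `weights` sums to one, so every output row is a convex combination of the projected value vectors.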

In [1]:
import torch
import math

In [2]:
class MultiHeadSelfAttention(torch.nn.Module):
    
    def __init__(self, D, kdim = None, vdim = None, heads = 1):
        super().__init__()
        self._D = D
        self._heads = heads
        self._kdim = kdim if kdim is not None else D // heads
        self._vdim = vdim if vdim is not None else D // heads
        for h in range(self._heads):
            wq_name = f"_wq_h{h}"
            wk_name = f"_wk_h{h}"
            wv_name = f"_wv_h{h}"
            wq = torch.randn(self._D, self._kdim)
            wk = torch.randn(self._D, self._kdim)
            wv = torch.randn(self._D, self._vdim)
            setattr(self, wq_name, torch.nn.Parameter(wq))
            setattr(self, wk_name, torch.nn.Parameter(wk))
            setattr(self, wv_name, torch.nn.Parameter(wv))
        wo = torch.randn(self._heads*self._vdim, self._D)
        self._wo = torch.nn.Parameter(wo)
        
    def forward(self, X):
        for h in range(self._heads):
            wq_name = f"_wq_h{h}"
            wk_name = f"_wk_h{h}"
            wv_name = f"_wv_h{h}"
            Q = X@getattr(self, wq_name)
            K = X@getattr(self, wk_name)
            V = X@getattr(self, wv_name)
            head = Q@K.t() / math.sqrt(float(self._kdim))
            head = torch.softmax(head, dim = -1)
            head = head@V
            if 0 == h:
                out = head
            else:
                out = torch.cat([out, head], dim = 1)
        return out@self._wo

Let us now try to match this to the PyTorch implementation. We start with the case of only one head. Let us create an input, a PyTorch model and our own model and try to synchronize the parameters. Unfortunately PyTorch does not provide an easy way to access the weights, but a short look at the source code provides the following translation table (which, however, only holds for the case that all dimensions are equal):

| Attribute in torch.nn.MultiheadAttention | Weight matrix |
| ---                                      | ---           |
| out_proj                                 | $W^O$         |
| in_proj_weight                           | $W^Q$, $W^K$ and $W^V$|

More specifically, [this function](https://github.com/pytorch/pytorch/blob/5ee5a164ffeb7b7a167c53009fb8fe5f5bd439d9/torch/nn/functional.py#L4732) suggests that the weight matrices are placed in one tensor *in_proj_weight*, packed along dimension 0 in q, k, v order. Also note that the out projection is an instance of *torch.nn.Linear*, and [this function](https://github.com/pytorch/pytorch/blob/5ee5a164ffeb7b7a167c53009fb8fe5f5bd439d9/torch/nn/functional.py#L4732) reveals that PyTorch uses *torch.nn.functional.linear* under the hood instead of a plain matrix multiplication to carry out the projections, which multiplies the input by the **transpose** of the weight matrix. Thus we have to transpose the weight matrices from the PyTorch model before pulling them into our model. 
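
The layout described above can be inspected directly. The following sketch only reads attributes of the module. It also checks whether the biases are all zero after initialization, which is, as far as I can tell, the reason a bias-free implementation can reproduce the output of a freshly created layer; treat that as an assumption that the check confirms for your PyTorch version rather than a guarantee.

```python
import torch

D = 3
mha = torch.nn.MultiheadAttention(embed_dim=D, num_heads=1)

# The three projection matrices are stacked along dimension 0: shape (3D, D)
print(mha.in_proj_weight.shape)
# The output projection is a linear layer (a subclass of torch.nn.Linear)
print(isinstance(mha.out_proj, torch.nn.Linear))

# Splitting along dimension 0 recovers one D x D block each for q, k and v
w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)
print(w_q.shape, w_k.shape, w_v.shape)

# Assumption to verify: biases start at zero, so ignoring them is harmless here
in_bias_zero = bool((mha.in_proj_bias == 0).all())
out_bias_zero = bool((mha.out_proj.bias == 0).all())
print(in_bias_zero, out_bias_zero)
```

If either bias check prints `False`, for instance after training, the biases would have to be carried over as well for the outputs to match.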

In [3]:
D = 3
L = 2
ptAttention = torch.nn.MultiheadAttention(embed_dim = D, num_heads = 1)
myAttention = MultiHeadSelfAttention(D, heads = 1)
#
# Extract weights 
#
wq, wk, wv = ptAttention.in_proj_weight.chunk(3)
wo = ptAttention.out_proj
print(f"Shape of wq: {wq.shape}")
print(f"Shape of wk: {wk.shape}")
print(f"Shape of wv: {wv.shape}")
wo = wo.weight
print(f"Shape of wo: {wo.shape}")
#
# Copy weights to our model
#
myAttention._wq_h0 = torch.nn.Parameter(wq.clone().t())
myAttention._wk_h0 = torch.nn.Parameter(wk.clone().t())
myAttention._wv_h0 = torch.nn.Parameter(wv.clone().t())
myAttention._wo = torch.nn.Parameter(wo.clone().t())
#
# Create input, feed through both models and compare
#
X = torch.randn(L, D)
out, _ = ptAttention(X, X, X)
_out = myAttention(X)
print(f"Outputs match: {torch.allclose(out, _out)}")

Shape of wq: torch.Size([3, 3])
Shape of wk: torch.Size([3, 3])
Shape of wv: torch.Size([3, 3])
Shape of wo: torch.Size([3, 3])
Outputs match: True


Let us now go in more detail through the PyTorch implementation of attention, which is essentially [this function](https://github.com/pytorch/pytorch/blob/37cde56658e20afae6d94b70d53e4131043e09e8/torch/nn/functional.py#L5025), and repeat what it does, ignoring a few special cases.

Our example will use low dimensions to be able to track the inputs visually. Specifically, we will use two heads, embedding dimension 4 (which, as it should be, is a multiple of the number of heads) and sequence length $L = 2$.

In [4]:
L = 2
D = 4
h = 2
#
# Prepare an input X of shape L x D, i.e. with 8 elements
#
X = torch.arange(1, 9).reshape(L, D)

For every head $i$, we have three weight matrices $W_i^Q$, $W_i^K$ and $W_i^V$. The query, key and value dimensions are all $d_k = D / h = 2$ (embedding dimension divided by number of heads). Each of these matrices has dimension $D \times d_k$ and therefore $d_k D$ parameters. Since $h d_k = D$, the total number of parameters is $3 h d_k D = 3 D^2$, so that we could as well organize our weights in a single matrix $W$ of dimensions $3 D \times D$. Let us locate this matrix in the PyTorch attention module.

In [5]:
ptAttention = torch.nn.MultiheadAttention(embed_dim = D, num_heads = h)
W = ptAttention.in_proj_weight.clone()
print(W.shape)

torch.Size([12, 4])


Let us now follow the path of our input through the attention layer. We start with a new value of $W$ and $X$ to be able to identify the individual components of $W$ more easily as they are propagated through the code. If we pass our input and $W$ to the forward function of the attention module, PyTorch will first [unsqueeze our input](https://github.com/pytorch/pytorch/blob/37cde56658e20afae6d94b70d53e4131043e09e8/torch/nn/functional.py#L5153) to put it into the form $L \times B \times D$, where $B = 1$ is the batch size. 

In [6]:
W = torch.arange(1, 1+D*D*3).reshape(3*D, D).to(torch.float32)
print(W.shape)
X = torch.arange(51, 59).reshape(L, D).to(torch.float32)
X = X.unsqueeze(1)
print(X.shape)

torch.Size([12, 4])
torch.Size([2, 1, 4])


Next a few parameters are calculated, namely the head dimension (which is the embedding dimension divided by the number of heads) and the target and source length (which is $L$ in our case). So the head dimension is 2 and the target and source length are both equal to $L = 2$. In the next step, the projections onto query, key and value are calculated in [this function](https://github.com/pytorch/pytorch/blob/37cde56658e20afae6d94b70d53e4131043e09e8/torch/nn/functional.py#L4732). In the most general case, where the inputs are not equal, this will split $W$ into three parts of dimension $D \times D$ and apply *torch.nn.functional.linear* to query, key and value using the resulting matrices as weights. Note that this will calculate the product of the input with the transpose of the weight matrix. 

In [7]:
_w_q, _w_k, _w_v = W.chunk(3)
w_q = _w_q.t()
w_k = _w_k.t()
w_v = _w_v.t()
print(w_q)
print(w_k)
print(w_v)

tensor([[ 1.,  5.,  9., 13.],
        [ 2.,  6., 10., 14.],
        [ 3.,  7., 11., 15.],
        [ 4.,  8., 12., 16.]])
tensor([[17., 21., 25., 29.],
        [18., 22., 26., 30.],
        [19., 23., 27., 31.],
        [20., 24., 28., 32.]])
tensor([[33., 37., 41., 45.],
        [34., 38., 42., 46.],
        [35., 39., 43., 47.],
        [36., 40., 44., 48.]])


Next, query, key and value are determined by applying a linear layer with the respective weights to the input. Note that in PyTorch, applying a linear layer without bias amounts to multiplying the input matrix from the right with the **transpose** of the weight matrix, which is why we have applied the transpose when extracting the weight matrices above. Let us go through this and verify that the linear layer does in fact give the same result as the product with the transpose.

In [8]:
_q = torch.nn.functional.linear(X, _w_q)
q = torch.matmul(X, w_q)
print(q)
assert(torch.allclose(q, _q))
k = torch.matmul(X, w_k)
v = torch.matmul(X, w_v)

tensor([[[ 530., 1370., 2210., 3050.]],

        [[ 570., 1474., 2378., 3282.]]])


Note that query, key and value for each individual head are matrices of dimension $L \times B \times 2$, whereas the $q$ computed here has dimension $L \times B \times D$, so it still contains the information for all heads. 

[Back in the main function](https://github.com/pytorch/pytorch/blob/37cde56658e20afae6d94b70d53e4131043e09e8/torch/nn/functional.py#L5247), the three matrices are reshaped. First, each matrix is reshaped to dimension $(L, B \cdot h, D / h)$, i.e. its last dimension (which corresponds to the columns of the transposed weight matrix) is split into one part for every head. Then, the first and second dimension of the resulting matrix are switched.

In [9]:
B = 1
head_dim = D // h
q = q.view(L, B * h, head_dim).transpose(0, 1)
k = k.view(L, B * h, head_dim).transpose(0, 1)
v = v.view(L, B * h, head_dim).transpose(0, 1)
print(q)

tensor([[[ 530., 1370.],
         [ 570., 1474.]],

        [[2210., 3050.],
         [2378., 3282.]]])


So now the first dimension reflects both the batch dimension and the different heads. As our batch size is one, we can easily print the values of $q$ for both heads. Note that this arrangement will allow us to apply batched multiplication, i.e. we essentially treat the head dimension as an additional batch dimension.

In [10]:
q_0 = q[0, :, :]
q_1 = q[1, :, :]
print(q_0)
print(q_1)

tensor([[ 530., 1370.],
        [ 570., 1474.]])
tensor([[2210., 3050.],
        [2378., 3282.]])
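
The remark that heads can be treated as an additional batch dimension can be checked in isolation. The sketch below uses random tensors of illustrative sizes (not the values from the cells above) and compares one call to *torch.bmm* with an explicit loop over the heads.

```python
import torch

torch.manual_seed(0)
h, L, head_dim = 2, 3, 4                 # illustrative sizes
q = torch.randn(h, L, head_dim)          # first dimension indexes the heads
k = torch.randn(h, L, head_dim)

# One batched call computes q_i @ k_i^T for every head i at once: (h, L, L)
batched = torch.bmm(q, k.transpose(1, 2))
# Reference: the same products computed head by head and stacked
looped = torch.stack([q[i] @ k[i].t() for i in range(h)])
print(torch.allclose(batched, looped, atol=1e-6))
```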


Let us now try to understand by which matrix we would have to multiply $X$ to get the same result. The values that we see in the query for head 0 are the elements of the first two columns of $q = X \cdot w_q$, ignoring the batch dimension. We are therefore looking for a matrix of dimension $D \times 2 = 4 \times 2$, i.e. embedding dimension times head dimension, that gives $q_0$ when $X$ is multiplied by it from the right. Let us try to split $w_q$ along the columns.

In [11]:
w_q_0, w_q_1 = w_q.split(head_dim, dim = 1)
print(w_q_0)
_q_0 = torch.matmul(X, w_q_0)[:, 0, :]
_q_1 = torch.matmul(X, w_q_1)[:, 0, :]
print(_q_0)
assert(torch.allclose(q_0, _q_0))
assert(torch.allclose(q_1, _q_1))

tensor([[1., 5.],
        [2., 6.],
        [3., 7.],
        [4., 8.]])
tensor([[ 530., 1370.],
        [ 570., 1474.]])


Next, the attention scores are determined [here](https://github.com/pytorch/pytorch/blob/37cde56658e20afae6d94b70d53e4131043e09e8/torch/nn/functional.py#L5307), using batched multiplication (let us suppose that we have requested those, so that we enter the corresponding branch in the function). For that purpose, the function *torch.bmm* is invoked, which performs a batch multiplication (and assumes that the batch dimension is the first one, which was the reason for the reordering). This gives the attention matrix per batch item and head. Then softmax is applied to the last dimension. Finally, the attention weights are multiplied by $v$.

In [12]:
A = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(float(head_dim))
A = torch.softmax(A, dim = -1)
out = torch.bmm(A, v)
print(out.shape)

torch.Size([2, 2, 2])


As expected, the output has the dimensions $(B \cdot h, L, D / h)$, which in our example is $(2, 2, 2)$. Next, the batch / head dimension is moved back into the middle, and the result is reshaped to $(L \cdot B, D)$, which concatenates the heads along the last dimension. Finally, the output projection is applied, and the result is reshaped to $(L, B, D)$ and returned.
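
The remaining steps described above are not spelled out in a cell, so here is a sketch that strings the whole path together and compares it with the module. It follows my reading of the linked function and ignores masks, dropout and the other special cases. The variable names are chosen for illustration.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
L, B, D, h = 2, 1, 4, 2
head_dim = D // h
mha = torch.nn.MultiheadAttention(embed_dim=D, num_heads=h)
X = torch.randn(L, B, D)                  # already in the form L x B x D

with torch.no_grad():
    # Input projections: split the packed weight and bias into q, k, v parts
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3)
    q, k, v = F.linear(X, w_q, b_q), F.linear(X, w_k, b_k), F.linear(X, w_v, b_v)
    # (L, B, D) -> (B*h, L, head_dim): heads become part of the batch dimension
    q = q.view(L, B * h, head_dim).transpose(0, 1)
    k = k.view(L, B * h, head_dim).transpose(0, 1)
    v = v.view(L, B * h, head_dim).transpose(0, 1)
    # Scaled dot-product attention per batch item and head
    A = torch.softmax(torch.bmm(q, k.transpose(1, 2)) / math.sqrt(head_dim), dim=-1)
    out = torch.bmm(A, v)                 # (B*h, L, head_dim)
    # Move heads back next to the feature dimension and merge them: (L*B, D)
    out = out.transpose(0, 1).contiguous().view(L * B, D)
    # Output projection, then restore the (L, B, D) layout
    out = F.linear(out, mha.out_proj.weight, mha.out_proj.bias).view(L, B, D)
    expected, _ = mha(X, X, X)

print(torch.allclose(out, expected, atol=1e-6))
```

The call to `contiguous()` is needed because `view` cannot merge dimensions of a transposed tensor without copying it first.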

Armed with this understanding, let us now try to set up a PyTorch attention layer with two heads, extract the weights, feed them into our implementation and compare the results.

In [13]:
D = 4
L = 2
h = 2
head_dim = D // h
ptAttention = torch.nn.MultiheadAttention(embed_dim = D, num_heads = h)
myAttention = MultiHeadSelfAttention(D, heads = h)
#
# Extract weights from W as before
#
_w_q, _w_k, _w_v = ptAttention.in_proj_weight.chunk(3)
w_q = _w_q.clone().t()
w_k = _w_k.clone().t()
w_v = _w_v.clone().t()
w_q_0, w_q_1 = w_q.split(head_dim, dim = 1)
w_k_0, w_k_1 = w_k.split(head_dim, dim = 1)
w_v_0, w_v_1 = w_v.split(head_dim, dim = 1)
#
# Inject weights into our model
#
myAttention._wq_h0 = torch.nn.Parameter(w_q_0)
myAttention._wq_h1 = torch.nn.Parameter(w_q_1)
myAttention._wk_h0 = torch.nn.Parameter(w_k_0)
myAttention._wk_h1 = torch.nn.Parameter(w_k_1)
myAttention._wv_h0 = torch.nn.Parameter(w_v_0)
myAttention._wv_h1 = torch.nn.Parameter(w_v_1)
wo = ptAttention.out_proj.weight.clone().t()
myAttention._wo = torch.nn.Parameter(wo)
#
# Generate input, feed into both models and compare
#
X = torch.randn(L, D)
print(X.shape)
out = myAttention(X)
_out, _ = ptAttention(X, X, X)
print(f"Outputs equal: {torch.allclose(_out, out)}")

torch.Size([2, 4])
Outputs equal: True
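
The two-head extraction above writes out every head by hand. As a sketch of how the same idea might generalize to any number of heads, the hypothetical helpers `per_head_weights` and `naive_attention` below split the transposed weights in a loop. Like the rest of this notebook, they ignore the biases, which assumes a freshly initialized layer whose biases are zero.

```python
import math
import torch

def per_head_weights(mha):
    """Per-head W^Q, W^K, W^V lists and W^O, transposed for right-multiplication."""
    head_dim = mha.embed_dim // mha.num_heads
    w_q, w_k, w_v = (w.detach().t() for w in mha.in_proj_weight.chunk(3))
    # Columns of the transposed matrices belong to consecutive heads
    split = [list(w.split(head_dim, dim=1)) for w in (w_q, w_k, w_v)]
    return split[0], split[1], split[2], mha.out_proj.weight.detach().t()

def naive_attention(X, wqs, wks, wvs, wo):
    """Head-by-head self-attention on a non-batched input X of shape L x D."""
    heads = []
    for wq, wk, wv in zip(wqs, wks, wvs):
        Q, K, V = X @ wq, X @ wk, X @ wv
        A = torch.softmax(Q @ K.t() / math.sqrt(wq.shape[1]), dim=-1)
        heads.append(A @ V)
    # Concatenate the heads along the feature dimension, then project
    return torch.cat(heads, dim=1) @ wo

torch.manual_seed(0)
D, L, h = 8, 5, 4                         # illustrative sizes
mha = torch.nn.MultiheadAttention(embed_dim=D, num_heads=h)
X = torch.randn(L, D)
with torch.no_grad():
    expected, _ = mha(X, X, X)
out = naive_attention(X, *per_head_weights(mha))
print(torch.allclose(out, expected, atol=1e-6))
```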
