The paper starts by discussing scaled dot-product attention. *Attention* refers to a mechanism that allows for "modeling of dependencies without regard to their distance in the input or output sequences". In other words, attention allows the model to *attend* to different parts of the input when learning to approximate a function.

The common example shown for attention is how different words in a sentence are associated with each other. For example, consider the sentence "A big red dog jumped over a small pond". As a reader, it's easy to see that the words "big", "red", and "jumped" all relate to the dog, or are at least more relevant to understanding what the dog is doing than the word "small". Attention allows a model to learn the strength of these associations, which helps it use the context to predict the correct output.

> Dot-product attention is much faster and more space-efficient in practice [than additive attention], since it can be implemented using highly optimized matrix multiplication code.
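
For reference, the function implemented in the next cells is the paper's Equation 1, where $d_k$ is the dimension of the keys:

$$\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

The softmax is applied row-wise, so each query's weights over the keys sum to 1. The product $QK^T$ is the single matrix multiplication that the quote above refers to.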

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

In [38]:
def scaled_dot_product_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor: # TODO: implement masking
    """
    Performs scaled dot-product attention as defined in the Transformer paper.
    Assumes that d_keys and d_values are equal to d_model / num_heads.

    :param Q: The query matrices of shape (num_heads, seq_len, d_keys)
    :param K: The key matrices of shape (num_heads, seq_len, d_keys)
    :param V: The value matrices of shape (num_heads, seq_len, d_values)
    :return: The attention output of shape (num_heads, seq_len, d_values)
    """
    d_keys = Q.shape[-1]

    scaling_factor = 1 / math.sqrt(d_keys)

    # scores[h, i, j] is the scaled similarity between query i and key j in head h
    scores = Q @ K.transpose(-2, -1) * scaling_factor

    # softmax over the key axis (the last one), so each query's weights sum to 1
    weights = F.softmax(scores, dim=-1)

    return weights @ V
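
The TODO above leaves masking unimplemented. As a minimal sketch, and not necessarily how it should be done in this notebook, one common approach is to set the masked scores to negative infinity before the softmax. The name `masked_attention` and the `mask` argument are hypothetical, introduced here only for illustration; the mask is assumed to be boolean, with `True` meaning "may attend".

```python
import math

import torch
import torch.nn.functional as F


def masked_attention(Q, K, V, mask=None):
    """Scaled dot-product attention with an optional boolean mask (hypothetical helper).

    Q, K: (..., seq_len, d_keys); V: (..., seq_len, d_values).
    mask: broadcastable to (..., seq_len, seq_len); True = may attend.
    """
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])
    if mask is not None:
        # masked positions get -inf, which softmax turns into a weight of 0
        scores = scores.masked_fill(~mask, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights


# causal mask: position i may attend only to positions j <= i
seq_len = 4
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

Q = torch.randn(2, seq_len, 8)
K = torch.randn(2, seq_len, 8)
V = torch.randn(2, seq_len, 8)
out, weights = masked_attention(Q, K, V, causal_mask)
print(out.shape)
```

Note that a row that is masked everywhere would produce NaNs under this approach; the causal mask avoids that because every position can at least attend to itself.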

In [39]:
# let's try this out with some random values
d_model = 512
num_heads = 8
d_keys = d_model // num_heads
d_values = d_model // num_heads
seq_len = 10  # number of tokens in the sequence; independent of d_model

Q = torch.randn((num_heads, seq_len, d_keys))
K = torch.randn((num_heads, seq_len, d_keys))
V = torch.randn((num_heads, seq_len, d_values))

scaled_scores = scaled_dot_product_attention(Q, K, V)
print(scaled_scores.shape)

torch.Size([8, 10, 64])
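
The paper motivates the $1/\sqrt{d_k}$ factor by noting that if the components of a query and a key are independent with mean 0 and variance 1, their dot product has mean 0 and variance $d_k$, which can push the softmax into regions with very small gradients. The sketch below checks the variance claim empirically; the sample count and seed are arbitrary choices made for this illustration.

```python
import math

import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

d_keys = 64
num_samples = 100_000

# components drawn independently with mean 0 and variance 1
q = torch.randn(num_samples, d_keys)
k = torch.randn(num_samples, d_keys)

dots = (q * k).sum(dim=-1)         # one dot product per sample
scaled = dots / math.sqrt(d_keys)  # the same dot products after scaling

raw_var = dots.var().item()        # should be close to d_keys
scaled_var = scaled.var().item()   # should be close to 1
print(raw_var, scaled_var)
```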


Ok, now we can start building out the attention layer. Let's see what the paper has to say about multi-head attention:

> Instead of performing a single attention function with $d_{\text{model}}$-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values $h$ times with different, learned linear projections to $d_k$, $d_k$ and $d_v$ dimensions, respectively. On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding $d_v$-dimensional output values. These are concatenated and once again projected, resulting in the final values, as depicted in Figure 2.

> Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.

> $\text{MultiHead}(Q,K,V)=\text{Concat}(\text{head}_1,...,\text{head}_h)W^O$
> $\text{where head}_i=\text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$
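
Before batching anything, it can help to translate the two formulas above literally, with an explicit loop over heads. This is only a sketch under stated assumptions: it uses self-attention (the same `x` serves as queries, keys, and values), random untrained weights, and small toy dimensions chosen for readability. The variable names are illustrative and not taken from the paper.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

seq_len, d_model, num_heads = 5, 16, 4
d_k = d_v = d_model // num_heads

x = torch.randn(seq_len, d_model)  # self-attention: Q = K = V = x

# one projection matrix per head, plus the shared output projection W^O
W_Q = torch.randn(num_heads, d_model, d_k)
W_K = torch.randn(num_heads, d_model, d_k)
W_V = torch.randn(num_heads, d_model, d_v)
W_O = torch.randn(num_heads * d_v, d_model)


def attention(Q, K, V):
    # softmax over the key axis, so each query's weights sum to 1
    weights = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1]), dim=-1)
    return weights @ V


# head_i = Attention(x W_i^Q, x W_i^K, x W_i^V), each of shape (seq_len, d_v)
heads = [attention(x @ W_Q[i], x @ W_K[i], x @ W_V[i]) for i in range(num_heads)]

# Concat(head_1, ..., head_h) W^O, giving shape (seq_len, d_model)
multi_head = torch.cat(heads, dim=-1) @ W_O
print(multi_head.shape)
```

The loop is slow but unambiguous, which makes it a useful reference to compare a batched implementation against.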


In [40]:
class MultiHeadedAttention(nn.Module):
    """
    Implements multi-headed attention from the Attention paper.

    :param d_model: The embedding dimension. "All sub-layers in the model, as well as the embedding layers, produce outputs of dimension d_model".
    :param num_heads: How many attention heads to use. Must evenly divide d_model.
    """
    def __init__(self, d_model, num_heads):
        assert (d_model % num_heads == 0), "d_model is not a multiple of num_heads"
        
        super().__init__()
        
        self.d_model = d_model
        self.d_keys = d_model // num_heads  # integer division: these are used as tensor dimensions
        self.d_values = d_model // num_heads
        self.num_heads = num_heads

        # each head has its own query, key, and value projection matrix; in order to parallelize all of these calculations,
        # we batch them by num_heads and perform broadcasted matrix multiplication
        self.W_Q = nn.Parameter(torch.zeros((num_heads, d_model, self.d_keys)))
        self.W_K = nn.Parameter(torch.zeros((num_heads, d_model, self.d_keys)))
        self.W_V = nn.Parameter(torch.zeros((num_heads, d_model, self.d_values)))
        
        self.W_O = nn.Parameter(torch.zeros((num_heads * self.d_values, d_model)))
        self.b_O = nn.Parameter(torch.zeros(d_model))  # added to the d_model-sized output; left at zero
        # TODO: double check if this is the correct initialization for each parameter

        nn.init.xavier_normal_(self.W_Q)
        nn.init.xavier_normal_(self.W_K)
        nn.init.xavier_normal_(self.W_V)
        nn.init.xavier_normal_(self.W_O)
        # no Xavier init for b_O: it is 1-D, and xavier_normal_ needs at least 2 dimensions

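    def forward(self, x):
        # x: (seq_len, d_model); broadcasting against (num_heads, d_model, d_keys)
        # yields one projection per head, of shape (num_heads, seq_len, d_keys)
        Q = x @ self.W_Q
        K = x @ self.W_K
        V = x @ self.W_V

        Z = scaled_dot_product_attention(Q, K, V) # should be dim (num_heads, seq_len, d_values)

        # now we need to concatenate the heads' outputs along the feature axis:
        # (num_heads, seq_len, d_values) -> (seq_len, num_heads, d_values) -> (seq_len, num_heads * d_values)
        Z = Z.transpose(0, 1).reshape(x.shape[0], -1)

        # finally, multiply the output by W_O and add the biases
        out = Z @ self.W_O + self.b_O  # (seq_len, d_model)

        return out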
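
The concatenation step in `forward` has to place each token's per-head vectors side by side, and with a batched tensor this is easy to get subtly wrong. The sketch below uses a toy tensor (dimensions and names chosen only for this illustration) to compare two ways of flattening the head axis against an explicit `torch.cat`.

```python
import torch

num_heads, seq_len, d_values = 2, 3, 4

# toy per-head outputs; Z[h, t] is head h's output vector for token t
Z = torch.arange(num_heads * seq_len * d_values, dtype=torch.float32)
Z = Z.reshape(num_heads, seq_len, d_values)

# reference: concatenate each head's (seq_len, d_values) block along the feature axis
reference = torch.cat([Z[h] for h in range(num_heads)], dim=-1)

# moving the head axis next to the feature axis first gives the same result
transposed = Z.transpose(0, 1).reshape(seq_len, num_heads * d_values)

# reshaping directly keeps the memory order (head, token, feature),
# so a row ends up holding vectors that belong to different tokens
direct = Z.reshape(seq_len, num_heads * d_values)

print(torch.equal(transposed, reference))  # True
print(torch.equal(direct, reference))      # False
```

Both reshapes produce a tensor of the right shape, so a shape check alone would not catch the difference.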

