In [1]:
import torch

## Self-attention without trainable weights

Self-attention measures how relevant each position in the input sequence is to every position, including itself.

Given an input sequence `x = ["Your", "journey", "starts", "with", "one", "step"]`, we compute a context vector for each element `x_i`, which can be interpreted as an enriched embedding vector.

We use the following example embeddings for this input sequence:


In [2]:
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your    (x_1)
        [0.55, 0.87, 0.66],  # journey (x_2)
        [0.57, 0.85, 0.64],  # starts  (x_3)
        [0.22, 0.58, 0.33],  # with    (x_4)
        [0.77, 0.25, 0.10],  # one     (x_5)
        [0.05, 0.80, 0.55],  # step    (x_6)
    ]
)

Now we compute the attention scores, using the second input token ("journey") as the query. Each score is the dot product of the query with one of the input vectors.

In [3]:
query = inputs[1]

attention_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attention_scores_2[i] = torch.dot(query, x_i)

print(attention_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


Next, we normalize the attention scores so that they sum to 1. This makes the weights easier to interpret and helps keep training stable.

In [4]:
attention_weights_2 = attention_scores_2 / torch.sum(attention_scores_2)
print("Attention weights for input query 2: ", attention_weights_2)
print("Sum of attention weights: ", torch.sum(attention_weights_2))

Attention weights for input query 2:  tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum of attention weights:  tensor(1.0000)


A better way to normalize the attention scores is the softmax function: it always yields positive weights, handles extreme values more gracefully, and has better-behaved gradients during training.

In [5]:
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

attention_weights_2_softmax = softmax_naive(attention_scores_2)
print("Attention weights for input query 2: ", attention_weights_2_softmax)
print("Sum of attention weights: ", torch.sum(attention_weights_2_softmax))

Attention weights for input query 2:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum of attention weights:  tensor(1.)


PyTorch has a built-in implementation, `torch.softmax`, which is numerically more stable (the naive version can overflow for large inputs) and should be preferred.

In [6]:
attention_weights_2_torch = torch.softmax(attention_scores_2, dim=0)
print("Attention weights for input query 2: ", attention_weights_2_torch)
print("Sum of attention weights: ", torch.sum(attention_weights_2_torch))

Attention weights for input query 2:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum of attention weights:  tensor(1.)
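
To illustrate the stability difference, the following minimal sketch feeds deliberately large scores (made-up values, not taken from the example above) into both implementations. In 32-bit floats `torch.exp(1000.)` overflows to `inf`, so the naive version returns `nan`. The helper `softmax_shifted` is a hypothetical name for a standard trick: subtracting the maximum leaves the softmax result unchanged mathematically but avoids the overflow.

```python
import torch

def softmax_naive(x):
    # Direct translation of the formula; exp() overflows for large inputs
    return torch.exp(x) / torch.exp(x).sum(dim=0)

def softmax_shifted(x):
    # Subtracting the maximum keeps every exponent <= 0, so exp() cannot overflow
    x_shifted = x - x.max()
    return torch.exp(x_shifted) / torch.exp(x_shifted).sum(dim=0)

large_scores = torch.tensor([1000.0, 1001.0, 1002.0])  # hypothetical values

naive = softmax_naive(large_scores)
shifted = softmax_shifted(large_scores)
builtin = torch.softmax(large_scores, dim=0)

print("naive:  ", naive)
print("shifted:", shifted)
print("builtin:", builtin)
```

The shifted and built-in results agree, while the naive result is not usable.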


Now we can compute the context vector for the second input element. This means multiplying the input tokens `x_i` with the attention weights and summing the resulting vectors so we obtain the weighted sum of all input tokens.

In [7]:
query = inputs[1]
context_vector = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vector += attention_weights_2_torch[i] * x_i

print("Context vector for input query 2: ", context_vector)

Context vector for input query 2:  tensor([0.4419, 0.6515, 0.5683])


### Attention weights for all input tokens

Now we want to generalize this procedure to compute the attention weights for each input token simultaneously. Naively, this can be done as follows:

In [8]:
attention_scores = torch.empty(inputs.shape[0], inputs.shape[0])
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attention_scores[i, j] = torch.dot(x_i, x_j)

print("Attention scores for all input tokens:\n", attention_scores)

Attention scores for all input tokens:
 tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


These values are unnormalized. The double for-loop can be replaced by a single matrix multiplication using the `@` operator.

In [9]:
attention_scores = inputs @ inputs.T
print("Attention scores for all input tokens:\n", attention_scores)

Attention scores for all input tokens:
 tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


We can then normalize the scores using the softmax function along the last dimension and verify that each row sums to 1.

In [10]:
attention_weights = torch.softmax(attention_scores, dim=1)
print("Attention weights for all input tokens:\n", attention_weights)
print("Sum of attention weights for each input token:\n", torch.sum(attention_weights, dim=1))

Attention weights for all input tokens:
 tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
Sum of attention weights for each input token:
 tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
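
The choice of `dim` decides which direction is normalized. As a small sketch with a made-up, non-square score matrix (so the two directions are easy to tell apart): `dim=-1` makes each row sum to 1, whereas `dim=0` makes each column sum to 1.

```python
import torch

# Hypothetical 2x3 score matrix, only used to show the effect of `dim`
demo_scores = torch.tensor([[1.0, 2.0, 3.0],
                            [1.0, 1.0, 1.0]])

row_weights = torch.softmax(demo_scores, dim=-1)  # normalize within each row
col_weights = torch.softmax(demo_scores, dim=0)   # normalize within each column

print("row sums with dim=-1:  ", row_weights.sum(dim=-1))
print("column sums with dim=0:", col_weights.sum(dim=0))
print("row sums with dim=0:   ", col_weights.sum(dim=-1))
```

For attention weights we want every query (row) to distribute a total weight of 1 over the keys, which is why the last dimension is used.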


The last step is to compute the context vector for each input token by multiplying the input tokens with the attention weights and summing the resulting vectors.

In [11]:
context_vectors = attention_weights @ inputs
print("Context vectors for all input tokens:\n", context_vectors)

Context vectors for all input tokens:
 tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


### Self-attention with trainable weights

We now want to extend the self-attention mechanism to include trainable weights so that the model can learn to attend to the most relevant parts of the input sequence.

We first define the trainable weights for the query, key, and value matrices and again start by taking the second input element as an example.

In [12]:
x_2 = inputs[1]
d_in = inputs.shape[1]  # dimension of the input embeddings
d_out = 2  # dimension of the output embeddings

torch.manual_seed(123)
# We do not train these weights here, so requires_grad can be set to False
W_query = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)

Next, we compute the query, key, and value vectors for `x_2`.

In [13]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print("Query vector for input token 2: ", query_2)

# To compute the context vector for x_2, we need the keys and values of all input tokens,
# because they are required to compute the attention weights.
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape: ", keys.shape)
print("values.shape: ", values.shape)

Query vector for input token 2:  tensor([-1.1729, -0.0048])
keys.shape:  torch.Size([6, 2])
values.shape:  torch.Size([6, 2])


The second step is computing the attention scores. We start with $\omega_{22}$, the score of the second query against the second key.

In [14]:
keys_2 = keys[1]
attention_scores_22 = query_2.dot(keys_2)
print("attention_scores_22: ", attention_scores_22)

attention_scores_22:  tensor(0.1376)


We can generalize this to compute all attention scores for this query using matrix multiplication.

In [15]:
attention_scores_2 = query_2 @ keys.T
print("attention_scores_2: ", attention_scores_2)

attention_scores_2:  tensor([ 0.2172,  0.1376,  0.1730, -0.0491,  0.7616, -0.3809])


Next, we normalize the attention scores to obtain attention weights. Before applying softmax, we divide the scores by the square root of the key dimension (here $d_k = 2$), which keeps the softmax from becoming too peaked as the dimension grows.

In [16]:
d_k = keys.shape[1]
attention_weights_2 = torch.softmax(attention_scores_2 / torch.sqrt(torch.tensor(d_k)), dim=-1)
print("attention_weights_2: ", attention_weights_2)

attention_weights_2:  tensor([0.1704, 0.1611, 0.1652, 0.1412, 0.2505, 0.1117])
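
The effect of the scaling is easier to see with a larger dimension than in our toy example. The following sketch uses random vectors (hypothetical data, with a key dimension of 64 chosen purely for illustration) and compares the largest attention weight with and without the division by $\sqrt{d_k}$.

```python
import torch

torch.manual_seed(0)
demo_d_k = 64                             # illustrative key dimension
demo_query = torch.randn(demo_d_k)        # hypothetical query vector
demo_keys = torch.randn(8, demo_d_k)      # 8 hypothetical key vectors

demo_scores = demo_keys @ demo_query
weights_unscaled = torch.softmax(demo_scores, dim=-1)
weights_scaled = torch.softmax(demo_scores / demo_d_k**0.5, dim=-1)

print("largest weight without scaling:", weights_unscaled.max().item())
print("largest weight with scaling:   ", weights_scaled.max().item())
```

Dividing all scores by a constant greater than 1 can only flatten the softmax output, so the scaled weights are spread more evenly across the keys.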


The final step is to compute the context vector for x_2 by multiplying the attention weights with the values.

In [17]:
context_vector_2 = attention_weights_2 @ values
print("context_vector_2: ", context_vector_2)

context_vector_2:  tensor([0.2854, 0.4081])


### Implementing a self-attention class

In order to make the self-attention mechanism reusable, we can implement it as a class.

In [18]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.randn(d_in, d_out))
        self.W_key = nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, x):
        query = x @ self.W_query
        key = x @ self.W_key
        value = x @ self.W_value

        attention_scores = query @ key.T
        attention_weights = torch.softmax(attention_scores / key.shape[-1]**0.5, dim=-1)
        context_vector = attention_weights @ value

        return context_vector

We can apply this class to the input sequence as follows:

In [19]:
torch.manual_seed(123)

self_attention_v1 = SelfAttention_v1(d_in=3, d_out=2)
context_vector_2 = self_attention_v1(inputs)
print("context_vector_2:\n", context_vector_2)

context_vector_2:
 tensor([[0.2845, 0.4071],
        [0.2854, 0.4081],
        [0.2854, 0.4075],
        [0.2864, 0.3974],
        [0.2863, 0.3910],
        [0.2860, 0.4039]], grad_fn=<MmBackward0>)


We can see that the context vector for the second input element is equal to the one we computed before.

We can improve our implementation by using `nn.Linear` layers, which perform the same matrix multiplication when the bias is disabled and come with a weight initialization scheme that is better suited to training.

In [20]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        query = self.W_query(x)
        key = self.W_key(x)
        value = self.W_value(x)

        attention_scores = query @ key.T
        attention_weights = torch.softmax(attention_scores / key.shape[-1]**0.5, dim=-1)
        context_vector = attention_weights @ value

        return context_vector

Applying `SelfAttention_v2` to our example, we see that the output differs due to the difference in weight initialization.

In [21]:
torch.manual_seed(123)

self_attention_v2 = SelfAttention_v2(d_in=3, d_out=2)
context_vector_2 = self_attention_v2(inputs)
print("context_vector_2:\n", context_vector_2)

context_vector_2:
 tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)
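
That the two variants differ only in their initial weights can be checked by copying the weights from one to the other. The sketch below uses two hypothetical stand-in classes, `AttentionWithParams` and `AttentionWithLinear`, which mirror the parameter-based and the `nn.Linear`-based implementation, and random demo inputs. Note that `nn.Linear` stores its weight with shape `(d_out, d_in)`, so it has to be transposed before copying.

```python
import torch
import torch.nn as nn

class AttentionWithParams(nn.Module):
    """Stand-in for the nn.Parameter variant: weights of shape (d_in, d_out)."""
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.randn(d_in, d_out))
        self.W_key = nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, x):
        query, key, value = x @ self.W_query, x @ self.W_key, x @ self.W_value
        weights = torch.softmax((query @ key.T) / key.shape[-1]**0.5, dim=-1)
        return weights @ value

class AttentionWithLinear(nn.Module):
    """Stand-in for the nn.Linear variant (bias disabled)."""
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        query, key, value = self.W_query(x), self.W_key(x), self.W_value(x)
        weights = torch.softmax((query @ key.T) / key.shape[-1]**0.5, dim=-1)
        return weights @ value

torch.manual_seed(123)
demo_inputs = torch.rand(6, 3)  # hypothetical inputs: 6 tokens, 3 dimensions
linear_version = AttentionWithLinear(d_in=3, d_out=2)
param_version = AttentionWithParams(d_in=3, d_out=2)

# Copy the (transposed) nn.Linear weights into the parameter-based module
with torch.no_grad():
    param_version.W_query.copy_(linear_version.W_query.weight.T)
    param_version.W_key.copy_(linear_version.W_key.weight.T)
    param_version.W_value.copy_(linear_version.W_value.weight.T)

out_linear = linear_version(demo_inputs)
out_param = param_version(demo_inputs)
print("outputs match:", torch.allclose(out_linear, out_param, atol=1e-6))
```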


### Causal attention

In order to make the attention mechanism causal, we need to modify it so that each token can only attend to itself and to previous tokens. We can achieve this by masking out the future tokens, for example by setting their attention scores to $-\infty$ before applying softmax.

As a first step, we compute the attention weights as we have done previously.

In [22]:
query = self_attention_v2.W_query(inputs)
key = self_attention_v2.W_key(inputs)
attention_scores = query @ key.T
attention_weights = torch.softmax(attention_scores / key.shape[-1]**0.5, dim=-1)
print("attention_weights:\n", attention_weights)

attention_weights:
 tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
        [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
        [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)


We then mask the values above the diagonal (= the future tokens) using PyTorch's `tril` function.

In [23]:
context_length = attention_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print("mask_simple:\n", mask_simple)

attention_weights_masked = attention_weights * mask_simple
print("attention_weights_masked:\n", attention_weights_masked)

mask_simple:
 tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
attention_weights_masked:
 tensor([[0.1717, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1749, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1637, 0.1749, 0.1746, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.0000, 0.0000],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<MulBackward0>)


As a final step, we need to renormalize the attention weights so that each row sums to 1 again.

In [24]:
row_sums = attention_weights_masked.sum(dim=-1, keepdim=True)
masked_attention_weights_norm = attention_weights_masked / row_sums
print("masked_attention_weights_norm:\n", masked_attention_weights_norm)

masked_attention_weights_norm:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<DivBackward0>)


This gives us a correct implementation of causal attention, but it can be made more efficient. The softmax function converts its inputs into a probability distribution, and because every input is exponentiated, a score of $-\infty$ contributes $e^{-\infty} = 0$ to both the numerator and the denominator. Masked positions therefore receive a weight of exactly 0, and the remaining weights are normalized automatically.

We can use this to implement our causal mask as follows:

In [25]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).type(torch.bool)
masked = attention_scores.masked_fill(mask, -torch.inf)
print("masked:\n", masked)

attention_weights_causal = torch.softmax(masked / key.shape[-1]**0.5, dim=-1)
print("attention_weights_causal:\n", attention_weights_causal)

masked:
 tensor([[0.3111,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1655, 0.2602,   -inf,   -inf,   -inf,   -inf],
        [0.1667, 0.2602, 0.2577,   -inf,   -inf,   -inf],
        [0.0510, 0.1080, 0.1064, 0.0643,   -inf,   -inf],
        [0.1415, 0.1875, 0.1863, 0.0987, 0.1121,   -inf],
        [0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MaskedFillBackward0>)
attention_weights_causal:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)
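
The equivalence of the two masking strategies does not depend on our particular weights. As a minimal check on random scores (hypothetical data, with illustrative variable names), masking after softmax followed by renormalization gives the same result as filling future positions with $-\infty$ before softmax:

```python
import torch

torch.manual_seed(0)
n = 4
demo_scores = torch.randn(n, n)  # hypothetical attention scores

# Approach 1: softmax, zero out future positions, renormalize each row
weights = torch.softmax(demo_scores, dim=-1)
weights_zeroed = weights * torch.tril(torch.ones(n, n))
approach_1 = weights_zeroed / weights_zeroed.sum(dim=-1, keepdim=True)

# Approach 2: set future positions to -inf, then apply softmax once
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
approach_2 = torch.softmax(demo_scores.masked_fill(future, float("-inf")), dim=-1)

print("results match:", torch.allclose(approach_1, approach_2, atol=1e-6))
```

The second approach needs only one normalization pass, which is why it is the one used in the classes that follow.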


### Adding dropout

In order to prevent overfitting, we can add dropout to the attention mechanism. This will randomly set some attention weights to 0 during training.

Dropout in transformer architectures is usually applied at one of two points:

- after calculating the attention weights
- after applying the attention weights to the value vectors

We will implement the more commonly used way of applying the dropout mask after computing the weights.

In [26]:
torch.manual_seed(123)

dropout = nn.Dropout(p=0.5)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


With `p=0.5`, each element is set to 0 with probability 0.5 (so roughly half of them on average), and the remaining values are scaled by $1 / (1 - 0.5) = 2$. This scaling keeps the expected value of each weight unchanged, so the average influence of the attention mechanism stays consistent between training and inference.

Finally, we apply dropout to our attention weight matrix.

In [27]:
torch.manual_seed(123)
print(dropout(attention_weights_causal))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6380, 0.6816, 0.6804, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5090, 0.5085, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4120, 0.0000, 0.3869, 0.0000, 0.0000],
        [0.0000, 0.3418, 0.3413, 0.3308, 0.3249, 0.0000]],
       grad_fn=<MulBackward0>)
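
Dropout is only active in training mode. The following sketch (with a made-up tensor of ones and an illustrative module name `demo_dropout`) shows that in training mode the surviving values are rescaled so that the mean stays close to 1, while in evaluation mode the input passes through unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
demo_dropout = nn.Dropout(p=0.5)
ones = torch.ones(1000, 100)  # large enough for a stable average

demo_dropout.train()          # training mode: elements are dropped and rescaled
train_out = demo_dropout(ones)

demo_dropout.eval()           # evaluation mode: dropout does nothing
eval_out = demo_dropout(ones)

print("distinct values in training mode:", train_out.unique())
print("mean in training mode:", train_out.mean().item())
print("unchanged in eval mode:", torch.equal(eval_out, ones))
```

When using an attention module for inference, calling `.eval()` on the model therefore disables the dropout mask.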


### Implementing a causal attention class

We will now put all the pieces together into a causal attention class.

Such a class should be able to handle batches of inputs and we will create a batch with 2 inputs for our test.

In [28]:
batch = torch.stack([inputs, inputs], dim=0)
print("batch.shape: ", batch.shape)

batch.shape:  torch.Size([2, 6, 3])


In [29]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # By using a buffer, PyTorch will automatically move it to the correct device
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        query = self.W_query(x)
        key = self.W_key(x)
        value = self.W_value(x)

        attention_scores = query @ key.transpose(1, 2)
        attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attention_weights = torch.softmax(attention_scores / key.shape[-1]**0.5, dim=-1)
        attention_weights = self.dropout(attention_weights)

        context_vector = attention_weights @ value
        return context_vector
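
Two details of this class can be looked at in isolation: the mask is stored as a buffer, and it is sliced to the actual sequence length in `forward`. The sketch below uses a hypothetical minimal module, `MaskHolder`, that contains nothing but the mask.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):
    """Hypothetical minimal module that only stores a causal mask as a buffer."""
    def __init__(self, context_length):
        super().__init__()
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

holder = MaskHolder(context_length=6)

# Buffers are part of the state dict but are not trainable parameters
print("in state_dict:", "mask" in holder.state_dict())
print("number of parameters:", len(list(holder.parameters())))

# Slicing the top-left corner yields the causal mask for a shorter sequence
short_mask = holder.mask.bool()[:3, :3]
print(short_mask)
```

Because of the slicing, the module can process sequences shorter than `context_length` without creating a new mask.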

We can use `CausalAttention` in the same way we used `SelfAttention_v2`, now with a batched input.

In [30]:
torch.manual_seed(123)
context_length = batch.shape[1]

causal_attention = CausalAttention(d_in=3, d_out=2, context_length=context_length, dropout=0.0)
context_vector = causal_attention(batch)
print("context_vector:\n", context_vector)
print("context_vector.shape: ", context_vector.shape)

context_vector:
 tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
context_vector.shape:  torch.Size([2, 6, 2])


### Extending the single-head attention class to multi-head attention

In order to be able to attend to different parts of the input sequence, we can extend the single-head attention class to multi-head attention.

To do this, we create multiple instances of the single-head attention class with separate weights and concatenate their outputs.

We code this as a wrapper around our `CausalAttention` class which will stack multiple instances.

In [31]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)]
        )

    def forward(self, x):
        # Concatenate the outputs of the individual heads
        return torch.cat([head(x) for head in self.heads], dim=-1)            

This means that, if we use this class with `d_out = 2` and `num_heads = 2`, we will get a 4-dimensional context vector for each token (see example below).

In [32]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2  # output dimension of each individual head
num_heads = 2

multi_head_attention = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0.0, num_heads=num_heads)
context_vector = multi_head_attention(batch)
print("context_vector:\n", context_vector)
print("context_vector.shape: ", context_vector.shape)


context_vector:
 tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vector.shape:  torch.Size([2, 6, 4])


### Multi-head attention using weight splits

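A more efficient implementation of multi-head attention than the wrapper around the `CausalAttention` class is achieved using weight splitting. Instead of keeping separate weight matrices per head, we use one larger projection each for the queries, keys, and values, and split the result into heads. This lets us compute, for example, the keys of all heads with a single matrix multiplication, whereas the previous implementation computed them for each head separately. Since matrix multiplication is the most expensive operation here, this is an important optimization. Note that `d_out` now denotes the combined output dimension of all heads, so each head has dimension `d_out // num_heads`.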

In [34]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", 
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
        
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        values = self.W_value(x)
        query = self.W_query(x)

        # Split the last dimension of the projections into multiple heads
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        query = query.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose num_heads and num_tokens
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        query = query.transpose(1, 2)

        # Compute the dot product for each head and apply causal mask
        attention_scores = query @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attention_scores.masked_fill_(mask_bool, -torch.inf)

        # Normalize the attention scores and apply dropout
        attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Compute the context vector, flip back num_heads and num_tokens and combine 
        # the heads again
        context_vector = attention_weights @ values
        context_vector = context_vector.transpose(1, 2)
        context_vector = context_vector.contiguous().view(b, num_tokens, self.d_out)

        # Apply the output projection (not strictly required, but common in practice)
        context_vector = self.out_proj(context_vector)
        
        return context_vector

We can use the MultiHeadAttention class in the same way we used the stacked version.

In [35]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
num_heads = 2
dropout = 0.0

multi_head_attention = MultiHeadAttention(d_in, d_out, context_length, dropout, num_heads)
context_vector = multi_head_attention(batch)
print("context_vector:\n", context_vector)
print("context_vector.shape: ", context_vector.shape)

context_vector:
 tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vector.shape:  torch.Size([2, 6, 2])
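
The reshaping steps inside `forward` can be traced on a small tensor. This sketch uses illustrative sizes and variable names (none of them from the class above) and a tensor filled with consecutive numbers, so that it is easy to see which features end up in which head and that merging the heads restores the original layout.

```python
import torch

bsz, n_tokens, total_dim, n_heads = 2, 6, 4, 2  # illustrative sizes
per_head = total_dim // n_heads

# Stand-in for the output of one of the linear projections
projected = torch.arange(bsz * n_tokens * total_dim, dtype=torch.float32)
projected = projected.view(bsz, n_tokens, total_dim)

# Split the last dimension into heads, then move the head axis forward
split = projected.view(bsz, n_tokens, n_heads, per_head).transpose(1, 2)
print("split shape:", split.shape)  # (bsz, n_heads, n_tokens, per_head)

# Head 0 holds the first `per_head` features of every token
print("head 0 matches first features:", torch.equal(split[:, 0], projected[..., :per_head]))

# Reverse the steps: swap the axes back and flatten the heads again
merged = split.transpose(1, 2).contiguous().view(bsz, n_tokens, total_dim)
print("round trip restores input:", torch.equal(merged, projected))
```

The call to `.contiguous()` is needed because `transpose` returns a non-contiguous view, which `view` cannot flatten directly.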
