Step 1: Start with the input

In [1]:
import torch
x = torch.tensor([[
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], # the
    [6.0, 5.0, 4.0, 3.0, 2.0, 1.0], # kid
    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]  # smiles
]])

In [2]:
print(x.shape)
batch_size, seq_len, d_model = x.shape

torch.Size([1, 3, 6])


Step 2: Decide d_out and num_heads

In [3]:
d_out = 6 # Keeping dim same as input
num_heads = 2
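
The heads later split the d_out columns evenly, so d_out has to be divisible by num_heads. A minimal sketch of that check follows; the helper name `split_head_dim` is hypothetical and not part of this walkthrough.

```python
def split_head_dim(d_out: int, num_heads: int) -> int:
    """Return the per-head width, or raise if d_out cannot be split evenly."""
    if d_out % num_heads != 0:
        raise ValueError(f"d_out={d_out} is not divisible by num_heads={num_heads}")
    return d_out // num_heads

head_dim_example = split_head_dim(6, 2)
print(head_dim_example)  # 3
```

With d_out = 6 and num_heads = 2, each head works on 3 columns, which is the head_dim used in Step 5.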

Step 3: Initialise Wq, Wk, Wv

In [4]:
# Each weight matrix has shape (d_model, d_out); d_out == d_model here, so 6x6.
torch.manual_seed(0)
Wq = torch.randn(d_model, d_out)
Wk = torch.randn(d_model, d_out)
Wv = torch.randn(d_model, d_out)
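
As an aside, the same projection is often written with `torch.nn.Linear` instead of a raw matrix. The sketch below assumes a bias-free layer and uses hypothetical names (`W`, `linear`, `x_demo`); it only shows that the two forms agree once the weight is transposed.

```python
import torch

torch.manual_seed(0)
d_model, d_out = 6, 6
W = torch.randn(d_model, d_out)          # raw weight, used as x @ W

# nn.Linear stores its weight as (out_features, in_features) and computes x @ weight.T
linear = torch.nn.Linear(d_model, d_out, bias=False)
with torch.no_grad():
    linear.weight.copy_(W.T)

x_demo = torch.randn(1, 3, d_model)      # hypothetical input, same shape as x
same = torch.allclose(x_demo @ W, linear(x_demo), atol=1e-4)
print(same)  # True
```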

Step 4: Calculate Q, K, V

In [5]:
# Each result has shape (1, 3, 6): x is (1, 3, 6), each W is (6, 6), and @ applies W to every token row
Q = x @ Wq
K = x @ Wk
V = x @ Wv

print("Q:\n", Q)
print("K:\n", K)
print("V:\n", V)

Q:
 tensor([[[ -9.0244, -11.7287,  15.5360,  -1.4474,  -4.5326,   9.4674],
         [ -8.0564, -13.2309,   8.2228,  -8.9680,   3.1995,   4.8321],
         [ -2.4401,  -3.5657,   3.3941,  -1.4879,  -0.1904,   2.0428]]])
K:
 tensor([[[  8.2602,  14.1116,  -5.0345, -16.4865,  -2.9948,   8.3139],
         [ -6.1188,  -0.1587,  -5.0885, -14.3014,   4.9540,   5.6093],
         [  0.3059,   1.9933,  -1.4461,  -4.3983,   0.2799,   1.9890]]])
V:
 tensor([[[ 0.5076, -3.4353,  1.8576,  2.8041,  8.9427, 13.1841],
         [-1.9113, -3.6934,  1.8502,  1.7622,  1.6981,  3.0978],
         [-0.2005, -1.0184,  0.5297,  0.6523,  1.5201,  2.3260]]])
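
One way to sanity-check the projection: the "smiles" row of x is all ones, so its projected row is simply the column sums of the weight matrix. The sketch below uses a stand-in matrix `W_demo` (a hypothetical name) rather than the notebook's variables.

```python
import torch

torch.manual_seed(0)
W_demo = torch.randn(6, 6)                 # stand-in for any of Wq, Wk, Wv
ones_row = torch.ones(1, 1, 6)             # same values as the "smiles" row of x

projected = ones_row @ W_demo              # shape (1, 1, 6)
column_sums = W_demo.sum(dim=0)            # shape (6,)
print(torch.allclose(projected[0, 0], column_sums, atol=1e-5))  # True
```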


Step 5: Unroll the last dimension of Q, K, V to include num_heads

Q, K, V : 1 x 3 x 6 [3D] --> 1 x 3 x 2 x 3 [4D]

In [6]:
head_dim = d_out // num_heads # 6 // 2 = 3
Q = Q.view(batch_size, seq_len, num_heads, head_dim)
K = K.view(batch_size, seq_len, num_heads, head_dim)
V = V.view(batch_size, seq_len, num_heads, head_dim)

print("Q after unrolling:\n", Q)
print("K after unrolling:\n", K)
print("V after unrolling:\n", V)

Q after unrolling:
 tensor([[[[ -9.0244, -11.7287,  15.5360],
          [ -1.4474,  -4.5326,   9.4674]],

         [[ -8.0564, -13.2309,   8.2228],
          [ -8.9680,   3.1995,   4.8321]],

         [[ -2.4401,  -3.5657,   3.3941],
          [ -1.4879,  -0.1904,   2.0428]]]])
K after unrolling:
 tensor([[[[  8.2602,  14.1116,  -5.0345],
          [-16.4865,  -2.9948,   8.3139]],

         [[ -6.1188,  -0.1587,  -5.0885],
          [-14.3014,   4.9540,   5.6093]],

         [[  0.3059,   1.9933,  -1.4461],
          [ -4.3983,   0.2799,   1.9890]]]])
V after unrolling:
 tensor([[[[ 0.5076, -3.4353,  1.8576],
          [ 2.8041,  8.9427, 13.1841]],

         [[-1.9113, -3.6934,  1.8502],
          [ 1.7622,  1.6981,  3.0978]],

         [[-0.2005, -1.0184,  0.5297],
          [ 0.6523,  1.5201,  2.3260]]]])


```
Q after unrolling:
 tensor([[[[ -9.0244, -11.7287,  15.5360],     Head 1 for token 1
          [ -1.4474,  -4.5326,   9.4674]],     Head 2 for token 1

         [[ -8.0564, -13.2309,   8.2228],      Head 1 for token 2
          [ -8.9680,   3.1995,   4.8321]],     Head 2 for token 2

         [[ -2.4401,  -3.5657,   3.3941],      Head 1 for token 3
          [ -1.4879,  -0.1904,   2.0428]]]])   Head 2 for token 3
```
Here the rows are grouped by token: each token holds one row per head.
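
The unrolling does not move any numbers: head h of a token is just columns h*head_dim to (h+1)*head_dim of that token's original row. A small sketch with a hypothetical stand-in tensor `Q_demo` (values 0 to 17) makes this easy to see.

```python
import torch

num_heads, head_dim = 2, 3
# Hypothetical stand-in for Q with easy-to-read values 0..17, shape (1, 3, 6)
Q_demo = torch.arange(18.0).view(1, 3, num_heads * head_dim)
Q_split = Q_demo.view(1, 3, num_heads, head_dim)

for h in range(num_heads):
    cols = slice(h * head_dim, (h + 1) * head_dim)
    # Head h of every token is columns [h*head_dim, (h+1)*head_dim) of the original rows
    assert torch.equal(Q_split[0, :, h, :], Q_demo[0, :, cols])

print(Q_split[0, 1])  # token 2: head 1 = [6, 7, 8], head 2 = [9, 10, 11]
```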

Step 6: Group matrices by number of heads

Q, K, V : 1x3x2x3 [batch, num_tokens, num_heads, head_dim]

to

Q, K, V : 1x2x3x3 [batch, num_heads, num_tokens, head_dim]

In [7]:
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)

print("Q after grouping by heads:\n", Q)
print("K after grouping by heads:\n", K)
print("V after grouping by heads:\n", V)

Q after grouping by heads:
 tensor([[[[ -9.0244, -11.7287,  15.5360],
          [ -8.0564, -13.2309,   8.2228],
          [ -2.4401,  -3.5657,   3.3941]],

         [[ -1.4474,  -4.5326,   9.4674],
          [ -8.9680,   3.1995,   4.8321],
          [ -1.4879,  -0.1904,   2.0428]]]])
K after grouping by heads:
 tensor([[[[  8.2602,  14.1116,  -5.0345],
          [ -6.1188,  -0.1587,  -5.0885],
          [  0.3059,   1.9933,  -1.4461]],

         [[-16.4865,  -2.9948,   8.3139],
          [-14.3014,   4.9540,   5.6093],
          [ -4.3983,   0.2799,   1.9890]]]])
V after grouping by heads:
 tensor([[[[ 0.5076, -3.4353,  1.8576],
          [-1.9113, -3.6934,  1.8502],
          [-0.2005, -1.0184,  0.5297]],

         [[ 2.8041,  8.9427, 13.1841],
          [ 1.7622,  1.6981,  3.0978],
          [ 0.6523,  1.5201,  2.3260]]]])


```
Q after grouping by heads:
 tensor([[[[ -9.0244, -11.7287,  15.5360],     Head 1 Token 1
          [ -8.0564, -13.2309,   8.2228],      Head 1 Token 2
          [ -2.4401,  -3.5657,   3.3941]],     Head 1 Token 3

         [[ -1.4474,  -4.5326,   9.4674],      Head 2 Token 1
          [ -8.9680,   3.1995,   4.8321],      Head 2 Token 2
          [ -1.4879,  -0.1904,   2.0428]]]])   Head 2 Token 3
```
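
It is tempting to get the [batch, num_heads, num_tokens, head_dim] layout with a second `view`, but that only re-chunks memory and mixes heads together. The sketch below contrasts the two on a hypothetical stand-in tensor `t`.

```python
import torch

# Hypothetical stand-in with shape [batch, num_tokens, num_heads, head_dim] = (1, 3, 2, 3)
t = torch.arange(18.0).view(1, 3, 2, 3)

by_transpose = t.transpose(1, 2)   # swaps axes: [batch, num_heads, num_tokens, head_dim]
by_view = t.view(1, 2, 3, 3)       # same shape, but only re-chunks memory order

print(by_transpose[0, 0])            # head 1 of tokens 1..3: [0,1,2], [6,7,8], [12,13,14]
print(by_view[0, 0])                 # [0,1,2], [3,4,5], [6,7,8]: mixes heads 1 and 2
print(by_transpose.is_contiguous())  # False: transpose returns a strided view
```

The non-contiguous result matters again in Step 13, where the heads are merged back.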

Step 7: Prepare K transpose for attention scores calculation

K: 1x2x3x3 [batch, num_heads, num_tokens, head_dim]

K_T: 1x2x3x3 [batch, num_heads, head_dim, num_tokens]

The shape looks unchanged only because num_tokens and head_dim are both 3 in this example.

In [8]:
K_T = K.transpose(2, 3)

print("K before transpose:\n", K)
print("K after transpose:\n", K_T)

K before transpose:
 tensor([[[[  8.2602,  14.1116,  -5.0345],
          [ -6.1188,  -0.1587,  -5.0885],
          [  0.3059,   1.9933,  -1.4461]],

         [[-16.4865,  -2.9948,   8.3139],
          [-14.3014,   4.9540,   5.6093],
          [ -4.3983,   0.2799,   1.9890]]]])
K after transpose:
 tensor([[[[  8.2602,  -6.1188,   0.3059],
          [ 14.1116,  -0.1587,   1.9933],
          [ -5.0345,  -5.0885,  -1.4461]],

         [[-16.4865, -14.3014,  -4.3983],
          [ -2.9948,   4.9540,   0.2799],
          [  8.3139,   5.6093,   1.9890]]]])
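
To see the swap in the shape itself, here is a sketch with a hypothetical K that has 4 tokens and head_dim 3. It also shows two equivalent spellings of the same transpose.

```python
import torch

# Hypothetical K with num_tokens=4 and head_dim=3 so the swap is visible in the shape
K_demo = torch.randn(1, 2, 4, 3)

K_T_demo = K_demo.transpose(2, 3)
print(K_demo.shape, "->", K_T_demo.shape)  # torch.Size([1, 2, 4, 3]) -> torch.Size([1, 2, 3, 4])

# Equivalent spellings of the same swap
assert torch.equal(K_T_demo, K_demo.transpose(-2, -1))
assert torch.equal(K_T_demo, K_demo.permute(0, 1, 3, 2))
```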


Step 8: Find attention scores

Q @ K_T : [batch, num_heads, num_tokens, num_tokens]

In [9]:
attention_scores = Q @ K_T
print("Attention scores shape:\n", attention_scores.shape)
print("Attention scores:\n", attention_scores)

Attention scores shape:
 torch.Size([1, 2, 3, 3])
Attention scores:
 tensor([[[[-318.2692,  -21.9748,  -48.6063],
          [-294.6534,    9.5538,  -40.7285],
          [ -87.5604,   -1.7744,  -12.7621]],

         [[ 116.1476,   51.3506,   23.9283],
          [ 178.4425,  171.2106,   49.9505],
          [  42.0843,   31.7945,   10.5541]]]])


```
Attention scores (rows: query token, columns: key token)
               the        kid       smiles
 tensor([[[[-318.2692,  -21.9748,  -48.6063],     Head 1  the
          [-294.6534,    9.5538,  -40.7285],      Head 1  kid
          [ -87.5604,   -1.7744,  -12.7621]],     Head 1  smiles

         [[ 116.1476,   51.3506,   23.9283],      Head 2  the
          [ 178.4425,  171.2106,   49.9505],      Head 2  kid
          [  42.0843,   31.7945,   10.5541]]]])   Head 2  smiles
```
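
Each score is a plain dot product between one query row and one key row of the same head. As a hand check, the sketch below recomputes the Head 1 score of "the" against "the" from the printed (rounded) values of Q and K.

```python
# Head 1 slices for the token "the", copied from the printed Q and K
q_the = [-9.0244, -11.7287, 15.5360]
k_the = [8.2602, 14.1116, -5.0345]

score = sum(q * k for q, k in zip(q_the, k_the))
print(round(score, 2))  # close to the printed -318.2692 (inputs are rounded to 4 decimals)
```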

Step 9: Apply causal mask

In [10]:
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
print("Causal mask\n", mask)
attention_scores.masked_fill_(mask, -torch.inf)
print("Attention scores after masking:\n", attention_scores)

Causal mask
 tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])
Attention scores after masking:
 tensor([[[[-318.2692,      -inf,      -inf],
          [-294.6534,    9.5538,      -inf],
          [ -87.5604,   -1.7744,  -12.7621]],

         [[ 116.1476,      -inf,      -inf],
          [ 178.4425,  171.2106,      -inf],
          [  42.0843,   31.7945,   10.5541]]]])
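
Note that `masked_fill_` with the trailing underscore overwrites `attention_scores` in place. If the unmasked scores are still needed, the out-of-place `masked_fill` is an alternative. The sketch below uses hypothetical stand-in tensors (`scores_demo`, `mask_demo`).

```python
import torch

seq_len_demo = 3
scores_demo = torch.zeros(1, 2, seq_len_demo, seq_len_demo)   # hypothetical scores

# True above the diagonal = "future" positions a token must not attend to
mask_demo = torch.triu(torch.ones(seq_len_demo, seq_len_demo), diagonal=1).bool()

# masked_fill (no trailing underscore) returns a new tensor and leaves scores_demo untouched;
# the (3, 3) mask broadcasts over the batch and head dimensions.
masked = scores_demo.masked_fill(mask_demo, float("-inf"))

print(masked[0, 0])
print(torch.isinf(scores_demo).any())   # tensor(False): the original was not modified
```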


Step 10: Calculate attention weights

In [11]:
torch.set_printoptions(precision=3, sci_mode=False) # Show 3 decimals and turn off scientific notation
attention_weights = torch.softmax(attention_scores / head_dim**0.5, dim=-1)
print("Attention weights shape:\n", attention_weights.shape)
print("Attention weights:\n", attention_weights)

Attention weights shape:
 torch.Size([1, 2, 3, 3])
Attention weights:
 tensor([[[[    1.000,     0.000,     0.000],
          [    0.000,     1.000,     0.000],
          [    0.000,     0.998,     0.002]],

         [[    1.000,     0.000,     0.000],
          [    0.985,     0.015,     0.000],
          [    0.997,     0.003,     0.000]]]])
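
The softmax can be reproduced by hand for a single row. The sketch below takes the masked Head 2 row for "kid", divides by sqrt(head_dim), and normalises the exponentials; it uses only the standard library, and the variable names are illustrative.

```python
import math

head_dim_demo = 3
# Head 2, row "kid" after masking: the third entry is -inf and contributes exp(-inf) = 0
row = [178.4425, 171.2106, float("-inf")]

scaled = [s / math.sqrt(head_dim_demo) for s in row]
peak = max(scaled)                                # subtract the max for numerical stability
exps = [math.exp(s - peak) for s in scaled]
weights = [e / sum(exps) for e in exps]

print([round(w, 3) for w in weights])  # [0.985, 0.015, 0.0]
```

Because the raw scores are large, even a gap of about 7 between two scores leaves almost all of the weight on one token.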


Step 11: Apply dropout

In [12]:
dropout = torch.nn.Dropout(0.1)
attention_weights = dropout(attention_weights)
print("Attention weights after dropout:\n", attention_weights)

Attention weights after dropout:
 tensor([[[[    1.111,     0.000,     0.000],
          [    0.000,     1.111,     0.000],
          [    0.000,     1.109,     0.002]],

         [[    1.111,     0.000,     0.000],
          [    1.094,     0.017,     0.000],
          [    1.108,     0.003,     0.000]]]])
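
In training mode, `torch.nn.Dropout(p)` zeroes each element with probability p and multiplies the survivors by 1 / (1 - p). That is why the rows above sum to about 1.111 rather than 1. A small sketch on a hypothetical all-ones tensor shows the rescaling and the evaluation-mode behaviour.

```python
import torch

p = 0.1
dropout_demo = torch.nn.Dropout(p)
weights_demo = torch.ones(1000)           # hypothetical weights, all 1.0

dropout_demo.train()                      # training mode: drop and rescale
dropped = dropout_demo(weights_demo)
kept = dropped[dropped != 0]
print(kept.unique())                      # every survivor equals 1 / (1 - p), about 1.1111

dropout_demo.eval()                       # evaluation mode: dropout is the identity
print(torch.equal(dropout_demo(weights_demo), weights_demo))  # True
```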


Step 12: Calculate context vectors

attention weights: [batch_size, num_heads, num_tokens, num_tokens] -> 1x2x3x3

value matrix: [batch_size, num_heads, num_tokens, head_dim] -> 1x2x3x3

context vectors = attention weights @ value matrix

In [13]:
context_vectors = attention_weights @ V
print("Context vectors shape:\n", context_vectors.shape)
print("Context vectors:\n", context_vectors)

Context vectors shape:
 torch.Size([1, 2, 3, 3])
Context vectors:
 tensor([[[[ 0.564, -3.817,  2.064],
          [-2.124, -4.104,  2.056],
          [-2.120, -4.099,  2.053]],

         [[ 3.116,  9.936, 14.649],
          [ 3.098,  9.814, 14.479],
          [ 3.113,  9.915, 14.620]]]])


```
Context vectors:
 tensor([[[[ 0.564, -3.817,  2.064],     Head 1  token 1
          [-2.124, -4.104,  2.056],      Head 1  token 2
          [-2.120, -4.099,  2.053]],     Head 1  token 3

         [[ 3.116,  9.936, 14.649],      Head 2  token 1
          [ 3.098,  9.814, 14.479],      Head 2  token 2
          [ 3.113,  9.915, 14.620]]]])   Head 2  token 3
```
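
Each context vector is a weighted sum of value rows. The sketch below recomputes the Head 2 context vector for "kid" by hand from the printed (rounded) scores and values. It assumes dropout kept both weights in that row, as the printed weights suggest, so they are only rescaled by 1 / (1 - 0.1).

```python
import math

# Head 2, token "kid": the two unmasked scores (against "the" and "kid")
scores = [178.4425, 171.2106]
v_the = [2.8041, 8.9427, 13.1841]   # head 2 value row for "the"
v_kid = [1.7622, 1.6981, 3.0978]    # head 2 value row for "kid"

scale = math.sqrt(3)                                        # sqrt(head_dim)
exps = [math.exp((s - max(scores)) / scale) for s in scores]
weights = [e / sum(exps) / (1 - 0.1) for e in exps]         # softmax, then dropout rescale

context = [weights[0] * a + weights[1] * b for a, b in zip(v_the, v_kid)]
print([round(c, 3) for c in context])  # approximately [3.098, 9.814, 14.479]
```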

Step 13: Reformat and concatenate

context vectors: attention weights @ value matrix --> [batch, num_heads, num_tokens, head_dim]

The input tensor had shape [1, 3, 6], i.e. [batch, num_tokens, d_model], so the heads need to be merged back into one vector per token.

First swap num_heads with num_tokens

In [14]:
context_vectors = context_vectors.transpose(1, 2)
print("Context vectors shape after swapping dimension 1 and 2:\n", context_vectors.shape)
print("Context vectors:\n", context_vectors)

Context vectors shape after swapping dimension 1 and 2:
 torch.Size([1, 3, 2, 3])
Context vectors:
 tensor([[[[ 0.564, -3.817,  2.064],
          [ 3.116,  9.936, 14.649]],

         [[-2.124, -4.104,  2.056],
          [ 3.098,  9.814, 14.479]],

         [[-2.120, -4.099,  2.053],
          [ 3.113,  9.915, 14.620]]]])


```
Context vectors:
 tensor([[[[ 0.564, -3.817,  2.064],    Token 1 Head 1
          [ 3.116,  9.936, 14.649]],    Token 1 Head 2

         [[-2.124, -4.104,  2.056],     Token 2 Head 1
          [ 3.098,  9.814, 14.479]],    Token 2 Head 2

         [[-2.120, -4.099,  2.053],     Token 3 Head 1
          [ 3.113,  9.915, 14.620]]]])  Token 3 Head 2
```
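
One detail before merging the heads: `transpose` returns a non-contiguous view, so a plain `view` to [batch, num_tokens, d_out] fails at this point, while `reshape` (or `contiguous().view`) works. A sketch on a hypothetical stand-in tensor `ctx`:

```python
import torch

# Hypothetical context vectors in [batch, num_heads, num_tokens, head_dim] layout
ctx = torch.arange(18.0).view(1, 2, 3, 3)
swapped = ctx.transpose(1, 2)             # [batch, num_tokens, num_heads, head_dim]

try:
    swapped.view(1, 3, 6)                 # view needs compatible strides
    view_worked = True
except RuntimeError:
    view_worked = False
print(view_worked)                        # False

merged_a = swapped.reshape(1, 3, 6)       # copies when it has to
merged_b = swapped.contiguous().view(1, 3, 6)
print(torch.equal(merged_a, merged_b))    # True
print(merged_a[0, 0])                     # token 1: head 1 [0, 1, 2] then head 2 [9, 10, 11]
```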

Then merge the last two dimensions (num_heads and head_dim) to get the shape [batch, num_tokens, d_out]

In [15]:
context_vectors = context_vectors.reshape(batch_size, seq_len, num_heads * head_dim)
print("Context vectors shape after concatenating heads:\n", context_vectors.shape)
print("Context vectors:\n", context_vectors)

Context vectors shape after concatenating heads:
 torch.Size([1, 3, 6])
Context vectors:
 tensor([[[ 0.564, -3.817,  2.064,  3.116,  9.936, 14.649],
         [-2.124, -4.104,  2.056,  3.098,  9.814, 14.479],
         [-2.120, -4.099,  2.053,  3.113,  9.915, 14.620]]])
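
To tie the steps together, here is a compact sketch of Steps 4 to 13 as a single function. It leaves out dropout so the result is deterministic; the function name `causal_mha_sketch` and the `_d` variables are hypothetical. The final check confirms the effect of the causal mask: changing the last token leaves the outputs of the earlier tokens unchanged.

```python
import math
import torch

def causal_mha_sketch(x, Wq, Wk, Wv, num_heads):
    """Steps 4-13 in one function, without dropout. Hypothetical helper name."""
    b, t, _ = x.shape
    d_out = Wq.shape[1]
    head_dim = d_out // num_heads

    def split(m):
        # (b, t, d_out) -> (b, num_heads, t, head_dim)
        return m.view(b, t, num_heads, head_dim).transpose(1, 2)

    Q, K, V = split(x @ Wq), split(x @ Wk), split(x @ Wv)
    scores = Q @ K.transpose(2, 3)                      # (b, num_heads, t, t)
    mask = torch.triu(torch.ones(t, t), diagonal=1).bool()
    scores = scores.masked_fill(mask, float("-inf"))
    weights = torch.softmax(scores / math.sqrt(head_dim), dim=-1)
    context = weights @ V                               # (b, num_heads, t, head_dim)
    return context.transpose(1, 2).reshape(b, t, d_out)

torch.manual_seed(0)
Wq_d, Wk_d, Wv_d = (torch.randn(6, 6) for _ in range(3))
x_d = torch.randn(1, 3, 6)                              # hypothetical input
out = causal_mha_sketch(x_d, Wq_d, Wk_d, Wv_d, num_heads=2)

# Causality check: changing the last token must not change earlier outputs
x_changed = x_d.clone()
x_changed[0, 2] += 1.0
out_changed = causal_mha_sketch(x_changed, Wq_d, Wk_d, Wv_d, num_heads=2)
print(out.shape)                                         # torch.Size([1, 3, 6])
print(torch.allclose(out[0, :2], out_changed[0, :2]))    # True
```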
