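The goal of this notebook is to better understand the tensor-based formulation of multi-headed attention, in which all heads are computed together with a single set of tensor operations, so that explicitly concatenating the results of each head is no longer necessary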

In [1]:
import torch
import torch.nn.functional as F

In [2]:
# For this example, we will use a batch of 3 sequences, 5 tokens per sequence, 4-dimensional embeddings, and 2 attention heads
batch, toks, num_embed, num_heads = 3, 5, 4, 2

# The embedding dimension is split evenly across the heads, so it must be divisible by num_heads
assert num_embed % num_heads == 0, "num_embed must be divisible by num_heads"
embeds_per_head = num_embed // num_heads

# Seed the RNG so that Restart & Run All reproduces the same random tensors every time.
# Outputs saved before the seed was introduced will differ until the notebook is re-run.
torch.manual_seed(0)
X = torch.randn(batch, toks, num_embed)

In [3]:
# Inspect the random input embeddings, created with shape (batch, toks, num_embed)
X

tensor([[[-0.4107,  0.3778,  0.7939, -0.6678],
         [ 1.0667, -0.6242,  0.1589,  0.3039],
         [-0.2674, -0.1341,  0.4748,  0.1322],
         [ 0.6318, -0.0575, -1.3057, -1.2878],
         [ 0.1779, -0.0569, -0.2827,  0.2198]],

        [[-0.7141,  1.9954, -1.2001, -1.6837],
         [-1.6832,  0.7826, -0.9433,  0.8802],
         [-0.1342, -0.0999, -0.8499, -0.0073],
         [ 0.1937, -0.8343,  0.2319,  1.2950],
         [-0.0885, -1.2308, -0.2638,  1.3568]],

        [[-0.7015, -0.8345, -0.1715, -1.6207],
         [ 1.0907, -0.6023, -0.2791,  1.2498],
         [ 0.3522, -0.3134,  0.2640,  0.5063],
         [-1.3687, -1.8509, -0.8862, -0.4323],
         [ 0.7272, -0.2031, -0.4499,  0.4885]]])

## Define the weight matrices
Recall that in multi-headed attention with a 3x5x4 input and two heads, each head works with its own 3x5x2 queries, keys, and values.  
Since we use a single full num_embed x num_embed (here 4x4) weight matrix for each of Q, K, and V, rather than a separate smaller matrix per head, the first two columns of each weight matrix represent our first head and the last two columns represent our second head.

In [4]:

# One full (num_embed x num_embed) projection matrix each for queries, keys, and values.
# Sizing them from num_embed instead of a literal 4 keeps this cell in sync with the
# dimensions chosen in the configuration cell.
W_query, W_key, W_val = (torch.randn(num_embed, num_embed) for _ in range(3))
W_query, W_key, W_val

(tensor([[-1.4780, -0.9516,  0.4366, -2.9062],
         [-0.1026, -0.4280, -1.3834,  1.7105],
         [ 1.1762, -0.0482,  0.2124,  0.4522],
         [-1.3839,  0.0114,  1.0256,  1.1996]]),
 tensor([[-1.3381, -0.4638, -0.3170, -0.2636],
         [-0.0936,  0.3665,  0.6969, -2.0485],
         [ 1.0330, -0.0390,  0.6749,  0.5791],
         [-0.8078, -0.1961,  1.3025,  0.3515]]),
 tensor([[-0.8592,  0.3266, -0.4633, -0.3311],
         [-0.5480, -1.2023,  0.2072, -0.6125],
         [-0.3491, -0.8961, -0.2169,  2.0952],
         [ 0.4727,  1.8188,  0.0249, -0.8341]]))

In [5]:
# Project the inputs into queries, keys, and values with one matrix multiplication each.
# The heads are laid out along the last (embedding) dimension: in each of Q, K, and V,
# the first two columns belong to the first head and the last two columns to the second head.
Q = torch.matmul(X, W_query)
K = torch.matmul(X, W_key)
V = torch.matmul(X, W_val)
Q.shape, K.shape, V.shape, Q, K, V

(torch.Size([3, 5, 4]),
 torch.Size([3, 5, 4]),
 torch.Size([3, 5, 4]),
 tensor([[[ 2.4263,  0.1833, -1.2183,  1.3977],
          [-1.7461, -0.7521,  1.6746, -3.7312],
          [ 0.7845,  0.2905,  0.3053,  0.9209],
          [-0.6815, -0.5284, -1.2427, -4.0697],
          [-0.8939, -0.1288,  0.3218, -0.4786]],
 
         [[ 1.7692, -0.1359, -5.0540,  2.9260],
          [ 0.0800,  1.3224, -1.1150,  6.8596],
          [-0.7809,  0.2113, -0.1084, -0.1740],
          [-1.7200,  0.1764,  2.6162, -0.3316],
          [-1.9308,  0.6392,  2.9996, -0.3397]],
 
         [[ 3.1636,  1.0145, -0.8505, -1.4105],
          [-3.6081, -0.7524,  2.5320, -2.8268],
          [-0.8785, -0.2079,  1.1627, -0.8327],
          [ 1.7688,  2.1324,  1.3316, -0.1077],
          [-2.2591, -0.5778,  1.0038, -2.0781]]]),
 tensor([[[ 1.8739e+00,  4.2898e-01,  5.9448e-02, -4.4064e-01],
          [-1.4501e+00, -7.8936e-01, -2.7011e-01,  1.1963e+00],
          [ 7.5407e-01,  3.0389e-02,  4.8396e-01,  6.6672e-01],
       

In [6]:
# Split the embedding dimension into (num_heads, embeds_per_head) so that each head's slice
# of Q, K, and V gets its own axis:
# (batch, toks, num_embed) -> (batch, toks, num_heads, embeds_per_head)
Q_reshaped, K_reshaped, V_reshaped = (
    projection.reshape(batch, toks, num_heads, embeds_per_head)
    for projection in (Q, K, V)
)
Q, Q_reshaped

(tensor([[[ 2.4263,  0.1833, -1.2183,  1.3977],
          [-1.7461, -0.7521,  1.6746, -3.7312],
          [ 0.7845,  0.2905,  0.3053,  0.9209],
          [-0.6815, -0.5284, -1.2427, -4.0697],
          [-0.8939, -0.1288,  0.3218, -0.4786]],
 
         [[ 1.7692, -0.1359, -5.0540,  2.9260],
          [ 0.0800,  1.3224, -1.1150,  6.8596],
          [-0.7809,  0.2113, -0.1084, -0.1740],
          [-1.7200,  0.1764,  2.6162, -0.3316],
          [-1.9308,  0.6392,  2.9996, -0.3397]],
 
         [[ 3.1636,  1.0145, -0.8505, -1.4105],
          [-3.6081, -0.7524,  2.5320, -2.8268],
          [-0.8785, -0.2079,  1.1627, -0.8327],
          [ 1.7688,  2.1324,  1.3316, -0.1077],
          [-2.2591, -0.5778,  1.0038, -2.0781]]]),
 tensor([[[[ 2.4263,  0.1833],
           [-1.2183,  1.3977]],
 
          [[-1.7461, -0.7521],
           [ 1.6746, -3.7312]],
 
          [[ 0.7845,  0.2905],
           [ 0.3053,  0.9209]],
 
          [[-0.6815, -0.5284],
           [-1.2427, -4.0697]],
 
          [

In [7]:
# Swap the token and head axes so that the head axis comes right after the batch axis:
# (batch, toks, num_heads, embeds_per_head) -> (batch, num_heads, toks, embeds_per_head).
# The next operations then act on each head's (toks, embeds_per_head) matrix within each batch element.
Q_transposed, K_transposed, V_transposed = (
    per_head.permute(0, 2, 1, 3)
    for per_head in (Q_reshaped, K_reshaped, V_reshaped)
)
Q_transposed.shape, Q_transposed

(torch.Size([3, 2, 5, 2]),
 tensor([[[[ 2.4263,  0.1833],
           [-1.7461, -0.7521],
           [ 0.7845,  0.2905],
           [-0.6815, -0.5284],
           [-0.8939, -0.1288]],
 
          [[-1.2183,  1.3977],
           [ 1.6746, -3.7312],
           [ 0.3053,  0.9209],
           [-1.2427, -4.0697],
           [ 0.3218, -0.4786]]],
 
 
         [[[ 1.7692, -0.1359],
           [ 0.0800,  1.3224],
           [-0.7809,  0.2113],
           [-1.7200,  0.1764],
           [-1.9308,  0.6392]],
 
          [[-5.0540,  2.9260],
           [-1.1150,  6.8596],
           [-0.1084, -0.1740],
           [ 2.6162, -0.3316],
           [ 2.9996, -0.3397]]],
 
 
         [[[ 3.1636,  1.0145],
           [-3.6081, -0.7524],
           [-0.8785, -0.2079],
           [ 1.7688,  2.1324],
           [-2.2591, -0.5778]],
 
          [[-0.8505, -1.4105],
           [ 2.5320, -2.8268],
           [ 1.1627, -0.8327],
           [ 1.3316, -0.1077],
           [ 1.0038, -2.0781]]]]))

In [8]:
# Scaled dot-product attention, computed for every batch element and every head at once.
# Scores have shape (batch, num_heads, toks, toks): each query token against each key token,
# divided by sqrt(embeds_per_head).
attention_scores = (Q_transposed @ K_transposed.transpose(-2, -1)) / torch.sqrt(torch.tensor(embeds_per_head))
# Softmax over the last axis turns each query token's scores into weights that sum to 1
attention_weights = attention_scores.softmax(dim=-1)
# Weighted sum of the values: (batch, num_heads, toks, embeds_per_head)
attention_output = attention_weights @ V_transposed

In [9]:
# Check the per-head output shape; the layout is still (batch, num_heads, toks, embeds_per_head)
attention_output.shape

torch.Size([3, 2, 5, 2])

In [10]:
# Move the head axis back next to the per-head embeddings, then flatten those two axes together:
# (batch, num_heads, toks, embeds_per_head) -> (batch, toks, num_heads, embeds_per_head) -> (batch, toks, num_embed).
# The reshape places each head's output side by side in the last dimension, so no explicit concatenation is needed.
result = attention_output.permute(0, 2, 1, 3).reshape(batch, toks, num_embed)

In [11]:
# Confirm the merged result is back to (batch, toks, num_embed), the same shape X was created with
result.shape

torch.Size([3, 5, 4])

In [12]:
# Final multi-head attention output: along the last dimension, the first embeds_per_head
# values come from the first head and the remaining values from the second head
result

tensor([[[-0.3653, -2.1769, -0.2775, -0.3521],
         [-0.4260,  0.7236,  0.0292,  0.8496],
         [-0.2518, -1.1838, -0.2297,  0.4191],
         [-0.3408,  0.3906, -0.0524, -1.8018],
         [-0.3533,  0.3619, -0.0861,  0.2703]],

        [[ 0.2911, -1.5369, -0.0725, -0.9939],
         [ 0.2940, -1.3769, -0.1282, -0.8760],
         [ 0.8445,  1.6671,  0.5764, -1.7316],
         [ 0.9703,  2.5523,  0.8498, -2.1205],
         [ 0.9332,  2.3235,  0.8879, -2.1783]],

        [[ 0.6329, -1.4681, -0.0613,  0.2982],
         [ 0.0656,  3.4597, -0.3245, -0.9952],
         [ 0.0967,  2.5831, -0.3388, -0.8774],
         [ 0.7639, -1.1742, -0.3501, -0.8506],
         [ 0.0478,  3.2505, -0.3105, -0.9926]]])