In [24]:
import torch

In [25]:
# Seed the global RNG so the randomly initialised embedding (and every Linear
# layer created in later cells) is identical on each Restart & Run All.
torch.manual_seed(123)
# Create an embedding layer with 4 possible inputs and an embedding dimension of 8
inputs = torch.nn.Embedding(4, 8)

In [26]:
# Use the embedding's randomly initialised weight matrix (4 x 8) as a toy sequence
# of 4 token vectors. `.detach()` is the recommended replacement for the legacy
# `.data` attribute: it shares the same storage but is excluded from autograd.
# NOTE: this rebinds `inputs` from the Embedding module to a plain tensor, so this
# cell cannot be re-run on its own; re-run the previous cell first.
inputs = inputs.weight.detach()


In [27]:
# (number of tokens, embedding dimension) of the toy input matrix
inputs.size()

torch.Size([4, 8])

In [28]:
# Build a batch by stacking two copies of the same token matrix along a new
# leading axis. The remaining axes (context_length, d_in) come from `inputs`.
batches = torch.stack([inputs] * 2)
batches

tensor([[[-0.1921,  1.0251,  1.5100, -0.4795, -0.3530,  0.3964,  0.1063,
          -0.2738],
         [ 1.7395, -0.2655, -0.4987, -0.5834, -0.6828, -0.1388, -0.5295,
           0.5828],
         [ 1.5744,  1.3453, -1.8660,  0.0479,  0.6781,  0.1494,  0.7446,
          -0.0604],
         [-1.2803,  0.7010,  0.0425,  0.3091,  1.6385, -0.7256,  0.3384,
          -0.2931]],

        [[-0.1921,  1.0251,  1.5100, -0.4795, -0.3530,  0.3964,  0.1063,
          -0.2738],
         [ 1.7395, -0.2655, -0.4987, -0.5834, -0.6828, -0.1388, -0.5295,
           0.5828],
         [ 1.5744,  1.3453, -1.8660,  0.0479,  0.6781,  0.1494,  0.7446,
          -0.0604],
         [-1.2803,  0.7010,  0.0425,  0.3091,  1.6385, -0.7256,  0.3384,
          -0.2931]]])

In [29]:
# Pick the projection width d_out and the number of attention heads.
d_out = 16
num_heads = 4
# d_in is not a free choice: the projections must accept the batch's embedding
# width, so derive it from the data rather than hard-coding a mismatched value.
d_in = batches.shape[-1]
assert (d_out % num_heads) == 0, "d_out must be divisible by num_heads"


In [30]:
# The following cells replay, one step per cell, the query / attention-score
# computation performed inside the MultiHeadAttention class, printing every
# intermediate tensor and its shape so the need for each reshape is visible.
# First, unpack the batch dimensions: batch size, sequence length, embedding width.
b, num_tokens, d_in = batches.size()
print("batches:", batches)
print("batches.shape:", batches.shape)



batches: tensor([[[-0.1921,  1.0251,  1.5100, -0.4795, -0.3530,  0.3964,  0.1063,
          -0.2738],
         [ 1.7395, -0.2655, -0.4987, -0.5834, -0.6828, -0.1388, -0.5295,
           0.5828],
         [ 1.5744,  1.3453, -1.8660,  0.0479,  0.6781,  0.1494,  0.7446,
          -0.0604],
         [-1.2803,  0.7010,  0.0425,  0.3091,  1.6385, -0.7256,  0.3384,
          -0.2931]],

        [[-0.1921,  1.0251,  1.5100, -0.4795, -0.3530,  0.3964,  0.1063,
          -0.2738],
         [ 1.7395, -0.2655, -0.4987, -0.5834, -0.6828, -0.1388, -0.5295,
           0.5828],
         [ 1.5744,  1.3453, -1.8660,  0.0479,  0.6781,  0.1494,  0.7446,
          -0.0604],
         [-1.2803,  0.7010,  0.0425,  0.3091,  1.6385, -0.7256,  0.3384,
          -0.2931]]])
batches.shape: torch.Size([2, 4, 8])


In [33]:
# One bias-free linear projection d_in -> d_out per role. They are constructed in
# keys, queries, values order, so the global RNG is consumed exactly as in a
# statement-by-statement version.
W_keys, W_queries, W_values = (
    torch.nn.Linear(d_in, d_out, bias=False) for _ in range(3)
)
print("W_keys, W_queries, W_values:", W_keys, W_queries, W_values)

# Project every token of every batch into key / query / value space.
keys, queries, values = (proj(batches) for proj in (W_keys, W_queries, W_values))
for label, projected in (("keys", keys), ("queries", queries), ("values", values)):
    print(f"{label}:", projected)
    print(f"{label}.shape:", projected.shape)

W_keys, W_queries, W_values: Linear(in_features=8, out_features=16, bias=False) Linear(in_features=8, out_features=16, bias=False) Linear(in_features=8, out_features=16, bias=False)
keys: tensor([[[-0.2110, -0.3065, -0.4244, -0.0132,  0.1109,  0.3079,  0.6386,
           0.2159,  0.1624, -0.0054,  0.1911, -0.2630, -0.4189, -0.0223,
           0.4067,  0.5185],
         [-0.1475,  0.0765,  0.4400,  0.6493, -0.1237, -0.1922,  0.5323,
          -0.1795, -0.4362, -0.5583,  0.7247,  0.3939, -0.0954, -0.3837,
          -0.0972,  0.0892],
         [ 0.4775, -0.7502,  0.3627, -0.6154, -0.4659,  0.3159,  0.1196,
          -0.1515, -0.5103, -0.4454,  0.0662,  0.6648,  1.0871,  0.4892,
          -0.0080, -0.6345],
         [ 0.1429, -0.4597, -0.1869, -1.3242, -0.5193, -0.2505, -0.5958,
          -0.3976,  0.4893,  0.3918, -0.7144, -0.6985,  0.2271,  0.5787,
           0.2730, -0.4339]],

        [[-0.2110, -0.3065, -0.4244, -0.0132,  0.1109,  0.3079,  0.6386,
           0.2159,  0.1624, -0.0054, 

In [34]:
# Split the last (d_out) axis into num_heads slices of width head_dim:
# (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim).
head_dim = d_out // num_heads
split_shape = (b, num_tokens, num_heads, head_dim)
keys, queries, values = keys.view(split_shape), queries.view(split_shape), values.view(split_shape)
for label, projected in (("keys", keys), ("queries", queries), ("values", values)):
    print(f"{label} reshaped:", projected)
    print(f"{label} reshaped.shape:", projected.shape)



keys reshaped: tensor([[[[-0.2110, -0.3065, -0.4244, -0.0132],
          [ 0.1109,  0.3079,  0.6386,  0.2159],
          [ 0.1624, -0.0054,  0.1911, -0.2630],
          [-0.4189, -0.0223,  0.4067,  0.5185]],

         [[-0.1475,  0.0765,  0.4400,  0.6493],
          [-0.1237, -0.1922,  0.5323, -0.1795],
          [-0.4362, -0.5583,  0.7247,  0.3939],
          [-0.0954, -0.3837, -0.0972,  0.0892]],

         [[ 0.4775, -0.7502,  0.3627, -0.6154],
          [-0.4659,  0.3159,  0.1196, -0.1515],
          [-0.5103, -0.4454,  0.0662,  0.6648],
          [ 1.0871,  0.4892, -0.0080, -0.6345]],

         [[ 0.1429, -0.4597, -0.1869, -1.3242],
          [-0.5193, -0.2505, -0.5958, -0.3976],
          [ 0.4893,  0.3918, -0.7144, -0.6985],
          [ 0.2271,  0.5787,  0.2730, -0.4339]]],


        [[[-0.2110, -0.3065, -0.4244, -0.0132],
          [ 0.1109,  0.3079,  0.6386,  0.2159],
          [ 0.1624, -0.0054,  0.1911, -0.2630],
          [-0.4189, -0.0223,  0.4067,  0.5185]],

         [[-0

In [35]:
# Swap the token and head axes: (b, num_tokens, num_heads, head_dim) ->
# (b, num_heads, num_tokens, head_dim). Each head now owns a
# (num_tokens x head_dim) matrix, so one batched matmul covers all heads.
keys, queries, values = (t.transpose(1, 2) for t in (keys, queries, values))
for label, projected in (("keys", keys), ("queries", queries), ("values", values)):
    print(f"{label} transposed:", projected)
    print(f"{label} transposed.shape:", projected.shape)

keys transposed: tensor([[[[-0.2110, -0.3065, -0.4244, -0.0132],
          [-0.1475,  0.0765,  0.4400,  0.6493],
          [ 0.4775, -0.7502,  0.3627, -0.6154],
          [ 0.1429, -0.4597, -0.1869, -1.3242]],

         [[ 0.1109,  0.3079,  0.6386,  0.2159],
          [-0.1237, -0.1922,  0.5323, -0.1795],
          [-0.4659,  0.3159,  0.1196, -0.1515],
          [-0.5193, -0.2505, -0.5958, -0.3976]],

         [[ 0.1624, -0.0054,  0.1911, -0.2630],
          [-0.4362, -0.5583,  0.7247,  0.3939],
          [-0.5103, -0.4454,  0.0662,  0.6648],
          [ 0.4893,  0.3918, -0.7144, -0.6985]],

         [[-0.4189, -0.0223,  0.4067,  0.5185],
          [-0.0954, -0.3837, -0.0972,  0.0892],
          [ 1.0871,  0.4892, -0.0080, -0.6345],
          [ 0.2271,  0.5787,  0.2730, -0.4339]]],


        [[[-0.2110, -0.3065, -0.4244, -0.0132],
          [-0.1475,  0.0765,  0.4400,  0.6493],
          [ 0.4775, -0.7502,  0.3627, -0.6154],
          [ 0.1429, -0.4597, -0.1869, -1.3242]],

         [[

In [36]:
# Raw (unscaled) per-head attention scores: every query row dotted with every key
# row. (b, heads, T, hd) @ (b, heads, hd, T) -> (b, heads, T, T)
attn_scores = torch.matmul(queries, keys.transpose(-2, -1))
print("attn_scores:", attn_scores)
print("attn_scores.shape:", attn_scores.shape)

attn_scores: tensor([[[[ 0.1442,  0.2383, -0.2741, -0.4362],
          [-0.0103, -0.4822,  1.3495,  1.3038],
          [-0.5999,  0.5301,  0.1345, -0.5533],
          [-0.2151,  0.9204, -1.3538, -1.8418]],

         [[ 0.0976,  0.1440,  0.2287,  0.0583],
          [-0.5689, -0.2454, -0.1370,  0.5763],
          [-1.0879, -0.7065,  0.3808,  1.4644],
          [-0.1056,  0.2363,  0.1563,  0.3135]],

         [[ 0.0089,  0.1627,  0.0754, -0.2195],
          [ 0.0607,  0.0641,  0.0663,  0.2689],
          [-0.0890, -0.4280, -0.0488,  0.4349],
          [-0.0275, -0.5885, -0.4825,  0.2237]],

         [[-0.5195, -0.1201,  0.7729,  0.3285],
          [-0.0365, -0.0776,  0.0500, -0.0540],
          [ 0.3822,  0.1222, -0.7518, -0.5604],
          [ 0.3543,  0.2606, -0.7685, -0.4231]]],


        [[[ 0.1442,  0.2383, -0.2741, -0.4362],
          [-0.0103, -0.4822,  1.3495,  1.3038],
          [-0.5999,  0.5301,  0.1345, -0.5533],
          [-0.2151,  0.9204, -1.3538, -1.8418]],

         [[ 0.0

In [37]:
# Causal mask: True strictly above the diagonal, i.e. where a query would look at
# a later token. Those scores become -inf so softmax gives them zero weight.
future_mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
attn_scores.masked_fill_(future_mask, -torch.inf)
print("attn_scores after masking:", attn_scores)

attn_scores after masking: tensor([[[[ 0.1442,    -inf,    -inf,    -inf],
          [-0.0103, -0.4822,    -inf,    -inf],
          [-0.5999,  0.5301,  0.1345,    -inf],
          [-0.2151,  0.9204, -1.3538, -1.8418]],

         [[ 0.0976,    -inf,    -inf,    -inf],
          [-0.5689, -0.2454,    -inf,    -inf],
          [-1.0879, -0.7065,  0.3808,    -inf],
          [-0.1056,  0.2363,  0.1563,  0.3135]],

         [[ 0.0089,    -inf,    -inf,    -inf],
          [ 0.0607,  0.0641,    -inf,    -inf],
          [-0.0890, -0.4280, -0.0488,    -inf],
          [-0.0275, -0.5885, -0.4825,  0.2237]],

         [[-0.5195,    -inf,    -inf,    -inf],
          [-0.0365, -0.0776,    -inf,    -inf],
          [ 0.3822,  0.1222, -0.7518,    -inf],
          [ 0.3543,  0.2606, -0.7685, -0.4231]]],


        [[[ 0.1442,    -inf,    -inf,    -inf],
          [-0.0103, -0.4822,    -inf,    -inf],
          [-0.5999,  0.5301,  0.1345,    -inf],
          [-0.2151,  0.9204, -1.3538, -1.8418]],

 

In [38]:
# Scale scores by sqrt(head_dim) (the keys' last axis) so the softmax does not
# saturate as head_dim grows, then normalise over the key axis.
# Dropout rate on the attention weights: 0.0 leaves them unchanged. With p > 0,
# surviving weights are rescaled by 1/(1-p), so rows no longer sum to exactly 1.
attn_dropout_rate = 0.0
attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
# Documented functional API instead of the lower-level `torch.dropout` binding.
attn_weights = torch.nn.functional.dropout(attn_weights, p=attn_dropout_rate, training=True)

print("attn_weights:", attn_weights)
print("attn_weights.shape:", attn_weights.shape)

attn_weights: tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5587, 0.4413, 0.0000, 0.0000],
          [0.2379, 0.4186, 0.3435, 0.0000],
          [0.2650, 0.4675, 0.1500, 0.1175]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4596, 0.5404, 0.0000, 0.0000],
          [0.2329, 0.2818, 0.4853, 0.0000],
          [0.2193, 0.2602, 0.2500, 0.2705]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4996, 0.5004, 0.0000, 0.0000],
          [0.3491, 0.2947, 0.3562, 0.0000],
          [0.2713, 0.2050, 0.2161, 0.3076]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5051, 0.4949, 0.0000, 0.0000],
          [0.4089, 0.3591, 0.2320, 0.0000],
          [0.3123, 0.2980, 0.1781, 0.2117]]],


        [[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5587, 0.4413, 0.0000, 0.0000],
          [0.2379, 0.4186, 0.3435, 0.0000],
          [0.2650, 0.4675, 0.1500, 0.1175]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4596, 0.5404, 0.0000, 0.0000],
      

In [None]:
# Sanity check: softmax normalises over the key axis, so each query's attention
# weights must sum to 1. This holds only while attention dropout is 0; dropout
# rescales the surviving weights.
row_sums = attn_weights.sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums)), "attention weights must sum to 1 per query"
row_sums

tensor([[[1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000]]], grad_fn=<SumBackward1>)

In [None]:
# Weighted sum of value vectors per head:
# (b, heads, T, T) @ (b, heads, T, hd) -> (b, heads, T, hd).
# Then swap heads and tokens back to (b, T, heads, hd), so each token's head
# outputs are adjacent and ready to be merged.
per_head_context = torch.matmul(attn_weights, values)
context_vector = per_head_context.transpose(1, 2)
print("context_vector:", context_vector)
print("context_vector.shape:", context_vector.shape)

context_vector: tensor([[[[-0.5607,  0.6784, -0.2333,  0.2510],
          [ 0.4472,  1.0648,  1.3709,  0.6711],
          [-0.7534, -0.0681, -0.3097, -0.8914],
          [-0.7052,  1.3109, -0.2726,  0.1822]],

         [[-0.4828,  0.1275,  0.1373, -0.2131],
          [ 0.1001,  0.5819,  0.7361,  0.5679],
          [-0.2963, -0.0278,  0.2488, -0.0653],
          [ 0.0186,  0.6759,  0.3274, -0.6581]],

         [[-0.5356,  0.3507,  0.4577,  0.0535],
          [ 0.0167,  0.5395,  0.4742,  0.8185],
          [ 0.1197, -0.0714,  0.0216, -0.1946],
          [-0.5843,  0.6565,  0.4837,  0.2180]],

         [[-0.5560,  0.2566,  0.2996,  0.0626],
          [-0.0361,  0.5094,  0.6006,  0.6674],
          [-0.1068, -0.1472,  0.0408, -0.1809],
          [-0.4264,  0.8434,  0.3065, -0.0407]]],


        [[[-0.5607,  0.6784, -0.2333,  0.2510],
          [ 0.4472,  1.0648,  1.3709,  0.6711],
          [-0.7534, -0.0681, -0.3097, -0.8914],
          [-0.7052,  1.3109, -0.2726,  0.1822]],

         [[-

In [None]:
# Merge heads: flatten (heads, head_dim) back into d_out. The tensor is
# non-contiguous after the transpose, so reshape materialises a contiguous copy
# before collapsing the last two axes: (b, T, heads, hd) -> (b, T, d_out).
context_vector = context_vector.reshape(b, num_tokens, d_out)
print("context_vector:", context_vector)
print("context_vector.shape:", context_vector.shape)

context_vector: tensor([[[-0.5607,  0.6784, -0.2333,  0.2510,  0.4472,  1.0648,  1.3709,
           0.6711, -0.7534, -0.0681, -0.3097, -0.8914, -0.7052,  1.3109,
          -0.2726,  0.1822],
         [-0.4828,  0.1275,  0.1373, -0.2131,  0.1001,  0.5819,  0.7361,
           0.5679, -0.2963, -0.0278,  0.2488, -0.0653,  0.0186,  0.6759,
           0.3274, -0.6581],
         [-0.5356,  0.3507,  0.4577,  0.0535,  0.0167,  0.5395,  0.4742,
           0.8185,  0.1197, -0.0714,  0.0216, -0.1946, -0.5843,  0.6565,
           0.4837,  0.2180],
         [-0.5560,  0.2566,  0.2996,  0.0626, -0.0361,  0.5094,  0.6006,
           0.6674, -0.1068, -0.1472,  0.0408, -0.1809, -0.4264,  0.8434,
           0.3065, -0.0407]],

        [[-0.5607,  0.6784, -0.2333,  0.2510,  0.4472,  1.0648,  1.3709,
           0.6711, -0.7534, -0.0681, -0.3097, -0.8914, -0.7052,  1.3109,
          -0.2726,  0.1822],
         [-0.4828,  0.1275,  0.1373, -0.2131,  0.1001,  0.5819,  0.7361,
           0.5679, -0.2963, -0.027