In [4]:
import torch

In [11]:
# Create an embedding layer with 4 possible inputs and an embedding dimension of 8.
# Seed first so the random embedding table (and the Linear weights created in later
# cells) are reproducible on Restart & Run All.
torch.manual_seed(123)
inputs = torch.nn.Embedding(4, 8)

In [None]:
# Replace the Embedding layer with its raw weight table, so later cells work on a
# plain (4, 8) tensor. detach() is the documented way to drop autograd tracking
# (preferred over .data). The isinstance guard makes the cell safe to re-run: once
# `inputs` is already a tensor, there is no `.weight` left to read.
if isinstance(inputs, torch.nn.Embedding):
    inputs = inputs.weight.detach()


torch.Size([4, 8])

In [13]:
# Confirm the embedding table is (num_embeddings, embedding_dim).
inputs.size()

torch.Size([4, 8])

In [7]:
# Duplicate the token matrix along a new leading batch axis (batch size 2).
# The resulting shape (batch, context_length, d_in) fixes d_in and context_length.
batches = torch.stack([inputs] * 2)


In [14]:
# Pick a d_out dimension and the number of attention heads.
# d_in is not a free choice: it is the embedding dimension of `batches`, and it is
# read from batches.shape in the next cell.
d_out = 16
num_heads = 4
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"


In [15]:
# Walk through the MultiHeadAttention forward pass one step per cell, printing every
# intermediate tensor and its shape so the need for each reshape is visible.
b, num_tokens, d_in = batches.size()
print(f"batches: {batches}")
print(f"batches.shape: {batches.shape}")



batches: tensor([[[ 2.0086, -0.1510, -2.0556,  1.2846, -0.2302,  0.5610, -0.7713,
          -0.0872],
         [ 1.8591, -0.2699, -1.4520, -0.4074,  0.8306,  0.4252,  0.4360,
           0.1402],
         [ 1.1846,  0.1166, -1.6958,  1.6242,  0.5902,  1.7603, -1.1903,
           1.2696],
         [ 0.4485,  1.8525, -0.2593,  1.1432,  0.9514,  0.2961, -0.1009,
          -1.9448]],

        [[ 2.0086, -0.1510, -2.0556,  1.2846, -0.2302,  0.5610, -0.7713,
          -0.0872],
         [ 1.8591, -0.2699, -1.4520, -0.4074,  0.8306,  0.4252,  0.4360,
           0.1402],
         [ 1.1846,  0.1166, -1.6958,  1.6242,  0.5902,  1.7603, -1.1903,
           1.2696],
         [ 0.4485,  1.8525, -0.2593,  1.1432,  0.9514,  0.2961, -0.1009,
          -1.9448]]])
batches.shape: torch.Size([2, 4, 8])


In [16]:
# One bias-free d_in -> d_out projection per role, created in keys, queries, values
# order (so the random initialisation draws happen in the same sequence).
W_keys, W_queries, W_values = (
    torch.nn.Linear(d_in, d_out, bias=False) for _ in range(3)
)
print("W_keys, W_queries, W_values:", W_keys, W_queries, W_values)

# Project the same batch three ways: (b, num_tokens, d_in) -> (b, num_tokens, d_out).
keys, queries, values = W_keys(batches), W_queries(batches), W_values(batches)
for name, projected in (("keys", keys), ("queries", queries), ("values", values)):
    print(f"{name}:", projected)
    print(f"{name}.shape:", projected.shape)
del name, projected

W_keys, W_queries, W_values: Linear(in_features=8, out_features=16, bias=False) Linear(in_features=8, out_features=16, bias=False) Linear(in_features=8, out_features=16, bias=False)
keys: tensor([[[ 0.1792,  0.5344, -0.0737, -1.9711, -0.8171,  1.0168, -0.0925,
          -0.2938,  0.7098,  0.7146, -1.5393,  0.0260,  0.7139,  0.0185,
          -0.1698,  1.1959],
         [-0.4296,  0.0281, -0.2799, -0.5735, -0.5150,  0.3262, -0.9122,
           0.2082,  0.1979,  0.6983, -0.4326, -0.0775,  0.3171, -0.4864,
          -0.5605,  0.4821],
         [-0.7619, -0.0713,  0.8305, -1.7972, -0.3432,  0.7982,  0.3749,
          -0.5405,  1.0852,  0.8112, -1.4471, -0.1882,  0.3356, -0.2845,
          -0.6668,  0.5266],
         [ 0.5948,  0.2070, -0.1664, -1.0679, -0.2395,  0.1607,  0.3524,
           0.6726,  0.0969,  0.0148, -0.0999, -0.9833,  0.4822,  0.4107,
          -0.3367,  0.7070]],

        [[ 0.1792,  0.5344, -0.0737, -1.9711, -0.8171,  1.0168, -0.0925,
          -0.2938,  0.7098,  0.7146, 

In [17]:
# Split the feature axis into heads:
# (b, num_tokens, d_out) -> (b, num_tokens, num_heads, d_out // num_heads).
head_shape = (b, num_tokens, num_heads, d_out // num_heads)
keys, queries, values = keys.view(head_shape), queries.view(head_shape), values.view(head_shape)
for label, split in (("keys", keys), ("queries", queries), ("values", values)):
    print(f"{label} reshaped:", split)
    print(f"{label} reshaped.shape:", split.shape)
del head_shape, label, split



keys reshaped: tensor([[[[ 0.1792,  0.5344, -0.0737, -1.9711],
          [-0.8171,  1.0168, -0.0925, -0.2938],
          [ 0.7098,  0.7146, -1.5393,  0.0260],
          [ 0.7139,  0.0185, -0.1698,  1.1959]],

         [[-0.4296,  0.0281, -0.2799, -0.5735],
          [-0.5150,  0.3262, -0.9122,  0.2082],
          [ 0.1979,  0.6983, -0.4326, -0.0775],
          [ 0.3171, -0.4864, -0.5605,  0.4821]],

         [[-0.7619, -0.0713,  0.8305, -1.7972],
          [-0.3432,  0.7982,  0.3749, -0.5405],
          [ 1.0852,  0.8112, -1.4471, -0.1882],
          [ 0.3356, -0.2845, -0.6668,  0.5266]],

         [[ 0.5948,  0.2070, -0.1664, -1.0679],
          [-0.2395,  0.1607,  0.3524,  0.6726],
          [ 0.0969,  0.0148, -0.0999, -0.9833],
          [ 0.4822,  0.4107, -0.3367,  0.7070]]],


        [[[ 0.1792,  0.5344, -0.0737, -1.9711],
          [-0.8171,  1.0168, -0.0925, -0.2938],
          [ 0.7098,  0.7146, -1.5393,  0.0260],
          [ 0.7139,  0.0185, -0.1698,  1.1959]],

         [[-0

In [18]:
# Move the head axis ahead of the token axis, so each head becomes its own batch of
# (num_tokens, head_dim) matrices: (b, num_tokens, num_heads, hd) -> (b, num_heads, num_tokens, hd).
# NOTE: not idempotent. Running this cell twice swaps the axes back, so re-run from
# the projection cell instead.
keys, queries, values = (t.permute(0, 2, 1, 3) for t in (keys, queries, values))
for label, moved in (("keys", keys), ("queries", queries), ("values", values)):
    print(f"{label} transposed:", moved)
    print(f"{label} transposed.shape:", moved.shape)
del label, moved

keys transposed: tensor([[[[ 0.1792,  0.5344, -0.0737, -1.9711],
          [-0.4296,  0.0281, -0.2799, -0.5735],
          [-0.7619, -0.0713,  0.8305, -1.7972],
          [ 0.5948,  0.2070, -0.1664, -1.0679]],

         [[-0.8171,  1.0168, -0.0925, -0.2938],
          [-0.5150,  0.3262, -0.9122,  0.2082],
          [-0.3432,  0.7982,  0.3749, -0.5405],
          [-0.2395,  0.1607,  0.3524,  0.6726]],

         [[ 0.7098,  0.7146, -1.5393,  0.0260],
          [ 0.1979,  0.6983, -0.4326, -0.0775],
          [ 1.0852,  0.8112, -1.4471, -0.1882],
          [ 0.0969,  0.0148, -0.0999, -0.9833]],

         [[ 0.7139,  0.0185, -0.1698,  1.1959],
          [ 0.3171, -0.4864, -0.5605,  0.4821],
          [ 0.3356, -0.2845, -0.6668,  0.5266],
          [ 0.4822,  0.4107, -0.3367,  0.7070]]],


        [[[ 0.1792,  0.5344, -0.0737, -1.9711],
          [-0.4296,  0.0281, -0.2799, -0.5735],
          [-0.7619, -0.0713,  0.8305, -1.7972],
          [ 0.5948,  0.2070, -0.1664, -1.0679]],

         [[

In [None]:
# Per-head dot product of every query with every key:
# (b, num_heads, num_tokens, hd) @ (b, num_heads, hd, num_tokens)
#   -> (b, num_heads, num_tokens, num_tokens)
attn_scores = torch.matmul(queries, keys.transpose(-2, -1))
print(f"attn_scores: {attn_scores}")
print(f"attn_scores.shape: {attn_scores.shape}")

attn_scores: tensor([[[[ 1.1404, -0.1088,  1.6739,  0.5728],
          [ 0.1010, -0.1527,  0.6780, -0.0337],
          [ 3.0026,  0.5928,  4.1315,  1.4346],
          [ 0.8191, -0.2085,  1.2100,  0.6873]],

         [[ 0.8445,  0.5484,  0.4263,  0.4687],
          [ 0.2357, -0.0402,  0.1806,  0.5347],
          [-0.1721,  0.0657, -0.3189,  0.5850],
          [-0.0270,  1.2941, -0.7699, -0.3132]],

         [[-1.4983, -0.8562, -2.0953, -0.3200],
          [-0.7500, -0.2332, -0.6076,  0.6687],
          [-0.3512, -0.3216, -0.9137, -0.6057],
          [-0.8700, -0.2104, -0.8382,  0.0746]],

         [[ 1.4319,  0.5373,  0.6251,  0.9289],
          [ 0.9646,  0.0623,  0.1073,  0.6181],
          [ 0.8231,  0.0708,  0.1351,  0.5497],
          [ 0.7958,  0.6663,  0.6486,  0.4231]]],


        [[[ 1.1404, -0.1088,  1.6739,  0.5728],
          [ 0.1010, -0.1527,  0.6780, -0.0337],
          [ 3.0026,  0.5928,  4.1315,  1.4346],
          [ 0.8191, -0.2085,  1.2100,  0.6873]],

         [[ 0.8

In [23]:
# Causal mask: token i may only attend to tokens 0..i. Entries strictly above the
# diagonal are set to -inf (in place), so the softmax in the next cell gives them zero weight.
# The mask is built as bool on attn_scores' device, so this also works for GPU tensors.
causal_mask = torch.triu(
    torch.ones(num_tokens, num_tokens, dtype=torch.bool, device=attn_scores.device),
    diagonal=1,
)
attn_scores.masked_fill_(causal_mask, -torch.inf)
print("attn_scores after masking:", attn_scores)

attn_scores after masking: tensor([[[[ 1.1404,    -inf,    -inf,    -inf],
          [ 0.1010, -0.1527,    -inf,    -inf],
          [ 3.0026,  0.5928,  4.1315,    -inf],
          [ 0.8191, -0.2085,  1.2100,  0.6873]],

         [[ 0.8445,    -inf,    -inf,    -inf],
          [ 0.2357, -0.0402,    -inf,    -inf],
          [-0.1721,  0.0657, -0.3189,    -inf],
          [-0.0270,  1.2941, -0.7699, -0.3132]],

         [[-1.4983,    -inf,    -inf,    -inf],
          [-0.7500, -0.2332,    -inf,    -inf],
          [-0.3512, -0.3216, -0.9137,    -inf],
          [-0.8700, -0.2104, -0.8382,  0.0746]],

         [[ 1.4319,    -inf,    -inf,    -inf],
          [ 0.9646,  0.0623,    -inf,    -inf],
          [ 0.8231,  0.0708,  0.1351,    -inf],
          [ 0.7958,  0.6663,  0.6486,  0.4231]]],


        [[[ 1.1404,    -inf,    -inf,    -inf],
          [ 0.1010, -0.1527,    -inf,    -inf],
          [ 3.0026,  0.5928,  4.1315,    -inf],
          [ 0.8191, -0.2085,  1.2100,  0.6873]],

 

In [26]:
# Scale by sqrt(head_dim) (keys.shape[-1] is the per-head dimension after the split),
# then normalise each row; the masked -inf entries become exactly 0.
attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
# Public functional dropout API (torch.dropout is undocumented). p=0.0 makes this a
# no-op; it is kept to mirror where dropout sits in the MultiHeadAttention class.
attn_weights = torch.nn.functional.dropout(attn_weights, p=0.0, training=True)

print("attn_weights:", attn_weights)
print("attn_weights.shape:", attn_weights.shape)

attn_weights: tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5317, 0.4683, 0.0000, 0.0000],
          [0.3270, 0.0980, 0.5750, 0.0000],
          [0.2666, 0.1595, 0.3242, 0.2496]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5344, 0.4656, 0.0000, 0.0000],
          [0.3273, 0.3686, 0.3041, 0.0000],
          [0.2226, 0.4309, 0.1535, 0.1929]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4358, 0.5642, 0.0000, 0.0000],
          [0.3610, 0.3664, 0.2725, 0.0000],
          [0.1996, 0.2776, 0.2028, 0.3201]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.6109, 0.3891, 0.0000, 0.0000],
          [0.4175, 0.2866, 0.2960, 0.0000],
          [0.2705, 0.2536, 0.2513, 0.2245]]],


        [[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5317, 0.4683, 0.0000, 0.0000],
          [0.3270, 0.0980, 0.5750, 0.0000],
          [0.2666, 0.1595, 0.3242, 0.2496]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5344, 0.4656, 0.0000, 0.0000],
      

In [28]:
# Weighted sum of value vectors per head, then put the token axis back in front of
# the head axis: (b, num_heads, num_tokens, hd) -> (b, num_tokens, num_heads, hd).
context_vector = torch.matmul(attn_weights, values).permute(0, 2, 1, 3)
print(f"context_vector: {context_vector}")
print(f"context_vector.shape: {context_vector.shape}")

context_vector: tensor([[[[-0.0884,  0.8430, -0.0927,  1.3196],
          [-0.5998, -0.5374,  1.5991, -0.5446],
          [-0.7812,  0.3848, -0.0438, -0.0863],
          [ 0.6292,  1.2160, -0.4138, -0.3224]],

         [[-0.2552,  0.4645, -0.1990,  1.1270],
          [-0.3884, -0.5494,  1.1331, -0.4935],
          [-0.4176,  0.4528,  0.0896, -0.1705],
          [ 0.6013,  0.8854, -0.3189, -0.1474]],

         [[-0.3164,  1.0176, -0.4802,  1.4849],
          [-0.3590, -0.6487,  1.1662, -0.5169],
          [-0.7393,  0.3871, -0.0564, -0.1088],
          [ 0.5345,  1.1039, -0.1056, -0.1421]],

         [[-0.3172,  0.7105, -0.2045,  1.3013],
          [-0.2476, -0.4774,  0.8094, -0.3356],
          [-0.6597,  0.1512, -0.1611, -0.1583],
          [ 0.5857,  0.8594,  0.0154, -0.2862]]],


        [[[-0.0884,  0.8430, -0.0927,  1.3196],
          [-0.5998, -0.5374,  1.5991, -0.5446],
          [-0.7812,  0.3848, -0.0438, -0.0863],
          [ 0.6292,  1.2160, -0.4138, -0.3224]],

         [[-

In [29]:
# Merge the heads back into one d_out feature axis: (b, num_tokens, num_heads, hd)
# -> (b, num_tokens, d_out). reshape() copies when a view is not possible, which
# matches .contiguous().view().
context_vector = context_vector.reshape(b, num_tokens, d_out)
print(f"context_vector: {context_vector}")
print(f"context_vector.shape: {context_vector.shape}")

context_vector: tensor([[[-0.0884,  0.8430, -0.0927,  1.3196, -0.5998, -0.5374,  1.5991,
          -0.5446, -0.7812,  0.3848, -0.0438, -0.0863,  0.6292,  1.2160,
          -0.4138, -0.3224],
         [-0.2552,  0.4645, -0.1990,  1.1270, -0.3884, -0.5494,  1.1331,
          -0.4935, -0.4176,  0.4528,  0.0896, -0.1705,  0.6013,  0.8854,
          -0.3189, -0.1474],
         [-0.3164,  1.0176, -0.4802,  1.4849, -0.3590, -0.6487,  1.1662,
          -0.5169, -0.7393,  0.3871, -0.0564, -0.1088,  0.5345,  1.1039,
          -0.1056, -0.1421],
         [-0.3172,  0.7105, -0.2045,  1.3013, -0.2476, -0.4774,  0.8094,
          -0.3356, -0.6597,  0.1512, -0.1611, -0.1583,  0.5857,  0.8594,
           0.0154, -0.2862]],

        [[-0.0884,  0.8430, -0.0927,  1.3196, -0.5998, -0.5374,  1.5991,
          -0.5446, -0.7812,  0.3848, -0.0438, -0.0863,  0.6292,  1.2160,
          -0.4138, -0.3224],
         [-0.2552,  0.4645, -0.1990,  1.1270, -0.3884, -0.5494,  1.1331,
          -0.4935, -0.4176,  0.452