## Attention mechanism toy tutorial
- set up
- self attention
- causal attention
- multihead attention

In [1]:
import torch
print(torch.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

2.2.1
Using device: cpu


#### set up

In [25]:
seq = "This is simple toy tutorial on how attention mechanism works"

# import tiktoken
# tokenizer = tiktoken.get_encoding("gpt2")
# tokens = tokenizer.encode(seq, allowed_special={'|eos|'})
# print(tokens)

seq_list = seq.split(" ")
print(seq_list)
token_tensors = [torch.tensor(x) for x in range(len(seq_list))]
token_tensors

['This', 'is', 'simple', 'toy', 'tutorial', 'on', 'how', 'attention', 'mechanism', 'works']


[tensor(0),
 tensor(1),
 tensor(2),
 tensor(3),
 tensor(4),
 tensor(5),
 tensor(6),
 tensor(7),
 tensor(8),
 tensor(9)]

In [26]:
from torch import nn
# convert tokens into 3 dimensional embeddings
embed = nn.Embedding(len(token_tensors), 3)
input = torch.LongTensor(token_tensors)
embeddings = embed(input)
embeddings

tensor([[-0.9182,  0.0033,  0.9627],
        [-1.0063, -0.3444, -1.4247],
        [ 1.1410,  0.3782, -0.5953],
        [ 0.4299, -0.0343, -0.3688],
        [-1.2274, -0.6004, -0.1838],
        [ 0.4596, -0.0693, -1.5469],
        [ 2.7773,  0.3163, -0.4481],
        [-1.2793,  0.0345, -0.4968],
        [-1.6882, -0.5491, -0.9991],
        [-1.8605, -0.1361, -0.8169]], grad_fn=<EmbeddingBackward0>)

### self attention
- also called scaled dot-product attention
- 3 steps: 
    1. attention score
    2. normalization
    3. context vector 
- make it with trainable weights

In [29]:
print(seq_list)
print(seq_list[3])

['This', 'is', 'simple', 'toy', 'tutorial', 'on', 'how', 'attention', 'mechanism', 'works']
toy


In [39]:
# calculate attention score
# the dot product between each input token's embedding and each of the rest of the tokens' embeddings in the same sequence
# we refer to the token of focus as query - here take "toy" as example, which is the indexed by 3 in the seq_list 
query = embeddings[3]
attention_score_for_toy = torch.empty(embeddings.shape[0])
for x in range(len(embeddings)):
    attention_score_for_toy[x] = torch.dot(embeddings[x], query)
print(attention_score_for_toy)


tensor([-0.7499,  0.1047,  0.6971,  0.3220, -0.4393,  0.7705,  1.3484, -0.3680,
        -0.3385, -0.4939], grad_fn=<CopySlices>)


In [41]:
# attention scores' meaning:
# dot product mathematically combines two vectors to yield a scalar value
# dot product also measures similarity between the two vectors: the higher the more similar

# in practice, rather than loop through each token as the query, we do matmul
attention_scores = torch.matmul(embeddings, embeddings.T)
attention_scores

tensor([[ 1.7698, -0.4488, -1.6195, -0.7499,  0.9481, -1.9114, -2.9804,  0.6965,
          0.5864,  0.9214],
        [-0.4488,  3.1609, -0.4302,  0.1047,  1.7037,  1.7652, -2.2653,  1.9832,
          3.3113,  3.0828],
        [-1.6195, -0.4302,  1.7993,  0.6971, -1.5181,  1.4192,  3.5553, -1.1509,
         -1.5391, -1.6880],
        [-0.7499,  0.1047,  0.6971,  0.3220, -0.4393,  0.7705,  1.3484, -0.3680,
         -0.3385, -0.4939],
        [ 0.9481,  1.7037, -1.5181, -0.4393,  1.9008, -0.2383, -3.5165,  1.6408,
          2.5854,  2.5154],
        [-1.9114,  1.7652,  1.4192,  0.7705, -0.2383,  2.6089,  1.9478,  0.1780,
          0.8076,  0.4179],
        [-2.9804, -2.2653,  3.5553,  1.3484, -3.5165,  1.9478,  8.0142, -3.3196,
         -4.4146, -4.8441],
        [ 0.6965,  1.9832, -1.1509, -0.3680,  1.6408,  0.1780, -3.3196,  1.8847,
          2.6371,  2.7813],
        [ 0.5864,  3.3113, -1.5391, -0.3385,  2.5854,  0.8076, -4.4146,  2.6371,
          4.1497,  4.0317],
        [ 0.9214,  

In [46]:
# normalize attention scores into attention weights
# use torch.softmax() to avoid overflow/underflow and optimize compute
attention_weights = torch.softmax(attention_scores, dim=-1)
attention_weights

tensor([[3.6066e-01, 3.9226e-02, 1.2165e-02, 2.9027e-02, 1.5857e-01, 9.0853e-03,
         3.1196e-03, 1.2330e-01, 1.1045e-01, 1.5440e-01],
        [6.7964e-03, 2.5115e-01, 6.9234e-03, 1.1820e-02, 5.8488e-02, 6.2197e-02,
         1.1050e-03, 7.7349e-02, 2.9190e-01, 2.3227e-01],
        [4.0431e-03, 1.3281e-02, 1.2346e-01, 4.1004e-02, 4.4746e-03, 8.4413e-02,
         7.1471e-01, 6.4598e-03, 4.3818e-03, 3.7755e-03],
        [3.4629e-02, 8.1389e-02, 1.4719e-01, 1.0115e-01, 4.7241e-02, 1.5839e-01,
         2.8229e-01, 5.0735e-02, 5.2255e-02, 4.4732e-02],
        [5.4625e-02, 1.1629e-01, 4.6378e-03, 1.3641e-02, 1.4162e-01, 1.6677e-02,
         6.2869e-04, 1.0920e-01, 2.8083e-01, 2.6185e-01],
        [3.8281e-03, 1.5126e-01, 1.0702e-01, 5.5941e-02, 2.0399e-02, 3.5169e-01,
         1.8156e-01, 3.0933e-02, 5.8055e-02, 3.9317e-02],
        [1.6540e-05, 3.3814e-05, 1.1401e-02, 1.2545e-03, 9.6760e-06, 2.2846e-03,
         9.8498e-01, 1.1782e-05, 3.9412e-06, 2.5651e-06],
        [3.7602e-02, 1.3614

In [48]:
# check to make sure attention scores sum up to 1 for each row
attention_weights[3].sum()

tensor(1.0000, grad_fn=<SumBackward0>)

In [47]:
# calculate context vector bease on attention weights and input embedding 
all_context_vecs = torch.matmul(attention_weights, embeddings)
# all_context_vecs = attention_weights @ embeddings
all_context_vecs

tensor([[-1.1575, -0.1810, -0.0689],
        [-1.3100, -0.3126, -0.9870],
        [ 2.1370,  0.2556, -0.5660],
        [ 0.6602,  0.0412, -0.7017],
        [-1.4214, -0.3106, -0.7217],
        [ 0.4206, -0.0289, -1.0308],
        [ 2.7501,  0.3157, -0.4522],
        [-1.4282, -0.2850, -0.7887],
        [-1.5255, -0.3317, -0.8916],
        [-1.5547, -0.3089, -0.8561]], grad_fn=<MmBackward0>)

#### make it with trainable weights
- instead of interacting directly with embeddings, now interact with transformed version
- transformed by weight matrices: W_query, W_key, W_value
    - query = search query in db; represent what the nn is now focused on understanding 
    - key = what is being indexed and searched; key to index each token in the context window 
    - value = value in key-value pair; "attention" finds the most relevant key to the query and retrieves the key's value

In [49]:
# make two batches of training data (by stacking the embeddings for simplicity)
from torch import nn
class self_attention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        # inherit from the nn.Module parent 
        super().__init__() 

        # technically, we initialize the weight
        # self.w_query = nn.Parameter(torch.rand(d_in, d_out))
        # in practice, we use the Linear layer with bias turned off to do so 
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        # initialize the w_query, w_key, w_value 
        # AND matmul with input embeddings x
        queries = self.w_query(x)
        keys = self.w_key(x)
        values = self.w_value(x)
        # attention score query @ key.T
        attention_scores = torch.matmul(queries, keys.T)
        # attention weights = normalized attention scores
        # scale the attention scores by the sqrt(embedding dimentsion) first 
        # to improve the training performance by avoiding small gradients.
        attention_weights = torch.softmax(
            attention_scores / (keys.shape[-1]**0.5),
            dim=-1
        )
        # calculate context vector attention weights @ values
        context_vectors = torch.matmul(attention_weights, values)
        return context_vectors



In [50]:
embeddings.shape[-1]

3

In [52]:
torch.manual_seed(123)
# the example embeddings have dim 3
d_in = embeddings.shape[-1]
# usually in models like GPTs, d_in = d_out
# here just for easy visualization
d_out = embeddings.shape[-1]-1
# generate a self attention object from the class
self_attn = self_attention(d_in, d_out)
# apply self_attn to embeddings
self_attn(embeddings)

tensor([[ 0.2513, -0.1568],
        [ 0.6291, -0.0860],
        [ 0.3139, -0.1451],
        [ 0.3260, -0.1417],
        [ 0.4510, -0.1157],
        [ 0.5232, -0.1044],
        [ 0.0267, -0.2098],
        [ 0.5554, -0.0983],
        [ 0.6226, -0.0863],
        [ 0.6407, -0.0842]], grad_fn=<MmBackward0>)

### causal attention
- mask out the future tokens that the current token has not seen 
- only look at the previous and current token in a context window
- further apply a small portion of dropout (usually 0.1, 0.2)
- re-applying softmax normalization

In [53]:
# reuse the weights from the self attention
queries = self_attn.w_query(embeddings)
keys = self_attn.w_key(embeddings)
attention_scores = torch.matmul(queries, keys.T)
attention_weights = torch.softmax(
            attention_scores / (keys.shape[-1]**0.5),
            dim=-1
        )
print(attention_weights)

tensor([[0.1035, 0.0925, 0.1075, 0.1049, 0.0971, 0.0995, 0.1187, 0.0947, 0.0910,
         0.0905],
        [0.0709, 0.1422, 0.0542, 0.0666, 0.1128, 0.0905, 0.0316, 0.1162, 0.1606,
         0.1543],
        [0.1002, 0.0996, 0.1016, 0.0996, 0.0976, 0.1001, 0.0998, 0.1015, 0.0987,
         0.1012],
        [0.0994, 0.1012, 0.0989, 0.0991, 0.1003, 0.1000, 0.0972, 0.1009, 0.1014,
         0.1016],
        [0.0899, 0.1172, 0.0802, 0.0883, 0.1097, 0.0986, 0.0668, 0.1066, 0.1237,
         0.1189],
        [0.0831, 0.1268, 0.0711, 0.0797, 0.1086, 0.0964, 0.0505, 0.1134, 0.1358,
         0.1346],
        [0.1106, 0.0687, 0.1344, 0.1147, 0.0788, 0.0937, 0.1903, 0.0802, 0.0627,
         0.0660],
        [0.0796, 0.1314, 0.0658, 0.0760, 0.1108, 0.0949, 0.0444, 0.1140, 0.1432,
         0.1398],
        [0.0717, 0.1414, 0.0548, 0.0677, 0.1143, 0.0909, 0.0328, 0.1149, 0.1598,
         0.1516],
        [0.0694, 0.1440, 0.0524, 0.0650, 0.1129, 0.0896, 0.0298, 0.1165, 0.1635,
         0.1568]], grad_fn=<

In [55]:
# take vantage of the fact that softmax(-inf) approaches 0
context_length = embeddings.shape[0]
mask_original = torch.tril(torch.ones(context_length, context_length))
print("technically the mask use if manually re-normalize")
print(mask_original)

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
print("practically the mask we use to take advantage of softmax(-inf)")
print(mask)

technically the mask use if manually re-normalize
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
practically the mask we use to take advantage of softmax(-inf)
tensor([[0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0.,

In [58]:
# then apply the mask where the mask is used as boolean values 
# when 1 fill with -inf
masked = attention_scores.masked_fill(mask.bool(), -torch.inf)
print("causally masked attention scores")
print(masked)

causally masked attention scores
tensor([[-0.0087,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 0.0422,  1.0274,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 0.0033, -0.0055,  0.0226,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 0.0016,  0.0270, -0.0058, -0.0019,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 0.0124,  0.3880, -0.1481, -0.0119,  0.2947,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 0.0280,  0.6264, -0.1917, -0.0309,  0.4076,  0.2384,    -inf,    -inf,
            -inf,    -inf],
        [-0.0254, -0.7004,  0.2491,  0.0260, -0.5058, -0.2611,  0.7416,    -inf,
            -inf,    -inf],
        [ 0.0311,  0.7404, -0.2384, -0.0336,  0.4987,  0.2799, -0.7939,  0.5392,
            -inf,    -inf],
        [ 0.0390,  0.9997, -0.3394, -0.0411,  0.6986,  0.3752, -1.0649,  0.7066,
          1.17

In [59]:
# causal attention weights
causal_attention_weights = torch.softmax(
    masked / keys.shape[-1]**0.5, 
    dim=1
    )
print(causal_attention_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.3326, 0.6674, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.3325, 0.3304, 0.3371, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2494, 0.2539, 0.2481, 0.2487, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1852, 0.2415, 0.1653, 0.1820, 0.2261, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1468, 0.2242, 0.1257, 0.1408, 0.1920, 0.1704, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1399, 0.0868, 0.1698, 0.1450, 0.0996, 0.1184, 0.2406, 0.0000, 0.0000,
         0.0000],
        [0.1110, 0.1833, 0.0918, 0.1061, 0.1545, 0.1324, 0.0619, 0.1590, 0.0000,
         0.0000],
        [0.0845, 0.1666, 0.0646, 0.0798, 0.1347, 0.1072, 0.0387, 0.1355, 0.1884,
         0.0000],
        [0.0694, 0.1440, 0.0524, 0.0650, 0.1129, 0.0896, 0.0298, 0.1165, 0.1635,
         0.1568]], grad_fn=<

In [60]:
# spot check that each row still adds up to 1
causal_attention_weights[3].sum()

tensor(1., grad_fn=<SumBackward0>)

#### apply dropout

In [64]:
# make two batches of training data (by stacking the embeddings for simplicity)
from torch import nn
class Causal_Attention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout_rate, qkv_bias=False):
        # inherit from the nn.Module parent 
        super().__init__() 

        # technically, we initialize the weight
        # self.w_query = nn.Parameter(torch.rand(d_in, d_out))
        # in practice, we use the Linear layer with bias turned off to do so 
        self.d_out = d_out
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # add the buffer to create mask and send it to device with the model 
        # but not update it
        self.register_buffer(
            'mask',
            torch.triu(
                torch.ones(context_length,context_length),
                diagonal=1)
        )
        # add the dropout - object from nn.Dropout with param dropout_rate
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # allowing batching: first is the batch dim of tensors
        batch, n_tokens, d_in = x.shape

        # initialize the w_query, w_key, w_value 
        # AND matmul with input embeddings x
        queries = self.w_query(x)
        keys = self.w_key(x)
        values = self.w_value(x)
        # attention score query @ key.T 
        # but remember the first dim is now batch so transpose the second and third
        attention_scores = torch.matmul(queries, keys.transpose(1, 2))

        # ###### add causal attention masks ######
        # computeation with trailing underscore are performed in-place
        attention_scores.masked_fill_(
            # change the mask to boolean
            self.mask.bool()[:n_tokens, :n_tokens],
            # fill value when 1 in mask
            -torch.inf
        )
        # ###### add causal attention masks ######

        # attention weights = normalized attention scores
        # scale the attention scores by the sqrt(embedding dimentsion) first 
        # to improve the training performance by avoiding small gradients.
        attention_weights = torch.softmax(
            attention_scores / (keys.shape[-1]**0.5),
            dim=-1
        )

        # ###### apply dropout to attention weights ######
        attention_weights = self.dropout(attention_weights)
        # ###### apply dropout to attention weights ######

        # calculate context vector attention weights @ values
        context_vectors = torch.matmul(attention_weights, values)
        return context_vectors

In [62]:
# to mimic actual training with batching 
# stack two embeddings as different batches
batch = torch.stack((embeddings, embeddings), dim=0)
# first dim shows the batch num
print(batch.shape)

torch.Size([2, 10, 3])


In [65]:
# try out Causal_Attention 
torch.manual_seed(123)
context_length = batch.shape[1]
causal_attention = Causal_Attention(d_in, d_out, context_length, dropout_rate=0.1)
context_vecs = causal_attention(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 10, 2])


### multihead attention

In [None]:
# write causal attention into a function 
