### Self-attention
Notebook to see lower level workings of self-attention

In [1]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F

The idea of self attention (in context of text processing) is: given an input sequence create a vector representation of each token. The vector is going to contain information about the token in context with the rest of the tokens in the sequence.  

In [2]:
sentence = 'Life is short, eat dessert first'

sentence = sentence.replace(',', '').split()

vocab = {word: idx for idx, word in enumerate(sorted(sentence))}
vocab

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

In [3]:
tokens = torch.tensor([vocab[word] for word in sentence])
tokens

tensor([0, 4, 5, 2, 1, 3])

In [4]:
torch.manual_seed(123)
# Embeddings
emb_dim = 16
emb = nn.Embedding(len(vocab), emb_dim)

In [5]:
# Pass tokens through embedding layer
inputs = emb(tokens).detach()
inputs.shape

torch.Size([6, 16])

In [6]:
inputs

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])

We have arbitrarily chosen embedding dimension of 24. To calculate self-attention for token 2 we will need the query vector of token 2 and key and value vectors for all tokens of the sequence. The vectors are created by matrix multiplying embedding vectors with Query, Key and Value matrices

In [14]:
token_2 = inputs[1]
print(token_2)
token_2.shape

tensor([ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
         0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465])


torch.Size([16])

In [44]:
torch.manual_seed(123)

# Dimensions for query, key and value matrices
d_q, d_k, d_v = 24, 24, 28
W_query = nn.Parameter(torch.rand(d_q, emb_dim))
W_key = nn.Parameter(torch.rand(d_k, emb_dim))
W_value = nn.Parameter(torch.rand(d_v, emb_dim))

In [45]:
W_query.shape, W_key.shape, W_value.shape

(torch.Size([24, 16]), torch.Size([24, 16]), torch.Size([28, 16]))

In [46]:
query_2 = token_2.matmul(W_query.T) # Query vector for token 2
key_2 = token_2.matmul(W_key.T) # Key vector for token 2
value_2 = token_2.matmul(W_value.T)
query_2.size(), key_2.size(), value_2.size()

(torch.Size([24]), torch.Size([24]), torch.Size([28]))

In [52]:
# Creating Key, Value vector for all tokens in the sequence
# keys = inputs.matmul(W_key.T)
keys = W_key.matmul(inputs.T).T
values = W_value.matmul(inputs.T).T
keys.shape, values.shape

(torch.Size([6, 24]), torch.Size([6, 28]))

In [48]:
omega_24 = query_2.dot(keys[4])
omega_24

tensor(11.1466, grad_fn=<DotBackward0>)

In [49]:
omega_2 = query_2.matmul(keys.T) 
omega_2

tensor([ 8.5808, -7.6597,  3.2558,  1.0395, 11.1466, -0.4800],
       grad_fn=<SqueezeBackward4>)

In [50]:
attention_weights2 = F.softmax((omega_2 / d_q**0.5), dim=0)
attention_weights2

tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458],
       grad_fn=<SoftmaxBackward0>)

In [55]:
context_vector2 = attention_weights2.matmul(values)
contex

tensor([[ 0.7392,  1.6789,  2.7937,  1.4616, -0.8727,  1.1213, -0.8441,  0.2829,
          1.7343,  1.6112,  2.1563,  1.1398,  1.6928,  1.5737,  1.7709,  0.9618,
          1.3077,  0.2716,  0.3070,  0.3427,  2.4012,  1.9869, -1.1107,  0.6782,
         -0.2181,  0.8178,  0.5018,  0.9887],
        [ 0.0166, -0.0809,  1.2402,  1.2786,  1.6755,  0.5242,  0.5165,  0.2638,
          0.1946,  0.1296, -0.2176, -1.2548, -0.9272, -1.3402, -0.4107, -0.0859,
          1.0926,  0.4078, -0.6770,  0.1110, -1.1055,  0.3156, -0.3169,  0.7937,
         -1.1166,  3.0497, -0.2863,  1.5513],
        [-4.1774, -1.6440, -1.9643, -1.6642, -1.0216, -5.0441, -1.4350, -3.0582,
         -1.3735, -1.0167, -0.9397, -2.5408, -2.1351, -1.8701, -1.9994, -3.7609,
         -3.8755, -3.1365, -2.1639, -3.0949, -3.7118, -1.8682, -1.8869, -1.7023,
         -1.4043, -4.1602, -3.5326, -1.8202],
        [ 0.5068, -1.0825, -2.4869, -0.3825,  0.3522, -3.0291, -1.0645, -0.9245,
         -3.0223, -0.6932, -2.1795,  0.1399,  0.0171

We are going to use a single sentence to implement self-attention. We will use the individual words in the sentence as our vocabulary for simplicity. We will assign a number to each unique word in the vocabulary which will the token for the word.

In [5]:
sentence = 'Life is short, eat dessert first'
sentence = sentence.replace(',', '')

vocab = {w:i for i,w in enumerate(sorted(sentence.split()))}
vocab

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

In [6]:
# Convert words in sentence to tokens
sentence_token = torch.tensor([vocab[word] for word in sentence.split()])
sentence_token

tensor([0, 4, 5, 2, 1, 3])

We will pass the tokens through an embedding layer. Embeddings are vectors with floating point numbers. The vector values after successful round of model training captures semantic meanings of the word.

In [7]:
torch.manual_seed(123)
dim = 16

emb = nn.Embedding(len(vocab), dim)

# Passing sentence through embedding layer
sentence_emb = emb(sentence_token)
sentence_emb.shape

torch.Size([6, 16])

The idea of self-attention is to capture the representation a word/token in relationship with all the other tokens in the sequence/input. To capture a vector representation second word in our sequence we calculate a query vector for the second word. We also calculate key and value vector for all the words in the sequence. These vectors are computed by Query, Key and Value matrices; values of which learned during training. Dot product between query vector for second word and all the other key vector creates an un-normalized score vector. This vector is divided by the square root of the number of dimensions of key vector. This is done so that the values do not exhibit numerical instability and helps convergence during training. The score vector goes through softmax to return a vector of probability distribution. This normalized score vector does a dot product with the value vectors to form self-attention score for the second word. 

In [12]:
torch.manual_seed(123)

q_d, k_d, v_d = 24, 24, 28
W_query = nn.Parameter(torch.rand(q_d, dim))
W_key = nn.Parameter(torch.rand(k_d, dim))
W_value = nn.Parameter(torch.rand(v_d, dim))
W_query.shape, W_key.shape, W_value.shape

(torch.Size([24, 16]), torch.Size([24, 16]), torch.Size([28, 16]))

In [15]:
token_2 = sentence_emb[1]
query_2 = W_query.matmul(token_2) # Query vector for token 2
query_2.shape

torch.Size([24])

In [20]:
keys = W_key.matmul(sentence_emb.T).T # Key vector for all tokens
values = W_value.matmul(sentence_emb.T).T # Value vector for all tokens
keys.shape, values.shape

(torch.Size([6, 24]), torch.Size([6, 28]))

In [21]:
omega24 = query_2.dot(keys[4]) # un-normalized score between second and fifth token
omega24

tensor(11.1466, grad_fn=<DotBackward0>)

In [23]:
omega2 = query_2.matmul(keys.T) # un-normalized score second and all tokens
omega2

tensor([ 8.5808, -7.6597,  3.2558,  1.0395, 11.1466, -0.4800],
       grad_fn=<SqueezeBackward4>)

In [27]:
attention_weights2 = F.softmax((omega2 / (k_d**0.5)), dim=0) # normalized attention scores
attention_weights2

tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458],
       grad_fn=<SoftmaxBackward0>)

In [41]:
# context_2 = (values.T).matmul(attention_weights2)
context_2 = attention_weights2.matmul(values)
context_2

tensor([-1.5993,  0.0156,  1.2670,  0.0032, -0.6460, -1.1407, -0.4908, -1.4632,
         0.4747,  1.1926,  0.4506, -0.7110,  0.0602,  0.7125, -0.1628, -2.0184,
         0.3838, -2.1188, -0.8136, -1.5694,  0.7934, -0.2911, -1.3640, -0.2366,
        -0.9564, -0.5265,  0.0624,  1.7084], grad_fn=<SqueezeBackward4>)

In [32]:
queries = W_query.matmul(sentence_emb.T).T # Query vectors for all tokens
queries.shape

torch.Size([6, 24])

In [33]:
omega = queries.matmul(keys.T) # Un-normalized attention score 
omega

tensor([[  16.4255,    8.1306,  -24.5414,  -19.6606,   -9.5164,   19.2777],
        [   8.5808,   -7.6597,    3.2558,    1.0395,   11.1466,   -0.4800],
        [ -39.2836,   -1.5165,  145.4604,   74.2561,   58.8008, -141.6884],
        [  -5.2174,   -4.6914,   74.9203,   30.6947,   35.7423,  -73.7312],
        [ -21.6148,   10.6362,   65.4889,   39.2832,   21.8496,  -80.2922],
        [  40.0110,   -8.6863, -129.7707,  -64.2901,  -39.9965,  102.5285]],
       grad_fn=<MmBackward0>)

In [35]:
attention_weights = F.softmax((omega / (k_d**0.5)), dim=1) 
attention_weights

tensor([[3.3559e-01, 6.1726e-02, 7.8361e-05, 2.1222e-04, 1.6829e-03, 6.0071e-01],
        [2.9123e-01, 1.0581e-02, 9.8213e-02, 6.2474e-02, 4.9169e-01, 4.5814e-02],
        [4.1922e-17, 9.3433e-14, 1.0000e+00, 4.8723e-07, 2.0779e-08, 3.5016e-26],
        [7.8632e-08, 8.7544e-08, 9.9954e-01, 1.2001e-04, 3.3626e-04, 6.6351e-14],
        [1.8886e-08, 1.3652e-05, 9.9512e-01, 4.7287e-03, 1.3467e-04, 1.1868e-13],
        [2.8696e-06, 1.3829e-10, 2.5508e-21, 1.6275e-15, 2.3183e-13, 1.0000e+00]],
       grad_fn=<SoftmaxBackward0>)

In [42]:
context = attention_weights.matmul(values)
context.shape

torch.Size([6, 28])

We have created a context vector for each token in the sequence. 

### Multi-headed attention
For multi-headed attention we will be doing the same computation we did above to create vector representations of each token weighted by the attention scores of each token. We will be using different Query, Key and Value matrices to create the initial vectors. The idea is to be able to capture different relationships betweens tokens with different attention weighted value vectors. Since the values of the matrices are updated during training; they will help capture dependencies between tokens in the sequence.

In [79]:
h = 3
multiheaded_w_query = nn.Parameter(torch.rand(h, q_d, dim))
multiheaded_w_key = nn.Parameter(torch.rand(h, k_d, dim))
multiheaded_w_value = nn.Parameter(torch.rand(h, v_d, dim))
multiheaded_w_query.shape, multiheaded_w_key.shape, multiheaded_w_value.shape

(torch.Size([3, 24, 16]), torch.Size([3, 24, 16]), torch.Size([3, 28, 16]))

In [88]:
multiheaded_query_2 = multiheaded_w_query.matmul(token_2) # Three heads will create 3 three queries 
multiheaded_key_2 = multiheaded_w_key.matmul(token_2)
multiheaded_value_2 = multiheaded_w_value.matmul(token_2)
multiheaded_query_2.shape, multiheaded_key_2.shape, multiheaded_value_2.shape

(torch.Size([3, 24]), torch.Size([3, 24]), torch.Size([3, 28]))

In [90]:
stacked_inputs = sentence_emb.repeat(3, 1, 1) # Stacking inputs since we have multiple weight matrices 
stacked_inputs.shape

torch.Size([3, 6, 16])

In [96]:
multi_headed_queries = multiheaded_w_query.bmm(stacked_inputs.permute(0, 2, 1))
multi_headed_keys = multiheaded_w_key.bmm(stacked_inputs.permute(0, 2, 1))
multi_headed_values = multiheaded_w_value.bmm(stacked_inputs.permute(0, 2, 1))
multi_headed_queries.shape, multi_headed_keys.shape, multi_headed_values.shape

(torch.Size([3, 24, 6]), torch.Size([3, 24, 6]), torch.Size([3, 28, 6]))

In [98]:
multi_headed_scores = multi_headed_queries.permute(0, 2, 1).bmm(multi_headed_keys) # Attention scores (un-normalized)
multi_headed_scores.shape

torch.Size([3, 6, 6])

In [101]:
multi_headed_attention = F.softmax((multi_headed_scores / (q_d**0.5)), dim=-1) # Normalized attention scores 
multi_headed_attention.shape

torch.Size([3, 6, 6])

In [102]:
multi_headed_context = multi_headed_attention.bmm(multi_headed_values.permute(0, 2, 1))
multi_headed_context.shape

torch.Size([3, 6, 28])