## Simple Self Attention without Trainable Weights

In [2]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)
# inputs with embedding dimension 3 for the six tokens
# ["A", "cat", "sat", "on", "the", "mat"]

In [3]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In self-attention, the goal is to compute a `context vector` for each element of the input sequence. A context vector can be interpreted as an enriched embedding vector: it combines the token's own embedding with information from every other token in the sequence.

In [5]:
query = inputs[1] # "cat"
attention_scores_1 = torch.empty(inputs.shape[0]) # uninitialized buffer; every entry is overwritten in the next cell
attention_scores_1 # will hold the attention scores of "cat" with respect to every word, including itself

tensor([0., 0., 0., 0., 0., 0.])

In [8]:
for index, tokens in enumerate(inputs):
    attention_scores_1[index] = torch.dot(query, tokens)
    # dot product between the query and the key
    
attention_scores_1

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

This gives us the attention scores S21, S22, S23, S24, S25, S26, i.e. the similarity of the second word, "cat", to every word in the input (including itself). These scores are not yet the context vector. Next we normalize them, ultimately with `Softmax`, to obtain the `Attention Weights`.

The higher the dot product between two token embeddings, the more similar they are and the more strongly the query attends to that token.
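
To make the similarity reading concrete, here is a small sketch that computes one dot product by hand and compares it with `torch.dot`. The variable names `x_cat`, `x_sat`, and `x_the` are introduced only for this illustration; the values are copied from `inputs`.

```python
import torch

# Embeddings copied from the inputs tensor above
x_cat = torch.tensor([0.55, 0.87, 0.66])  # "cat" (the query)
x_sat = torch.tensor([0.57, 0.85, 0.64])  # "sat"
x_the = torch.tensor([0.77, 0.25, 0.10])  # "the"

# A dot product is an element-wise multiplication followed by a sum
manual = (x_cat * x_sat).sum()
builtin = torch.dot(x_cat, x_sat)
print(manual, builtin)

# "sat" points in nearly the same direction as "cat", so it scores higher than "the"
score_sat = torch.dot(x_cat, x_sat)
score_the = torch.dot(x_cat, x_the)
print(score_sat, score_the)
```

Note that the dot product is not normalized, so the length of a vector influences the score as well as its direction.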

In [9]:
attention_weights_1_tmp = attention_scores_1 / attention_scores_1.sum()

print(f"Attention weights: {attention_weights_1_tmp}")
print(f"Sum of attention weights: {attention_weights_1_tmp.sum()}")

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum of attention weights: 1.0000001192092896
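
Dividing by the sum works here because every score is positive. The sketch below uses hypothetical scores (not taken from `inputs`) to show what can go wrong once a score is negative, and how `torch.softmax` behaves on the same values.

```python
import torch

# Hypothetical scores with one negative entry (illustration only)
scores = torch.tensor([2.0, -1.0, 0.5])

# Sum normalization still sums to 1, but individual "weights" can be
# negative or larger than 1, so they are hard to read as proportions
sum_normalized = scores / scores.sum()

# Softmax exponentiates first, so every weight is strictly positive
softmax_normalized = torch.softmax(scores, dim=0)

print(f"Sum-normalized: {sum_normalized}")
print(f"Softmax:        {softmax_normalized}")
```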


In [11]:
torch.exp(attention_scores_1), torch.exp(attention_scores_1).sum(0)

(tensor([2.5971, 4.4593, 4.3728, 2.3243, 2.0279, 2.9639]), tensor(18.7453))

In [12]:
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0) # broadcasting: shape [6] / scalar -> shape [6]

In [13]:
attention_weights_naive = softmax_naive(attention_scores_1)

print(f"Attention weights: {attention_weights_naive}")
print(f"Attention weights sum: {attention_weights_naive.sum()}")

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Attention weights sum: 1.0


Softmax has several advantages over plain sum normalization for computing attention weights. It copes better with extreme values, it maps negative scores to positive weights, and it behaves well during LLM training. The naive implementation above can still overflow or underflow for scores of large magnitude, so in practice we use PyTorch's built-in **`torch.softmax`**, which is numerically stable.
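
The overflow issue can be reproduced with a short sketch. The scores below are hypothetical and deliberately large, and `softmax_stable` is an illustrative helper written for this example (not a PyTorch function) that uses the common trick of subtracting the maximum before exponentiating.

```python
import torch

def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

def softmax_stable(x):
    # Subtracting the max does not change the result mathematically,
    # but it keeps every exponent <= 0, so exp() cannot overflow
    shifted = x - x.max()
    return torch.exp(shifted) / torch.exp(shifted).sum(dim=0)

# Hypothetical, deliberately large scores (illustration only)
large_scores = torch.tensor([1000.0, 1001.0, 1002.0])

naive = softmax_naive(large_scores)      # exp(1000) overflows float32: inf / inf = nan
stable = softmax_stable(large_scores)
builtin = torch.softmax(large_scores, dim=0)

print(f"Naive:   {naive}")
print(f"Stable:  {stable}")
print(f"Builtin: {builtin}")
```

On these values the naive version returns only `nan`, while the shifted version agrees with `torch.softmax`.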

In [None]:
def softmax(x):
    w = torch.softmax(x, dim=0)
    return w

In [16]:
attention_weights_1 = softmax(attention_scores_1)

print(f"Attention weights: {attention_weights_1}")

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


Now we can calculate the `Context Vector` for "cat" by multiplying each embedded **input token** by its **attention weight** and then summing the resulting vectors.

In [19]:
query = inputs[1] # "cat"
context_vector_1 = torch.zeros(query.shape) # shape [3], one entry per embedding dimension

for index, x_i in enumerate(inputs):
    context_vector_1 += attention_weights_1[index] * x_i
    
print(f"The context vector for word cat: {context_vector_1}")


The context vector for word cat: tensor([0.4419, 0.6515, 0.5683])
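
The weighted sum in the loop is the same operation as a vector-matrix product. The following sketch recomputes the weights for "cat" from scratch and checks that both routes agree; the names `context_loop` and `context_matmul` are introduced only for this comparison.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

query = inputs[1]  # "cat"
attention_scores_1 = inputs @ query  # (6, 3) @ (3,) -> (6,)
attention_weights_1 = torch.softmax(attention_scores_1, dim=0)

# Route 1: explicit weighted sum, as in the cell above
context_loop = torch.zeros(query.shape)
for weight, x_i in zip(attention_weights_1, inputs):
    context_loop += weight * x_i

# Route 2: a single vector-matrix product
context_matmul = attention_weights_1 @ inputs  # (6,) @ (6, 3) -> (3,)

print(context_loop)
print(context_matmul)
print(torch.allclose(context_loop, context_matmul))
```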


We can repeat these steps for every word at once and compute the attention scores for all tokens in **parallel** with a single matrix multiplication.

In [25]:
query = inputs
attention_scores = torch.matmul(query, query.T)

attention_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [31]:
torch.dot(inputs[0], inputs[0])

tensor(0.9995)

In [32]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [27]:
attention_scores_manually = torch.empty(inputs.shape[0], inputs.shape[0])

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attention_scores_manually[i, j] = torch.dot(x_i, x_j)
        
attention_scores_manually

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
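
Rather than comparing the two printed matrices by eye, the equivalence can be checked programmatically. This sketch also checks that the score matrix is symmetric, which holds here only because the queries and keys are the same tensor; `scores_matmul` and `scores_loop` are names chosen for this example.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

# One matrix multiplication: (6, 3) @ (3, 6) -> (6, 6)
scores_matmul = inputs @ inputs.T

# The same scores, one dot product at a time
n = inputs.shape[0]
scores_loop = torch.empty(n, n)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        scores_loop[i, j] = torch.dot(x_i, x_j)

same_result = torch.allclose(scores_matmul, scores_loop)
is_symmetric = torch.allclose(scores_matmul, scores_matmul.T)
print(same_result, is_symmetric)
```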

In [36]:
text = ["A", "cat", "sat", "on", "the", "mat"]

for i, attention in enumerate(attention_scores):
    for j, word in enumerate(text):
        print(f"Attention for '{text[i]}' to '{word}' is\t {attention[j]: .6f}")
    print("\n")

Attention for 'A' to 'A' is	  0.999500
Attention for 'A' to 'cat' is	  0.954400
Attention for 'A' to 'sat' is	  0.942200
Attention for 'A' to 'on' is	  0.475300
Attention for 'A' to 'the' is	  0.457600
Attention for 'A' to 'mat' is	  0.631000


Attention for 'cat' to 'A' is	  0.954400
Attention for 'cat' to 'cat' is	  1.495000
Attention for 'cat' to 'sat' is	  1.475400
Attention for 'cat' to 'on' is	  0.843400
Attention for 'cat' to 'the' is	  0.707000
Attention for 'cat' to 'mat' is	  1.086500


Attention for 'sat' to 'A' is	  0.942200
Attention for 'sat' to 'cat' is	  1.475400
Attention for 'sat' to 'sat' is	  1.457000
Attention for 'sat' to 'on' is	  0.829600
Attention for 'sat' to 'the' is	  0.715400
Attention for 'sat' to 'mat' is	  1.060500


Attention for 'on' to 'A' is	  0.475300
Attention for 'on' to 'cat' is	  0.843400
Attention for 'on' to 'sat' is	  0.829600
Attention for 'on' to 'on' is	  0.493700
Attention for 'on' to 'the' is	  0.347400
Attention for 'on' to 'mat' is	  0.656500


Attention for 

##### Normalized attention

In [42]:
query = inputs
key = inputs

attention_scores = query @ key.T
attention_weights = torch.softmax(attention_scores, dim=1)

attention_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [55]:
text = ["A", "cat", "sat", "on", "the", "mat"]

for i, attention in enumerate(attention_weights):
    for j, word in enumerate(text):
        print(f"Attention weight for '{text[i]}' to '{word}' is\t {attention[j]: .6f}")
    print("\n")

Attention weight for 'A' to 'A' is	  0.209835
Attention weight for 'A' to 'cat' is	  0.200581
Attention weight for 'A' to 'sat' is	  0.198149
Attention weight for 'A' to 'on' is	  0.124228
Attention weight for 'A' to 'the' is	  0.122049
Attention weight for 'A' to 'mat' is	  0.145158


Attention weight for 'cat' to 'A' is	  0.138548
Attention weight for 'cat' to 'cat' is	  0.237891
Attention weight for 'cat' to 'sat' is	  0.233274
Attention weight for 'cat' to 'on' is	  0.123992
Attention weight for 'cat' to 'the' is	  0.108182
Attention weight for 'cat' to 'mat' is	  0.158114


Attention weight for 'sat' to 'A' is	  0.139008
Attention weight for 'sat' to 'cat' is	  0.236921
Attention weight for 'sat' to 'sat' is	  0.232602
Attention weight for 'sat' to 'on' is	  0.124204
Attention weight for 'sat' to 'the' is	  0.110800
Attention weight for 'sat' to 'mat' is	  0.156464


Attention weight for 'on' to 'A' is	  0.143527
Attention weight for 'on' to 'cat' is	  0.207394
Attention weight for 'on' to 'sat' is	  

In [47]:
attention_weights.sum(dim=1) # the attention weights in each row sum to 1

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
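
The choice of `dim` decides which axis is normalized. As a sketch of the difference, the code below applies softmax along each axis of the same score matrix; `weights_rows` and `weights_cols` are names used only here.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)
attention_scores = inputs @ inputs.T

# dim=1 normalizes across each row: one distribution per query token
weights_rows = torch.softmax(attention_scores, dim=1)

# dim=0 normalizes down each column instead
weights_cols = torch.softmax(attention_scores, dim=0)

print(weights_rows.sum(dim=1))  # every row sums to 1
print(weights_cols.sum(dim=0))  # every column sums to 1
print(weights_cols.sum(dim=1))  # the rows no longer sum to 1
```

Because row `i` holds the scores of query `i` against every key, `dim=1` (equivalently `dim=-1` for a 2-D tensor) is the axis that yields one set of weights per query.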

##### Context Vector

In [53]:
value = inputs

context_vectors = attention_weights @ value # (6, 6) @ (6, 3) -> (6, 3)

context_vectors

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

In [52]:
for i, vector in enumerate(context_vectors):
    print(f"Context vector for word '{text[i]}' is  \t{vector}")

Context vector for word 'A' is  	tensor([0.4421, 0.5931, 0.5790])
Context vector for word 'cat' is  	tensor([0.4419, 0.6515, 0.5683])
Context vector for word 'sat' is  	tensor([0.4431, 0.6496, 0.5671])
Context vector for word 'on' is  	tensor([0.4304, 0.6298, 0.5510])
Context vector for word 'the' is  	tensor([0.4671, 0.5910, 0.5266])
Context vector for word 'mat' is  	tensor([0.4177, 0.6503, 0.5645])
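
Putting the three steps together (scores, softmax, weighted sum), a minimal sketch of the whole computation might look like this. The function name `simple_self_attention` is an assumption made for this example and does not come from PyTorch.

```python
import torch

def simple_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Self-attention without trainable weights for x of shape (tokens, dim)."""
    scores = x @ x.T                          # (tokens, tokens) dot-product scores
    weights = torch.softmax(scores, dim=-1)   # each row sums to 1
    return weights @ x                        # (tokens, dim) context vectors

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

context_vectors = simple_self_attention(inputs)
print(context_vectors.shape)
print(context_vectors[1])  # context vector for "cat"
```

Since every row of weights is positive and sums to 1, each context vector is a weighted average of the input embeddings, which is why all six results land close to one another.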
