In [1]:
import torch

#getting the vectors for the input sequence
inputs =  torch.tensor(
    [
        [0.43, 0.15, 0.89], #Your x1
        [0.55, 0.87, 0.66], #journey x2
        [0.57, 0.85, 0.64], #starts x3
        [0.22, 0.58, 0.33], #with x4
        [0.77, 0.25, 0.10], #one  x5
        [0.05, 0.80, 0.55]  #step x6
    ]
)

In [2]:
#defining variables
x_2 = inputs[2] #second input element, query
d_in = inputs.shape[1] #row dimension of the weights matrix, it must be set to the col dimention of the input/query embedding
d_out = 2 #col dimension of the weights matrix, it can be set to any thing

In [3]:
#initializing the query, key and value weight matrices
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) 
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) 
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) 

Setting the `requires_grad` to `False` here means that the values of the weights will not be optimized as the model is being trained. When we are building the model fr, we will set the vaue to true

Note that in GPT models the output dims are usually the same, but for the sake of this practical, we are making them different. 

In [4]:
print(W_query)
print(W_key)
print(W_value)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])
Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])
Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


In [7]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)
print(key_2)
print(value_2)

tensor([0.4300, 1.4343])
tensor([0.4361, 1.1156])
tensor([0.3879, 0.9831])


In [8]:
#on=btaining the querries, keys and values matrices
queries = inputs @ W_query
keys = inputs @ W_key
values = inputs @ W_value

print(f"Shape of Queries: {queries.shape}")
print(f"Shape of Keys: {keys.shape}")
print(f"Shape of values: {values.shape}")

Shape of Queries: torch.Size([6, 2])
Shape of Keys: torch.Size([6, 2])
Shape of values: torch.Size([6, 2])


In [9]:
#computing the anttention scores for x-2
keys_2 = keys[1]
attn_score_2 = query_2.dot(key_2)
print(attn_score_2)

tensor(1.7877)


In [10]:
#we can generalize the computation to all the attention score via matrix multiplication
attn_score_2 = query_2 @ keys.T #teh attention scores of all other words in relation to journey
print(attn_score_2)

tensor([1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238])


In [12]:
#obtaining the attention scores for the whole matrix
attn_scores = queries @ keys.T
print(attn_scores)

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])


Each row of the attn_score matrix shows the attn_score of the ith row in relation to other words.

In [14]:
#normalizing the attention scores to onbtain the attention weights
d_keys = keys.shape[-1]

attn_weights_2 = torch.softmax(attn_score_2 / d_keys ** 0.5, dim=-1)
print(attn_weights_2)


tensor([0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819])


Before using softmax to normalize the attention scores matrix, we first scale by the square root of the dim of the keys (cols) <br>

Reasons we divide with the sqrt of dimension. <br>
For stability in learning: the softmax function is sensitive to the magnitudes of the inputs. When the inputs are large, the differnces between the exponential values of each input becomes much more pronounced. this makes the softmax become peaky. Which can make the model overly confident in one particular key. <br>

TO make the variance of the dot product stable, multiplying two random numbers, increases the variance, so dividing by sqrt of the dimension keeps the varianc close to 1. the dimension of the vectors directly affects the variance.

In [17]:
#obtaining the context vectors 
context_vector_2 = attn_weights_2 @ values
print(context_vector_2)

tensor([0.3058, 0.8203])


Genaralizing the code 

In [21]:
import torch.nn as nn
#create a self attention clas
class SelfAttention_V1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__() #pareant 
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ W_value

        attn_scores = queries @ keys.T
        attm_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1   
        )

        context_vec = attm_weights @ values
        return context_vec

In [22]:
sam_v1 = SelfAttention_V1(d_in, d_out)
print(sam_v1(inputs))

tensor([[0.3076, 0.8176],
        [0.3171, 0.8418],
        [0.3166, 0.8406],
        [0.3003, 0.8048],
        [0.2971, 0.7953],
        [0.3066, 0.8196]], grad_fn=<MmBackward0>)


Using nn.linear function instead of Parameter, which has an optimized weight initilaization scheme, contributing to more stable and effective model training

In [28]:
class SelfAttention_V2(nn.Module):
    def __init__(self, d_in, d_out, qlv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qlv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qlv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qlv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vector = attn_weights @ values
        return context_vector


In [29]:
sam_v2 = SelfAttention_V2(d_in, d_out)
print(sam_v2(inputs))

tensor([[-0.4496,  0.4982],
        [-0.4462,  0.4957],
        [-0.4462,  0.4956],
        [-0.4478,  0.4971],
        [-0.4468,  0.4963],
        [-0.4479,  0.4971]], grad_fn=<MmBackward0>)
