In [1]:
import torch

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [2]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [3]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2
print('d_in:',d_in)
print('d_out:',d_out)

d_in: 3
d_out: 2


Note: in GPT-like models, the input and output dimensions are usually the same. Here we chose that the dimensions are not the same

We will initialize the three weight matrices Wq,Wk and Wv

In [4]:
torch.manual_seed(123)
w_query = torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)
w_key = torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)
w_value = torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)


In [5]:
print(w_query)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])


In [6]:
print(w_key)

Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])


In [7]:
print(w_value)

Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


Now, we can get the query,key and value matrices

In [8]:
q_2 = x_2 @ w_query
k_2 = x_2 @ w_key
v_2 = x_2 @ w_value

In [9]:
print(q_2)

tensor([0.4306, 1.4551])


We can now generalize to all the inputs

In [10]:
queries = inputs @ w_query
keys = inputs @ w_key
values = inputs @ w_value

In [11]:
print(queries)

tensor([[0.2309, 1.0966],
        [0.4306, 1.4551],
        [0.4300, 1.4343],
        [0.2355, 0.7990],
        [0.2983, 0.6565],
        [0.2568, 1.0533]])


Now let's compute the attension score w22

In [12]:
keys_2 = keys[1]
attn_score_22 = q_2.dot(keys_2)
print(attn_score_22)

tensor(1.8524)


Computing attention score between the second query ("journey") and all keys.

In [13]:
attn_score_2 = q_2 @ keys.T
print(attn_score_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


We can generalize all attention scores of the queries and keys

In [14]:
attn_scores = queries @ keys.T 
print(attn_scores)

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])


Computing the attention weights by scaling the attention scores and using the softmax function.

In [15]:
d_keys = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_score_2/d_keys**0.5,dim=-1)
print(attn_weights_2)
print(d_keys)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
2


In [16]:
attn_weights = torch.softmax(attn_scores/d_keys**0.5,dim=-1)
print(attn_weights)

tensor([[0.1551, 0.2104, 0.2059, 0.1413, 0.1074, 0.1799],
        [0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
        [0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819],
        [0.1591, 0.1994, 0.1962, 0.1477, 0.1206, 0.1769],
        [0.1610, 0.1949, 0.1923, 0.1501, 0.1265, 0.1752],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]])


Dividing by sqrt(keys_dim):
* For stability in learning: the softmax function is sensitive to the magnitude of its inputs. When the inputs are large, the differences between the exponential values of each input value are much more pronounced. This causes the softmax output to become "peaky". Where the highest value receives almost all the probability mass.
* Particularly in transformers, if the dot products between query and key vectors become too large, the attention score becomes very large. This results in a very sharp softmax distribution, making the model overly confident in one particular key.
* It is also useful to make the variance of the dot product stable.
* The dot product of Q and K increases the variance because multiplying two random numbers increases the variance.
* This variance increases if the dimension grows.
* Dividing by sqrt(dim) keeps the variance close to one.

Final step is to compute the context vectors

In [17]:
context_vecs = attn_weights @ values
print(context_vecs)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])


Implementing a compact self attention class

In [18]:
import torch.nn as nn 

class SelfAttention_v1(nn.Module):
    def __init__(self,d_in,d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in,d_out))
        self.W_key = nn.Parameter(torch.rand(d_in,d_out))
        self.W_value = nn.Parameter(torch.rand(d_in,d_out))
    def forward(self,x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores/keys.shape[-1]**0.5,dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

* SelfAttention_v1 is a class derived from nn.Module, which is a fundamental building block of PyTorch models and that provides necessary functionalities for model layer creation and management.

In [19]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in,d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


We can improve the SelfAttention_V1 by utilizing PyTorch's nn.Linear layers, which effectively perform matrix multiplication when the bias units are disabled.
* nn.Linear has an optimized weight initialization scheme, contributing to more stable and effective model training.

In [20]:
class SelfAttention_v2(nn.Module):
    
    def __init__(self,d_in,d_out,qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value = nn.Linear(d_in,d_out,bias=qkv_bias)
        
    def forward(self,x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_query(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
        context_vec = attn_weights @ values
        
        return context_vec

In [21]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in,d_out)
print(sa_v2(inputs))

tensor([[ 0.6729, -0.3148],
        [ 0.6757, -0.3146],
        [ 0.6757, -0.3146],
        [ 0.6750, -0.3164],
        [ 0.6749, -0.3168],
        [ 0.6752, -0.3157]], grad_fn=<MmBackward0>)
