## Self Attention with Trainable Weights

Here we want to compute context vectors as weighted sums over the input vectors, where the weights are specific to a certain input element (i.e. each input element gets its own context vector).

In [1]:
import torch

# Toy input: six rows, each a 3-dimensional embedding vector.
embedding_rows = [
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
]
inputs = torch.tensor(embedding_rows)

In [2]:
# Projection sizes: d_in is read from the data, d_out is chosen by hand.
d_in, d_out = inputs.shape[-1], 2  # embedding dimension = 3, projected down to 2
# (in general the in and out are of same dimension in GPT like architectures)

# The second row of `inputs` is the worked example used in the cells below.
x_1 = inputs[1]

In [8]:
# Seed the global RNG so the torch.rand draws below are reproducible.
torch.manual_seed(123)

# Query / key / value projection matrices, each of shape (d_in, d_out) with
# entries drawn uniformly from [0, 1). They are drawn in this order, so
# reordering these lines would change which random values each matrix gets.
# requires_grad=False: autograd does not track gradients for these tensors.
w_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
w_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
w_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [9]:
# Shapes of the three projection matrices, displayed as one tuple.
w_query.shape, w_key.shape, w_value.shape

(torch.Size([3, 2]), torch.Size([3, 2]), torch.Size([3, 2]))

In [10]:
# Project the example row x_1 into query, key and value space.
query_1, key_1, value_1 = (x_1 @ w for w in (w_query, w_key, w_value))

In [11]:
# Query, key and value projections of the single row x_1.
query_1, key_1, value_1

(tensor([0.4306, 1.4551]), tensor([0.4433, 1.1419]), tensor([0.3951, 1.0037]))

In [13]:
# Queries and keys for every row of `inputs` in one matrix product each.
query, key = torch.matmul(inputs, w_query), torch.matmul(inputs, w_key)

query.shape, key.shape

(torch.Size([6, 2]), torch.Size([6, 2]))

In [14]:
# attention score

# Un-normalised score of row 1 attending to itself: the dot product of its
# query vector with its key vector.
key_1 = key[1]
attention_score_11 = torch.dot(query_1, key_1)

print(f"Attention score for 'cat' and 'cat' is {attention_score_11.item()}")

Attention score for 'cat' and 'cat' is 1.8523844480514526


In [15]:
# All pairwise scores at once: entry (i, j) is the dot product of query i with key j.
attention_scores = torch.matmul(query, key.transpose(0, 1))

attention_scores

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])

#### Normalization

In [20]:
d_k = key.size(-1) # dimension of key i.e 2
# Divide the scores of row 1 by sqrt(d_k), then softmax so they sum to 1.
attention_weight_1 = (attention_scores[1] / d_k ** 0.5).softmax(dim=-1)

attention_weight_1 # weight for cat after scaling and normalizing

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])

In [21]:
# Sum of the softmax output for row 1 (displayed below as 1.0000).
attention_weight_1.sum()

tensor(1.0000)

In [22]:
# Same scaling and softmax as for row 1, applied to every row of the score
# matrix; softmax runs along the last dimension, i.e. independently per row.
attention_weight_scaled_normalized = (attention_scores / d_k ** 0.5).softmax(dim=-1)

attention_weight_scaled_normalized

tensor([[0.1551, 0.2104, 0.2059, 0.1413, 0.1074, 0.1799],
        [0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
        [0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819],
        [0.1591, 0.1994, 0.1962, 0.1477, 0.1206, 0.1769],
        [0.1610, 0.1949, 0.1923, 0.1501, 0.1265, 0.1752],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]])

These attention weights are computed from the trainable weight matrices (`w_query` and `w_key`); the model will learn to adjust those matrices as we train, so that contextual representations are learned in training.

In [24]:
# Sum along dim=1, i.e. one total per row of the weight matrix.
attention_weight_scaled_normalized.sum(dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [30]:
text = ["A", "cat", "sat", "on", "the", "mat"]

# Row 1 of the weight matrix holds the weights of the query token 'cat'
# towards every token in `text`.
cat_weights = attention_weight_scaled_normalized[1]
for i in range(len(cat_weights)):
    attention = cat_weights[i]
    print(f"Attention weight for 'cat' and '{text[i]}' is   \t{attention.item(): .6f}")

Attention weight for 'cat' and 'A' is   	 0.150019
Attention weight for 'cat' and 'cat' is   	 0.226384
Attention weight for 'cat' and 'sat' is   	 0.219872
Attention weight for 'cat' and 'on' is   	 0.131070
Attention weight for 'cat' and 'the' is   	 0.090629
Attention weight for 'cat' and 'mat' is   	 0.182026


In [41]:
# Project every input row into value space. `value` is used below but is not
# defined in any cell shown above (only the single-row `value_1` is), so it is
# computed here to keep the notebook runnable from a fresh kernel.
value = inputs @ w_value

# Context vector for row 1 ('cat'): attention-weighted sum of all value vectors.
context_vector_1 = attention_weight_scaled_normalized[1] @ value

context_vector_1 # new representation of 'cat' in 2 dimension after self attention mechanism

tensor([0.3061, 0.8210])

## Self Attention Class

In [33]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    Args:
        d_in: Size of each input embedding (last dimension of the input).
        d_out: Size of the query/key/value projections and of each context vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.d_in = d_in
        # Trainable projection matrices, initialised uniformly in [0, 1).
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return one context vector per input position.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, optionally with leading
                batch dimensions, i.e. ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        # Swap only the last two dims. `.T` reverses *all* dims, which is only
        # correct for 2-D tensors (and is deprecated for other ranks), so it
        # would break as soon as a batch dimension is present.
        attention_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the softmax; softmax over the last dim
        # makes each query's weights sum to 1.
        attention_weights_scaled_normalized = torch.softmax(attention_scores / (self.d_out ** 0.5), dim=-1)

        context_vector = attention_weights_scaled_normalized @ values
        return context_vector

In [38]:
# Seed 123 is the same seed used for the hand-built w_query / w_key / w_value
# earlier, and the class draws its three matrices in the same order, so row 1
# of the output below matches the context_vector_1 computed by hand.
torch.manual_seed(123)

self_attention_v1 = SelfAttention_v1(d_in=3, d_out=2)

# Calling the module runs forward() on all six input rows at once.
context_vectors = self_attention_v1(inputs)

print(context_vectors) # our contextual embeddings after projection into 2 dimension from self attention

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


We can use PyTorch's `nn.Linear` layers (with `bias=False`) for the weight matrices, because they provide out-of-the-box support for many features and optimizations when training the weights.

In [49]:
# Seed the global RNG so that nn.Linear layers constructed after this point
# start from reproducible random weights.
torch.manual_seed(789)

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention built from ``nn.Linear`` layers.

    Args:
        d_in: Size of each input embedding (last dimension of the input).
        d_out: Size of the query/key/value projections and of each context vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.d_in = d_in
        # Bias-free linear layers act as plain trainable projection matrices.
        self.W_query = nn.Linear(in_features=d_in, out_features=d_out, bias=False)
        self.W_key = nn.Linear(in_features=d_in, out_features=d_out, bias=False)
        self.W_value = nn.Linear(in_features=d_in, out_features=d_out, bias=False)

    def forward(self, x):
        """Return one context vector per input position.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, optionally with leading
                batch dimensions, i.e. ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two dims; `.T` is only correct for 2-D tensors.
        attention_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(key dimension). The key dimension is the *last* axis;
        # shape[1] would be the token count once a batch dimension is present.
        attention_weights = torch.softmax(attention_scores / (keys.shape[-1] ** 0.5), dim=-1)

        context_vectors = attention_weights @ values

        return context_vectors

In [50]:
# Seed immediately before construction so that re-running this cell on its own
# reproduces the same randomly initialised layers, and therefore the same output.
torch.manual_seed(789)
self_attention = SelfAttention(d_in=3, d_out=2)

# Calling the module runs forward() on all six input rows at once.
context_vectors = self_attention(inputs)

context_vectors

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

In [52]:
# Display the input embeddings again for reference.
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [53]:
# Build a new module and recompute only the attention weights, i.e. the
# forward pass without the final multiplication by the values.
# NOTE(review): this is a freshly initialised instance, so its random weights
# differ from the instance that produced `context_vectors` - confirm intended.
self_attention = SelfAttention(d_in=3, d_out=2)

queries, keys = self_attention.W_query(inputs), self_attention.W_key(inputs)

attention_scores = torch.matmul(queries, keys.T)
attention_weights = (attention_scores / keys.shape[1] ** 0.5).softmax(dim=-1)


attention_weights

tensor([[0.1766, 0.1701, 0.1699, 0.1597, 0.1618, 0.1620],
        [0.1772, 0.1720, 0.1717, 0.1580, 0.1596, 0.1615],
        [0.1769, 0.1719, 0.1716, 0.1582, 0.1597, 0.1616],
        [0.1725, 0.1696, 0.1695, 0.1618, 0.1627, 0.1638],
        [0.1687, 0.1694, 0.1692, 0.1637, 0.1634, 0.1656],
        [0.1758, 0.1704, 0.1702, 0.1598, 0.1615, 0.1623]],
       grad_fn=<SoftmaxBackward0>)