# Simplified version of self attention

In [1]:
import torch

# Toy 3-dimensional embeddings, one row per token of the sentence
# "Your journey starts with one step" (row i holds x^(i+1)).
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1 -> Your
    [0.55, 0.87, 0.66],  # x^2 -> journey
    [0.57, 0.85, 0.64],  # x^3 -> starts
    [0.22, 0.58, 0.33],  # x^4 -> with
    [0.77, 0.25, 0.10],  # x^5 -> one
    [0.05, 0.80, 0.55],  # x^6 -> step
])

In [2]:
# Step 1: the attention scores for the second token ("journey") are the dot
# products of its embedding (the query) with every input embedding.
query = inputs[1]
attn_scores_2 = inputs @ query
print("attn_scores for the second word journey is:", attn_scores_2)

attn_scores for the second word journey is: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [3]:
# Step 2: softmax turns the raw scores into attention weights that sum to 1.
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("attn_weights for the second word journey is:", attn_weights_2)
print("sum of attn_weights:", torch.sum(attn_weights_2))

attn_weights for the second word journey is: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
sum of attn_weights: tensor(1.)


In [4]:
# Step 3: the context vector is the weighted sum of all input vectors,
# where each input vector is weighted by its attention weight.
context_vector_2 = attn_weights_2 @ inputs
print("context vector for the second word journey is:", context_vector_2)

context vector for the second word journey is: tensor([0.4419, 0.6515, 0.5683])


In [5]:
# Repeat the three steps for every token at once:
# pairwise dot products -> row-wise softmax -> weighted sum of the inputs.
attn_scores = inputs @ inputs.T
attn_weights = torch.softmax(attn_scores, dim=1)
context_vector = attn_weights @ inputs

print("attn_scores are:", attn_scores)
print("attn_weights are:", attn_weights)
print("sum of all rows of attn_weights:", torch.sum(attn_weights, dim=1))
print("context_vector:", context_vector)

attn_scores are: tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
attn_weights are: tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
sum of all rows of attn_weights: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
context_vector: tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

# Implementing self-attention with trainable weights

In [6]:
# Unlike the simplified mechanism above, each input is now projected into three
# vectors (query, key, value) by multiplying it with the weight matrices
# W_query, W_key and W_value.
x_2 = inputs[1]
d_in = x_2.size(0)
d_out = 2

# Seed first, then draw the three matrices in the order query, key, value.
# requires_grad is False because this walkthrough does no training.
torch.manual_seed(123)
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [7]:
# Project x_2 into its query, key and value vectors.
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print("query_2:", query_2)

query_2: tensor([0.4306, 1.4551])


In [8]:
# Project every input token into key and value space in one matmul each.
keys = inputs @ W_key
values = inputs @ W_value
print("keys shape:", keys.shape)
print("values shape:", values.shape)

keys shape: torch.Size([6, 2])
values shape: torch.Size([6, 2])


In [9]:
# Attention scores for x_2: dot product of its query with every key.
attn_scores_2 = query_2 @ keys.T
print("attn_scores:", attn_scores_2)

attn_scores: tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [10]:
# Compute the attention weights: divide the scores by the square root of the
# key dimension d_k, then apply softmax.

d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print("attn_weights:", attn_weights_2)
print("sum of attn_weights:", torch.sum(attn_weights_2))

attn_weights: tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
sum of attn_weights: tensor(1.)


In [11]:
# Context vector for x_2: attention-weighted sum of the value vectors.
context_vector_2 = attn_weights_2 @ values
print("context_vector:", context_vector_2)

context_vector: tensor([0.3061, 0.8210])


# Implementing a Python class for self-attention

In [12]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
  """Scaled dot-product self-attention with raw ``nn.Parameter`` weight matrices.

  Args:
    d_in: size of each input embedding.
    d_out: size of the query/key/value projections (and of each output row).
  """
  def __init__(self, d_in, d_out):
    super().__init__()
    # Creation order (query, key, value) determines which random values each
    # matrix receives under a fixed seed, so keep it stable.
    self.W_query = nn.Parameter(torch.rand(d_in, d_out))
    self.W_key = nn.Parameter(torch.rand(d_in, d_out))
    self.W_value = nn.Parameter(torch.rand(d_in, d_out))

  def forward(self, x):
    """Return one context vector per row of ``x``.

    Args:
      x: 2-D tensor whose last dimension is ``d_in`` (``keys.T`` below
        requires a 2-D input).

    Returns:
      Tensor with one ``d_out``-sized context vector per input row.
    """
    # Project the argument `x`, not the notebook-level `inputs` tensor, so the
    # module works on whatever it is called with.
    queries = torch.matmul(x, self.W_query)
    keys = torch.matmul(x, self.W_key)
    values = torch.matmul(x, self.W_value)
    attn_scores = torch.matmul(queries, keys.T)
    # Scale by sqrt(d_k) before the softmax over the key dimension.
    attn_weights = nn.functional.softmax(attn_scores/(keys.shape[-1] ** 0.5), dim = -1)
    context_vector = torch.matmul(attn_weights, values)
    return context_vector


In [13]:
# Smoke-test SelfAttention_v1 on the toy inputs; the seed fixes the random weights.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [14]:
# stabilizing the self attention by using nn.Linear

class SelfAttention_v2(nn.Module):
  """Scaled dot-product self-attention whose projections are ``nn.Linear`` layers.

  Args:
    d_in: size of each input embedding.
    d_out: size of the query/key/value projections.
    qvk_bias: whether the three linear projections carry a bias term.
  """
  def __init__(self, d_in, d_out, qvk_bias = False):
    super().__init__()
    # Layers are created in the order query, key, value.
    self.W_query = nn.Linear(d_in, d_out, bias = qvk_bias)
    self.W_key = nn.Linear(d_in, d_out, bias = qvk_bias)
    self.W_value = nn.Linear(d_in, d_out, bias = qvk_bias)

  def forward(self, x):
    """Return one context vector per row of the 2-D input ``x``."""
    q = self.W_query(x)
    k = self.W_key(x)
    v = self.W_value(x)
    # Query-key dot products, scaled by the square root of the key dimension.
    scaled_scores = (q @ k.T) / (k.shape[-1] ** 0.5)
    weights = torch.softmax(scaled_scores, dim = -1)
    return weights @ v

In [15]:
# Smoke-test SelfAttention_v2 on the toy inputs; the seed fixes the layer initialisation.
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


# Hiding future words with causal attention (masked attention)

In [16]:
# Start from the unmasked attention weights, reusing the projections of sa_v2.

queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
values = sa_v2.W_value(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [17]:
# Build a lower-triangular mask: 1 where a token may attend (itself and
# earlier positions), 0 for later positions.
context_length = attn_scores.size(0)
mask = torch.ones(context_length, context_length).tril()
print(mask)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [18]:
# Zero out the weights above the diagonal by multiplying with the mask.
masked_attn_weights = torch.mul(attn_weights, mask)
print(masked_attn_weights)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


In [19]:
# Renormalize so that every row of the masked weights sums to 1 again.
row_sums = torch.sum(masked_attn_weights, dim=-1, keepdim=True)
normalized_masked_attn_weights = torch.div(masked_attn_weights, row_sums)
print(normalized_masked_attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


In [20]:
# A more efficient way to get the same masked weights is to fill the
# attention *scores* above the diagonal with -inf before the softmax:
# softmax maps -inf to 0, so no renormalization step is needed.
mask = torch.ones(context_length, context_length).tril()
masked_attn_scores = attn_scores.masked_fill(mask.logical_not(), float("-inf"))
print(masked_attn_scores)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [21]:
# Apply the scaled softmax to the masked attention scores to obtain the
# causal attention weights (this rebinds attn_weights).
attn_weights = nn.functional.softmax(masked_attn_scores / (keys.size(-1) ** 0.5), dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [22]:
# Dropout is added to the model to reduce overfitting.
# First, a small demo of how dropout behaves on a matrix of ones.

torch.manual_seed(123)
dropout = nn.Dropout(p=0.5)
print(dropout(torch.ones(5, 5)))
# As the output shows, roughly half of the entries are zeroed out and the
# surviving entries are scaled up by 1 / (1 - 0.5) = 2.

tensor([[2., 2., 2., 2., 2.],
        [2., 0., 2., 0., 0.],
        [0., 0., 0., 0., 2.],
        [0., 2., 0., 2., 2.],
        [0., 0., 0., 2., 2.]])


In [23]:
# Apply the same dropout layer to our causal attn_weights.
torch.manual_seed(123)
print(dropout.forward(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


# Creating a compact causal attention class

In [24]:
# Stack two copies of the inputs along a new leading dimension so that the
# attention classes below can be exercised on batched input.
batch = torch.stack([inputs, inputs], dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [25]:
class CasualAttention(nn.Module):
  """Single-head causal (masked) self-attention over batched inputs.

  Args:
    d_in: size of each input embedding.
    d_out: size of the query/key/value projections.
    context_length: maximum number of tokens supported; fixes the size of the
      registered lower-triangular ``mask`` buffer.
    dropout: dropout probability applied to the attention weights.
    qvk_bias: whether the three linear projections carry a bias term.
  """
  def __init__(self, d_in, d_out, context_length, dropout, qvk_bias = False):
    super().__init__()
    self.d_out = d_out
    self.W_query = nn.Linear(d_in, d_out, bias = qvk_bias)
    self.W_key = nn.Linear(d_in, d_out, bias = qvk_bias)
    self.W_value = nn.Linear(d_in, d_out, bias = qvk_bias)
    self.dropout = nn.Dropout(dropout)
    # Registered as a buffer so it moves with the module (e.g. .to(device))
    # without being treated as a trainable parameter.
    self.register_buffer("mask", torch.tril(torch.ones(context_length, context_length)))

  def forward(self, x):
    """Compute causal context vectors.

    Args:
      x: tensor of shape (batch, num_tokens, d_in), with
        num_tokens <= context_length.

    Returns:
      Tensor of shape (batch, num_tokens, d_out).
    """
    _, num_tokens, _ = x.shape
    queries = self.W_query(x)
    keys = self.W_key(x)
    values = self.W_value(x)
    attn_scores = torch.matmul(queries, keys.transpose(1, 2))
    # Slice the mask to the actual sequence length so inputs shorter than
    # context_length work; positions above the diagonal are future tokens.
    future_positions = self.mask[:num_tokens, :num_tokens] == 0
    # Out-of-place fill: attn_scores itself is left untouched.
    masked_attn_scores = attn_scores.masked_fill(future_positions, -torch.inf)
    attn_weights = nn.functional.softmax(masked_attn_scores/(keys.shape[-1] ** 0.5), dim = -1)
    attn_weights = self.dropout(attn_weights)
    context_vector = torch.matmul(attn_weights, values)
    return context_vector

In [26]:
# Test the causal attention class on the batched inputs (dropout disabled).
torch.manual_seed(123)
context_length = batch.size(1)
ca = CasualAttention(d_in, d_out, context_length, dropout=0.0)
context_vector = ca(batch)
print("context vector shape:", context_vector.shape)

context vector shape: torch.Size([2, 6, 2])


# Multi Head Attention

In [27]:
# A wrapper class to implement Multi Head attention

class MultiHeadAttentionWrapper(nn.Module):
  """Runs ``num_heads`` independent CasualAttention heads on the same input.

  The per-head outputs are concatenated along the last dimension, so each
  output row has ``num_heads * d_out`` features.
  """
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qvk_bias = False):
    super().__init__()
    head_modules = []
    for _ in range(num_heads):
      head_modules.append(CasualAttention(d_in, d_out, context_length, dropout, qvk_bias))
    self.heads = nn.ModuleList(head_modules)

  def forward(self, x):
    """Apply every head to ``x`` and concatenate the results feature-wise."""
    per_head_outputs = []
    for head in self.heads:
      per_head_outputs.append(head(x))
    return torch.cat(per_head_outputs, dim = -1)


In [28]:
# Test MultiHeadAttentionWrapper: two heads of width 2 give 4 output features.
torch.manual_seed(123)
context_length = batch.size(1)
d_in, d_out, num_heads = 3, 2, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0.0, num_heads=num_heads)
context_vector = mha(batch)
print("contet_vector shape:", context_vector.shape)
print("context_vector:", context_vector)

contet_vector shape: torch.Size([2, 6, 4])
context_vector: tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)


In [29]:
# computing multi-head attention efficiently with batched matrix multiplications:
# an efficient implementation of the multi-head attention class

class MultiHeadAttention(nn.Module):
  """Causal multi-head self-attention computed with batched matmuls.

  A single set of query/key/value projections of width ``d_out`` is split
  into ``num_heads`` heads of width ``d_out // num_heads``.

  Args:
    d_in: size of each input embedding.
    d_out: total output width across all heads; must be divisible by
      ``num_heads``.
    context_length: maximum number of tokens supported; fixes the size of the
      registered lower-triangular ``mask`` buffer.
    dropout: dropout probability applied to the attention weights.
    num_heads: number of attention heads.
    qvk_bias: whether the query/key/value projections carry a bias term.
  """
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qvk_bias = False):
    super().__init__()
    assert (d_out % num_heads) == 0, "d_out must be divisible by num_heads"

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads
    self.W_query = nn.Linear(d_in, d_out, bias = qvk_bias)
    self.W_key = nn.Linear(d_in, d_out, bias = qvk_bias)
    self.W_value = nn.Linear(d_in, d_out, bias = qvk_bias)
    # Output projection applied to the concatenated head outputs.
    self.out_proj = nn.Linear(d_out, d_out)
    self.dropout = nn.Dropout(dropout)
    self.register_buffer("mask", torch.tril(torch.ones(context_length, context_length)))

  def forward(self, x):
    """Compute causal multi-head context vectors.

    Args:
      x: tensor of shape (batch, num_tokens, d_in), with
        num_tokens <= context_length.

    Returns:
      Tensor of shape (batch, num_tokens, d_out).
    """
    b, num_tokens, _ = x.shape
    queries = self.W_query(x)
    keys = self.W_key(x)
    values = self.W_value(x)

    # Split the last dimension into heads:
    # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
    queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
    keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
    values = values.view(b, num_tokens, self.num_heads, self.head_dim)

    # Move the head dimension forward: (b, num_heads, num_tokens, head_dim)
    queries = queries.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)

    # Per-head scores: (b, num_heads, num_tokens, num_tokens)
    attn_scores = torch.matmul(queries, keys.transpose(2, 3))
    # Slice the mask to the actual sequence length so inputs shorter than
    # context_length work; it broadcasts over the batch and head dimensions.
    future_positions = self.mask[:num_tokens, :num_tokens] == 0
    # Out-of-place fill: attn_scores itself is left untouched.
    masked_attn_scores = attn_scores.masked_fill(future_positions, -torch.inf)
    attn_weights = nn.functional.softmax(masked_attn_scores/(keys.shape[-1] ** 0.5), dim = -1)
    attn_weights = self.dropout(attn_weights)
    # Back to (b, num_tokens, num_heads, head_dim) before merging the heads.
    context_vector = torch.matmul(attn_weights, values).transpose(1, 2)
    # transpose yields a non-contiguous tensor; view requires contiguity.
    context_vector = context_vector.contiguous().view(b, num_tokens, self.d_out)
    context_vector = self.out_proj(context_vector)
    return context_vector

In [30]:
# Test the MultiHeadAttention class on the batched inputs (dropout disabled).
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out, num_heads = 2, 1
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=num_heads)
context_vector = mha(batch)
print("contet_vector shape:", context_vector.shape)
print("context_vector:", context_vector)

contet_vector shape: torch.Size([2, 6, 2])
context_vector: tensor([[[0.3190, 0.4858],
         [0.2926, 0.3896],
         [0.2841, 0.3592],
         [0.2689, 0.3877],
         [0.2632, 0.3933],
         [0.2572, 0.4033]],

        [[0.3190, 0.4858],
         [0.2926, 0.3896],
         [0.2841, 0.3592],
         [0.2689, 0.3877],
         [0.2632, 0.3933],
         [0.2572, 0.4033]]], grad_fn=<ViewBackward0>)
