# Studying the mechanism of self-attention

[GitHub link](https://github.com/foobar167/junkyard/blob/master/fine_tuning/Attention_study.ipynb) and
[Colab link](https://colab.research.google.com/drive/1912LC9Tn7lyBzIwlZd2_QuNJqBtZzZ2D) to this Python script.

* Video [Understanding the Self-Attention Mechanism in **8 min**](https://youtu.be/W28LfOld44Y) - theory
* Video [Multi-head Attention Mechanism Explained](https://youtu.be/W6s9i02EiR0) in **4 min** - theory
* Video [Implementing the Self-Attention Mechanism from Scratch in PyTorch](https://youtu.be/ZPLym9rJtM8) - **4 min** to implement + **11 min** testing and playing with code
* Article [Attention Is All You Need](https://ar5iv.labs.arxiv.org/html/1706.03762)

![Self-Attention scheme](./data/Self-Attention_Layer.jpg)

## Attention implementation

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """ Attention mechanism implementation """
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.Q = nn.Linear(d_in, d_out)  # initialize fully connected layer with random values
        self.K = nn.Linear(d_in, d_out)
        self.V = nn.Linear(d_in, d_out)

    def forward(self, x):
        queries = self.Q(x)  # function performs the math operation y=xA^{T}+b
        keys = self.K(x)
        values = self.V(x)
        # torch.bmm  - Batch Matrix-Matrix Product
        # Use torch.matmul for Multi-Head Attention
        scores = torch.bmm(queries, keys.transpose(1, 2))
        # Normalize. Normalized scores do not depend on the number of elements
        scores = scores * (self.d_out ** -0.5)
        # Apply the softmax so that the sum of the values in a row equals 1.0
        attention = F.softmax(scores, dim=-1)  # apply sortmax to the last dimension
        hidden_states = torch.bmm(attention, values)
        return hidden_states

## Implementation of a very simple tokenizer

In [10]:
BOS_token = 0  # BOS (Beginning of Sequence) the same as SOS (Start of Sequence)
EOS_token = 1  # EOS (End of Sequence)

index2word = {
    BOS_token: "BOS",
    EOS_token: "EOS",
}

text = "How are you doing ? I am good and you ? I am fine, thank you ."
vocabulary = set(text.lower().split(" "))  # set of unique words

for word in vocabulary:
    index2word[len(index2word)] = word

print(*list(index2word.items()), sep="\n")

word2index = {w: i for i, w in index2word.items()}
print("\n", *list(word2index.items()), sep="\n")

(0, 'BOS')
(1, 'EOS')
(2, 'how')
(3, 'thank')
(4, '?')
(5, 'are')
(6, 'and')
(7, 'doing')
(8, 'good')
(9, 'fine,')
(10, 'am')
(11, 'you')
(12, 'i')
(13, '.')


('BOS', 0)
('EOS', 1)
('how', 2)
('thank', 3)
('?', 4)
('are', 5)
('and', 6)
('doing', 7)
('good', 8)
('fine,', 9)
('am', 10)
('you', 11)
('i', 12)
('.', 13)


In [11]:
def convert2tensor(sentence):
    """ Convert sentence to tensor """
    words_list = sentence.lower().split(" ")
    indices = [word2index[word] for word in words_list]
    # Add new dimension for batches using view(1,-1) at the beginning of the tensor
    return torch.tensor(indices, dtype=torch.long).view(1, -1)


def convert2sentence(tensor):
    """ Convert tensor with tokens to sentence """
    indices = tensor.tolist()[0]
    words_list = [index2word[index] for index in indices]
    return " ".join(words_list).capitalize()


sentence = "How are you doing ?"

input_tensor = convert2tensor(sentence)
print(f"Input tensor: {input_tensor}")
print(f"Size of input tensor: {input_tensor.size()}")

sentence_converted = convert2sentence(input_tensor)
print(f"Converted sentence: {sentence_converted}")

Input tensor: tensor([[ 2,  5, 11,  7,  4]])
Size of input tensor: torch.Size([1, 5])
Converted sentence: How are you doing ?


## Create a very small neural network with attention layer

In [12]:
HIDDEN_SIZE = 12
VOCAB_SIZE = len(word2index)

sentence = "How are you doing ?"
input_tensor = convert2tensor(sentence)

embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embedded = embedding(input_tensor)

print(embedded)
# size: [batch_size, sentence_length, hidden_size]
print(f"Size of embedded tensor: {embedded.size()}")

tensor([[[ 0.5221,  0.7157, -0.3543, -0.5256,  0.1637, -0.0708, -0.0091,
           1.0332, -1.3994,  3.2293,  0.3880,  1.6749],
         [-0.3050,  0.1364,  0.5668, -0.1816,  3.0615, -0.6120,  1.1956,
          -0.5862,  0.6557,  0.8610, -0.2304, -0.8637],
         [ 1.7445, -0.2422,  1.5841,  0.6004, -0.7200, -0.1400,  0.5470,
           0.0795,  1.4277, -2.1220,  0.0203, -0.0942],
         [ 0.5990, -0.2254, -0.3385,  0.4682,  0.3350, -1.9881,  0.5771,
          -1.4849,  0.3929,  3.1484, -1.7804,  0.1608],
         [-1.6338,  1.6158, -0.2962, -0.2781,  0.7622,  0.8868, -1.0893,
          -0.0453, -0.1252, -0.2076,  0.1610,  0.3798]]],
       grad_fn=<EmbeddingBackward0>)
Size of embedded tensor: torch.Size([1, 5, 12])


In [13]:
attention = Attention(HIDDEN_SIZE, HIDDEN_SIZE)  # initialize object

hidden_states = attention(embedded)  # call forward funcion

print(hidden_states)
print(hidden_states.size())

tensor([[[ 0.1953,  0.2028, -0.2137, -0.0722, -0.0161, -0.0196, -0.1961,
          -0.0577,  0.0434, -0.2525, -0.4601, -0.1460],
         [ 0.3746,  0.3702, -0.0224, -0.2034, -0.0802, -0.2717, -0.1027,
          -0.2241,  0.1758, -0.2005, -0.4023, -0.1239],
         [ 0.5535,  0.4315,  0.0739, -0.3023, -0.0796, -0.2415,  0.0369,
          -0.2942,  0.2452, -0.1164, -0.3334, -0.0777],
         [ 0.5658,  0.5691,  0.0541, -0.2639, -0.1582, -0.4898, -0.0853,
          -0.3069,  0.3372, -0.1678, -0.4179, -0.0766],
         [-0.1154, -0.0788, -0.0958, -0.0858,  0.0741,  0.0317, -0.1548,
          -0.0803, -0.1996, -0.2901, -0.3319, -0.2645]]],
       grad_fn=<BmmBackward0>)
torch.Size([1, 5, 12])


In [14]:
d_in = HIDDEN_SIZE
d_out = HIDDEN_SIZE

Q = nn.Linear(d_in, d_out)  # initialize fully connected layer with random values
K = nn.Linear(d_in, d_out)
V = nn.Linear(d_in, d_out)

queries = Q(embedded)  # function performs the math operation y=xA^{T}+b
keys = K(embedded)
values = V(embedded)

print(queries.size(), keys.size(), values.size(), sep="\n")

torch.Size([1, 5, 12])
torch.Size([1, 5, 12])
torch.Size([1, 5, 12])


In [15]:
scores = torch.bmm(queries, keys.transpose(1, 2))
print(scores.size())

scores = scores * (d_out ** -0.5)  # normalize
attention_tensor = F.softmax(scores, dim=-1)
print("", attention_tensor.size(), sep="\n")
print(f"Sum of the rows of the last dimension:\n\t{attention_tensor.sum(dim=-1)}")
print(attention_tensor)

hidden_states = torch.bmm(attention_tensor, values)
print("", hidden_states.size(), sep="\n")
print(hidden_states)

torch.Size([1, 5, 5])

torch.Size([1, 5, 5])
Sum of the rows of the last dimension:
	tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000]], grad_fn=<SumBackward1>)
tensor([[[0.1559, 0.1356, 0.2756, 0.2237, 0.2092],
         [0.2023, 0.1844, 0.2509, 0.1668, 0.1955],
         [0.1155, 0.2493, 0.1633, 0.1428, 0.3291],
         [0.3623, 0.1713, 0.1715, 0.2217, 0.0732],
         [0.1743, 0.1350, 0.2164, 0.2106, 0.2637]]],
       grad_fn=<SoftmaxBackward0>)

torch.Size([1, 5, 12])
tensor([[[ 0.0576, -0.3408, -0.4609, -0.2570, -0.1346, -0.2015, -0.0409,
          -0.4859,  0.0735,  0.4759, -0.1061, -0.3496],
         [ 0.0649, -0.3369, -0.4622, -0.3170, -0.1237, -0.1676, -0.0298,
          -0.4763,  0.0419,  0.4923, -0.1804, -0.3864],
         [ 0.1752, -0.3712, -0.5860, -0.3137, -0.0623, -0.0788, -0.0824,
          -0.3972,  0.0887,  0.3973, -0.0809, -0.3958],
         [ 0.0655, -0.2313, -0.5443, -0.5866, -0.2000, -0.3241,  0.0220,
          -0.3987, -0.0678,  0.7846, -0.4739, -0.2767],
        

## Multi-head Attention Implementaion

Multi-head Attention scheme

![Multi-head Attention scheme](./data/Multihead_Attention.jpg)

Attention head scheme

![Attention head scheme](./data/Attention_head.jpg)

In [16]:
import types
import torch
import torch.nn as nn
import torch.nn.functional as F


config = types.SimpleNamespace(
    vocab_size = len(word2index),
    embed_dim = 12,  # hidden size
    num_heads = 3,
    seq_len = 1024,
    attention_dropout = 0.1,
    residual_dropout = 0.1,
)


class Multihead_Attention(nn.Module):
    def __init__(self, config):
        super().__init__()  # initialize nn.Module
        self.embed_dim = config.embed_dim
        self.n_heads = config.num_heads
        assert self.embed_dim % self.n_heads == 0, "embedding dimension must be divisible by number of heads"
        self.head_size = self.embed_dim // self.n_heads  # 4 = 12 // 3

        self.c_attn = nn.Linear(self.embed_dim,
                                self.embed_dim * 3,
                                bias=True)

        self.scale = self.head_size ** -0.5

        # The self.register_buffer() method in PyTorch's nn.Module is used to
        # register a NON-LEARNABLE tensor
        self.register_buffer(
            "mask",  # self.mask
            torch.tril(torch.ones(1, 1, config.seq_len, config.seq_len)) == 0,
        )  # initialize non-learnable boolean (True/False) mask tensor of fize (1,1,1024,1024)

        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.attn_dropout = nn.Dropout(config.attention_dropout)  # dropout for attention matrix
        self.resid_dropout = nn.Dropout(config.residual_dropout)  # dropout for resulting matrix


    def forward(self, x):
        print("\nForwarding embedded tensor...")  # test message

        b, t, c = x.shape  # batch_size, sentence_length, embed_dim
        q, k, v = self.c_attn(x).chunk(3, dim=-1)  # query, key, value

        # (0, 2, 1, 3) --> batch * n_heads * t * head_size
        q = q.view(b, t, self.n_heads, self.head_size).permute(0, 2, 1, 3)
        k = k.view(b, t, self.n_heads, self.head_size).permute(0, 2, 1, 3)
        v = v.view(b, t, self.n_heads, self.head_size).permute(0, 2, 1, 3)

        scores = (q@k.transpose(-2, -1)) * self.scale
        scores = scores.masked_fill(
            self.mask[:, :, :t, :t],  # truncate (1,1,1024,1024) to (1,1,t,t)
            float("-inf"),
        )
        scores = F.softmax(scores, dim=-1)
        scores = self.attn_dropout(scores)

        attention = scores @ v
        attention = attention.permute(0, 2, 1, 3).contiguous().view(b, t, c)

        out = self.c_proj(attention)
        out = self.resid_dropout(out)

        return out

## Test multi-head attention

In [17]:
sentence = "How are you doing ?"
input_tensor = convert2tensor(sentence)

embedding = nn.Embedding(config.vocab_size, config.embed_dim)
embedded = embedding(input_tensor)

# tensor size: [batch_size, sentence_length, hidden_size]
print(f"Embedded tensor size: {embedded.size()}")

multihead_attention = Multihead_Attention(config)  # initialize object

hidden_states = multihead_attention(embedded)  # call forward function
print(f"Hidden states size: {hidden_states.size()}")

Embedded tensor size: torch.Size([1, 5, 12])

Forwarding embedded tensor...
Hidden states size: torch.Size([1, 5, 12])
