# Studying the mechanism of self-attention

[GitHub link](https://github.com/foobar167/junkyard/blob/master/fine_tuning/Attention_study.ipynb) and
[Colab link](https://colab.research.google.com/drive/1912LC9Tn7lyBzIwlZd2_QuNJqBtZzZ2D) to this Python script.

* Video [Understanding the Self-Attention Mechanism in **8 min**](https://youtu.be/W28LfOld44Y) - theory
* Video [Multi-head Attention Mechanism Explained](https://youtu.be/W6s9i02EiR0) in **4 min** - theory
* Video [Implementing the Self-Attention Mechanism from Scratch in PyTorch](https://youtu.be/ZPLym9rJtM8) - **4 min** to implement + **11 min** testing and playing with code
* Article [Attention Is All You Need](https://ar5iv.labs.arxiv.org/html/1706.03762)

![Self-Attention scheme](./data/Self-Attention_Layer.jpg)

## Attention implementation

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """ Attention mechanism implementation """
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.Q = nn.Linear(d_in, d_out)  # initialize fully connected layer with random values
        self.K = nn.Linear(d_in, d_out)
        self.V = nn.Linear(d_in, d_out)

    def forward(self, x):
        queries = self.Q(x)  # function performs the math operation y=xA^{T}+b
        keys = self.K(x)
        values = self.V(x)
        # torch.bmm  - Batch Matrix-Matrix Product
        # Use torch.matmul for Multi-Head Attention
        scores = torch.bmm(queries, keys.transpose(1, 2))
        # Normalize. Normalized scores do not depend on the number of elements
        scores = scores * (self.d_out ** -0.5)
        # Apply the softmax so that the sum of the values in a row equals 1.0
        attention = F.softmax(scores, dim=-1)  # apply sortmax to the last dimension
        hidden_states = torch.bmm(attention, values)
        return hidden_states

## Implementation of a very simple tokenizer

In [2]:
BOS_token = 0  # BOS (Beginning of Sequence) the same as SOS (Start of Sequence)
EOS_token = 1  # EOS (End of Sequence)

index2word = {
    BOS_token: "BOS",
    EOS_token: "EOS",
}

text = "How are you doing ? I am good and you ?"
vocabulary = set(text.lower().split(" "))  # set of unique words

for word in vocabulary:
    index2word[len(index2word)] = word

print(*list(index2word.items()), sep="\n")

word2index = {w: i for i, w in index2word.items()}
print("\n", *list(word2index.items()), sep="\n")

(0, 'BOS')
(1, 'EOS')
(2, 'how')
(3, '?')
(4, 'are')
(5, 'and')
(6, 'doing')
(7, 'good')
(8, 'am')
(9, 'you')
(10, 'i')


('BOS', 0)
('EOS', 1)
('how', 2)
('?', 3)
('are', 4)
('and', 5)
('doing', 6)
('good', 7)
('am', 8)
('you', 9)
('i', 10)


In [3]:
def convert2tensor(sentence):
    """ Convert sentence to tensor """
    words_list = sentence.lower().split(" ")
    indices = [word2index[word] for word in words_list]
    # Add new dimension for batches using view(1,-1) at the beginning of the tensor
    return torch.tensor(indices, dtype=torch.long).view(1, -1)


def convert2sentence(tensor):
    """ Convert tensor with tokens to sentence """
    indices = tensor.tolist()[0]
    words_list = [index2word[index] for index in indices]
    return " ".join(words_list).capitalize()


sentence = "How are you doing ?"

input_tensor = convert2tensor(sentence)
print(f"Input tensor: {input_tensor}")
print(f"Size of input tensor: {input_tensor.size()}")

sentence_converted = convert2sentence(input_tensor)
print(f"Converted sentence: {sentence_converted}")

Input tensor: tensor([[2, 4, 9, 6, 3]])
Size of input tensor: torch.Size([1, 5])
Converted sentence: How are you doing ?


## Create a very small neural network with attention layer

In [4]:
HIDDEN_SIZE = 12
VOCAB_SIZE = len(word2index)

sentence = "How are you doing ?"
input_tensor = convert2tensor(sentence)

embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embedded = embedding(input_tensor)

print(embedded)
# size: [batch_size, sentence_length, hidden_size]
print(f"Size of embedded tensor: {embedded.size()}")

tensor([[[-1.0587,  0.5892, -0.8730,  1.5204, -0.2458, -0.4056,  1.0812,
          -0.8234,  0.0461, -1.3530, -0.4353,  0.5372],
         [ 0.8191,  0.8516,  0.0484,  1.8933, -2.2349,  0.9433,  0.1067,
           0.8285,  1.0547,  0.1152,  0.5180,  0.7104],
         [-0.3938, -0.0363, -1.0047,  1.2835,  0.6363, -0.5910, -0.0508,
           0.3778, -0.0836, -0.7666, -0.5286,  1.2667],
         [-1.9934,  0.2951, -0.7795,  0.3300,  0.4425,  0.7721,  0.3109,
           2.1065, -0.0852,  0.6770, -0.6713,  0.2350],
         [ 0.4861, -1.1047, -0.8358, -0.4331, -1.5519, -0.5540,  0.7782,
          -1.2695,  0.5603,  0.5940, -0.4143,  1.2459]]],
       grad_fn=<EmbeddingBackward0>)
Size of embedded tensor: torch.Size([1, 5, 12])


In [5]:
attention = Attention(HIDDEN_SIZE, HIDDEN_SIZE)  # initialize object

hidden_states = attention(embedded)  # call forward funcion

print(hidden_states)
print(hidden_states.size())

tensor([[[ 0.0850,  0.0218, -0.0600, -0.3665,  0.5640, -0.2356,  0.0426,
          -0.1988,  0.4133,  0.5180,  0.4875,  0.2378],
         [ 0.1598, -0.0544, -0.0115, -0.3971,  0.5660, -0.2071,  0.0570,
          -0.1700,  0.3649,  0.4253,  0.4144,  0.2667],
         [ 0.0835,  0.0121, -0.0396, -0.3072,  0.5538, -0.2886,  0.0469,
          -0.1749,  0.3380,  0.4791,  0.4640,  0.2334],
         [ 0.0735,  0.0074, -0.0534, -0.3287,  0.5718, -0.2741,  0.0937,
          -0.2135,  0.3769,  0.5014,  0.4872,  0.1904],
         [ 0.0896,  0.0123, -0.0451, -0.3228,  0.5547, -0.2576,  0.0227,
          -0.1709,  0.3516,  0.4803,  0.4584,  0.2517]]],
       grad_fn=<BmmBackward0>)
torch.Size([1, 5, 12])


In [6]:
d_in = HIDDEN_SIZE
d_out = HIDDEN_SIZE

Q = nn.Linear(d_in, d_out)  # initialize fully connected layer with random values
K = nn.Linear(d_in, d_out)
V = nn.Linear(d_in, d_out)

queries = Q(embedded)  # function performs the math operation y=xA^{T}+b
keys = K(embedded)
values = V(embedded)

print(queries.size(), keys.size(), values.size(), sep="\n")

torch.Size([1, 5, 12])
torch.Size([1, 5, 12])
torch.Size([1, 5, 12])


In [7]:
scores = torch.bmm(queries, keys.transpose(1, 2))
print(scores.size())

scores = scores * (d_out ** -0.5)  # normalize
attention_tensor = F.softmax(scores, dim=-1)
print("", attention_tensor.size(), sep="\n")
print(f"Sum of the rows of the last dimension:\n\t{attention_tensor.sum(dim=-1)}")
print(attention_tensor)

hidden_states = torch.bmm(attention_tensor, values)
print("", hidden_states.size(), sep="\n")
print(hidden_states)

torch.Size([1, 5, 5])

torch.Size([1, 5, 5])
Sum of the rows of the last dimension:
	tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000]], grad_fn=<SumBackward1>)
tensor([[[0.2148, 0.2227, 0.2044, 0.1723, 0.1857],
         [0.1697, 0.1840, 0.2207, 0.2094, 0.2162],
         [0.1942, 0.2211, 0.2339, 0.1662, 0.1847],
         [0.2044, 0.1757, 0.2147, 0.1637, 0.2415],
         [0.1964, 0.1768, 0.2207, 0.1553, 0.2508]]],
       grad_fn=<SoftmaxBackward0>)

torch.Size([1, 5, 12])
tensor([[[ 0.0998, -0.1397,  0.2777,  0.5193,  0.1172, -0.8438,  0.0805,
           0.2359, -0.1140,  0.3579, -0.1809,  0.3405],
         [ 0.1004, -0.1379,  0.2651,  0.5714,  0.1167, -0.8165,  0.0713,
           0.2358, -0.1519,  0.3807, -0.2103,  0.2778],
         [ 0.0875, -0.1543,  0.2816,  0.5295,  0.1123, -0.8349,  0.0684,
           0.2363, -0.1320,  0.3502, -0.1821,  0.3394],
         [ 0.0992, -0.1796,  0.2455,  0.5202,  0.1077, -0.8397,  0.1039,
           0.2282, -0.1397,  0.3218, -0.2293,  0.2924],
        

## Multi-head Attention Implementaion

Multi-head Attention scheme

![Multi-head Attention scheme](./data/Multihead_Attention.jpg)

Attention head scheme

![Attention head scheme](./data/Attention_head.jpg)

In [8]:
import types
import torch
import torch.nn as nn
import torch.nn.functional as F


config = types.SimpleNamespace(
    vocab_size = len(word2index),
    embed_dim = 12,  # hidden size
    num_heads = 3,
    seq_len = 1024,
    attention_dropout = 0.1,
    residual_dropout = 0.1,
)


class Multihead_Attention(nn.Module):
    def __init__(self, config):
        super().__init__()  # initialize nn.Module
        self.embed_dim = config.embed_dim
        self.n_heads = config.num_heads
        assert self.embed_dim % self.n_heads == 0, "embedding dimension must be divisible by number of heads"
        self.head_size = self.embed_dim // self.n_heads  # 4 = 12 // 3

        self.c_attn = nn.Linear(self.embed_dim,
                                self.embed_dim * 3,
                                bias=True)

        self.scale = self.head_size ** -0.5

        # The self.register_buffer() method in PyTorch's nn.Module is used to
        # register a NON-LEARNABLE tensor
        self.register_buffer(
            "mask",  # self.mask
            torch.tril(torch.ones(1, 1, config.seq_len, config.seq_len)) == 0,
        )  # initialize non-learnable boolean (True/False) mask tensor of fize (1,1,1024,1024)

        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.attn_dropout = nn.Dropout(config.attention_dropout)  # dropout for attention matrix
        self.resid_dropout = nn.Dropout(config.residual_dropout)  # dropout for resulting matrix


    def forward(self, x):
        print("\nForwarding embedded tensor...")  # test message

        b, t, c = x.shape  # batch_size, sentence_length, embed_dim
        q, k, v = self.c_attn(x).chunk(3, dim=-1)  # query, key, value

        # (0, 2, 1, 3) --> batch * n_heads * t * head_size
        q = q.view(b, t, self.n_heads, self.head_size).permute(0, 2, 1, 3)
        k = k.view(b, t, self.n_heads, self.head_size).permute(0, 2, 1, 3)
        v = v.view(b, t, self.n_heads, self.head_size).permute(0, 2, 1, 3)

        scores = (q@k.transpose(-2, -1)) * self.scale
        scores = scores.masked_fill(
            self.mask[:, :, :t, :t],  # truncate (1,1,1024,1024) to (1,1,t,t)
            float("-inf"),
        )
        scores = F.softmax(scores, dim=-1)
        scores = self.attn_dropout(scores)

        attention = scores @ v
        attention = attention.permute(0, 2, 1, 3).contiguous().view(b, t, c)

        out = self.c_proj(attention)
        out = self.resid_dropout(out)

        return out

## Test multi-head attention

In [9]:
sentence = "How are you doing ?"
input_tensor = convert2tensor(sentence)

embedding = nn.Embedding(config.vocab_size, config.embed_dim)
embedded = embedding(input_tensor)

# tensor size: [batch_size, sentence_length, hidden_size]
print(f"Embedded tensor size: {embedded.size()}")

multihead_attention = Multihead_Attention(config)  # initialize object

hidden_states = multihead_attention(embedded)  # call forward function
print(f"Hidden states size: {hidden_states.size()}")

Embedded tensor size: torch.Size([1, 5, 12])

Forwarding embedded tensor...
Hidden states size: torch.Size([1, 5, 12])
