Understanding encoders is key to understanding how transformers work. In this notebook, we present what encoders are and why they are useful, both on their own and inside more complex architectures such as the famous transformer architecture.

# I. How do encoders work?

### 1. What are encoders?

Encoders are neural networks that, once trained, output a numerical representation of their input. This definition alone does not say much about what they are or why they are useful, so let's use words as an example.

For example, say you feed the sentence "the cat likes cheese" into an encoder: what you might want as output is an embedding vector (also called embedding tensor, feature vector or feature tensor) for each of these words. Much like a human reader, the encoder extracts more information than the letters composing each word provide on their own.

In this example, the embedding vector of the word "cat" will probably be close, in the representation space, to those of "dog", "mouse" and other animals. It may even encode information such as cats being carnivorous, having claws or usually being domesticated. To be clear, none of this can be read directly by a human: the only thing we can do is project these high-dimensional vectors into two or three dimensions and visualize how close "similar" words are. The dimension of these vectors is defined by the architecture of the model.
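As a minimal sketch of this idea, assuming PyTorch's built-in `nn.TransformerEncoder` as a stand-in for a trained encoder and a hypothetical four-word vocabulary, the snippet below shows the shape of what an encoder returns: one vector per input word. The weights are random and no positional information is added, so the vectors themselves carry no meaning here; only their shape matters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy vocabulary: in practice a trained tokenizer provides the token IDs
vocab = {"the": 0, "cat": 1, "likes": 2, "cheese": 3}
token_ids = torch.tensor([[vocab[w] for w in "the cat likes cheese".split()]])  # [1, 4]

embed_dim = 16  # dimension of each embedding vector, fixed by the architecture
embedding = nn.Embedding(len(vocab), embed_dim)
layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)
encoder.eval()  # disable dropout

with torch.no_grad():
    features = encoder(embedding(token_ids))  # untrained weights: values are arbitrary

print(features.shape)  # torch.Size([1, 4, 16]): one 16-dim vector per word
```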

Even better, the embedding also captures the meaning of a word in the context of the sentence: in this example, the embedding of "cat" will be influenced by the fact that it comes before the verb "to like". This is how different meanings of the same word can be distinguished depending on context. For instance:
- "bank" in "river bank" vs "savings bank"
- "apple" in "apple fruit" vs "Apple company"

This is done by the **self-attention mechanism**, a very important part of how encoders work. 
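To illustrate that context changes the output, here is a small sketch using PyTorch's `nn.MultiheadAttention` with random (untrained) weights and a hypothetical toy vocabulary. The token "bank" has the same input embedding in both phrases, but after self-attention its vector differs because its neighbor differs. With trained weights the difference would carry meaning; here it only demonstrates the mechanism.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy vocabulary, for illustration only
vocab = {"river": 0, "savings": 1, "bank": 2}
embed_dim = 16
embedding = nn.Embedding(len(vocab), embed_dim)
attention = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)
attention.eval()

def contextual(words):
    x = embedding(torch.tensor([[vocab[w] for w in words]]))  # [1, T, embed_dim]
    out, _ = attention(x, x, x)  # self-attention: query = key = value = x
    return out[0]                # [T, embed_dim]

with torch.no_grad():
    bank_river = contextual(["river", "bank"])[1]
    bank_savings = contextual(["savings", "bank"])[1]

# Same token and same input embedding, but different outputs because the context differs
print(torch.allclose(bank_river, bank_savings))  # False
```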

### 2. How does the self-attention mechanism in encoders work?

The **self-attention mechanism** is what makes encoders so powerful:

1. **Query, Key, Value**: Each word produces three vectors: a Query (what it is looking for), a Key (what it can be matched against, like an identity) and a Value (the information it carries).

2. **Attention scores**: Each word "looks at" every other word in the sentence and decides how much attention to pay to each one, by comparing its Query with their Keys.

3. **Weighted combination**: The final representation of each word is a weighted combination of all words' Values, where the weights are the attention scores from step 2.
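The three steps above can be sketched in a few lines of PyTorch for a single attention head. This is an illustrative sketch: `X` and the weight matrices are random stand-ins for learned values, and the division by $\sqrt{d}$ is the scaling step detailed in Part II.

```python
import torch

torch.manual_seed(0)

T, d = 4, 8                      # sequence length and embedding dimension (arbitrary)
X = torch.randn(T, d)            # one feature vector per word
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))  # would be learned in practice

# 1. Query, Key and Value vectors for every word
Q, K, V = X @ W_q, X @ W_k, X @ W_v

# 2. Attention scores: how much each word (row) attends to every word (column)
weights = torch.softmax(Q @ K.T / d ** 0.5, dim=-1)  # [T, T], each row sums to 1

# 3. Weighted combination of the values
Z = weights @ V                  # [T, d], one context-aware vector per word
print(Z.shape)  # torch.Size([4, 8])
```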

**Example**: In "The cat that I saw yesterday was black"
- When processing "cat", the attention mechanism may focus heavily on "black" and "saw"
- When processing "black", it would likely focus on "cat" to understand what is black

This is what allows the model to understand long-range dependencies and complex relationships within the sentence.
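In PyTorch, the attention weights of `nn.MultiheadAttention` can be inspected to see which words a given word "looks at". In this sketch the inputs are random stand-ins for word embeddings and the module is untrained, so the printed weights for "cat" are arbitrary; after training, a pattern like the one described above could emerge. The returned weights are averaged over heads (the default `average_attn_weights=True`), and each row sums to 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

words = "The cat that I saw yesterday was black".split()
embed_dim = 16
x = torch.randn(1, len(words), embed_dim)  # stand-in for real word embeddings
attention = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)

with torch.no_grad():
    _, weights = attention(x, x, x, need_weights=True)  # averaged over heads: [1, T, T]

cat_row = weights[0, words.index("cat")]  # how much "cat" attends to each word
for word, w in zip(words, cat_row.tolist()):
    print(f"{word:>10s}: {w:.3f}")
```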

### 3. When are encoders useful?

Even though encoders are often discussed in the context of transformers, they can also be used as standalone models for a variety of tasks:
- Sequence classification (the mini-project of this repo focuses on this task, with a sentiment analysis model)
- Masked language modeling
- Question answering
- ...

In general, encoders are useful for tasks that require extracting meaningful information from a sequence bidirectionally, i.e. where the representation of each token may depend on the tokens both before and after it.
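As a hedged sketch of how an encoder can serve as a sequence classifier (for instance for sentiment analysis), the hypothetical `EncoderClassifier` below uses PyTorch's built-in `nn.TransformerEncoder`, averages the output vectors of the non-padding tokens, and maps the result to class logits. Mean pooling is only one option (BERT-style models instead use the vector of a special classification token), and positional embeddings are omitted for brevity; the names, padding index and hyperparameters are assumptions, not the mini-project's actual model.

```python
import torch
import torch.nn as nn

class EncoderClassifier(nn.Module):
    """Hypothetical classifier: encoder + mean pooling over real tokens + linear head."""

    def __init__(self, vocab_size, embed_dim=32, num_heads=4, num_layers=2, num_classes=2, pad_id=0):
        super().__init__()
        self.pad_id = pad_id
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
        layer = nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, token_ids):                          # token_ids: [N, T]
        pad_mask = token_ids == self.pad_id                # True where the token is padding
        h = self.encoder(self.embedding(token_ids), src_key_padding_mask=pad_mask)  # [N, T, E]
        keep = (~pad_mask).unsqueeze(-1).float()           # [N, T, 1], 1 for real tokens
        pooled = (h * keep).sum(dim=1) / keep.sum(dim=1)   # mean over real tokens only
        return self.head(pooled)                           # [N, num_classes] logits

torch.manual_seed(0)
model = EncoderClassifier(vocab_size=100)
batch = torch.tensor([[5, 17, 42, 0], [8, 9, 10, 11]])    # 0 = padding (assumed)
logits = model(batch)
print(logits.shape)  # torch.Size([2, 2])
```

Training would then minimize a cross-entropy loss between these logits and the sentiment labels.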



# II. Implementation

The idea of computing "embeddings" of objects, i.e. numerical representations of them, is old, and encoders in neural networks predate transformers: they were already at the core of seq2seq models (2014-2015). However, the **Transformer encoder architecture** was introduced in "Attention Is All You Need" (2017), which changed how encoders work by replacing recurrent connections with self-attention.

This is the implementation that we will see in this notebook.

Here is what an encoder looks like, based on the original paper:

<div align="center">
    <img src="ressources/aiayn-encoder-scheme.png" alt="Architecture" width="300">
</div>

Let's go through each of its components:

## 1. Multi-Head Attention

**Multi-Head Attention** is by far the most complex component of the encoder/decoder in the transformer architecture.

To understand Multi-Head Attention, we first need to know about **Scaled Dot-Product Attention**.

### 1.1. Scaled Dot-Product Attention

Given a sequence of words:

"\<bos\> the green frog ran across the river \<eos\>"

the goal is to model the relationships between these words $w_i$, with $1 \leq i \leq T$, where $T$ is the length of the sequence.

For each of these words, we extract a feature vector $x_i \in \mathbb{R}^{d}$, and from each feature vector we compute three vectors $q_i, k_i, v_i$, the "Query", "Key" and "Value" vectors. These computations are matrix multiplications with learned weight matrices $W^q, W^k, W^v \in \mathbb{R}^{d \times d}$:

$q_i = x_i W^q$

$k_i = x_i W^k$

$v_i = x_i W^v$
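In practice the feature vectors of all words are stacked as the rows of a matrix $X \in \mathbb{R}^{T \times d}$, so all queries, keys and values come from three matrix products: $Q = X W^q$, $K = X W^k$, $V = X W^v$. A small check, with random stand-in values (a real model learns the weights), that row $i$ of $Q$ is exactly $q_i$:

```python
import torch

torch.manual_seed(0)

T, d = 8, 16                 # 8 tokens, as in "<bos> the green frog ran across the river <eos>"
X = torch.randn(T, d)        # row i is the feature vector x_i (random stand-in)
W_q = torch.randn(d, d)      # learned in a real model
W_k = torch.randn(d, d)
W_v = torch.randn(d, d)

# Stacking the x_i as rows of X computes every q_i, k_i, v_i in one matrix product each
Q, K, V = X @ W_q, X @ W_k, X @ W_v    # each [T, d]

# Row i of Q equals q_i = x_i W^q (up to floating-point rounding)
i = 3
print(torch.allclose(Q[i], X[i] @ W_q, atol=1e-4))  # True
```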


<div align="center">
    <img src="ressources/attention_youtube_channel_vectors.png" alt="Scheme" width="500">
</div>

Now that we introduced the input, we can focus on the diagram of the Scaled Dot-Product Attention:

<div align="center">
    <img src="ressources/aiayn-scaleddotproductattention-diagram.png" alt="Scheme" width="300">
</div>


For the full mathematical details, see the video "A Dive Into Multihead Attention, Self-Attention and Cross-Attention" by Machine Learning Studio on YouTube.

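After the first matmul, the scale step divides the scores by $\sqrt{d_k}$, where $d_k$ is the dimension of the key vectors ($d_k = d$ here): for large $d_k$ the dot products grow large in magnitude and push the softmax into regions with extremely small gradients.

The "Mask" step is optional. In a decoder, which is trained to generate sequences, a causal ("look-ahead") mask sets the upper triangular part of the "compatibility matrix" (i.e. the matrix produced by the first matmul) to $-\infty$, so that a word cannot "see" words with a larger index. In an encoder, which reads the whole sequence in both directions, a mask is typically only used to ignore padding tokens.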

The softmax step, applied to each row, gives the attention weights. Since $e^{-\infty} = 0$, masked positions receive a weight of exactly $0$.
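A minimal sketch of the causal mask and softmax steps, using random scores as a stand-in for the scaled compatibility matrix. `torch.triu(..., diagonal=1)` selects the entries strictly above the diagonal. Note that the implementation later in this notebook uses a large negative number (`-1e20`) instead of $-\infty$: after the softmax it behaves the same, and it avoids NaN values if an entire row happens to be masked.

```python
import torch

torch.manual_seed(0)

T = 4
scores = torch.randn(T, T)  # stand-in for the scaled compatibility matrix

# Causal mask: True strictly above the diagonal, i.e. where word i would "see" a word j > i
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
masked_scores = scores.masked_fill(future, float("-inf"))

weights = torch.softmax(masked_scores, dim=-1)
print(weights)
# Every entry above the diagonal is exactly 0 because exp(-inf) = 0,
# and each row still sums to 1 (the first word can only attend to itself).
```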

The last matmul, between the attention weights and the value matrix $V$, gives us the context matrix $Z$.
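Putting the diagram together gives $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$. Below is a minimal sketch of this formula with an optional boolean mask (True where attention is allowed, an assumed convention), cross-checked against PyTorch's `torch.nn.functional.scaled_dot_product_attention`, which requires PyTorch 2.0 or later.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Q: [..., T_q, d_k], K: [..., T_k, d_k], V: [..., T_k, d_v].
    mask: bool tensor broadcastable to [..., T_q, T_k], True where attention is allowed."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)      # matmul + scale
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))  # optional mask
    weights = torch.softmax(scores, dim=-1)                 # attention weights
    return weights @ V, weights                             # context matrix Z, weights

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 5, 8) for _ in range(3))  # [batch, T, d_k]
Z, weights = scaled_dot_product_attention(Q, K, V)
print(Z.shape)  # torch.Size([2, 5, 8])

# Cross-check against PyTorch's built-in implementation
print(torch.allclose(Z, F.scaled_dot_product_attention(Q, K, V), atol=1e-5))  # True
```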

### 1.2. Multi-Head Attention

**Multi-Head Attention** is the attention brick used within the transformer architecture, hence in the encoder architecture that we are implementing. 

<div align="center">
    <img src="ressources/aiayn-attentions.png" alt="Scheme" width="300">
</div>

Instead of performing a single attention on large matrices $Q$, $K$ and $V$, multi-head attention splits them into several smaller matrices (one set per head, each of dimension $d/h$ for $h$ heads) and performs a scaled dot-product attention separately on each of them.

1. First, we define how many heads we want to use for the multi-head attention

2. Then, multiply $X$ with the weight matrices for $Q$, $K$ and $V$ of each head

<div align="center">
    <img src="ressources/head_scheme.png" alt="Scheme" width="300">
</div>

3. Then, we perform the same computation as seen before with Scaled Dot-Product Attention:

<div align="center">
    <img src="ressources/scaled_dot_product_multihead.png" alt="Scheme" width="300">
</div>

4. Then, we concatenate the results of all heads to get back a single matrix of the original size:

<div align="center">
    <img src="ressources/concat_heads.png" alt="Scheme" width="300">
</div>

5. Then, we finally go through a linear layer to get the output
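The five steps can be sketched from scratch with a hypothetical `multi_head_attention` function. To check the sketch, it reuses the weights of PyTorch's `nn.MultiheadAttention` (created with `bias=False`; its three input projections are stored stacked in `in_proj_weight`) and compares the outputs. Splitting one $d \times d$ projection into $h$ slices, as done here, is equivalent to giving each head its own $d \times d/h$ projection.

```python
import math
import torch
import torch.nn as nn

def multi_head_attention(x, W_q, W_k, W_v, W_o, num_heads):
    """Steps 1-5 on x: [N, T, E]. Weights are [E, E] in nn.Linear's (out, in) layout."""
    N, T, E = x.shape
    head_dim = E // num_heads                      # step 1: each head gets E / h dimensions

    def split(t):                                  # [N, T, E] -> [N, h, T, head_dim]
        return t.view(N, T, num_heads, head_dim).transpose(1, 2)

    # Step 2: project X, then split the result into heads
    Q, K, V = split(x @ W_q.T), split(x @ W_k.T), split(x @ W_v.T)

    # Step 3: scaled dot-product attention, computed independently for every head
    weights = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(head_dim), dim=-1)
    heads = weights @ V                            # [N, h, T, head_dim]

    # Step 4: concatenate the heads back into one [N, T, E] matrix
    concat = heads.transpose(1, 2).reshape(N, T, E)

    # Step 5: final linear layer
    return concat @ W_o.T

torch.manual_seed(0)
E, h = 16, 4
mha = nn.MultiheadAttention(E, h, bias=False, batch_first=True)
x = torch.randn(2, 6, E)

# nn.MultiheadAttention stores W_q, W_k, W_v stacked in in_proj_weight ([3E, E])
W_q, W_k, W_v = mha.in_proj_weight.detach().chunk(3, dim=0)
W_o = mha.out_proj.weight.detach()

ours = multi_head_attention(x, W_q, W_k, W_v, W_o, h)
with torch.no_grad():
    ref, _ = mha(x, x, x)
print(torch.allclose(ours, ref, atol=1e-5))  # True
```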

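Because each head works on a reduced dimension ($d/h$), the total computational cost is similar to that of a single attention with full dimensionality. The benefit is not speed but expressiveness: each head can attend to information from a different representation subspace, at different positions of the input sequence, whereas a single attention head averages these signals together.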



### 1.3. Implementation



In [1]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self,
                 embed_dim,
                 num_heads):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.head_dim = embed_dim // num_heads

        assert (self.head_dim * num_heads == embed_dim), "embed_dim must be divisible by num_heads"

        # Projections producing the Value, Key and Query vectors (W^v, W^k, W^q)
        self.V = nn.Linear(embed_dim, embed_dim, bias=False)
        self.K = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Q = nn.Linear(embed_dim, embed_dim, bias=False)

        # Final linear layer applied after concatenating the heads
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self,
                values,
                keys,
                query,
                mask=None):
        # Argument order (values, keys, query) matches the call in TransformerBlock
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # 1. Project the inputs to queries, keys and values
        Q = self.Q(query)   # [N, query_len, embed_dim]
        K = self.K(keys)    # [N, key_len, embed_dim]
        V = self.V(values)  # [N, value_len, embed_dim]

        # 2. Split the embeddings into multiple heads
        Queries = Q.reshape(N, query_len, self.num_heads, self.head_dim)  # [N, query_len, num_heads, head_dim]
        Keys = K.reshape(N, key_len, self.num_heads, self.head_dim)        # [N, key_len, num_heads, head_dim]
        Values = V.reshape(N, value_len, self.num_heads, self.head_dim)    # [N, value_len, num_heads, head_dim]

        # 3. Compute the attention scores
        # matmul: [N, num_heads, query_len, key_len]
        energy = torch.einsum("nqhd,nkhd->nhqk", [Queries, Keys])

        # scale by sqrt(d_k), where d_k = head_dim is the dimension of each head's keys
        # Explanations: https://youtu.be/1IKrHh2X0F0?si=fQozjbfBRPw7J9p9
        energy = energy / (self.head_dim ** (1/2))

        # apply mask: positions where mask == 0 are hidden
        # (mask must broadcast to [N, num_heads, query_len, key_len])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # apply softmax over the key dimension to get attention weights
        attention = torch.softmax(energy, dim=3)  # [N, num_heads, query_len, key_len]

        # 4. Weighted sum of the values (key_len == value_len), then concatenate the heads
        out = torch.einsum("nhql,nlhd->nqhd", [attention, Values])      # [N, query_len, num_heads, head_dim]
        out = out.reshape(N, query_len, self.num_heads * self.head_dim)  # [N, query_len, embed_dim]

        # 5. Final linear layer
        out = self.fc_out(out)

        return out
        


## 2. Transformer block

We now have everything needed to build the main block of the encoder, which is also the core building block of the whole transformer architecture:

<div align="center">
    <img src="ressources/transformer_block.png" alt="Scheme" width="300">
</div>

Without going too far into the details, skip (residual) connections are used here as in many deep architectures: popularized by ResNets, they give gradients a direct path through the network, which mitigates vanishing gradients and makes much deeper networks trainable. If needed, check the theory of how residual connections allow much deeper networks.

This block can be split into two parts:
- The first part is the attention sub-layer seen before, which extracts information about each token in the specific context in which it appears in the input sequence. Layer normalization stabilizes training, and the skip connection helps gradients flow. Note that only the query is used for the residual connection (in the encoder's self-attention, query, key and value are the same tensor, so this makes no difference there).

- The second part processes each position individually with a feed-forward network: an expansion (by a factor `forward_expansion`), a ReLU activation on the expanded embedding to allow more complex, non-linear representations, and a compression back to the original dimension. Layer normalization and a skip connection are used again, for the same reasons.
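The feed-forward part is "position-wise": the same network is applied to each token separately. This small sketch, with random inputs and untrained weights, checks that applying it to the whole sequence gives the same result as applying it token by token.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_dim, forward_expansion = 16, 4
feed_forward = nn.Sequential(
    nn.Linear(embed_dim, forward_expansion * embed_dim),  # expansion
    nn.ReLU(),                                            # non-linearity
    nn.Linear(forward_expansion * embed_dim, embed_dim),  # compression back to embed_dim
)

x = torch.randn(2, 5, embed_dim)  # [N, seq_len, embed_dim]
whole = feed_forward(x)           # applied to the full sequence at once

# Applying it token by token gives the same result: no information flows between positions
per_token = torch.stack([feed_forward(x[:, t]) for t in range(x.shape[1])], dim=1)
print(torch.allclose(whole, per_token, atol=1e-5))  # True
```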


In [3]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, forward_expansion * embed_dim),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_dim, embed_dim)
        )

        self.dropout = nn.Dropout(dropout)

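    def forward(self, value, key, query, mask):
        # Multi-head attention sub-layer
        attention = self.attention(value, key, query, mask)
        # Dropout on the sub-layer output, then Add & Norm (as in the original paper);
        # the residual connection uses only the query input
        x = self.norm1(query + self.dropout(attention))
        # Position-wise feed-forward sub-layer, followed by Add & Norm
        forward = self.feed_forward(x)
        out = self.norm2(x + self.dropout(forward))
        return out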

## 3. Encoder

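Finally, the encoder itself is short: it computes the input embeddings (word embeddings plus positional embeddings) and then applies a stack of `TransformerBlock`s. This implementation uses learned positional embeddings; the original paper uses fixed sinusoidal encodings, but reports that learned positional embeddings produced nearly identical results: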

<div align="center">
    <img src="ressources/encoder.png" alt="Scheme" width="300">
</div>

In [None]:
class Encoder(nn.Module):
    def __init__(
            self,
            src_vocab_size,
            embed_size,
            num_layers,
            num_heads,
            device,
            forward_expansion,
            dropout,
            max_length):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    num_heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                ) for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)


    def forward(self, x, mask):
        N, seq_length = x.shape

        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out
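`Encoder.forward` takes a `mask` argument, but the notebook does not show how to build it. A common choice for an encoder is a padding mask; the hypothetical helper `make_src_mask` below is a sketch that follows the convention of `SelfAttention` (positions where `mask == 0` are hidden) and returns a tensor of shape `[N, 1, 1, seq_len]`, which broadcasts against attention scores of shape `[N, num_heads, query_len, key_len]`. The padding index `0` is an assumption.

```python
import torch

def make_src_mask(src, pad_idx):
    """Hypothetical helper: padding mask for token IDs of shape [N, seq_len].

    Returns a bool tensor of shape [N, 1, 1, seq_len]: True for real tokens,
    False for padding, so masked_fill(mask == 0, ...) hides the padded keys.
    """
    return (src != pad_idx).unsqueeze(1).unsqueeze(2)

src = torch.tensor([[5, 6, 7, 0, 0],
                    [8, 9, 1, 2, 3]])  # 0 is the (assumed) padding index
mask = make_src_mask(src, pad_idx=0)
print(mask.shape)  # torch.Size([2, 1, 1, 5])

# Check the broadcast against dummy scores with 4 heads
energy = torch.zeros(2, 4, 5, 5)
weights = torch.softmax(energy.masked_fill(mask == 0, float("-1e20")), dim=-1)
print(weights[0, 0, 0])  # first three entries are 1/3, the two padded positions are 0
```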

## Sources:

"Transformer: encoder", Hugging Face YouTube channel (https://www.youtube.com/watch?v=MUqNwgPjJvQ)

"A Dive Into Multihead Attention, Self-Attention and Cross-Attention", Machine Learning Studio YouTube channel (https://www.youtube.com/watch?v=mmzRYGCfTzc)

"Attention Is All You Need", Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin (arXiv:1706.03762)

"Self-Attention Using Scaled Dot-Product Approach", Machine Learning Studio YouTube channel (https://youtu.be/1IKrHh2X0F0?si=fQozjbfBRPw7J9p9)
