Understanding decoders is key to understanding how transformers work. In this notebook, we will present what decoders are and why they are useful, both on their own and inside more complex architectures such as the famous transformer architecture.

*Note: prior knowledge of encoders is assumed. Most of the information given in my Encoders_Explained repository will not be repeated here.*

# I. How do decoders work?

## 1. What are decoders?

Just like encoders, decoders can be used as a standalone architecture for different tasks. Decoders can be used for the same purposes as encoders, although generally with lower performance. On the other hand, decoders are much better suited to certain tasks, such as text generation. Let's look at the architectural differences between encoders and decoders.

Again, let's use the example of the sentence "the cat likes cheese".

Passing these words through the decoder gives a numerical representation of each of them. Just like an encoder, the decoder outputs one feature vector per word.

## 2. Main difference with encoders: the masked self-attention mechanism

The difference is that the decoder does not use the self-attention mechanism in the same way: it uses **masked self-attention**.

This difference can be illustrated with the word "cat":

The feature vector returned for "cat" only depends on "cat" itself and on the preceding word "the". The following (i.e. future) elements of the sequence, "likes" and "cheese", are not used by the decoder to build the representation of "cat". These future elements are said to be "masked".

This example is a simplification. More precisely, while encoders have access to a bidirectional context, decoders only have access to the context on one side, either the left or the right: the context is unidirectional. In the "cat" example above, the decoder only looks at the left context, i.e. the past elements.
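A minimal sketch of the causal (look-ahead) mask for the four-word example, assuming PyTorch as in the rest of the notebook. The mask is lower-triangular: row *i* marks the positions that word *i* may attend to. The attention scores are all set to zero (a dummy choice, only for illustration), so the softmax spreads the weight evenly over the visible words. The notebook's `SelfAttention` class uses a large negative constant (`-1e20`) instead of `-inf`; both drive the masked weights to zero.

```python
import torch

tokens = ["the", "cat", "likes", "cheese"]
T = len(tokens)

# mask[i, j] = 1 if word i may attend to word j (j <= i), 0 otherwise (future words).
mask = torch.tril(torch.ones(T, T))

# Dummy attention scores (all equal), masked so that future positions get -inf.
scores = torch.zeros(T, T).masked_fill(mask == 0, float("-inf"))
weights = torch.softmax(scores, dim=-1)  # each row sums to 1 over the visible words only

for i, word in enumerate(tokens):
    visible = [tokens[j] for j in range(T) if mask[i, j] == 1]
    print(word, "->", visible)
```

The row for "cat" gives the weights `[0.5, 0.5, 0.0, 0.0]`: "likes" and "cheese" receive no attention at all.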

## 3. In what cases are decoders helpful ?

Decoders can be used as standalone models in a variety of tasks:

- Causal tasks: generating sequences (e.g. GPT-2)
- Language modeling
- Text generation


In general, decoders are useful for tasks that require extracting meaningful information from a sequence in a single direction, such as predicting the next element from the previous ones.
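To make text generation concrete, here is a minimal sketch of greedy left-to-right generation. `ToyCausalLM` and `greedy_generate` are hypothetical names. The toy model (an embedding followed by a linear layer) only stands in for a real decoder, which would map token ids of shape `[N, T]` to logits of shape `[N, T, vocab_size]`. At each step only the prediction at the last position is used, and the chosen token is appended to the input before the next step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a decoder: token ids [N, T] -> logits [N, T, vocab_size].
class ToyCausalLM(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.out = nn.Linear(embed_dim, vocab_size)

    def forward(self, ids):
        return self.out(self.embed(ids))  # [N, T, vocab_size]

@torch.no_grad()
def greedy_generate(model, prompt_ids, max_new_tokens):
    """Append the most likely next token, one step at a time (left to right)."""
    ids = prompt_ids
    for _ in range(max_new_tokens):
        logits = model(ids)                                      # [N, T, vocab_size]
        next_logits = logits[:, -1, :]                           # last position predicts the next token
        next_id = next_logits.argmax(dim=-1, keepdim=True)       # [N, 1]
        ids = torch.cat([ids, next_id], dim=1)                   # feed the prediction back as input
    return ids

model = ToyCausalLM(vocab_size=10, embed_dim=8).eval()
prompt = torch.tensor([[1, 2, 3]])
generated = greedy_generate(model, prompt, max_new_tokens=5)
print(generated.shape)  # torch.Size([1, 8])
```

Greedy decoding is only one option: sampling strategies (temperature, top-k) would replace the `argmax` line.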


# II. Implementation

As in the *Encoders_Explained* repo, we will implement the decoder architecture following the "Attention Is All You Need" paper:

<div align="center">
    <img src="assets/decoder_architecture_diagram.png" alt="Architecture" width="300">
</div>

## 1. Multi-Head Attention

See the *Encoders_Explained* repo for a detailed explanation.

In [None]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self,
                 embed_dim,
                 num_heads):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.head_dim = embed_dim // num_heads

        assert (self.head_dim * num_heads == embed_dim), "embed_dim must be divisible by num_heads"

        self.V = nn.Linear(embed_dim, embed_dim, bias=False)
        self.K = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Q = nn.Linear(embed_dim, embed_dim, bias=False)

        self.fc_out = nn.Linear(num_heads * self.head_dim, embed_dim)

    def forward(self,
                query,
                keys,
                values,
                mask=None):
        
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        
        # 1. Extract the embeddings from the input:
        Q = self.Q(query) # [N, query_len, embed_dim]
        K = self.K(keys) # [N, key_len, embed_dim]
        V = self.V(values) # [N, value_len, embed_dim]

        # 2. Split embeddings into multiple heads
        Queries = Q.reshape(N, query_len, self.num_heads, self.head_dim) # [N, query_len, num_heads, head_dim]
        Keys = K.reshape(N, key_len, self.num_heads, self.head_dim) # [N, key_len, num_heads, head_dim]
        Values = V.reshape(N, value_len, self.num_heads, self.head_dim) # [N, value_len, num_heads, head_dim]

        # 3. Compute the attention scores
        # matmul
        energy = torch.einsum("nqhd,nkhd->nhqk", [Queries, Keys])

        # scale
        energy = energy / (self.head_dim ** (1/2)) # scale by sqrt(d_k), where d_k = head_dim as in the paper. Explanations https://youtu.be/1IKrHh2X0F0?si=fQozjbfBRPw7J9p9
        
        # apply mask
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        
        # apply softmax to get attention weights
        attention = torch.softmax(energy, dim=3)

        # final matmul between attention weights with values
        out = torch.einsum("nhql,nlhd->nqhd", [attention, Values]).reshape(N, query_len, self.num_heads * self.head_dim) # [N, query_len, num_heads * head_dim]

        # Out shape: (N, query_len, num_heads, head_dim) after einsum, then (N, query_len, num_heads * head_dim) after flattening the last two dimensions

        # Final linear layer
        out = self.fc_out(out)

        
        return out 
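The attention computation above can be cross-checked against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (available from PyTorch 2.0). The sketch below writes out masked scaled dot-product attention step by step for tensors laid out as `[N, num_heads, T, head_dim]`, which is the built-in's layout (the class above uses `[N, T, num_heads, head_dim]` with `einsum` instead). The scores are divided by the square root of the per-head dimension `head_dim`, which is the d_k of the paper. The sizes are arbitrary assumptions.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, num_heads, T, head_dim = 2, 4, 5, 8

# Per-head queries, keys and values: [N, num_heads, T, head_dim].
q = torch.randn(N, num_heads, T, head_dim)
k = torch.randn(N, num_heads, T, head_dim)
v = torch.randn(N, num_heads, T, head_dim)

def masked_attention(q, k, v):
    """Scaled dot-product attention with a causal mask, written out explicitly."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])   # [N, h, T, T], scaled by sqrt(head_dim)
    causal = torch.tril(torch.ones(scores.shape[-2], scores.shape[-1], dtype=torch.bool))
    scores = scores.masked_fill(~causal, float("-inf"))         # hide future positions
    weights = torch.softmax(scores, dim=-1)
    return weights @ v                                           # [N, h, T, head_dim]

manual = masked_attention(q, k, v)
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(manual, reference, atol=1e-5))
```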
        

## 2. Transformer Block

Same as in the *Encoders_Explained* repo.

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, forward_expansion * embed_dim),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_dim, embed_dim)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):

        attention = self.attention(query, key, value, mask)
        x = self.dropout(self.norm1(attention + query)) # Residual connection is done with just the query input
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out
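As a quick illustration of `forward_expansion`: with the sizes of the paper's base model (`embed_dim = 512` and an inner dimension of 2048, i.e. `forward_expansion = 4`), the feed-forward sub-layer alone holds about 2.1 million parameters. The snippet below rebuilds the same `nn.Sequential` as in `TransformerBlock` and counts them.

```python
import torch.nn as nn

embed_dim, forward_expansion = 512, 4   # d_model = 512, d_ff = 4 * 512 = 2048

feed_forward = nn.Sequential(
    nn.Linear(embed_dim, forward_expansion * embed_dim),
    nn.ReLU(),
    nn.Linear(forward_expansion * embed_dim, embed_dim),
)

n_params = sum(p.numel() for p in feed_forward.parameters())
# (512 * 2048 + 2048) + (2048 * 512 + 512) = 1,050,624 + 1,049,088
print(n_params)  # 2099712
```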

## 3. Decoder Block


For the encoder, implementing the Transformer Block was enough. The decoder architecture is naturally different, since we do not expect the same behaviour from it.

A decoder block consists of the "usual" Transformer Block implemented just above, preceded by an additional masked self-attention layer (with its own residual connection and LayerNorm). In the context of a whole transformer, the Transformer Block inside the decoder takes inputs from both:

1. The embeddings processed by the first Masked Multi-Head Attention layer, in which each element is represented using only itself and the previous (past) elements of the sequence. These are used as the queries.

2. The embeddings produced by the encoder, which represent the elements (e.g. words) of the source sequence, each one built from its context in the source sequence (the whole sequence, unless the encoder itself uses a mask). These are used as the keys and values, which is why this step is often called cross-attention (or encoder-decoder attention).
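A minimal sketch of these two steps, using PyTorch's built-in `nn.MultiheadAttention` in place of the notebook's `SelfAttention` class and leaving out residual connections, LayerNorm and dropout for brevity. The sizes are arbitrary assumptions; the point is the shapes: the output has one vector per *target* position, while the cross-attention weights span the *source* positions. Note that `nn.MultiheadAttention` uses the opposite boolean mask convention from this notebook: `True` marks positions that may *not* be attended to.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, T_src, T_trg, embed_dim, num_heads = 2, 7, 5, 16, 4

enc_out = torch.randn(N, T_src, embed_dim)   # encoder output (source sequence)
x = torch.randn(N, T_trg, embed_dim)         # decoder input embeddings (target sequence)

self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# 1. Masked self-attention: True marks the future positions a query may not see.
causal_mask = torch.triu(torch.ones(T_trg, T_trg, dtype=torch.bool), diagonal=1)
h, self_weights = self_attn(x, x, x, attn_mask=causal_mask)

# 2. Cross-attention: queries from the decoder, keys and values from the encoder.
out, cross_weights = cross_attn(h, enc_out, enc_out)

print(out.shape)            # torch.Size([2, 5, 16]): one vector per target position
print(cross_weights.shape)  # torch.Size([2, 5, 7]): weights over the source positions
```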


<div align="center">
    <img src="assets/decoder_architecture_diagram.png" alt="Architecture" width="300">
</div>

In [7]:
class DecoderBlock(nn.Module):
    def __init__(
            self,
            embed_size,
            num_heads,
            forward_expansion,
            dropout,
            device):

        super(DecoderBlock, self).__init__()
        self.embed_size = embed_size

        self.attention = SelfAttention(embed_size, num_heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(
            embed_size,
            num_heads,
            dropout=dropout,
            forward_expansion=forward_expansion
        )

        self.dropout = nn.Dropout(dropout) 

    def forward(self, x, value, key, src_mask, trg_mask):
        attention = self.attention(x, x, x, trg_mask)
        x = self.dropout(self.norm(attention + x))
        out = self.transformer_block(value, key, x, src_mask)
        return out
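The `src_mask` and `trg_mask` arguments are not built in this notebook. Below is a minimal sketch of how they might be constructed, assuming a padding token id of `0` and using the hypothetical helper names `make_src_mask` and `make_trg_mask`. They follow the convention of `SelfAttention` (`1`/`True` = keep, `0` = hide) and have shapes that broadcast against the attention scores `[N, num_heads, query_len, key_len]`.

```python
import torch

def make_src_mask(src, pad_idx):
    """Hypothetical helper: padding mask [N, 1, 1, src_len], True = real token, False = padding."""
    return (src != pad_idx).unsqueeze(1).unsqueeze(2)

def make_trg_mask(trg):
    """Hypothetical helper: causal mask [N, 1, trg_len, trg_len], lower-triangular."""
    N, trg_len = trg.shape
    causal = torch.tril(torch.ones(trg_len, trg_len))
    return causal.expand(N, 1, trg_len, trg_len)

PAD = 0                                     # assumed padding id
src = torch.tensor([[5, 6, 7, PAD, PAD]])   # source sentence with two padding tokens
trg = torch.tensor([[1, 8, 9, 4]])          # target sentence

src_mask = make_src_mask(src, PAD)
trg_mask = make_trg_mask(trg)
print(src_mask.shape, trg_mask.shape)  # torch.Size([1, 1, 1, 5]) torch.Size([1, 1, 4, 4])
```

The source mask only hides padding, so every target position may attend to every real source token; the target mask adds the causal constraint.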

When the decoder is used as a standalone model for generation (without an encoder), we use this class instead:

In [None]:
class StandaloneDecoderBlock(nn.Module):
    def __init__(
            self,
            embed_size,
            num_heads,
            forward_expansion,
            dropout,
            device):

        super(StandaloneDecoderBlock, self).__init__()
        self.embed_size = embed_size

        self.attention = SelfAttention(embed_size, num_heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )

        self.dropout = nn.Dropout(dropout) 

    def forward(self, x, src_mask, trg_mask):
        # src_mask is unused here: without an encoder there is no source sequence to attend to.
        attention = self.attention(x, x, x, trg_mask)
        x = self.dropout(self.norm1(attention + x))

        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

It is the same as the previous decoder block, except that there is no encoder output to attend to: instead of a second full Transformer Block (with its cross-attention), it only keeps that block's feed-forward part.
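One way to sanity-check this structure: a block made of self-attention and a feed-forward layer, each followed by a residual connection and LayerNorm, is what PyTorch's `nn.TransformerEncoderLayer` implements (with its default post-norm layout), so with a causal mask it behaves like a decoder-only block. This sketch uses that built-in layer rather than the class above. It applies dropout slightly differently, which does not matter here since dropout is set to `0.0`. The sizes are arbitrary assumptions. Changing the last tokens of the input should leave the outputs at earlier positions untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads, T = 16, 4, 6

# Self-attention + feed-forward with post-norm residuals, like StandaloneDecoderBlock.
block = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=num_heads, dim_feedforward=4 * embed_dim,
    dropout=0.0, batch_first=True,
)

# True above the diagonal: future positions are hidden.
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

x = torch.randn(1, T, embed_dim)
x_changed = x.clone()
x_changed[:, 3:] = torch.randn(1, T - 3, embed_dim)   # alter positions 3, 4 and 5 only

y = block(x, src_mask=causal_mask)
y_changed = block(x_changed, src_mask=causal_mask)

# Positions 0..2 cannot see positions 3..5, so their outputs should be unchanged.
print(torch.allclose(y[:, :3], y_changed[:, :3], atol=1e-5))
print(torch.allclose(y[:, 3:], y_changed[:, 3:]))
```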

## 4. Decoder

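The decoder itself is simply built from the blocks above, following the same process as for the encoder: word embeddings plus learned positional embeddings, followed by a stack of `num_layers` decoder blocks, and a final linear layer that maps each vector to scores over the target vocabulary.

**Pay attention to the connection with the encoder in the context of a whole transformer architecture: the encoder output is taken as an input by the `forward` method (the `enc_out` variable) and used as keys and values in every decoder block.**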

<div align="center">
    <img src="assets/decoder.png" alt="Scheme" width="300">
</div>

In [8]:
class Decoder(nn.Module):
    def __init__(self,
                trg_vocab_size,
                embed_size,
                num_layers,
                num_heads,
                forward_expansion,
                dropout,
                device,
                max_length):
        super(Decoder, self).__init__()

        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(
                    embed_size,
                    num_heads,
                    forward_expansion,
                    dropout,
                    device
                ) for _ in range(num_layers)
            ]
        )


        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)


    def forward(self, x, enc_out, src_mask, trg_mask): # enc_out = encoder output, [N, src_len, embed_size]
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length, device=self.device).expand(N, seq_length) # [N, seq_length]
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)

        return out  # [N, seq_length, trg_vocab_size]
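The decoder returns one vector of scores (logits) over the target vocabulary per position. A minimal sketch of how such an output is typically trained for next-token prediction: the prediction at position *t* is compared with the token at position *t + 1* using cross-entropy. Random tensors stand in for a real batch of token ids and for the `Decoder` output; the sizes are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, seq_length, trg_vocab_size = 2, 6, 11

tokens = torch.randint(0, trg_vocab_size, (N, seq_length))  # stand-in batch of token ids
logits = torch.randn(N, seq_length, trg_vocab_size)          # stand-in for the Decoder output

# Position t predicts token t + 1: drop the last prediction and the first target.
pred = logits[:, :-1, :]   # [N, seq_length - 1, trg_vocab_size]
target = tokens[:, 1:]     # [N, seq_length - 1]

# cross_entropy expects [batch, classes] logits and [batch] class indices.
loss = F.cross_entropy(pred.reshape(-1, trg_vocab_size), target.reshape(-1))
print(loss.item())
```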

## 5. Practical case

In the mini-project, we will detail how to use the decoder as a standalone model for a generation task.

## Sources:

"Transformer: decodeur", Hugging Face YouTube channel (https://www.youtube.com/watch?v=d_ixlCubqQw)

"A Dive Into Multihead Attention, Self-Attention and Cross-Attention", Machine Learning Studio YouTube channel (https://www.youtube.com/watch?v=mmzRYGCfTzc)

"Attention Is All You Need", Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin (arXiv:1706.03762)

"Self-Attention Using Scaled Dot-Product Approach", Machine Learning Studio YouTube channel (https://youtu.be/1IKrHh2X0F0?si=fQozjbfBRPw7J9p9)
