# Encoder

Parecía que ya teníamos todo el encoder listo, pero hay un detalle de la arquitectura en el que no nos habíamos fijado: a la derecha del encoder aparece un `Nx`.

<div style="text-align:center;">
  <img src="Imagenes/transformer_architecture_model_encoder_Nx.png" alt="Encoder Nx" style="width:425px;height:626px;">
</div>

En el paper indican que hay que apilar el bloque del encoder (y el del decoder) N veces, de forma que la salida de cada capa es la entrada de la siguiente. El paper propone N = 6; en este ejemplo usaremos 8 capas. Vamos a implementarlo con el contenedor [nn.ModuleList](https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html) de PyTorch.

## Implementación

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(query @ key^T / sqrt(d)) @ value``."""

    def __init__(self, dim_embedding):
        """
        Args:
            dim_embedding: size of the vectors being compared; its square
                root is used to scale the raw scores.
        """
        super().__init__()
        self.dim_embedding = dim_embedding

    def forward(self, key, query, value):
        """
        Note the parameter order: ``key`` comes first, then ``query``.

        Args:
            key: key tensor, last dimension is the feature dimension
            query: query tensor, last dimension is the feature dimension
            value: value tensor, combined using the attention weights

        Returns:
            Attention-weighted combination of ``value``.
        """
        # Similarity of every query against every key, scaled by sqrt(d)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.dim_embedding)
        # Normalise over the key axis so each query's weights sum to 1
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, value)

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on top of ``ScaledDotProductAttention``.

    Q, K and V are linearly projected, split into ``heads`` slices of size
    ``dim_embedding // heads``, attended per head, merged back into a single
    ``dim_embedding``-sized vector per position and passed through a final
    linear layer.
    """

    def __init__(self, heads, dim_embedding):
        """
        Args:
            heads (int): number of attention heads.
            dim_embedding (int): embedding dimension; must be divisible by
                ``heads``.

        Raises:
            ValueError: if ``dim_embedding`` is not divisible by ``heads``.
        """
        super().__init__()

        # Without this check some sizes reshape "successfully" into garbage.
        if dim_embedding % heads != 0:
            raise ValueError(
                f"dim_embedding ({dim_embedding}) must be divisible by heads ({heads})"
            )

        self.dim_embedding = dim_embedding
        self.dim_proyection = dim_embedding // heads
        self.heads = heads

        self.proyection_Q = nn.Linear(dim_embedding, dim_embedding)
        self.proyection_K = nn.Linear(dim_embedding, dim_embedding)
        self.proyection_V = nn.Linear(dim_embedding, dim_embedding)
        self.attention = nn.Linear(dim_embedding, dim_embedding)

        self.scaled_dot_product_attention = ScaledDotProductAttention(self.dim_proyection)

    def _split_heads(self, projected, batch_size):
        """Reshape (bs, sl, dim_embedding) into (bs, heads, sl, dim_proyection)."""
        return projected.view(batch_size, -1, self.heads, self.dim_proyection).transpose(1, 2)

    def forward(self, Q, K, V):
        """
        Args:
            Q (torch.Tensor): queries, (batch_size, seq_len_q, dim_embedding)
                or unbatched (seq_len_q, dim_embedding).
            K (torch.Tensor): keys, (batch_size, seq_len_k, dim_embedding)
                or unbatched (seq_len_k, dim_embedding).
            V (torch.Tensor): values, same leading shape as ``K``.

        Returns:
            torch.Tensor: same leading shape as ``Q`` with last dimension
            ``dim_embedding``.
        """
        # A 2-D input is treated as a single unbatched sequence: add a
        # temporary batch axis and remove it again before returning.
        unbatched = Q.dim() == 2
        if unbatched:
            Q, K, V = Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0)

        batch_size = Q.size(0)

        # Linear projections, then split into heads: (bs, heads, sl, dim_proyection)
        proyection_Q = self._split_heads(self.proyection_Q(Q), batch_size)
        proyection_K = self._split_heads(self.proyection_K(K), batch_size)
        proyection_V = self._split_heads(self.proyection_V(V), batch_size)

        # All heads are attended in one batched call (matmul broadcasts over
        # the leading batch/head axes). Keyword arguments are used because the
        # attention module's positional order is (key, query, value).
        attended = self.scaled_dot_product_attention(
            key=proyection_K, query=proyection_Q, value=proyection_V
        )

        # Merge heads: (bs, heads, sl, dim_proyection) -> (bs, sl, dim_embedding)
        concat = attended.transpose(1, 2).contiguous().view(batch_size, -1, self.dim_embedding)

        output = self.attention(concat)

        if unbatched:
            output = output.squeeze(0)

        return output

class AddAndNorm(nn.Module):
    """Residual connection followed by layer normalization."""

    def __init__(self, dim_embedding):
        """
        Args:
            dim_embedding (int): Size of the last dimension to normalize over.
        """
        super().__init__()
        self.normalization = nn.LayerNorm(dim_embedding)

    def forward(self, x, sublayer):
        """
        Args:
            x (torch.Tensor): Input that was fed to the sublayer.
            sublayer (torch.Tensor): Output produced by the sublayer.

        Returns:
            torch.Tensor: ``LayerNorm(x + sublayer)``.
        """
        residual = x + sublayer
        return self.normalization(residual)

class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, dim_embedding, increment=4):
        """
        Args:
            dim_embedding (int): input and output feature size.
            increment (int): multiplier applied to ``dim_embedding`` to get
                the hidden layer width.
        """
        super().__init__()
        hidden_dim = dim_embedding * increment
        expand = nn.Linear(dim_embedding, hidden_dim)
        contract = nn.Linear(hidden_dim, dim_embedding)
        self.feed_forward = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (batch_size, seq_len, dim_embedding)

        Returns:
            torch.Tensor: (batch_size, seq_len, dim_embedding)
        """
        return self.feed_forward(x)

In [2]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward, each wrapped in add & norm."""

    def __init__(self, heads, dim_embedding):
        """
        Args:
            heads (int): number of attention heads.
            dim_embedding (int): embedding dimension.
        """
        super().__init__()
        self.multi_head_attention = MultiHeadAttention(heads, dim_embedding)
        self.add_and_norm_1 = AddAndNorm(dim_embedding)
        self.feed_forward = FeedForward(dim_embedding)
        self.add_and_norm_2 = AddAndNorm(dim_embedding)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (batch_size, seq_len, dim_embedding)

        Returns:
            torch.Tensor: (batch_size, seq_len, dim_embedding)
        """
        # Self-attention: the same tensor is passed as Q, K and V
        attended = self.multi_head_attention(x, x, x)
        x = self.add_and_norm_1(x, attended)
        # Feed-forward sublayer with its own residual + normalization
        return self.add_and_norm_2(x, self.feed_forward(x))

In [3]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` EncoderLayer blocks applied one after another."""

    def __init__(self, heads, dim_embedding, num_layers):
        """
        Args:
            heads (int): number of attention heads in every layer.
            dim_embedding (int): embedding dimension.
            num_layers (int): how many encoder layers to stack.
        """
        super().__init__()
        layers = [EncoderLayer(heads, dim_embedding) for _ in range(num_layers)]
        # ModuleList registers each layer so its parameters are tracked
        self.encoder_layers = nn.ModuleList(layers)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (batch_size, seq_len, dim_embedding)

        Returns:
            torch.Tensor: (batch_size, seq_len, dim_embedding)
        """
        hidden = x
        # Each layer consumes the output of the previous one
        for layer in self.encoder_layers:
            hidden = layer(hidden)
        return hidden

In [4]:
from transformers import BertModel, BertTokenizer
def extract_embeddings(input_sentence, model_name='bert-base-uncased'):
    """Tokenize a sentence with BERT and return its embeddings.

    The tokenizer and model are loaded with ``from_pretrained`` on every call.

    Args:
        input_sentence: text to tokenize and embed.
        model_name: checkpoint name passed to both the tokenizer and the model.

    Returns:
        A 5-tuple ``(tokens, input_ids, token_embeddings,
        positional_encodings, embeddings_with_positional_encoding)`` where the
        last element is the sum of the two before it.
    """
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name)

    input_ids = tokenizer.encode(input_sentence, add_special_tokens=True)
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Wrap the ids in a list to form a batch of one; no gradients are needed
    with torch.no_grad():
        outputs = model(torch.tensor([input_ids]))

    # First element of the model output, first (and only) batch entry
    # (presumably the last hidden state; confirm against the transformers docs)
    token_embeddings = outputs[0][0]

    # The positional embeddings live in the embeddings module of the BERT
    # architecture; keep one row per input token
    num_tokens = len(input_ids)
    positional_encodings = model.embeddings.position_embeddings.weight[:num_tokens, :].detach()

    return (
        tokens,
        input_ids,
        token_embeddings,
        positional_encodings,
        token_embeddings + positional_encodings,
    )

In [5]:
# Example sentence used to exercise the encoder.
sentence1 = "I gave the dog a bone because it was hungry"
# extract_embeddings returns a 5-tuple; the cells visible below only use the
# last element (embeddings with the positional encoding added).
tokens1, input_ids1, token_embeddings1, positional_encodings1, embeddings_with_positional_encoding1 = extract_embeddings(sentence1)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
# Encoder hyperparameters.
heads = 8  # number of attention heads in each layer
# The embedding size is read from the last dimension of the input tensor.
dim_embedding = embeddings_with_positional_encoding1.shape[-1]
Nx = 8  # number of stacked encoder layers (passed as num_layers)
encoder = Encoder(heads, dim_embedding, Nx)

In [7]:
# Run the stacked encoder on the sentence embeddings and inspect the shape.
# NOTE(review): the tensor passed here appears to be 2-D (the output below is
# [12, 768]), while the layer docstrings describe
# (batch_size, seq_len, dim_embedding) - confirm the intended input shape.
encoder_output = encoder(embeddings_with_positional_encoding1)
encoder_output.shape

torch.Size([12, 768])