# Neural Machine Traslation using Encoder-Decoder Architecture

The aim of this notebook is to implement a Neural Machine Traslation (NMT) using basic [encoder-decoder](https://proceedings.neurips.cc/paper/2014/file/a14ac55a4f27472c5d894ec1c3c743d2-Paper.pdf) approach with [Bahandanau attention mechanism](https://arxiv.org/pdf/1409.0473.pdf). 

In [1]:
# %%capture
# !mkdir MNT-Dataset
# !wget -P MNT-Dataset/ https://www.manythings.org/anki/spa-eng.zip
# !unzip MNT-Dataset/spa-eng.zip -d MNT-Dataset/

In [2]:
# import libaries
import torch
import spacy

import numpy as np
import pandas as pd
import torch.nn.functional as F
import string
from torch import nn
import multiprocessing as mp

from typing import List
from torchmetrics.functional import bleu_score
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, Dataset, TensorDataset
import re


In [3]:
# load dataset
dataset = pd.read_table("MNT-Dataset/spa.txt", header=None, names=["english", "spanish", "ref"]).drop(labels=["ref"], axis=1)
print(dataset.shape)
dataset.head(10)

(141543, 2)


Unnamed: 0,english,spanish
0,Go.,Ve.
1,Go.,Vete.
2,Go.,Vaya.
3,Go.,Váyase.
4,Hi.,Hola.
5,Run!,¡Corre!
6,Run!,¡Corran!
7,Run!,¡Huye!
8,Run!,¡Corra!
9,Run!,¡Corred!


In [4]:
# Basic preprocessing.
def preprocess(text: str):
    # lowercase
    text = text.lower()

    # remove accents
    # text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("utf-8", "ignore")

    # remove punctuation
    text = text.translate(str.maketrans("", "", string.punctuation))

    # replace numbers by %num
    text = re.sub(r'\d+', '%num', text)

    return text

In [5]:
# define a tokenizer using spacy
class Tokenizer:
    def __init__(self, language: str = None) -> None:
        """
        A simple tokenizer class that uses Spacy to tokenize text.

        Parameteres:
        ------------
            language (str, optional): The language of the text to be tokenized. Defaults to None.
                Supported languages are 'sp' for Spanish and 'en' for English.
        """

        if language == "sp":
            self.nlp = spacy.load("es_core_news_sm")  # load the Spanish Spacy model
        elif language == "en":
            self.nlp = spacy.load("en_core_web_sm")  # load the English Spacy model

    def __call__(self, text: str) -> str:
        """
        Tokenizes a given text using the Spacy tokenizer.

        Args:
            text (str): The text to be tokenized.

        Returns:
            A list of strings representing the tokens in the text.
        """

        return [w.text for w in self.nlp.tokenizer(text)]  # return the text tokens

In [6]:
# Now we a language class that represents a language and its vocabulary
class Lang:
    def __init__(self, language: str = "sp", sequence_length: int = 50):
        """
        A class for language preprocessing and encoding. It uses a tokenizer to split text into tokens, and encodes
        these tokens into integer values. It also provides methods to add sentences and words to the vocabulary, and to
        transform text into its encoded form.

        Parameters:
        -----------
        language : str, default='sp'
            The language of the text to process. Currently supported languages are 'sp' (Spanish) and 'en' (English).
        sequence_length : int, default=50
            The maximum length of a sequence of tokens. Longer sequences are truncated and shorter ones are padded.
        """

        self.language = language
        self.sequence_length = sequence_length
        self.word2index = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3}
        self.word2count = {}
        self.index2word = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.n_words = 4  # Count SOS and EOS
        self.tokenizer = Tokenizer(language)

    def addSentence(self, sentence: str):
        """
        Add a sentence to the vocabulary.

        Parameters:
        -----------
        sentence : str
            The sentence to add.
        """

        for word in self.tokenizer(sentence):
            self.addWord(word)

    def addWord(self, word: str):
        """
        Add a word to the vocabulary.

        Parameters:
        -----------
        word : str
            The word to add.
        """

        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def fit(self, dataset: List[str]):
        """
        Build the vocabulary from a dataset.

        Parameters:
        -----------
        dataset : list
            A list of sentences to add to the vocabulary.
        """

        for data in tqdm(dataset):
            self.addSentence(preprocess(data))

    def transform(self, text: str, padding: bool = True, start_token: bool = False, end_token: bool = False):
        """
        Transform text into its encoded form.

        Parameters:
        -----------
        text : str
            The text to encode.
        padding : bool, default=True
            Whether to pad the sequence to the maximum sequence length.
        start_token : bool, default=False
            Whether to add a start token to the sequence.
        end_token : bool, default=False
            Whether to add an end token to the sequence.

        Returns:
        --------
        encoding : list
            A list of integers representing the encoded sequence.
        """

        text = preprocess(text)
        tokens = self.tokenizer(text)

        if start_token:
            tokens = ["<start>"] + tokens
        if end_token:
            tokens = tokens + ["<end>"] 
        if padding:
            tokens = self.right_padding(tokens, self.sequence_length)

        encoding = [self.word2index[tk] if tk in self.word2index.keys() else 3 for tk in tokens]

        return encoding

    def inverse_transform(self, tokens: List):
        """
        Decodes the encoded sequence of integers using the vocabulary of the language.

        Parameters:
        -----------
            tokens: list
                The encoded sequence of integers to decode.

        Returns:
        --------
            str: The decoded sentence.
        """

        words = [self.index2word[tk] for tk in tokens]

        return " ".join(words)

    def right_padding(self, tokens: List, sequence_length: int):
        """
        Pads the sequence of tokens with <pad> tokens to match the desired sequence length.

        Parameters:
        -----------
            tokens: list
                The sequence of tokens to pad.
            sequence_length: int
                The desired length of the padded sequence.

        Returns:
        --------
            list: The padded sequence of tokens.
        """

        if len(tokens) < sequence_length:
            padded_tokens = tokens + ["<pad>"] * (sequence_length - len(tokens))

        elif len(tokens) == sequence_length:
            padded_tokens = tokens[: sequence_length + 1]

        elif len(tokens) > sequence_length:
            padded_tokens = tokens[: sequence_length] +  ["<end>"]

        return padded_tokens

# Custom Dataset

We define a custom data set that output the token for the sentences in spanish and english

In [7]:
# Create custom dataset
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, dataset: pd.DataFrame, max_seq_len: int, src_lang: Lang = None, trg_lang: Lang = None):
        """
        A PyTorch custom dataset for language translation.

        Parameters:
        ----------
            dataset : pandas DataFrame
                The dataset containing the English and Spanish sentences.
            max_seq_len : int
                The maximum length of the source and target sentences.
            src_lang : Lang
                The language object for the source language.
            trg_lang : Lang
                The language object for the target language.
        """

        self.dataset = dataset
        self.max_seq_len = max_seq_len

        if isinstance(src_lang, Lang) and isinstance(trg_lang, Lang):
            self.src_lang = src_lang
            self.trg_lang = trg_lang

        else:
            # Initialize language objects for Spanish and English
            self.src_lang = Lang(language="en", sequence_length=max_seq_len)
            self.src_lang.fit(dataset.english)

            self.trg_lang = Lang(language="sp", sequence_length=max_seq_len)
            self.trg_lang.fit(dataset.spanish)

    def __len__(self):
        """
        Returns the number of samples in the dataset.

        Returns:
        -------
        int
            The number of samples in the dataset

        """

        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Returns a sample from the dataset.

        Parameters:
        ----------
        idx : int
            The index of the sample to return.

        Returns:
        -------
        tuple of torch.Tensor
            The English sentence and the Spanish sentence as tensors.

        """

        # Get the Spanish and English sentences from the dataset
        src_text = self.dataset.english.tolist()[idx]
        trg_text = self.dataset.spanish.tolist()[idx]

        # Transform the Spanish and English sentences using the language objects
        src_text = self.src_lang.transform(src_text, start_token=True, end_token=True)
        trg_input_text = self.trg_lang.transform(trg_text, start_token=True)
        trg_output_text = self.trg_lang.transform(trg_text, end_token=True)

        # Convert the transformed sentences to tensors
        src_text = torch.Tensor(src_text).long()
        trg_input_text = torch.Tensor(trg_input_text).long()
        trg_output_text = torch.Tensor(trg_output_text).long()

        return src_text, trg_input_text, trg_output_text

In [8]:
# test the dataloader
ds_train = CustomDataset(dataset, 10)
next(iter(ds_train))

  0%|          | 0/141543 [00:00<?, ?it/s]

  0%|          | 0/141543 [00:00<?, ?it/s]

(tensor([1, 4, 2, 0, 0, 0, 0, 0, 0, 0]),
 tensor([1, 4, 0, 0, 0, 0, 0, 0, 0, 0]),
 tensor([4, 2, 0, 0, 0, 0, 0, 0, 0, 0]))

In [9]:
# get the spanish and English vocab size
src_vocab = ds_train.src_lang.n_words; print(src_vocab)
trg_vocab = ds_train.trg_lang.n_words; print(trg_vocab)

14213
27993


# Transformers

Next, we will construct the transformer model, which consists of the essential components listed below:
1. Word Embedding
2. Positional Encoding 
3. Multi-Head Attention 
4. Add and Normalize 
5. Point Wise Fully Connected 

<img src="https://i.imgur.com/ZdQnGV5.png" alt= “” width="500px" height="700px">

## Position Encoding

If the length of the sentence is given by $pos$ and the embedding dimension/depth is given by $dim$, positional encoding $\mathbf{P}$ is a 2-d matrix of same dimension, i.e., $\mathbf{P} \in \mathbb{R}^{l \times d}$. Every position can be represented with equation in terms of $i$ which is along the $pos$ and $j$ which is along the $dim$ dimension as

$$
\begin{gathered}
\mathbf{P}_{i, 2 j}=\sin \left(i / 1000^{2 j / dim}\right) \\
\mathbf{P}_{i, 2 j+1}=\cos \left(i / 1000^{2 j / dim}\right)
\end{gathered}
$$

In [10]:
# define positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, seq_len, embed_dim, device='cpu'):
        super().__init__()

        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.device = device

    def forward(self):

        P = torch.zeros(self.seq_len, self.embed_dim)
        pos = torch.arange(self.seq_len).reshape(-1, 1)
        dim = torch.arange(self.embed_dim).reshape(1, -1)

        P[:, 0::2] = torch.sin(pos / 10000 ** (dim[:, 0::2] / self.embed_dim))
        P[:, 1::2] = torch.cos(pos / 10000 ** ((dim[:, 1::2] - 1) / self.embed_dim))

        return P.to(self.device)

# test positional embedding
positional_encoding = PositionalEncoding(seq_len=5, embed_dim=2)
positional_encoding()

tensor([[ 0.0000,  1.0000],
        [ 0.8415,  0.5403],
        [ 0.9093, -0.4161],
        [ 0.1411, -0.9900],
        [-0.7568, -0.6536]])

## Embeddings

Here we construct the embedding class which include the token embedding and the positional encoding

<img src="https://i.imgur.com/c7EOe74.png" alt= “” width="300px" height="200px">

In [11]:
# Define embedding class using positional and token embeddings
class Embedding(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_seq_len, device="cpu"):
        super().__init__()

        self.token_embeddings = nn.Embedding(vocab_size, embed_dim, device=device)
        self.positional_encodings = PositionalEncoding(max_seq_len, embed_dim, device=device)

    def forward(self, x):
        token_embeddings = self.token_embeddings(x)
        positional_encodings = self.positional_encodings()

        embedding = token_embeddings + positional_encodings

        return embedding

In [12]:
# Test Embedding class
vocab_size = 10
embed_dim = 8
max_seq_len = 3
bs = 2

x = torch.randint(0, vocab_size-1, size=(bs, max_seq_len))

embedding = Embedding(vocab_size, embed_dim, max_seq_len)
embedding(x).shape

torch.Size([2, 3, 8])

### Multi-Head Attention

The Transformer paper introduced the multi-head attention layer as a significant and innovative concept.

<img src="https://i.imgur.com/I8ouVdr.png" alt= “” width="700px" height="450px">

The output of the i-th head is given by 

$$\text { head }_i=\operatorname{attention}\left(\mathbf{W}_q^i \mathbf{Q}, \mathbf{W}_k^i \mathbf{K}, \mathbf{W}_v^i \mathbf{V}\right)$$

where

$$\operatorname{attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V})=\operatorname{softmax}\left(\frac{\mathbf{Q K}^{\mathrm{T}}}{\sqrt{d_k}}\right) \mathbf{V}$$

Multi-head attention have multiple sets of query/key/value weight matrices, each resulting in different query/key/value matrices for the inputs, finally generating output matrices $z_i$ . These output matrices from each head are concatenated and multiplied with an additional weight matrix, $W_O$ , to get a single final matrix, $Z$, with vectors zi as output for each input $x_i$ .The MultiHead

$$\operatorname{multihead}(\mathbf{Q}, \mathbf{K}, \mathbf{V})=\mathbf{W}_O \text { concat }\left(\text { head }_1, \ldots, \text { head }_h\right)$$

In [13]:
# define MultiHeadAttention
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0) -> None:
        super().__init__()

        assert embed_dim % num_heads == 0

        self.embed_dim = embed_dim
        self.num_head = num_heads
        self.head_dim = embed_dim // num_heads

        self.w_q = nn.Linear(embed_dim, embed_dim)
        self.w_k = nn.Linear(embed_dim, embed_dim)
        self.w_v = nn.Linear(embed_dim, embed_dim)
        self.w_o = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        # q shape: (batch_size, seq_len, embed_dim)
        # k shape: (batch_size, seq_len, embed_dim)
        # v shape: (batch_size, seq_len, embed_dim)

        batch_size = q.shape[0]
        seq_len = q.shape[1]

        Q = self.w_q(q)  # Q shape : (batch_size, seq_len, embed_dim)
        K = self.w_k(k)  # K shape : (batch_size, seq_len, embed_dim)
        V = self.w_v(v)  # V shape : (batch_size, seq_len, embed_dim)

        Q = Q.reshape(batch_size, seq_len, self.num_head, self.head_dim).permute(0, 2, 1, 3)  # Q shape: (batch_size, num_head, seq_len, head_dim)
        K = K.reshape(batch_size, seq_len, self.num_head, self.head_dim).permute(0, 2, 1, 3)  # K shape: (batch_size, num_head, seq_len, head_dim)
        V = V.reshape(batch_size, seq_len, self.num_head, self.head_dim).permute(0, 2, 1, 3)  # V shape: (batch_size, num_head, seq_len, head_dim)

        energy = torch.matmul(Q, K.permute(0, 1, 3, 2))/np.sqrt(self.head_dim) # energy shape: (batch_size, num_head, seq_len, seq_len)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -float('inf'))

        attention = torch.matmul(F.softmax(energy, dim=-1), V)   # attention shape: (batch_size, num_head, seq_len, head_dim)

        attention = self.dropout(attention)
        
        Z = self.w_o(attention.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.num_head * self.head_dim)) # Z shape: (batch_size, seq_len, embed_dim)
 
        return Z, attention

In [14]:
# Test the MultiHeadAttention
bs = 2
embed_dim = 4
max_length = 5
num_heads = 1

X = torch.rand(size=(bs, max_length, embed_dim))
mask = torch.tril(torch.ones(max_length, max_length))
multi_head_attention = MultiHeadAttention(embed_dim, num_heads)
z, attention = multi_head_attention(X, X, X, mask)
z.shape

torch.Size([2, 5, 4])

### Residuals and Layer Normalization

Similar to ResNets, the inputs, $X$, are short circuited to the output, $Z$, and both are added and passed through layer normalization $$addAndN orm(X + Z)$$

In [15]:
# define add and normalize layer
class AddAndNormalize(nn.Module):
    def __init__(self, embed_dim) -> None:
        super().__init__()

        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x, z):
        return self.layer_norm(x + z)

### Positionwise Feed-forward Networks

Both encoder and decoder contain a fully connected feed-forward network after the attention sub layers. For each position, similar linear transformations with a ReLU activation in between is performed.
$$FFN(x) = max(0, xW1 + b1 )W2 + b2$$

In [16]:
# define PositionWise FFN class
class PositionWiseFFN(nn.Module):

    def __init__(self, embed_dim, pf_dim ,dropout) -> None:
        super().__init__()

        self.w1 = nn.Linear(embed_dim, pf_dim)
        self.w2 = nn.Linear(pf_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):

        x = self.w1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.w2(x)
        
        return x

### Encoder Layer

The encoder layer consists of $\{multi-headAttention, addAndNorm, FFN, addAndNorm\}$

<img src="https://i.imgur.com/f3jpYWe.png" alt= “” width="300px" height="350px">

In [17]:
# define Encoder Layer class
import torch
class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout, pf_dim):
        super().__init__()

        self.multihead_attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.feedforward = PositionWiseFFN(embed_dim, pf_dim, dropout)
        self.add_and_norm_1 = AddAndNormalize(embed_dim)
        self.add_and_norm_2 = AddAndNormalize(embed_dim)

    def forward(self, x, mask):

        # x shape: (batch_size, seq_len, embed_dim)
        # mask shape: (batch_size, 1, 1, seq_len)

        z, _ = self.multihead_attention(x, x, x, mask)
        x = self.add_and_norm_1(x, z)

        z = self.feedforward(x)
        x = self.add_and_norm_2(x, z)
        
        return x

In [18]:
# Test Encoder Layer
embed_dim = 8
max_length = 3
num_heads = 2
dropout = 0.5
pf_dim = 4
bs = 2


x = torch.rand(size=(bs, max_length, embed_dim))
mask = torch.tril(torch.ones(size=(bs, 1, 1, max_length)))

encoder_layer = EncoderLayer(embed_dim, num_heads, dropout, pf_dim)
encoder_layer(x, mask).shape

torch.Size([2, 3, 8])

### Encoder

The Encoder of the transformers consist of N EncoderLayer, the positional encoding and the token embedding.

In [19]:
# Define the EncoderTransformer layer
class EncoderTransformer(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout, pf_dim, vocab_size, max_seq_length, n_layers, device='cpu') -> None:
        super().__init__()

        self.embedding = Embedding(vocab_size, embed_dim, max_seq_length, device)
        self.encoder = nn.ModuleList([EncoderLayer(embed_dim, num_heads, dropout, pf_dim) for _ in range(n_layers)])
        self.to(device)

    def forward(self, x, mask):
        # x shape: (batch_size, seq_length)
        # mask shape: (batch_size, 1, 1, seq_length)

        x = self.embedding(x)  # x shape:  (batch_size, seq_length, embed_dim)

        for layer in self.encoder:
            x = layer(x, mask)  # x shape: (batch_size, seq_length, embed_dim)

        return x

In [20]:
# test EncoderTransformer
embed_dim = 8
max_length = 3
num_heads = 2
vocab_size = 10
n_layers = 2
dropout = 0.5
pf_dim = 4
bs = 2

x = torch.randint(0, vocab_size-1, size=(bs, max_length))
mask = ((x != 0)*1).unsqueeze(1).unsqueeze(2) # dummy mask

encoder = EncoderTransformer(embed_dim, num_heads, dropout, pf_dim, vocab_size, max_length, n_layers)
z = encoder(x, mask)
z.shape

torch.Size([2, 3, 8])

### Decoder Layer

The decoder block in the transformer consists of n blocks of ${maskedMul-tiheadAttention, addAndNorm, encoderDecoderAttention, addAndNorm, FFN, addAndNorm}$


<img src="https://i.imgur.com/HJfpj2Y.png" alt= “” width="300px" height="500px">

In [21]:
# Define decoder layer
class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, pf_dim, dropout):
        super().__init__()

        self.mask_attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.encoder_decoder_attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.feedforward = PositionWiseFFN(embed_dim, pf_dim, dropout)

        self.add_and_norm_1 = AddAndNormalize(embed_dim)
        self.add_and_norm_2 = AddAndNormalize(embed_dim)
        self.add_and_norm_3 = AddAndNormalize(embed_dim)

    def forward(self, x, encoder_output, encoder_mask, decoder_mask):
        # x shape: (batch_size, seq_len, embed_dim)
        # encoder_output shape: (batch_size, seq_len, embed_dim)
        # encoder_mask shape: (batch_size, 1, 1, seq_len)
        # decoder_mask shape: (batch_size, 1, seq_len, seq_len)

        z, _ = self.mask_attention(x, x, x, mask=decoder_mask)
        x = self.add_and_norm_1(x, z)

        z, attention = self.encoder_decoder_attention(q=x, k=encoder_output, v=encoder_output, mask=encoder_mask)
        x = self.add_and_norm_2(x, z)

        z = self.feedforward(x)
        x = self.add_and_norm_3(x, z)

        return x, attention

In [22]:
# test the DecoderLayer
embed_dim = 8
max_length = 5
num_heads = 2
dropout = 0.5
pf_dim = 4
vocab_size = 10
bs = 2



x = torch.randint(0, vocab_size, size=(bs, max_length))
encoder_mask = ((x != 0)*1).unsqueeze(1).unsqueeze(2)

x_embeddings = torch.rand(size=(bs, max_length, embed_dim))
encoder_output = torch.rand(size=(bs, max_length, embed_dim))
decoder_mask = torch.tril(torch.ones(size=(bs, 1, max_length, max_length)))



decoder_layer = DecoderLayer(embed_dim, num_heads, pf_dim, dropout)
z, attention = decoder_layer(x_embeddings, encoder_output, encoder_mask, decoder_mask)
z.shape

torch.Size([2, 5, 8])

### Decoder

The Decoder of a transformer consist of N DecoderLayer, the positional encoding and the token embedding.

In [23]:
# Define the DecoderTransformer layer
class DecoderTransformer(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout, pf_dim, vocab_size, max_seq_length, n_layers, device="cpu") -> None:
        super().__init__()

        self.embedding = Embedding(vocab_size, embed_dim, max_seq_length, device)
        self.decoder = nn.ModuleList([DecoderLayer(embed_dim, num_heads, pf_dim, dropout) for _ in range(n_layers)])
        self.to(device)

    def forward(self, x, encoder_output, encoder_mask, decoder_mask):
        x = self.embedding(x)

        for layer in self.decoder:
            x, _ = layer(x, encoder_output, encoder_mask, decoder_mask)

        return x

In [24]:
# Test the DecoderTransformer layer
embed_dim = 8
max_length = 5
num_heads = 2
dropout = 0.5
pf_dim = 4
vocab_size = 10
bs = 2
n_layers = 2

x = torch.randint(0, vocab_size, size=(bs, max_length))
encoder_output = torch.rand(size=(bs, max_length, embed_dim))
decoder_mask = torch.tril(torch.ones(size=(bs, 1, max_length, max_length)))
encoder_mask = ((x != 0)*1).unsqueeze(1).unsqueeze(2)

decoder = DecoderTransformer(embed_dim, num_heads, dropout, pf_dim, vocab_size, max_length, n_layers)
z = decoder(x, encoder_output, encoder_mask, decoder_mask)
print(z.shape)

torch.Size([2, 5, 8])


## Transformer Model

The transformer model is composed of an EncoderTransformer and a DecoderTransformer layer.

In [25]:
class Transformer(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout,
        pf_dim,
        src_vocab_size,
        trg_vocab_size,
        max_seq_len,
        src_lang,
        trg_lang,
        n_layers=6,
        device="cpu",
    ):
        super().__init__()

        self.max_seq_len = max_seq_len
        self.src_lang = src_lang
        self.trg_lang = trg_lang

        self.encoder = EncoderTransformer(embed_dim, num_heads, dropout, pf_dim, src_vocab_size, max_seq_len, n_layers, device)
        self.decoder = DecoderTransformer(embed_dim, num_heads, dropout, pf_dim, trg_vocab_size, max_seq_len, n_layers, device)
        self.final_layer = nn.Linear(embed_dim, trg_vocab_size, device=device)

        self.loss_fc = nn.CrossEntropyLoss(ignore_index=0)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=2e-5)
        self.device = device

    def forward(self, x, y):
        # x shape: (batch_size, src_len)
        # y shape: (batch_size, trg_len)

        encoder_mask, decoder_mask = self.compute_masks(x, y)

        encoder_output = self.encoder(x, encoder_mask)
        decoder_output = self.decoder(y, encoder_output, encoder_mask, decoder_mask)

        output = self.final_layer(decoder_output)

        return output

    def compute_masks(self, x, y):
        encoder_mask = (x != 0).unsqueeze(1).unsqueeze(2)
        decoder_mask = (y != 0).unsqueeze(1).unsqueeze(2)
        decoder_mask = decoder_mask & torch.tril(torch.ones(y.shape[0], 1, y.shape[1], y.shape[1], device=self.device)).bool()

        encoder_mask = encoder_mask.to(self.device)
        decoder_mask = decoder_mask.to(self.device)

        return encoder_mask, decoder_mask

    def train_one_epoch(self, train_loader):
        running_loss = 0.0
        running_bleu = 0.0

        self.train()
        bar = tqdm(train_loader)

        for step, (x, y_input, y_output) in enumerate(bar, 1):
            x = x.to(self.device)  # x shape: (batch_size, src_len)
            y_input = y_input.to(self.device)  # y_input shape: (batch_size, trg_len)
            y_output = y_output.to(self.device)  # y_output shape: (batch_size, trg_len)

            self.optimizer.zero_grad()

            output = self(x, y_input)  # output shape: (batch_size, trg_len, trg_vocab_size)

            loss = self.loss_fc(output.reshape(-1, output.shape[2]), y_output.reshape(-1))

            loss.backward()

            # clip gradients
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)

            self.optimizer.step()
 
            output = output.argmax(dim=-1).cpu().numpy()
            y_output = y_output.cpu().numpy()

            prediction = np.array([self.trg_lang.inverse_transform(pred) for pred in output])
            y_output = np.array([self.trg_lang.inverse_transform(y_out) for y_out in y_output])

            running_bleu += bleu_score(preds=prediction, target=y_output)
            running_loss += loss.item()

            bar.set_description(f"Training Step: {step}, Training Loss: {running_loss / step:.4f}, Training BLEU: {running_bleu / step:.4f}")

        logs = {"Training Loss": running_loss / step, "Training BLEU": running_bleu / step}

        return logs

    def test_one_epoch(self, test_loader):
        running_loss = 0.0
        running_bleu = 0.0

        self.eval()

        bar = tqdm(test_loader)

        with torch.no_grad():
            for step, (x, y_input, y_output) in enumerate(bar, 1):
                x = x.to(self.device)
                y_input = y_input.to(self.device)
                y_output = y_output.to(self.device)

                output = self(x, y_input)

                loss = self.loss_fc(output.reshape(-1, output.shape[2]), y_output.reshape(-1))

                output = output.argmax(dim=-1).cpu().numpy()
                y_output = y_output.cpu().numpy()

                prediction = np.array([self.trg_lang.inverse_transform(pred) for pred in output])
                y_output = np.array([self.trg_lang.inverse_transform(y_out) for y_out in y_output])

                running_bleu += bleu_score(preds=prediction, target=y_output)
                running_loss += loss.item()

                bar.set_description(f"Test Step: {step}, Test Loss: {running_loss / step:.4f}, Test BLEU: {running_bleu / step:.4f}")

        logs = {"Test Loss": running_loss / step, "Test BLEU": running_bleu / step}

        return logs


    def fit(self, train_loader, test_loader, epochs):
        bar = tqdm(range(epochs))

        for epoch in bar:
            logs_train = self.train_one_epoch(train_loader)
            logs_test = self.test_one_epoch(test_loader)

            bar.set_description(f"Epoch: {epoch}")

    def translate_sentence(self, sentence):
        src_sentence = torch.Tensor(self.src_lang.transform(sentence, start_token=True, end_token=True)).long().to(self.device)
        src_sentence = src_sentence.unsqueeze(0)

        trg_sentence = torch.Tensor(np.zeros(shape=(1, self.max_seq_len))).long().to(self.device)
        trg_sentence[0, 0] = self.trg_lang.word2index["<start>"]

        encoder_mask, decoder_mask = self.compute_masks(src_sentence, trg_sentence)

        self.eval()

        with torch.no_grad():
            encoder_output = self.encoder(src_sentence, encoder_mask)

            for i in range(1, self.max_seq_len):
                decoder_output = self.decoder(trg_sentence, encoder_output, encoder_mask, decoder_mask)
                output = self.final_layer(decoder_output)
                output = output.squeeze(0).argmax(dim=-1)
                trg_sentence[0, i] = output[i - 1]

                if output[i - 1] == self.trg_lang.word2index["<end>"]:
                    break

                encoder_mask, decoder_mask = self.compute_masks(src_sentence, trg_sentence)

        output = output.cpu().numpy().tolist()

        output = self.trg_lang.inverse_transform(output[:i])

        return output

In [26]:
# train test split
from torch.utils.data import DataLoader
dataset_train, dataset_test = train_test_split(dataset.sample(50000), test_size=0.2, random_state=42)

# pytorch datasets
ds_train = CustomDataset(dataset_train, max_seq_len=100)
ds_test = CustomDataset(dataset_test, max_seq_len=100, src_lang=ds_train.src_lang, trg_lang=ds_train.trg_lang)

#Dataloader
loader_train = DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=5, prefetch_factor=5)
loader_test = DataLoader(ds_test, batch_size=128, shuffle=False)

  0%|          | 0/40000 [00:00<?, ?it/s]

  0%|          | 0/40000 [00:00<?, ?it/s]

In [27]:
# instance model
embed_dim = 512
num_heads = 8
dropout = 0.3
pf_dim = 512
src_vocab_size = ds_train.src_lang.n_words
trg_vocab_size = ds_train.trg_lang.n_words
max_length = ds_train.max_seq_len
n_layers = 6
device = "cuda"

transformer = Transformer(
    embed_dim,
    num_heads,
    dropout,
    pf_dim,
    src_vocab_size,
    trg_vocab_size,
    max_length,
    ds_train.src_lang,
    ds_train.trg_lang,
    n_layers,
    device,
)

In [28]:
# fit the model
transformer.fit(loader_train, loader_test, epochs=10)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]



  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

In [29]:
# short translation
transformer.translate_sentence("i am really happy here!")

'estoy realmente feliz <end>'