# Recurrent translations

This notebook considers the solution of the translation task using recurrent layers.

In [2]:
from tqdm import tqdm
from datasets import load_dataset

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

## Data

As an example, the English->Russian data set `tatoeda` is considered. [Tatoeda](https://tatoeba.org/en/) - collections of sentences and their translations. In particular, it's [hugging face implementation](https://huggingface.co/datasets/Helsinki-NLP/tatoeba) was used. The following cell loads data, transforms it into a more convenient format and displays some sentences from the dataset.

In [None]:
dataset = load_dataset("tatoeba", lang1="en", lang2="ru")
dataset = [
    (translation["en"], translation["ru"])
    for translation in dataset["train"]["translation"]
]
dataset[200:205]

The next cell prepares the following:  
- Service tokens.
- Vocabulary. 
- Tokenization transformation.  
- Mappings between tokens and indices (both directions).

In [5]:
# Special tokens
PAD_TOKEN = "<PAD>"
# End of sentence
EOS_TOKEN = "<EOS>"
# Start of sentence
SOS_TOKEN = "<SOS>"
# Unknown
UNK_TOKEN = "<UNK>"


def tokenize(sentence: str) -> list[str]:
    '''Tokinelattion of the given stirng  by words'''
    return sentence.lower().split()

EN_VOCAB = {PAD_TOKEN, EOS_TOKEN, SOS_TOKEN, UNK_TOKEN}
RU_VOCAB = {PAD_TOKEN, EOS_TOKEN, SOS_TOKEN, UNK_TOKEN}
for en, ru in dataset:
    EN_VOCAB.update(tokenize(en))
    RU_VOCAB.update(tokenize(ru))

def create_mappings(vocab: set) -> tuple[dict[str, int], dict[int, str]]:
    '''
    Create mappings.
    '''
    word2int = {word: i for i, word in enumerate(vocab)}
    int2word = {i: word for word, i in word2int.items()}
    return word2int, int2word

EN_WORD2INT, EN_INT2WORD = create_mappings(EN_VOCAB)
RU_WORD2INT, RU_INT2WORD = create_mappings(RU_VOCAB)

Now, `torch.utils.data.Dataset` is set up to iterate over sentence pairs, already transformed into tensors containing token indices.

In [6]:
def tensor_tokenize(
    sentence: str,
    word2int: dict[str, int],
    device: torch.device = DEVICE
) -> torch.Tensor:
    '''Transform sentence into tensor with indeces of tokens.'''
    return torch.tensor(
        [
            word2int.get(word, word2int[UNK_TOKEN])
            for word in tokenize(sentence)
        ]
        + [word2int[EOS_TOKEN]],
        dtype=torch.long,
        device=device
    )


class TranslationDataset(torch.utils.data.Dataset):
    '''A data set iterating over input->translation pairs.'''
    def __init__(
        self,
        pairs: list[tuple[str, str]],
        en_word2int: dict[str, int],
        ru_word2int: dict[str, int]
    ):
        self.pairs = pairs
        self.en_word2int = en_word2int
        self.ru_word2int = ru_word2int

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        eng, rus = self.pairs[idx]
        eng_tensor = tensor_tokenize(sentence=eng, word2int=self.en_word2int)
        rus_tensor = tensor_tokenize(sentence=rus, word2int=self.ru_word2int)
        return eng_tensor, rus_tensor

translation_dataset = TranslationDataset(
    pairs=dataset,
    en_word2int=EN_WORD2INT,
    ru_word2int=RU_WORD2INT
)

Here is an example of using the created dataset - it simply returns a pair of tensors containing the indices of English and Russian tokens, respectively.

In [7]:
next(iter(translation_dataset))

(tensor([35081, 45487, 32934, 30542,  3478, 14224, 39060, 16212, 16250, 48871,
         44435, 48467,  1121, 19334, 49510]),
 tensor([ 54680,  39000,  37535,  60821, 120953,  29419,  54389,  24444,  19745,
          28568,  90264,  60302]))

And finally, a dataloader constructs minibatches of pairs and pads shorter sentences so that data across a set of sentences can be represented as a single tensor.

In [8]:
def collate_fn(batch: list[torch.Tensor, torch.Tensor]):
    '''
    Transforms a list of tokinized sentences into the torch tensor. Should be
    used as `collate_fn` argument of the dataloader. The main purpose is to
    pad all sentences to have the same length.
    '''
    eng_batch, rus_batch = zip(*batch)
    eng_batch_padded = pad_sequence(
        eng_batch, batch_first=True, padding_value=EN_WORD2INT[PAD_TOKEN]
    )
    rus_batch_padded = pad_sequence(
        rus_batch, batch_first=True, padding_value=RU_WORD2INT[PAD_TOKEN]
    )
    return eng_batch_padded, rus_batch_padded

batch_size = 64
translation_dataloader = torch.utils.data.DataLoader(
    translation_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=collate_fn,
)
batch = next(iter(translation_dataloader))
(batch[0][:2], batch[1][:2])

(tensor([[35081, 45487, 32934, 30542,  3478, 14224, 39060, 16212, 16250, 48871,
          44435, 48467,  1121, 19334, 49510],
         [24420,  6220, 31775, 49510, 11129, 11129, 11129, 11129, 11129, 11129,
          11129, 11129, 11129, 11129, 11129]]),
 tensor([[ 54680,  39000,  37535,  60821, 120953,  29419,  54389,  24444,  19745,
           28568,  90264,  60302,  80629,  80629,  80629,  80629],
         [122624,  64426,  45452,  60302,  80629,  80629,  80629,  80629,  80629,
           80629,  80629,  80629,  80629,  80629,  80629,  80629]]))

## Model

In [6]:
x = next(iter(translation_dataloader))[0]
x.shape

torch.Size([64, 15])

In [7]:
class Encoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embed_size: int,
        hidden_size: int,
        num_layers: int = 1
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(
            input_size=embed_size,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=True,
            num_layers=num_layers
        )

    def forward(self, x: torch.Tensor):
        embedded = self.embedding(x)
        outputs, hidden = self.rnn(embedded)
        # Concatenating results of all layers and all directions to one long
        # sequence
        hidden = torch.cat(
            [hidden[i, :, :] for i in range(len(hidden))], dim=1
        ).unsqueeze(0)
        return outputs, hidden

In [8]:
embed_size = 10
hidden_size = 10
num_layers = 4

encoder = Encoder(
    vocab_size=len(EN_VOCAB),
    embed_size=embed_size,
    hidden_size=hidden_size,
    num_layers=num_layers
).to(DEVICE)
encoder(x)[1].shape

torch.Size([1, 64, 80])

In [9]:
class Decoder(nn.Module):
    def __init__(
        self: int,
        vocab_size: int,
        embed_size: int,
        hidden_size: int
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(
            input_size=embed_size,
            hidden_size=hidden_size,
            batch_first=True
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, x, hidden):
        out = self.embedding(x)
        out, hidden = self.rnn(out, hidden)
        out = self.fc(out).reshape(out.size(0), -1)
        return out, hidden

In [10]:
decoder = Decoder(
    vocab_size=len(RU_VOCAB),
    embed_size=10,
    hidden_size=hidden_size*num_layers*2
).to(DEVICE)

In [11]:
def translate(
    encoder: Encoder,
    decoder: Decoder,
    sentence: str,
    en_word2int: dict[str: int],
    ru_int2word: dict[int: str],
    ru_word2int: dict[str: int],
    max_length: int = 15,
    device: torch.device = DEVICE
):
    encoder.eval()
    decoder.eval()

    with torch.inference_mode():
        input_tensor = tensor_tokenize(sentence=sentence, word2int=en_word2int)
        input_tensor = input_tensor.view(1, -1).to(device)

        # Pass input sentence through encoder
        _, encoder_hidden = encoder(input_tensor)
        # Intialise hidden state of decoder
        decoder_hidden = encoder_hidden

        decoded_words = []
        last_word = torch.tensor([[en_word2int[SOS_TOKEN]]]).to(device)
        for _ in range(max_length):
            # Pass last predicted token through decoder
            logits, decoder_hidden = decoder(last_word, decoder_hidden)
            # Selecting the most probable token
            next_token = logits.argmax(dim=1)
            last_word = next_token.unsqueeze(0).to(device)
            if next_token.item() == ru_word2int[EOS_TOKEN]:
                break
            else:
                decoded_words.append(ru_int2word.get(next_token.item()))

    # return predicted words as a string
    return " ".join(decoded_words)

In [12]:
translate(
    encoder=encoder,
    decoder=decoder,
    sentence="hello world",
    en_word2int=EN_WORD2INT,
    ru_int2word=RU_INT2WORD,
    ru_word2int=RU_WORD2INT,
    max_length=20
)

'экватор детьми-близнецами бифштекс? посещаете горячая? крыса отсутствии иностранца. принципа, умрёте? пловцы. спорно. джастин драться, иностранца. принципа, умрёте? пловцы. спорно. джастин'

## Fitting

In [13]:
for i, v in enumerate(translation_dataloader):
    if i == 1965:
        break

loss_fn = nn.CrossEntropyLoss(ignore_index=EN_WORD2INT[PAD_TOKEN])

input_tensor, target_tensor = v
target_length = target_tensor.size(1)
_, encoder_hidden = encoder(input_tensor)
decoder_input = torch.full(
    (batch_size, 1), EN_WORD2INT[SOS_TOKEN], dtype=torch.long
).to(DEVICE)
decoder_hidden = encoder_hidden
loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)
for di in range(target_length):
    logits, decoder_hidden = decoder(decoder_input, decoder_hidden)

    loss = loss + loss_fn(logits, target_tensor[:, di])

    decoder_input = target_tensor[:, di].reshape(batch_size, 1)

In [None]:
loss_fn = nn.CrossEntropyLoss(ignore_index=EN_WORD2INT[PAD_TOKEN])

encoder_optimizer = optim.AdamW(encoder.parameters())
decoder_optimizer = optim.AdamW(decoder.parameters())

num_epochs = 1

encoder.train()
decoder.train()

for epoch in range(num_epochs):
    iterator = tqdm(enumerate(translation_dataloader))
    for i, (input_tensor, target_tensor) in iterator:
        input_tensor, target_tensor = (
            input_tensor.to(DEVICE),
            target_tensor.to(DEVICE)
        )

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        target_length = target_tensor.size(1)

        _, encoder_hidden = encoder(input_tensor)

        decoder_input = torch.full(
            (batch_size, 1), EN_WORD2INT[SOS_TOKEN], dtype=torch.long
        ).to(DEVICE)
        decoder_hidden = encoder_hidden

        loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)
        for di in range(target_length):
            logits, decoder_hidden = decoder(decoder_input, decoder_hidden)

            loss = loss + loss_fn(logits, target_tensor[:, di])

            decoder_input = target_tensor[:, di].reshape(batch_size, 1)

        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()

        if i % 100 == 0:
            print(
                f"Epoch {epoch}, ",
                f"Batch {i}, ",
                f"Loss: {loss.item() / target_length:.4f}"
            )

In [18]:
translate(
    encoder=encoder,
    decoder=decoder,
    sentence="How does it works?",
    en_word2int=EN_WORD2INT,
    ru_int2word=RU_INT2WORD,
    ru_word2int=RU_WORD2INT,
    max_length=20
)

'как это не так.'

In [20]:
translate(
    encoder=encoder,
    decoder=decoder,
    sentence="Let's try something.",
    en_word2int=EN_WORD2INT,
    ru_int2word=RU_INT2WORD,
    ru_word2int=RU_WORD2INT,
    max_length=20
)

'давайте поговорим и не будем быть в бостоне.'