In [1]:
!pip install -q evaluate

In [2]:
import requests
import zipfile
import io
import unicodedata
import re
from torchtext.vocab import build_vocab_from_iterator
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import time
import random
import evaluate
import os

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Download and extract the dataset

- the text file containing the translation pairs should end up at the path `data/eng-fra.txt`


In [4]:
DATASET_URL = "https://download.pytorch.org/tutorial/data.zip"
r = requests.get(DATASET_URL, stream=True)
data_zip = zipfile.ZipFile(io.BytesIO(r.content))
data_zip.extractall()

## Process the dataset


### Get the pairs of translations


In [5]:
DATASET_PATH = "data/eng-fra.txt"
pairs = []
DATASET_USAGE = 1


def read_pairs():
    with open(DATASET_PATH, "r", encoding="utf8") as f:
        pairs = f.read().strip().split("\n")
        random.shuffle(pairs)
        pairs = pairs[: int(DATASET_USAGE * len(pairs))]
        return pairs

In [6]:
len(read_pairs())

135842

In [7]:
read_pairs()[:5]

['You seem happy.\tTu sembles heureux.',
 "The pain hasn't gone away.\tLa douleur n'est pas partie.",
 "This is a farce.\tIl s'agit d'une farce.",
 'I asked them to fix my car.\tJe leur ai demandé de réparer ma voiture.',
 "Don't you remember my name?\tNe te rappelles-tu pas mon nom ?"]

In [8]:
read_pairs()[-5:]

['This boat is seaworthy.\tLe navire est en état de naviguer.',
 "You're alone, aren't you?\tVous êtes seule, n'est-ce pas ?",
 "Yes, I'm a student too.\tOui, je suis aussi étudiant.",
 "I remember mentioning it once or twice.\tJe me rappelle l'avoir mentionné à une ou deux reprises.",
 'I know what those books are like.\tJe sais à quoi ces ouvrages ressemblent.']

### Split the pairs and normalize the data


In [9]:
PAIR_DELIMETER = "\t"
MAX_LEN = 10


# Remove any accented characters and non-ASCII symbols
# From: https://stackoverflow.com/a/7782177
def normalize_text(text):
    return str(
        unicodedata.normalize("NFKD", text.strip().lower())
        .encode("ascii", "ignore")
        .decode("ascii")
    )


def remove_special_chars(text):
    text = re.sub(r"([.!?])", r" \1", text)
    text = re.sub(r"[^a-zA-Z!?]+", " ", text)
    return text.strip()


def preprocess_text(text):
    text = normalize_text(text)
    return remove_special_chars(text)


def read_normalized_pairs():
    pairs = read_pairs()
    for index, pair in enumerate(pairs):
        eng, fra = pair.split(PAIR_DELIMETER)
        eng = preprocess_text(eng)
        fra = preprocess_text(fra)
        pairs[index] = [eng, fra]
    filtered = [
        pair
        for pair in pairs
        if len(pair[0].split(" ")) <= MAX_LEN and len(pair[1].split(" ")) <= MAX_LEN
    ]
    return filtered

In [10]:
len(read_normalized_pairs())

115596

In [11]:
read_normalized_pairs()[:5]

[['i saw him sawing a tree', 'je l ai vu en train de scier un arbre'],
 ['she aimed at the target', 'elle visa la cible'],
 ['he came at o clock in the afternoon',
  'il est venu a trois heures de l apres midi'],
 ['how does he do it ?', 'comment s y prend il ?'],
 ['come quickly !', 'depechez vous de venir !']]

In [12]:
max([len(eng.split(" ")) for eng, fra in read_normalized_pairs()])

10

In [13]:
read_normalized_pairs()[-5:]

[['it s gradually getting colder', 'il fait de plus en plus froid'],
 ['i want you to get your own place',
  'je veux que vous preniez votre propre logement'],
 ['you have the right to know', 'tu as le droit de savoir'],
 ['they re part of us', 'ils font partie de nous'],
 ['how long is the flight ?', 'combien de temps dure le vol ?']]

In [14]:
UNK_TOKEN = "<unk>"
SOS_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"
PAD_TOKEN = "<pad>"


def tokenize(sentence):
    return sentence.lower().split(" ")


def eng_iter(pairs):
    for eng, fra in pairs:
        yield tokenize(eng)


def fra_iter(pairs):
    for eng, fra in pairs:
        yield tokenize(fra)


def build_vocab(pairs=None):
    if pairs is None:
        pairs = read_normalized_pairs()
    specials = [UNK_TOKEN, SOS_TOKEN, EOS_TOKEN, PAD_TOKEN]
    eng_vocab = build_vocab_from_iterator(
        eng_iter(pairs), specials=specials, special_first=True
    )
    fra_vocab = build_vocab_from_iterator(
        fra_iter(pairs), specials=specials, special_first=True
    )
    eng_vocab.set_default_index(eng_vocab[UNK_TOKEN])
    fra_vocab.set_default_index(fra_vocab[UNK_TOKEN])
    return eng_vocab, fra_vocab

In [15]:
eng_vocab, fra_vocab = build_vocab()

In [16]:
eng_vocab(tokenize("he ate an apple"))

[12, 690, 79, 805]

In [17]:
eng_vocab(tokenize("I had a good day"))

[4, 69, 9, 74, 132]

In [18]:
print("English vocab length:", len(eng_vocab))

English vocab length: 11300


In [19]:
print("French vocab length:", len(fra_vocab))

French vocab length: 19001


## Create the dataloader

The data gets loaded with 0s as a default, which is the UNK token


In [20]:
def max_seq_length(pairs):
    eng_max = max([len(eng.split(" ")) for eng, fra in pairs])
    fra_max = max([len(fra.split(" ")) for eng, fra in pairs])
    return max(eng_max, fra_max)


MAX_SEQUENCE_LENGTH = max_seq_length(read_normalized_pairs()) + 1

In [21]:
def text_pipeline(text, vocab):
    return vocab(tokenize(text))


def prepare_input(text, vocab):
    text = text_pipeline(text, vocab)
    text += vocab([EOS_TOKEN])
    return text


def create_dataloader(batch_size):
    pairs = read_normalized_pairs()
    eng_vocab, fra_vocab = build_vocab(pairs)
    # Use MAX_SEQUENCE_LENGTH+1 to allow for the EOS_TOKEN to be added
    eng_data = torch.zeros(
        len(pairs), MAX_SEQUENCE_LENGTH, dtype=torch.long, device=device
    ).fill_(eng_vocab[PAD_TOKEN])
    fra_data = torch.zeros(
        len(pairs), MAX_SEQUENCE_LENGTH, dtype=torch.long, device=device
    ).fill_(eng_vocab[PAD_TOKEN])

    for index, (eng, fra) in enumerate(pairs):
        eng_datum = prepare_input(eng, eng_vocab)
        fra_datum = prepare_input(fra, fra_vocab)

        # Add the datum to the dataset
        eng_data[index, : len(eng_datum)] = torch.tensor(
            eng_datum, dtype=torch.long, device=device
        )
        fra_data[index, : len(fra_datum)] = torch.tensor(
            fra_datum, dtype=torch.long, device=device
        )

    combined_dataset = TensorDataset(eng_data, fra_data)
    return eng_vocab, fra_vocab, DataLoader(combined_dataset, batch_size=batch_size)

In [22]:
MAX_SEQUENCE_LENGTH

11

In [23]:
dl = create_dataloader(64)
dl

(Vocab(), Vocab(), <torch.utils.data.dataloader.DataLoader at 0x7ca4dae2c730>)

In [24]:
for eng, fra in dl[2]:
    print(eng, fra)
    print(eng.shape, fra.shape)
    break

tensor([[1475,    9, 1256,  864,    2,    3,    3,    3,    3,    3,    3],
        [ 170,   11,    9,  867,   17,   66,    2,    3,    3,    3,    3],
        [   4, 1535,   10, 1830,    6,  104,    5,    2,    3,    3,    3],
        [   4,  956,    7, 3496,    2,    3,    3,    3,    3,    3,    3],
        [  13,   14,   51, 2493,    2,    3,    3,    3,    3,    3,    3],
        [   4,  126,    6,    7, 1670,  158,    2,    3,    3,    3,    3],
        [  27, 5273, 3863,    2,    3,    3,    3,    3,    3,    3,    3],
        [  70,   22,   10,    5,  821,   96,    8,    2,    3,    3,    3],
        [  88,   92,    9,  676,    2,    3,    3,    3,    3,    3,    3],
        [  12,   26, 1044,  338, 1524,    2,    3,    3,    3,    3,    3],
        [  53,   26,    9,  920,   17, 1320, 4275,    2,    3,    3,    3],
        [   4,   99,   36,  925,   58,   15,    2,    3,    3,    3,    3],
        [  12,   26,  101,  249, 1078,  580,    2,    3,    3,    3,    3],
        [  1

In [25]:
eng[0]

tensor([1475,    9, 1256,  864,    2,    3,    3,    3,    3,    3,    3],
       device='cuda:0')

In [26]:
fra[0]

tensor([4749,   26, 1537, 1228,    2,    3,    3,    3,    3,    3,    3],
       device='cuda:0')

## Define the model architecture


In [27]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_dim, bidirectional=False, dropout=0.0):
        super().__init__()
        self.embedding = nn.Embedding(input_size, hidden_dim)
        self.rnn = nn.LSTM(
            hidden_dim, hidden_dim, batch_first=True, bidirectional=bidirectional
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, input):
        embedded = self.embedding(input)
        output, hidden = self.rnn(embedded)
        output = self.dropout(output)
        return output, hidden


class Decoder(nn.Module):
    def __init__(
        self,
        hidden_dim,
        output_size,
        bidirectional=False,
        dropout=0.0,
        fill_val=eng_vocab[SOS_TOKEN],
    ):
        super().__init__()
        self.fill_val = fill_val
        self.embedding = nn.Embedding(output_size, hidden_dim)
        self.rnn = nn.LSTM(
            hidden_dim, hidden_dim, batch_first=True, bidirectional=bidirectional
        )
        fc_in_features = hidden_dim * 2 if bidirectional else hidden_dim
        self.fc = nn.Linear(fc_in_features, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, encoder_outputs, hidden_state, target_tensor=None):
        batch_size = encoder_outputs.shape[0]
        input = torch.zeros(batch_size, 1, dtype=torch.long, device=device).fill_(
            self.fill_val
        )
        outputs = []

        for i in range(MAX_SEQUENCE_LENGTH):
            output, _ = self.forward_step(input, hidden_state)
            # output is of shape: [batch_size, 1, output_size], where output_size is
            # num. of unique words in target language
            outputs.append(output)

            if target_tensor is not None:
                # If teacher forcing:
                # Use the target tensor's values as the next input, converting them to
                # the same shape as the original input
                input = target_tensor[:, i].unsqueeze(1)
            else:
                # If not teacher forcing:
                # Take the topk of the output and use it as the next input (the topk
                # will be of size [batch_size, 1, 1])
                _, top_indexes = output.topk(1)
                input = top_indexes.squeeze(-1)

        # Concatenate all of the outputs along dimension 1, creating the complete
        # sequence from individual parts for each input
        # (size [batch_size, MAX_SEQUENCE_LENGTH, output_size])
        outputs = torch.cat(outputs, dim=1)
        outputs = F.log_softmax(outputs, dim=-1)
        return outputs

    def forward_step(self, input, hidden_state):
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden_state = self.rnn(output, hidden_state)
        output = self.dropout(output)
        output = self.fc(output)
        return output, hidden_state

In [28]:
class Seq2SeqModel(nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        hidden_dim,
        bidirectional=True,
        dropout=0.0,
        fill_val=eng_vocab[SOS_TOKEN],
    ):
        super().__init__()
        self.encoder = Encoder(
            input_size, hidden_dim, bidirectional=bidirectional, dropout=dropout
        )
        self.decoder = Decoder(
            hidden_dim,
            output_size,
            bidirectional=bidirectional,
            dropout=dropout,
            fill_val=fill_val,
        )

    def forward(self, input, target_tensor=None):
        encoder_output, encoder_hidden = self.encoder(input)
        decoder_output = self.decoder(encoder_output, encoder_hidden, target_tensor)
        return decoder_output

## Train the model


In [29]:
def train_epoch(train_dataloader, validation_dataloader, model, optimizer, criterion):
    # Training
    total_loss = 0
    for batch in train_dataloader:
        input, target = batch
        input.to(device)
        target.to(device)

        optimizer.zero_grad(set_to_none=True)

        output = model(input, target)

        loss = criterion(output.view(-1, output.size(-1)), target.view(-1))
        loss.backward()

        optimizer.step()

        total_loss += loss.item()

    # Validation loss
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in validation_dataloader:
            input, target = batch
            input.to(device)
            target.to(device)
            output = model(input, target)
            loss = criterion(output.view(-1, output.size(-1)), target.view(-1))
            val_loss += loss.item()
    model.train()

    return total_loss / len(train_dataloader), val_loss / len(validation_dataloader)

In [30]:
def train(
    train_dataloader,
    validation_dataloader,
    model,
    epochs,
    print_every=20,
    checkpoint_path=".",
    stop_early=True,
):
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    criterion = nn.CrossEntropyLoss()


    save_path = f"{checkpoint_path}/model_checkpoint.pt"

    best_loss = float("inf")

    saved_once = False


    start = time.time()

    last_epoch_time = start

    print("Training started...")


    for epoch in range(epochs):

        epoch += 1

        train_loss, validation_loss = train_epoch(
            train_dataloader, validation_dataloader, model, optimizer, criterion
        )


        # Save a checkpoint of the encoder and decoder's weights

        # Attempt to save the best of the combined losses, without overfitting (too much)

        if (
            checkpoint_path is not None
            and train_loss + validation_loss < best_loss
            and train_loss * 1.1 >= validation_loss
        ):
            torch.save(model.state_dict(), save_path)

            saved_once = True

            best_epoch = epoch


        if epoch % print_every == 0:
            print(
                f"Epoch {epoch}\t | Loss (train): {train_loss:.4f}\t| Loss (validation): {validation_loss:.4f}\t| Total Time: {time.time() - start:.4f}s\t| Epoch Time: {time.time() - last_epoch_time:.4f}s"
            )

        last_epoch_time = time.time()


        # Stop early if the model is overfitting too much

        if stop_early and train_loss * 1.2 < validation_loss and epoch >= 4:
            print("Stopping early...")

            break


    if saved_once:
        # Load the best version of the encoder and decoder's weights

        print(f"Loading best checkpoint from epoch {best_epoch}....")
        model.load_state_dict(torch.load(save_path))


    print(f"Took {time.time() - start}s")

In [31]:
def split_dataloader(dataloader, pcts):
    b = dataloader.batch_size
    datasets = torch.utils.data.random_split(dataloader.dataset, list(pcts))
    return tuple(DataLoader(dataset, b) for dataset in datasets)

In [32]:
hidden_size = 512
batch_size = 128

In [33]:
eng_vocab, fra_vocab, dataloader = create_dataloader(batch_size)

In [34]:
train_dataloader, validation_dataloader, test_dataloader = split_dataloader(
    dataloader, (0.7, 0.2, 0.1)
)

In [35]:
print(f"train_dataloader size:", len(train_dataloader.dataset))
print(f"validation_dataloader size:", len(validation_dataloader.dataset))
print(f"test_dataloader size:", len(test_dataloader.dataset))

train_dataloader size: 80918
validation_dataloader size: 23119
test_dataloader size: 11559


In [36]:
bidirectional = True
model = Seq2SeqModel(
    len(eng_vocab),
    len(fra_vocab),
    hidden_size,
    bidirectional=bidirectional,
    dropout=0.0,
).to(device)

In [37]:
EPOCHS = 4

In [38]:
train(
    train_dataloader,
    validation_dataloader,
    model,
    EPOCHS,
    print_every=1,
    checkpoint_path=None,
    stop_early=False,
)

Training started...
Epoch 1	 | Loss (train): 2.3223	| Loss (validation): 1.6109	| Total Time: 76.9069s	| Epoch Time: 76.9069s
Epoch 2	 | Loss (train): 1.1992	| Loss (validation): 1.3426	| Total Time: 153.5835s	| Epoch Time: 76.6765s
Epoch 3	 | Loss (train): 0.7852	| Loss (validation): 1.2690	| Total Time: 230.4485s	| Epoch Time: 76.8648s
Epoch 4	 | Loss (train): 0.5664	| Loss (validation): 1.2651	| Total Time: 307.6803s	| Epoch Time: 77.2317s
Took 307.68041491508484s


## Test the model


In [39]:
def tensor_to_sentence(model_output_indexes, vocab):
    sentence = []
    for word_index in model_output_indexes:
        word = vocab.lookup_token(word_index)
        if word == EOS_TOKEN:
            break
        sentence.append(word)
    return " ".join(sentence)


def inference(model, inputs):
    with torch.no_grad():
        model.eval()
        outputs = model(inputs)
        model.train()
    return outputs


def translate(model, input, eng_vocab, fra_vocab):
    with torch.no_grad():
        input = torch.tensor(
            prepare_input(input, eng_vocab), dtype=torch.long, device=device
        ).unsqueeze(0)

        output = inference(model, input)

        output = output.squeeze(0)

        _, pred_indexes = output.topk(1, dim=1)
        return tensor_to_sentence(pred_indexes, fra_vocab)

In [40]:
prepare_input("Wow", eng_vocab)

[3336, 2]

In [41]:
eng_vocab[UNK_TOKEN]

0

In [42]:
translate(model, "we had a good day", eng_vocab, fra_vocab)

'nous avons eu une bonne journee'

In [43]:
translate(model, "we went to swim", eng_vocab, fra_vocab)

'nous sommes alles nager'

In [44]:
translate(model, "we are happy", eng_vocab, fra_vocab)

'nous sommes heureux'

In [45]:
translate(model, "I don't like that", eng_vocab, fra_vocab)

'j aime ca'

In [46]:
%%timeit
translate(model, "we had a good day", eng_vocab, fra_vocab)

6.29 ms ± 37.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Bleu Scoring


In [47]:
def ref_and_pred(dataloader, model, src_vocab, target_vocab):
    references = []
    predictions = []
    for batch in dataloader:
        input, target = batch
        output = inference(model, input)
        for index, (input_sample, target_sample, output_sample) in enumerate(
            zip(input, target, output)
        ):
            english_input = tensor_to_sentence(input_sample, src_vocab)
            _, pred_indexes = output_sample.topk(1, dim=1)
            model_translation = tensor_to_sentence(pred_indexes, fra_vocab)
            ground_truth = tensor_to_sentence(target_sample, fra_vocab)
            references.append(ground_truth)
            predictions.append(model_translation)
    return references, predictions

In [49]:
bleu = evaluate.load("bleu")
predictions, references = ref_and_pred(test_dataloader, model, eng_vocab, fra_vocab)
results = bleu.compute(predictions=predictions, references=references)

In [50]:
results

{'bleu': 0.25979312491709844,
 'precisions': [0.5123417299746768,
  0.3263852783160814,
  0.20590218249764028,
  0.13229985659892202],
 'brevity_penalty': 1.0,
 'length_ratio': 1.0903938381049265,
 'translation_length': 75030,
 'reference_length': 68810}