In [47]:
import os
import re
import random
import numpy as np
from tqdm import tqdm
from torchinfo import summary
from collections import Counter
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

In [48]:
import nltk

nltk.download("punkt")
nltk.download("punkt_tab")

from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /usr/share/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


## Config

In [49]:
en_train = '/kaggle/input/ted-talks-corpus/train.en'
fr_train = '/kaggle/input/ted-talks-corpus/train.fr'
en_val = '/kaggle/input/ted-talks-corpus/dev.en'
fr_val = '/kaggle/input/ted-talks-corpus/dev.fr'
en_test = '/kaggle/input/ted-talks-corpus/test.en'
fr_test = '/kaggle/input/ted-talks-corpus/test.fr'

In [None]:
train = False
padding_before = False
plot_losses = False

In [51]:
embedding_dim = 300
max_length = 64
lr=1e-4

heads = 6
layers = 6

epochs = 15
batch_size = 32

In [52]:
os.makedirs("models", exist_ok=True)

save_path="./models/transformer"

save_path = save_path + f"_heads{heads}_layers{layers}"

save_path = save_path + ".pth"

print(f"Saving model to {save_path}")

Saving model to ./models/transformer_heads6_layers6.pth


In [53]:
random_seed = 42

random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
print("Using Random Seed:", random_seed)

Using Random Seed: 42


In [54]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

Using cuda


## Utils

In [55]:
def clean_text(text):
    text = str(text).lower().strip()
    text = text.rstrip('\n')
#     text = re.sub(r"<[^>]+>", "", text)
    text = re.sub(r"[^a-zA-ZÀ-ÿ0-9\s.,;!?':()\[\]{}-]", " ", text)  # Keep selected punctuation marks, symbols and apostrophes
    text = re.sub(r"\s+", " ", text)

    text = text.encode("utf-8", errors="ignore").decode("utf-8")  # Corrected encoding

    return text

def clean_sentences(sentences):
    sentences = [clean_text(sentence) for sentence in sentences]
    sentences = [s for s in sentences if s and s != ""]  # remove empty strings
    return sentences

In [56]:
def read_data(en_path, fr_path):
    with open(en_path, "r") as f:
        en_data = f.readlines()
    with open(fr_path, "r") as f:
        fr_data = f.readlines()

    assert len(en_data) == len(fr_data), "Data mismatch"

    en_data = clean_sentences(en_data)
    fr_data = clean_sentences(fr_data)

    assert len(en_data) == len(fr_data), "Data mismatch in cleaned data"

    return en_data, fr_data

def word_tokenizer(sentence):
    words = word_tokenize(sentence)
    return words

In [57]:
def flatten_concatenation(list_of_lists, unique=False):
    # flat_list = []
    # for sublist in list_of_lists:
    #     flat_list += sublist

    # flat_list = list(set(flat_list))
    # return flat_list
    flat_array = np.concatenate(list_of_lists)
    if unique:
        flat_list = np.unique(flat_array).tolist()
    else:
        flat_list = flat_array.tolist()
    return flat_list

In [58]:
def reverse_vocab(vocab):
    return {v: k for k, v in vocab.items()}

In [59]:
def return_words_till_EOS(lst, eos=2):
    if eos not in lst:
        return lst
    return lst[:lst.index(eos)]

### Dataset

In [60]:
def pad_sequence(sequence, max_len, before=True, pad_token=0):
    if len(sequence) > max_len:
        return sequence[:max_len]
    elif before:
        return [pad_token] * (max_len - len(sequence)) + sequence
    else:
        return sequence + [pad_token] * (max_len - len(sequence))

In [61]:
class MyDataset(Dataset):
    def __init__(
        self,
        en_data,
        fr_data,
        en_vocab,
        fr_vocab,
        pad_before=False,
    ):
        self.en_data = []
        self.fr_data = []
        self.labels = []
        self.en_vocab = en_vocab
        self.fr_vocab = fr_vocab

        assert len(en_data) == len(fr_data)
        self.length = len(en_data)

        en_pad = self.en_vocab["<pad>"]
        en_unk = self.en_vocab["<unk>"]
        en_sos = self.en_vocab["<sos>"]
        en_eos = self.en_vocab["<eos>"]
        fr_pad = self.fr_vocab["<pad>"]
        fr_unk = self.fr_vocab["<unk>"]
        fr_sos = self.fr_vocab["<sos>"]
        fr_eos = self.fr_vocab["<eos>"]

        tqdm_obj = tqdm(
            total=self.length, desc="Creating dataset"
        )
        for index, (en_sentence, fr_sentence) in enumerate(zip(en_data, fr_data)):
            en_indices = [int(self.en_vocab.get(w, en_unk)) for w in en_sentence]
            en_indices = [en_sos] + en_indices[: max_length - 2] + [en_eos]
            en_indices = pad_sequence(
                en_indices, max_length, before=pad_before, pad_token=en_pad
            )
            self.en_data.append(
                torch.tensor(en_indices, dtype=torch.int, device=device)
            )

            fr_indices1 = [int(self.fr_vocab.get(w, fr_unk)) for w in fr_sentence]
            fr_indices = [fr_sos] + fr_indices1
            fr_indices = pad_sequence(
                fr_indices, max_length, before=pad_before, pad_token=fr_pad
            )
            self.fr_data.append(
                torch.tensor(fr_indices, dtype=torch.int, device=device)
            )

            fr_indices = fr_indices1 + [fr_eos]
            fr_indices = pad_sequence(
                fr_indices, max_length, before=pad_before, pad_token=fr_pad
            )
            self.labels.append(torch.tensor(fr_indices, device=device))

            if index % 10 == 0:
                tqdm_obj.update(10)

        tqdm_obj.close()

        print(f"Dataset created with {self.length} samples")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return self.en_data[idx], self.fr_data[idx], self.labels[idx]

### Model

In [62]:
def create_positional_encoding(max_length, embedding_dim):
    pe = torch.zeros(max_length, embedding_dim)
    position = torch.arange(0, max_length).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, embedding_dim, 2) * -(np.log(10000.0) / embedding_dim)
    )
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)
    return pe.to(device)


def make_src_mask(src):
    src1 = src
    if len(src.shape) == 3:
        src1 = torch.sum(src, dim=-1)

    src_mask = (src1 != 0).unsqueeze(1).unsqueeze(2)
    return src_mask.to(device)


def make_trg_mask(trg):
    trg1 = trg
    if len(trg.shape) == 3:
        trg1 = torch.sum(trg, dim=-1)
    
    n, trg_len = trg1.size()
    trg_mask = torch.tril(torch.ones(trg_len, trg_len)).expand(n, 1, trg_len, trg_len)
    return trg_mask.to(device)

In [63]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim: int = 512, num_heads: int = 8):
        super(MultiHeadAttention, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads

        assert (
            self.head_dim * num_heads == embedding_dim
        ), "Embedding dimension must be divisible by number of heads"

        self.q = nn.Linear(self.head_dim, self.head_dim)
        self.k = nn.Linear(self.head_dim, self.head_dim)
        self.v = nn.Linear(self.head_dim, self.head_dim)
        self.fc = nn.Linear(self.embedding_dim, self.embedding_dim)

    def forward(self, value, key, query, mask):
        n = query.size(0)
        query_len, key_len, value_len = query.size(1), key.size(1), value.size(1)

        value = self.v(value.reshape(n, value_len, self.num_heads, self.head_dim))
        query = self.q(query.reshape(n, query_len, self.num_heads, self.head_dim))
        key = self.k(key.reshape(n, key_len, self.num_heads, self.head_dim))

        energy = torch.einsum("nqhd,nkhd->nhqk", [query, key])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -float("inf"))
        attention = F.softmax(energy / np.sqrt(self.head_dim), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, value]).reshape(
            n, query_len, self.embedding_dim
        )
        out = self.fc(out)

        return out

In [64]:
class TransformerBlock(nn.Module):
    def __init__(
        self,
        embed_size: int,
        heads: int,
        forward_expansion: int,
        dropout: float,
    ):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.layer_norm1 = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Dropout(dropout),
        )
        self.layer_norm2 = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Dropout(dropout),
        )

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.layer_norm1(attention + query)
        forward = self.feed_forward(x)
        out = self.layer_norm2(forward + x)
        return out

In [65]:
class Encoder(nn.Module):
    def __init__(
        self,
        src_vocab_size: int,
        embed_size: int,
        num_layers: int,
        heads: int,
        forward_expansion: int,
        dropout: float,
        max_len: int,
    ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_len, embed_size)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, forward_expansion, dropout)
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        n, seq_len = x.size()
        positions = torch.arange(0, seq_len).expand(n, seq_len).to(device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

In [66]:
class DecoderBlock(nn.Module):
    def __init__(
        self, embed_size: int, heads: int, forward_expansion: int, dropout: float
    ):
        super(DecoderBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.transformer_block = TransformerBlock(
            embed_size, heads, forward_expansion, dropout
        )
        self.layer_norm = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Dropout(dropout),
        )

    def forward(self, x, value, key, src_mask, trg_mask):
        attention = self.attention(x, x, x, trg_mask)
        query = self.layer_norm(attention + x)
        out = self.transformer_block(value, key, query, src_mask)
        return out

In [67]:
class Decoder(nn.Module):
    def __init__(
        self,
        trg_vocab_size: int,
        embed_size: int,
        num_layers: int,
        heads: int,
        forward_expansion: int,
        dropout: float,
        max_len: int,
    ):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_len, embed_size)
        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout)
                for _ in range(num_layers)
            ]
        )
        self.fc = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        n, seq_len = x.size()
        positions = torch.arange(0, seq_len).expand(n, seq_len).to(device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)
        out = self.fc(x)
        return out

In [68]:
class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size: int,
        trg_vocab_size: int,
        embed_size: int = 512,
        num_layers: int = 6,
        forward_expansion: int = 4,
        heads: int = 8,
        dropout: float = 0.2,
        max_len: int = 50,
        save_path=None,
    ):
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_len,
        )
        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_len,
        )
        self.best_val_loss = float("inf")
        self.save_path = save_path

    def forward(self, src, trg):
        src_mask = make_src_mask(src)
        trg_mask = make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out

    def fit(self, train_loader, val_loader, criterion, optimizer, num_epochs: int = 10):
        train_losses = []
        val_losses = []

        for epoch in range(num_epochs):
            self.train()
            train_loss = 0
            for src, trg, label in tqdm(train_loader, total=len(train_loader)):
                optimizer.zero_grad()
                output = self(src, trg)
                output = output.reshape(-1, output.size(-1))
                label = label.reshape(-1)

                loss = criterion(output, label)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            train_loss /= len(train_loader)
            train_losses.append(train_loss)
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss}")

            if device == "cuda":
                torch.cuda.empty_cache()

            val_loss = self.evaluate(val_loader, criterion, True)
            val_losses.append(val_loss)
            print(f"Validation Loss: {val_loss}")

            if self.save_path and val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                torch.save(self.state_dict(), self.save_path)

        return train_losses, val_losses

    def evaluate(self, val_loader, criterion, tqdm_disabled: bool = False):
        self.eval()
        val_loss = 0
        with torch.no_grad():
            for src, trg_input, trg_target in tqdm(
                val_loader, total=len(val_loader), disable=tqdm_disabled
            ):
                output = self(src, trg_input)
                output_dim = output.shape[-1]
                output = output.view(-1, output_dim)
                trg_target = trg_target.view(-1)
                loss = criterion(output, trg_target)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        return val_loss

    def load(self, path=None):
        if path:
            self.load_state_dict(torch.load(path))
        elif self.save_path:
            self.load_state_dict(torch.load(self.save_path))
        else:
            raise ValueError("No model path provided")

    def predict(
        self,
        src,
        src_preprocessed=False,
        max_len=64,
        return_sentence=False,
        fr_vocab=None,
        en_vocab=None,
        start_token_idx=1,
        end_token_idx=2,
    ):
        self.eval()  # Set the model to evaluation mode

        if not src_preprocessed:
            assert en_vocab is not None
            src = word_tokenizer(src)
            src = [en_vocab.get(w, en_vocab["<unk>"]) for w in src]
            src = [start_token_idx] + src[: max_len - 2] + [end_token_idx]
            src = pad_sequence(src, max_len, before=padding_before, pad_token=0)
            src = torch.tensor([src], dtype=torch.int, device=device)

        trg = torch.tensor([[start_token_idx]], dtype=torch.int, device=device)

        src_mask = make_src_mask(src)
        with torch.no_grad():
            enc_src = self.encoder(src, src_mask)

        for _ in range(max_len):
            trg_mask = make_trg_mask(trg)

            with torch.no_grad():
                output = self.decoder(trg, enc_src, src_mask, trg_mask)
                output = output[:, -1]

            next_token = output.argmax(-1).unsqueeze(0)
            trg = torch.cat((trg, next_token), dim=1)

            if next_token.item() == end_token_idx:
                break

        generated_sequence = trg.squeeze(0).tolist()[1:]
        if generated_sequence[-1] == 2:
            generated_sequence = generated_sequence[:-1]

        if return_sentence:
            assert fr_vocab is not None
            fr_vocab_rev = reverse_vocab(fr_vocab)
            generated_sequence = [fr_vocab_rev[idx] for idx in generated_sequence]

        return generated_sequence

    def test(self, test_loader, en_vocab, fr_vocab):
        self.eval()
        bleu_scores = []

        reverse_fr_vocab = reverse_vocab(fr_vocab)

        with torch.no_grad():
            for src, _, label in tqdm(test_loader, total=len(test_loader)):
                for i in range(src.size(0)):
                    src_i = src[i].unsqueeze(0)
                    trg_target_i = label[i].unsqueeze(0)

                    candidate = self.predict(
                        src_i,
                        src_preprocessed=True,
                        max_len=max_length,
                        fr_vocab=fr_vocab,
                        start_token_idx=en_vocab["<sos>"],
                        end_token_idx=en_vocab["<eos>"],
                    )
                    candidate = [reverse_fr_vocab[idx] for idx in candidate]

                    reference = return_words_till_EOS(
                        trg_target_i.squeeze(0).tolist(), eos=fr_vocab["<eos>"]
                    )
                    reference = [reverse_fr_vocab[idx] for idx in reference]

                    bleu_score = sentence_bleu(
                        [reference],
                        candidate,
                        smoothing_function=SmoothingFunction().method1,
                    )
                    bleu_scores.append(bleu_score)

        print(f"Test BLEU Score: {np.mean(bleu_scores)}")
        # print("BLEU Scores:", bleu_scores)

        if device == "cuda":
            torch.cuda.empty_cache()

## Main

In [69]:
train_en, train_fr = read_data(en_train, fr_train)
val_en, val_fr = read_data(en_val, fr_val)
test_en, test_fr = read_data(en_test, fr_test)

In [70]:
train_en_words = [word_tokenizer(s) for s in train_en]
train_fr_words = [word_tokenizer(s) for s in train_fr]
val_en_words = [word_tokenizer(s) for s in val_en]
val_fr_words = [word_tokenizer(s) for s in val_fr]
test_en_words = [word_tokenizer(s) for s in test_en]
test_fr_words = [word_tokenizer(s) for s in test_fr]

all_en_words = flatten_concatenation(train_en_words + val_en_words + test_en_words)
all_fr_words = flatten_concatenation(train_fr_words + val_fr_words + test_fr_words)

In [71]:
en_word_counts = Counter(all_en_words)
assert en_word_counts.total() == len(all_en_words)
fr_word_counts = Counter(all_fr_words)
assert fr_word_counts.total() == len(all_fr_words)

en_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
fr_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}

for word, count in en_word_counts.items():
    # if count > 1:
        en_vocab[word] = len(en_vocab.keys())

for word, count in fr_word_counts.items():
    # if count > 1:
        fr_vocab[word] = len(fr_vocab.keys())

In [72]:
if train:
    train_dataset = MyDataset(
        train_en_words,
        train_fr_words,
        en_vocab,
        fr_vocab,
        padding_before,
    )
    val_dataset = MyDataset(
        val_en_words,
        val_fr_words,
        en_vocab,
        fr_vocab,
        padding_before,
    )
test_dataset = MyDataset(
    test_en_words,
    test_fr_words,
    en_vocab,
    fr_vocab,
    padding_before,
)

Creating dataset: 100%|██████████| 30000/30000 [00:04<00:00, 7149.03it/s]


Dataset created with 30000 samples


Creating dataset: 890it [00:00, 7017.12it/s]                         


Dataset created with 887 samples


Creating dataset: 1310it [00:00, 7131.28it/s]                         

Dataset created with 1305 samples





In [73]:
if train:
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
    )
    print("Length of train_loader:", len(train_loader))
    print("Length of val_loader:", len(val_loader))

test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
print("Length of test_loader:", len(test_loader))

Length of train_loader: 938
Length of val_loader: 28
Length of test_loader: 1305


In [74]:
model = Transformer(
    len(en_vocab),
    len(fr_vocab),
    embed_size=embedding_dim,
    num_layers=layers,
    heads=heads,
    forward_expansion=4,
    dropout=0.2,
    max_len=max_length,
    save_path=save_path,
).to(device)

In [75]:
# print(model)
# summary(model, input_data=[
#     torch.randint(low=0, high=len(en_vocab), size=(batch_size, max_length), device=device)
#     torch.randint(low=0, high=len(fr_vocab), size=(batch_size, max_length), device=device)
# ])

In [76]:
optimizer = optim.Adam(model.parameters(), lr=lr) # type: ignore
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [77]:
if train:
    train_losses, val_losses = model.fit(
        train_loader,
        val_loader,
        criterion,
        optimizer,
        num_epochs=epochs,
    )

    if plot_losses:
        plt.plot(train_losses, label="Train Loss")
        plt.plot(val_losses, label="Validation Loss")
        plt.legend()
        plt.show()

100%|██████████| 938/938 [01:21<00:00, 11.54it/s]


Epoch 1/15, Loss: 6.90152710282218
Validation Loss: 6.569449509893145


100%|██████████| 938/938 [01:21<00:00, 11.57it/s]


Epoch 2/15, Loss: 6.272208140094651
Validation Loss: 6.2743649653026035


100%|██████████| 938/938 [01:21<00:00, 11.56it/s]


Epoch 3/15, Loss: 5.933016585388672
Validation Loss: 6.088992510523115


100%|██████████| 938/938 [01:21<00:00, 11.57it/s]


Epoch 4/15, Loss: 5.707155293238951
Validation Loss: 5.966449567249843


100%|██████████| 938/938 [01:21<00:00, 11.58it/s]


Epoch 5/15, Loss: 5.5347091053594655
Validation Loss: 5.874079857553754


100%|██████████| 938/938 [01:20<00:00, 11.58it/s]


Epoch 6/15, Loss: 5.390632296421889
Validation Loss: 5.781954424721854


100%|██████████| 938/938 [01:21<00:00, 11.56it/s]


Epoch 7/15, Loss: 5.270443738904843
Validation Loss: 5.727157711982727


100%|██████████| 938/938 [01:21<00:00, 11.55it/s]


Epoch 8/15, Loss: 5.170359751308904
Validation Loss: 5.649943845612662


100%|██████████| 938/938 [01:21<00:00, 11.51it/s]


Epoch 9/15, Loss: 5.0770362841803385
Validation Loss: 5.600454092025757


100%|██████████| 938/938 [01:21<00:00, 11.53it/s]


Epoch 10/15, Loss: 4.992152666994758
Validation Loss: 5.543810163225446


100%|██████████| 938/938 [01:21<00:00, 11.56it/s]


Epoch 11/15, Loss: 4.917658916160242
Validation Loss: 5.544030870710101


100%|██████████| 938/938 [01:21<00:00, 11.54it/s]


Epoch 12/15, Loss: 4.850807659661592
Validation Loss: 5.512316295078823


100%|██████████| 938/938 [01:21<00:00, 11.54it/s]


Epoch 13/15, Loss: 4.785940251624915
Validation Loss: 5.469122852597918


100%|██████████| 938/938 [01:21<00:00, 11.55it/s]


Epoch 14/15, Loss: 4.726124031457311
Validation Loss: 5.45763145174299


100%|██████████| 938/938 [01:21<00:00, 11.57it/s]


Epoch 15/15, Loss: 4.675078034146762
Validation Loss: 5.429043531417847


In [78]:
model.load()

  self.load_state_dict(torch.load(self.save_path))


In [79]:
model.evaluate(test_loader, criterion)

100%|██████████| 1305/1305 [00:18<00:00, 71.26it/s]


4.895798279407837

In [80]:
model.predict(
    "I am a student.",
    en_vocab=en_vocab,
    fr_vocab=fr_vocab,
    return_sentence=True,
)

['les', 'gens', 'sont', 'un', 'peu', 'de', 'la', 'vie', '.']

In [81]:
model.test(test_loader, en_vocab, fr_vocab)

100%|██████████| 1305/1305 [05:42<00:00,  3.81it/s]

Test BLEU Score: 0.031163257459442834



