In [1]:
import time
from collections.abc import Callable

import evaluate
import matplotlib.pyplot as plt
import numpy as np
import spacy
import torch
import torch.nn.functional as F  # noqa: N812
from spacy.language import Language
from tokenizers import Tokenizer, decoders, normalizers
from tokenizers.models import WordPiece
from tokenizers.normalizers import NFD, Lowercase, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordPieceTrainer
from torch import Tensor, nn
from tqdm import tqdm

import datasets


  from .autonotebook import tqdm as notebook_tqdm


### Data processing

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    # for Apple chips
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(device)

mps


In [3]:
dataset = datasets.load_dataset("bentrevett/multi30k")

In [4]:
train_data, valid_data, test_data = (
    dataset["train"],
    dataset["validation"],
    dataset["test"],
)

In [5]:
train_data[0]

{'en': 'Two young, White males are outside near many bushes.',
 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}

In [6]:
train_data[:2]

{'en': ['Two young, White males are outside near many bushes.',
  'Several men in hard hats are operating a giant pulley system.'],
 'de': ['Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
  'Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem.']}

In [8]:
def craete_tokenizer(max_vocab_size: int = 4000) -> tuple[Tokenizer, WordPieceTrainer]:
    tokenizer = Tokenizer(WordPiece(unk_token="<UNK>"))  # noqa: S106
    tokenizer.normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()])
    tokenizer.pre_tokenizer = Whitespace()
    tokenizer.post_processor = TemplateProcessing(
        single="<BOS> $A <EOS>",
        special_tokens=[
            ("<BOS>", 1),
            ("<EOS>", 2),
        ],
    )
    tokenizer.decoder = decoders.WordPiece()
    trainer = WordPieceTrainer(vocab_size=max_vocab_size,
                               special_tokens=["<UNK>", "<BOS>", "<EOS>", "<PAD>"])

    return tokenizer, trainer

def batch_iterator(batch_size: int, data: datasets.Dataset, key: str) -> list[str]:
    for i in range(0, len(data), batch_size):
        yield data[i : i + batch_size][key]

In [9]:
src_lang, src_lang_trainer = craete_tokenizer()
tgt_lang, tgt_lang_trainer = craete_tokenizer()

src_lang.train_from_iterator(batch_iterator(100, train_data, "de"), src_lang_trainer)
tgt_lang.train_from_iterator(batch_iterator(100, train_data, "en"), tgt_lang_trainer)









In [10]:
def tokenize_example(example: dict[str, str],
                     en_tokenizer: Tokenizer,
                     de_tokenizer: Tokenizer,
                     max_length: int) -> dict[str, list[str | int]]:
    en_encoded = en_tokenizer.encode(example["en"])
    de_encoded = de_tokenizer.encode(example["de"])

    return {"en_tokens": en_encoded.tokens[:max_length],
            "de_tokens": de_encoded.tokens[:max_length],
            "en_ids": en_encoded.ids[:max_length],
            "de_ids": de_encoded.ids[:max_length]}

max_length = 1000
sos_token = "<BOS>"  # noqa: S105
eos_token = "<EOS>"  # noqa: S105

fn_kwargs = {
    "en_tokenizer": tgt_lang,
    "de_tokenizer": src_lang,
    "max_length": max_length,
}

train_data = train_data.map(tokenize_example, fn_kwargs=fn_kwargs)
valid_data = valid_data.map(tokenize_example, fn_kwargs=fn_kwargs)
test_data = test_data.map(tokenize_example, fn_kwargs=fn_kwargs)

Map: 100%|██████████| 29000/29000 [00:02<00:00, 11874.73 examples/s]
Map: 100%|██████████| 1014/1014 [00:00<00:00, 11207.77 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 11190.57 examples/s]


In [11]:
train_data.num_rows

29000

In [12]:
print(f"Vocab size src: {src_lang.get_vocab_size()}")
print(f"Vocab size tgt: {tgt_lang.get_vocab_size()}")

Vocab size src: 4000
Vocab size tgt: 4000


In [13]:
print(src_lang.decode(train_data[2]["de_ids"]))
print(tgt_lang.decode(train_data[2]["en_ids"]))

ein kleines madchen klettert in ein spielhaus aus holz.
a little girl climbing into a wooden playhouse.


In [14]:
data_type = "torch"
format_columns = ["en_ids", "de_ids"]

train_data = train_data.with_format(
    type=data_type, columns=format_columns, output_all_columns=True
    )

valid_data = valid_data.with_format(
    type=data_type, columns=format_columns, output_all_columns=True
)

test_data = test_data.with_format(
    type=data_type, columns=format_columns, output_all_columns=True
)

In [15]:
def get_collate_fn(pad_index: int) -> Callable[[dict], dict]:
    def collate_fn(batch: dict) -> dict[str, int]:
        batch_en_ids = [example["en_ids"] for example in batch]
        batch_de_ids = [example["de_ids"] for example in batch]
        batch_en_ids = nn.utils.rnn.pad_sequence(batch_en_ids, padding_value=pad_index)
        batch_de_ids = nn.utils.rnn.pad_sequence(batch_de_ids, padding_value=pad_index)
        return {
            "en_ids": batch_en_ids,
            "de_ids": batch_de_ids,
        }

    return collate_fn

def get_data_loader(dataset: datasets.Dataset, batch_size: int, pad_index: int, *,
                    shuffle: bool = True, pin_memory: bool = False) -> torch.utils.data.DataLoader:
    collate_fn = get_collate_fn(pad_index)
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=shuffle,
        pin_memory=pin_memory,
    )

batch_size = 128

train_data_loader = get_data_loader(train_data, batch_size, tgt_lang.token_to_id("<PAD>"))
valid_data_loader = get_data_loader(valid_data, batch_size, tgt_lang.token_to_id("<PAD>"))
test_data_loader = get_data_loader(test_data, batch_size, tgt_lang.token_to_id("<PAD>"))

### Model

In [16]:
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim: int, num_heads: int, dropout_ratio: float = 0.1):
        if hidden_dim % num_heads != 0:
            msg = "hidden_dim must be divisible by num_heads"
            raise ValueError(msg)

        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        self.fc_q = nn.Linear(hidden_dim, hidden_dim)
        self.fc_k = nn.Linear(hidden_dim, hidden_dim)
        self.fc_v = nn.Linear(hidden_dim, hidden_dim)

        self.out = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout_ratio)

        self.scaling = 1 / (self.head_dim ** .5)

    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None) -> Tensor:
        """
        Inputs
            q: query of size (batch_size, seq_length, hidden_dim)
            k: query of size (batch_size, seq_length, hidden_dim)
            v: query of size (batch_size, seq_length, hidden_dim)
            mask: optional mask of size (batch_size, 1, 1, seq_length)
                  or (batch_size, 1, seq_length, seq_length)
        Outputs
            attention weighted embedding vectors of size (batch_size, seq_length, hidden_dim)
        """
        # all Q, K, V are of shape (batch_size, seq_length, hidden_dim)
        Q = self.fc_q(q)
        K = self.fc_k(k)
        V = self.fc_v(v)

        batch_size, seq_length, _ = Q.size()

        # all Q, K, V are of shape (batch_size, num_heads, seq_length, head_dim)
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # energy.shape (batch_size, num_heads, seq_length, seq_length)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) * self.scaling  # type: Tensor

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -torch.inf)

        # attention.shape (batch_size, num_heads, seq_length, seq_length)
        attention = energy.softmax(dim=-1)
        attention = self.dropout(attention)

        # x.shape (batch_size, num_heads, seq_length, head_dim)
        x = torch.matmul(attention, V)

        # x.shape (batch_size, seq_length, num_heads, head_dim)
        x = x.permute(0, 2, 1, 3)

        # x.shape (batch_size, seq_length, hidden_dim)
        x = x.reshape(batch_size, seq_length, self.hidden_dim)

        # x.shape (batch_size, seq_length, hidden_dim)
        x = self.out(x)

        return x

In [41]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, hidden_dim: int, ff_dim: int, dropout_ratio: float = 0.1):
        super().__init__()

        self.fc_1 = nn.Linear(hidden_dim, ff_dim)
        self.fc_2 = nn.Linear(ff_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, x) -> Tensor:
        # x.shape (batch_size, seq_length, emb_dim)
        x = self.fc_1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc_2(x)
        return x

class EncoderLayer(nn.Module):
    def __init__(self, hidden_dim: int, num_heads: int, ff_dim:int, dropout_ratio: float = 0.1):
        super().__init__()

        self.norm_1 = nn.LayerNorm(hidden_dim)
        self.norm_2 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout_ratio)
        self.ff = PositionWiseFeedForward(hidden_dim=hidden_dim, ff_dim=ff_dim)

        self.mha = MultiHeadAttention(hidden_dim=hidden_dim,
                                      num_heads=num_heads,
                                      dropout_ratio=dropout_ratio)

    def forward(self, src: Tensor, mask: Tensor | None = None) -> Tensor:
        """
        Inputs
            input of size (batch_size, seq_length, hidden_dim)
            mask of size (batch_size, 1, 1, seq_length)
        Outputs
            (batch_size, seq_length, hidden_dim)
        """
        # x1.shape (batch_size, seq_length, hidden_dim)
        x1 = self.mha(src, src, src, mask=mask)  # type: Tensor
        x1 = self.norm_1(self.dropout(x1) + src)

        # x2.shape (batch_size, seq_length, hidden_dim)
        x2 = self.ff(x1)
        x2 = self.norm_2(x1 + self.dropout(x2))

        return x2

In [42]:
class Encoder(nn.Module):
    def __init__(self, vocab_size: int, hidden_dim: int,
                 num_heads: int, num_layers: int, ff_dim: int, max_seq_length: int,
                 device: torch.device, dropout_ratio: float = 0.1):
        super().__init__()

        self.device = device
        self.scaling = hidden_dim ** (0.5)

        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.positional_embedding = nn.Embedding(max_seq_length, hidden_dim)
        self.dropout = nn.Dropout(dropout_ratio)

        self.layers = nn.ModuleList([EncoderLayer(hidden_dim=hidden_dim,
                                                  num_heads=num_heads, ff_dim=ff_dim,
                                                  dropout_ratio=dropout_ratio)
                                     for _ in range(num_layers)])

    def forward(self, src: Tensor, mask: Tensor | None = None) -> Tensor:
        """
        Inputs
            input of shape (batch_size, seq_legth)
        Outputs
            encoded sequence of shape (batch_size, seq_legth, hidden_dim)
        """
        batch_size, seq_length = src.size()
        positions = torch.arange(0, seq_length).unsqueeze(0).repeat(batch_size, 1).to(self.device)

        # x.shape (batch_size, seq_legth, hidden_dim)
        x = self.token_embedding(src) * self.scaling + self.positional_embedding(positions)
        x = self.dropout(x)

        for layer in self.layers:
            # x.shape (batch_size, seq_legth, hidden_dim)
            x = layer(x, mask)

        return x

In [43]:
class DecoderLayer(nn.Module):
    def __init__(self, hidden_dim: int, num_heads: int,
                 ff_dim:int, dropout_ratio: float = 0.1):
        super().__init__()

        self.norm_1 = nn.LayerNorm(hidden_dim)
        self.norm_2 = nn.LayerNorm(hidden_dim)
        self.norm_3 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout_ratio)
        self.num_heads = num_heads
        self.ff = PositionWiseFeedForward(hidden_dim=hidden_dim, ff_dim=ff_dim)

        self.self_attention = MultiHeadAttention(hidden_dim=hidden_dim,
                                                 num_heads=num_heads,
                                                 dropout_ratio=dropout_ratio)

        self.enc_attention = MultiHeadAttention(hidden_dim=hidden_dim,
                                                num_heads=num_heads,
                                                dropout_ratio=dropout_ratio)

    def forward(self, dec_input: Tensor, enc_outputs: Tensor,
                enc_mask: Tensor, dec_mask: Tensor) -> Tensor:
        """
        Inputs
            dec_input of shape (batch_size, seq_length, hidden_dim)
            enc_input of shape (batch_size, seq_length, hidden_dim)
            enc_mask of shape (batch_size, 1, 1, seq_length)
            dec_mask of shape (batch_size, 1, seq_length, seq_length)
        Outputs
            (batch_size, seq_length, hidden_dim)
        """
        # x1.shape (batch_size, seq_length, hidden_dim)
        x1 = self.self_attention(dec_input, dec_input, dec_input, mask=dec_mask)

        x1 = self.norm_1(dec_input + self.dropout(x1))

        # x2.shape (batch_size, seq_length, hidden_dim)
        x2 = self.enc_attention(x1, enc_outputs, enc_outputs, mask=enc_mask)

        x2 = self.norm_2(x1 + self.dropout(x2))
        x2 = self.norm_3(self.ff(x2))

        return x2

In [44]:
class Decoder(nn.Module):
    def __init__(self, vocab_size: int, hidden_dim: int,
                 num_heads: int, num_layers: int, ff_dim: int, max_seq_length: int,
                 device: torch.device, dropout_ratio: float = 0.1):
        super().__init__()

        self.device = device
        self.scaling = hidden_dim ** (0.5)

        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.positional_embedding = nn.Embedding(max_seq_length, hidden_dim)
        self.dropout = nn.Dropout(dropout_ratio)
        self.fc = nn.Linear(hidden_dim, vocab_size)

        self.layers = nn.ModuleList([DecoderLayer(hidden_dim=hidden_dim,
                                                  num_heads=num_heads, ff_dim=ff_dim,
                                                  dropout_ratio=dropout_ratio)
                                     for _ in range(num_layers)])

    def forward(self, dec_input: Tensor, enc_outputs: Tensor,
                enc_mask: Tensor, dec_mask: Tensor) -> Tensor:
        """
        Inputs
            dec_inputs of shape (batch_size, seq_legth)
            enc_outputs of shape (batch_size, seq_legth, hidden_dim)
            dec_mask of shape (batch_size, 1, seq_length, seq_length)
            enc_mask of shape (batch_size, 1, 1, seq_length)
        Outputs
            log-probabilities of shape (batch_size, seq_legth, hidden_dim)
        """
        batch_size, seq_length = dec_input.size()
        positions = torch.arange(0, seq_length).unsqueeze(0).repeat(batch_size, 1).to(self.device)

        # x.shape (batch_size, seq_legth, hidden_dim)
        x = self.token_embedding(dec_input) * self.scaling + self.positional_embedding(positions)
        x = self.dropout(x)

        for layer in self.layers:
            # x.shape (batch_size, seq_legth, hidden_dim)
            x = layer(x, enc_outputs, enc_mask, dec_mask)

        # x.shape (batch_size, seq_length, vocab_size)
        x = self.fc(x)
        x = F.log_softmax(x, dim=-1)

        return x

In [45]:
class TranslatorModel(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, enc_pad_token: int,
                 dec_pad_token: int, device: torch.device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.enc_pad_token = enc_pad_token
        self.dec_pad_token = dec_pad_token
        self.device = device

    def _create_enc_mask(self, enc_inputs: Tensor) -> Tensor:
        """
        Inputs
            enc_inputs of shape (batch_size, seq_length)
        Outputs
            mask with 0s for PAD tokens of shape (batch_size, 1, 1, seq_length)
        """
        return (enc_inputs != self.enc_pad_token).unsqueeze(1).unsqueeze(2)

    def _create_dec_mask(self, dec_inputs: Tensor) -> Tensor:
        """
        Inputs
            dec_inputs of shape (batch_size, seq_length)
        Outputs
            masks PAD tokens and future tokens; shape (batch_size, 1, seq_length, seq_length)
        """
        # mask_1 shape (batch_size, 1, 1, seq_length)
        mask_1 = (dec_inputs != self.dec_pad_token).unsqueeze(1).unsqueeze(2)

        batch_size, seq_length = dec_inputs.size()

        # mask_2 shape (seq_length, seq_length)
        mask_2 = torch.tril(torch.ones((seq_length, seq_length), device= self.device)).bool()

        # mask_2 shape (batch_size, 1, seq_length, seq_length)
        mask_2 = mask_2.unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1, 1)

        mask = mask_1 & mask_2

        return mask

    def forward(self, src: Tensor, tgt: Tensor) -> Tensor:
        """
        Inputs
            src of shape (batch_size, src_seq_length)
            tgt of shape (batch_size, tgt_seq_length)
        Outputs
            decoded sequence of shape (batch_size, seq_length, tgt_vocab_size)
        """
        enc_mask = self._create_enc_mask(src)
        dec_mask = self._create_dec_mask(tgt)

        enc_outputs = self.encoder(src, enc_mask)
        dec_outputs = self.decoder(tgt, enc_outputs, enc_mask, dec_mask)

        return dec_outputs

### Training

In [46]:
def init_weights(model: nn.Module) -> None:
    if hasattr(model, "weight") and model.weight.dim() > 1:
        nn.init.xavier_uniform_(model.weight.data)

def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

loss_function = nn.NLLLoss(ignore_index=tgt_lang.token_to_id("<PAD>"))

dropout_ratio = 0.3
hidden_dim = 256
ff_dim = 512
num_heads = 4
num_layers = 3

encoder = Encoder(vocab_size=src_lang.get_vocab_size(),
                  hidden_dim=hidden_dim,
                  num_heads=num_heads,
                  num_layers=num_layers,
                  ff_dim=ff_dim,
                  max_seq_length=100,
                  device=device,
                  dropout_ratio=dropout_ratio).to(device)

decoder = Decoder(vocab_size=tgt_lang.get_vocab_size(),
                  hidden_dim=hidden_dim,
                  num_heads=num_heads,
                  num_layers=num_layers,
                  ff_dim=ff_dim,
                  max_seq_length=100,
                  device=device,
                  dropout_ratio=dropout_ratio).to(device)

translator = TranslatorModel(encoder=encoder,
                             decoder=decoder,
                             enc_pad_token=src_lang.token_to_id("<PAD>"),
                             dec_pad_token=tgt_lang.token_to_id("<PAD>"),
                             device=device).to(device)

translator.apply(init_weights)
print(f"Model num parameters: {count_parameters(translator):,}")

optimizer = torch.optim.Adam(translator.parameters(), lr=0.0005)

Model num parameters: 7,080,864


In [47]:
def run_batch(model: nn.Module,
              loss_function: nn.NLLLoss,
              batch: dict[str, Tensor],
              device: torch.device) -> float:
    # src.shape (batch_size, seq_length)
    src = batch["de_ids"].to(device).transpose(1, 0)  # type: Tensor
    tgt = batch["en_ids"].to(device).transpose(1, 0)  # type: Tensor

    log_probs = model(src, tgt[:, :-1])  # type: Tensor
    log_probs = log_probs.reshape(-1, log_probs.size(-1))

    loss = loss_function(log_probs, tgt[:, 1:].reshape(-1).long())

    return loss

def train_one_epoch(model: nn.Module,
                    optimizer: torch.optim.Optimizer,
                    loss_function: nn.NLLLoss,
                    data_loader: torch.utils.data.DataLoader,
                    device: torch.device) -> float:
    model.train()

    losses = []
    for batch in tqdm(data_loader):
        optimizer.zero_grad()

        loss = run_batch(model=model, loss_function=loss_function, batch=batch, device=device)
        losses.append(loss.item())

        loss.backward()

        nn.utils.clip_grad_norm_(translator.parameters(), 1.0)

        optimizer.step()

    return sum(losses) / len(losses)

def translate_from_tensor(model: nn.Module, src: Tensor,
                          tgt_lang: Tokenizer, tgt: Tensor) -> str:
    # outputs.shape (batch_size, tgt_seq_length, tgt_vocab_size)
    log_probs = model(src, tgt=tgt)  # type: Tensor

    # pred_top2.shape (batch_size, tgt_seq_length, 2)
    _, pred_top2 = log_probs.topk(2, dim=-1)

    # pred_top2.shape (tgt_seq_length, 2)
    pred_top2 = pred_top2.squeeze(0)  # because batch_size=1 here

    # unpack first 2 top predictions
    first_pred, second_pred = pred_top2[:, 0].unsqueeze(1), pred_top2[:, 1].unsqueeze(1)

    # in case first top prediction is UNK use second top prediction
    unk_idx = tgt_lang.token_to_id("<UNK>")
    indices = torch.where(first_pred == unk_idx, second_pred, first_pred)

    indices = indices.squeeze().tolist()

    sentence = tgt_lang.decode(indices, skip_special_tokens=False)

    return sentence

def translate(model: TranslatorModel, src: Tensor | str, src_lang: Tokenizer,
              tgt_lang: Tokenizer, max_tgt_length: int, device: torch.device,
              clean: bool = False) -> str:
    model.eval()

    if isinstance(src, str):
        nlp = spacy.load("de_core_news_sm")  # TODO: don't like this workaround
        tokens = [token.text.lower() for token in nlp(src)]
        tokens = ["<BOS>", *tokens, "<EOS>"]
        src_idxs = np.array([src_lang.token_to_id(word) for word in tokens])
        src = torch.from_numpy(src_idxs).reshape(1, -1).to(device)  # (batch_size, src_seq_length)

    enc_mask = model._create_enc_mask(src)  # noqa: SLF001

    with torch.no_grad():
        encoder_outputs = model.encoder(src, enc_mask)

    tgt_indices = [tgt_lang.token_to_id("<BOS>")]

    for _ in range(max_tgt_length):
        tgt = torch.LongTensor(tgt_indices).unsqueeze(0).to(device)

        dec_mask = model._create_dec_mask(tgt)  # noqa: SLF001

        with torch.no_grad():
            log_probs = model.decoder(tgt, encoder_outputs, enc_mask, dec_mask)  # type: Tensor

        # pred_top2.shape (batch_size, tgt_seq_length, 2)
        _, pred_top2 = log_probs.topk(2, dim=-1)

        # pred_top2.shape (tgt_seq_length, 2)
        pred_top2 = pred_top2.squeeze(0)  # because batch_size=1 here

        # unpack first 2 top predictions
        first_pred, second_pred = pred_top2[-1, 0].item(), pred_top2[-1, 1].item()

        # in case first top prediction is UNK use second top prediction
        unk_idx = tgt_lang.token_to_id("<UNK>")
        pred_token = second_pred if first_pred == unk_idx else first_pred

        tgt_indices.append(pred_token)

        if pred_token == tgt_lang.token_to_id("<EOS>"):
            break

    sentence = tgt_lang.decode(tgt_indices, skip_special_tokens=clean)

    return sentence

def print_sentences(data: datasets.Dataset, idx: int, model: nn.Module,
                    src_lang: Tokenizer, tgt_lang: Tokenizer, device: torch.device) -> None:
    data_eval_src = data[idx]["de_ids"].to(device)
    data_eval_tgt = data[idx]["en_ids"].to(device)
    sentence_src = src_lang.decode(data_eval_src.detach().cpu().numpy().squeeze())
    sentence_tgt = tgt_lang.decode(data_eval_tgt.detach().cpu().numpy().squeeze())
    sentence_evaluated = translate_from_tensor(model, data_eval_src.unsqueeze(0), tgt_lang,
                                               data_eval_tgt.unsqueeze(0))

    print(f"SOURCE: {sentence_src}")
    print(f"TARGET: {sentence_tgt}")
    print(f"MODEL: {sentence_evaluated}")

def evaluate_model(model: nn.Module,
                   data_loader: torch.utils.data.DataLoader,
                   loss_function: nn.NLLLoss,
                   device: torch.device) -> float:
    model.eval()

    losses = []
    with torch.no_grad():
        for _, batch in enumerate(data_loader):
            loss = run_batch(model=model, loss_function=loss_function, batch=batch, device=device)
            losses.append(loss.item())

    return sum(losses) / len(losses)

In [53]:
translate(model=translator,
          src=train_data[42]["de_ids"].reshape(1, -1).to(device),
          src_lang=src_lang,
          tgt_lang=tgt_lang,
          max_tgt_length=10,
          device=device)

'<BOS> a man walking past a red car. <EOS>'

In [54]:
# print untrained model translations
print_sentences(data=train_data, idx=47, model=translator, src_lang=src_lang, tgt_lang=tgt_lang, device=device)

SOURCE: ein mann mit sonnenbrille legt seinen arm um eine frau in einer schwarz - weißen bluse.
TARGET: a man in sunglasses puts his arm around a woman in a black and white blouse.
MODEL: a man with sunglasses is his woman into a woman in a white shirt white shirt. <EOS>.


In [50]:
# training time on RTX 4090 1 min

In [52]:
NUM_EPOCHS = 5

translator.train()

model_name = f"translator_transformer_v2_{num_layers}_layers"
best_val_loss = float("inf")
train_losses, valid_losses = [], []
for epoch in range(NUM_EPOCHS):
    time_start = time.time()
    epoch_loss = train_one_epoch(translator, optimizer,
                                 loss_function, train_data_loader, device)

    time_passed_seconds = time.time() - time_start

    train_losses.append(epoch_loss)

    valid_loss = evaluate_model(model=translator, data_loader=valid_data_loader, loss_function=loss_function, device=device)
    valid_losses.append(valid_loss)

    if valid_loss < best_val_loss:
        # save best validaiton loss model
        best_val_loss = valid_loss
        print("Saving model state...")
        torch.save(translator.state_dict(), f"models/{model_name}_bestval.pt")

    # save model
    torch.save(translator.state_dict(), f"models/{model_name}.pt")

    print(f"Epoch: {epoch + 1}, elapsed: {time_passed_seconds:.0f} sec, train loss: {epoch_loss:.4f}, validation loss: {valid_loss:.4f}")

    if (epoch + 1) % 10 == 0:
        random_eval_idx = int(np.random.choice(list(range(len(train_data)))))
        print_sentences(data=train_data, idx=random_eval_idx, model=translator, src_lang=src_lang, tgt_lang=tgt_lang, device=device)

    print("-" * 100)

100%|██████████| 227/227 [02:04<00:00,  1.82it/s]


Saving model state...
Epoch: 1, elapsed: 125 sec, train loss: 4.8214, validation loss: 3.9135
----------------------------------------------------------------------------------------------------


100%|██████████| 227/227 [02:53<00:00,  1.31it/s]


Saving model state...
Epoch: 2, elapsed: 174 sec, train loss: 3.7449, validation loss: 3.4478
----------------------------------------------------------------------------------------------------


100%|██████████| 227/227 [02:43<00:00,  1.39it/s]


Saving model state...
Epoch: 3, elapsed: 164 sec, train loss: 3.3629, validation loss: 3.1322
----------------------------------------------------------------------------------------------------


100%|██████████| 227/227 [02:59<00:00,  1.26it/s]


Saving model state...
Epoch: 4, elapsed: 180 sec, train loss: 3.0509, validation loss: 2.7958
----------------------------------------------------------------------------------------------------


100%|██████████| 227/227 [03:01<00:00,  1.25it/s]


Saving model state...
Epoch: 5, elapsed: 182 sec, train loss: 2.7744, validation loss: 2.5852
----------------------------------------------------------------------------------------------------


In [None]:
torch.nn.functional.scaled_dot_product_attention

In [None]:
plt.plot(train_losses, color="blue", label="Train")
plt.plot(valid_losses, color="red", label="Validation")
plt.legend()
plt.grid()
plt.title("Average loss per epoch")

### Evaluate

In [None]:
translator.load_state_dict(torch.load(f"models/translator_transformer_v2_{num_layers}_layers_bestval.pt", map_location=device))
translator.eval()

In [None]:
random_eval_idx = int(np.random.choice(list(range(len(train_data)))))
print_sentences(data=train_data, idx=random_eval_idx, model=translator, src_lang=src_lang, tgt_lang=tgt_lang, device=device)

In [None]:
idxs = [42, 422, 10, 7, 999]

for idx in idxs:
    print_sentences(data=train_data, idx=idx, model=translator, src_lang=src_lang, tgt_lang=tgt_lang, device=device)
    print("-" * 100)

In [None]:
bleu = evaluate.load("bleu")

def get_tokenizer_fn(tgt_tokenizer: Tokenizer):
    def tokenizer_fn(s):
        tokens = tgt_tokenizer.encode(s).tokens
        return tokens

    return tokenizer_fn

tokenizer_fn = get_tokenizer_fn(tgt_lang)

In [None]:
# compute BLEU metric on test data
predictions, references = [], []
for idx in tqdm(range(test_data.num_rows)):
    data_eval_src = test_data[idx]["de_ids"].reshape(1, -1).to(device)
    sentence_evaluated = translate(model=translator,
                                   src=data_eval_src,
                                   src_lang=src_lang,
                                   tgt_lang=tgt_lang,
                                   max_tgt_length=100,
                                   device=device,
                                   clean=True)


    predictions.append(sentence_evaluated)
    references.append(test_data[idx]["en"])

bleu.compute(predictions=predictions, references=references, tokenizer=tokenizer_fn)