<a href="https://colab.research.google.com/github/code-1-mukul/Deep-Learning-Lab/blob/main/DL_LAB_6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import re
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from nltk.translate.bleu_score import sentence_bleu
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

Using device: mps


In [None]:
def normalize_sentence(raw_sentence):
    processed = raw_sentence.lower().strip()
    processed = re.sub(
        r"[^a-zA-Z¿?¡!.']+",
        " ",
        processed
    )
    return processed


parallel_corpus = []

with open("spa.txt", encoding="utf-8") as file_handle:

    for raw_line in file_handle:

        columns = raw_line.strip().split("\t")

        if len(columns) >= 2:

            source_text = normalize_sentence(columns[0])
            target_text = normalize_sentence(columns[1])

            parallel_corpus.append(
                (source_text, target_text)
            )


# limit dataset size
parallel_corpus = parallel_corpus[:10000]

random.shuffle(parallel_corpus)


total_examples = len(parallel_corpus)

train_cutoff = int(0.8 * total_examples)
val_cutoff = int(0.1 * total_examples)


train_data = parallel_corpus[:train_cutoff]

validation_data = parallel_corpus[
    train_cutoff: train_cutoff + val_cutoff
]

test_data = parallel_corpus[
    train_cutoff + val_cutoff:
]


print(
    "Train:", len(train_data),
    "Test:", len(test_data)
)

Train: 8000 Test: 1000


In [None]:
class TextVocabulary:

    def __init__(self, sentence_collection):

        token_counter = Counter()

        for sentence in sentence_collection:
            token_counter.update(sentence.split())

        base_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]

        self.index_to_token = list(base_tokens)
        self.index_to_token += list(token_counter.keys())

        self.token_to_index = {
            token: idx
            for idx, token in enumerate(self.index_to_token)
        }

    def encode(self, sentence):

        unknown_id = self.token_to_index["<unk>"]

        return [
            self.token_to_index.get(token, unknown_id)
            for token in sentence.split()
        ]


# ---------- build vocabularies ----------

source_vocab = TextVocabulary(
    [pair[0] for pair in train_data]
)

target_vocab = TextVocabulary(
    [pair[1] for pair in train_data]
)


SRC_VOCAB_SIZE = len(source_vocab.index_to_token)
TGT_VOCAB_SIZE = len(target_vocab.index_to_token)

print("Input vocab:", SRC_VOCAB_SIZE)
print("Output vocab:", TGT_VOCAB_SIZE)

Input vocab: 2931
Output vocab: 5122


In [None]:
class ParallelTextDataset(Dataset):

    def __init__(self, sentence_pairs):
        self.samples = sentence_pairs

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]


# ----------------------------------------------------------

def translation_collate(batch_samples):

    src_sequences = []
    tgt_sequences = []

    sos_src = source_vocab.token_to_index["<sos>"]
    eos_src = source_vocab.token_to_index["<eos>"]

    sos_tgt = target_vocab.token_to_index["<sos>"]
    eos_tgt = target_vocab.token_to_index["<eos>"]

    for src_text, tgt_text in batch_samples:

        src_ids = (
            [sos_src]
            + source_vocab.encode(src_text)
            + [eos_src]
        )

        tgt_ids = (
            [sos_tgt]
            + target_vocab.encode(tgt_text)
            + [eos_tgt]
        )

        src_sequences.append(torch.tensor(src_ids))
        tgt_sequences.append(torch.tensor(tgt_ids))

    padded_src = nn.utils.rnn.pad_sequence(
        src_sequences,
        padding_value=0
    )

    padded_tgt = nn.utils.rnn.pad_sequence(
        tgt_sequences,
        padding_value=0
    )

    return padded_src.to(DEVICE), padded_tgt.to(DEVICE)


# ----------------------------------------------------------

train_data_loader = DataLoader(
    ParallelTextDataset(train_data),
    batch_size=32,
    shuffle=True,
    collate_fn=translation_collate
)

In [None]:
class SeqEncoder(nn.Module):

    def __init__(self, vocab_size,
                 embedding_dim,
                 hidden_dim):

        super().__init__()

        self.embedding_layer = nn.Embedding(
            vocab_size,
            embedding_dim
        )

        self.recurrent_layer = nn.LSTM(
            embedding_dim,
            hidden_dim
        )

    def forward(self, src_tokens):

        embedded_tokens = self.embedding_layer(src_tokens)

        encoder_outputs, (h_n, c_n) = self.recurrent_layer(
            embedded_tokens
        )

        return encoder_outputs, h_n, c_n


# ----------------------------------------------------------

class SeqDecoder(nn.Module):

    def __init__(self, vocab_size,
                 embedding_dim,
                 hidden_dim):

        super().__init__()

        self.embedding_layer = nn.Embedding(
            vocab_size,
            embedding_dim
        )

        self.recurrent_layer = nn.LSTM(
            embedding_dim,
            hidden_dim
        )

        self.output_projection = nn.Linear(
            hidden_dim,
            vocab_size
        )

    def forward(self,
                input_token,
                hidden_state,
                cell_state):

        input_token = input_token.unsqueeze(0)

        embedded_token = self.embedding_layer(
            input_token
        )

        decoder_output, (h_n, c_n) = self.recurrent_layer(
            embedded_token,
            (hidden_state, cell_state)
        )

        logits = self.output_projection(
            decoder_output.squeeze(0)
        )

        return logits, h_n, c_n

In [None]:
class SequenceToSequenceModel(nn.Module):

    def __init__(self, encoder_module, decoder_module):
        super().__init__()

        self.encoder_module = encoder_module
        self.decoder_module = decoder_module

    def forward(self,
                source_batch,
                target_batch,
                teacher_force_prob=0.5):

        target_length = target_batch.size(0)
        batch_size = target_batch.size(1)

        vocab_dim = self.decoder_module.output_projection.out_features

        decoder_outputs = torch.zeros(
            target_length,
            batch_size,
            vocab_dim,
            device=DEVICE
        )

        _, hidden_state, cell_state = self.encoder_module(
            source_batch
        )

        current_input = target_batch[0]

        for step in range(1, target_length):

            logits, hidden_state, cell_state = self.decoder_module(
                current_input,
                hidden_state,
                cell_state
            )

            decoder_outputs[step] = logits

            use_teacher = (
                random.random() < teacher_force_prob
            )

            predicted_tokens = logits.argmax(dim=1)

            current_input = (
                target_batch[step]
                if use_teacher
                else predicted_tokens
            )

        return decoder_outputs

In [None]:
EMBEDDING_SIZE = 256
RNN_HIDDEN_SIZE = 512


encoder_model = SeqEncoder(
    SRC_VOCAB_SIZE,
    EMBEDDING_SIZE,
    RNN_HIDDEN_SIZE
)

decoder_model = SeqDecoder(
    TGT_VOCAB_SIZE,
    EMBEDDING_SIZE,
    RNN_HIDDEN_SIZE
)


translation_model = SequenceToSequenceModel(
    encoder_model,
    decoder_model
).to(DEVICE)


translation_optimizer = optim.Adam(
    translation_model.parameters()
)

loss_function = nn.CrossEntropyLoss(
    ignore_index=0
)

In [None]:
class AdditiveAttention(nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()

        self.energy_layer = nn.Linear(
            hidden_dim * 2,
            hidden_dim
        )

        self.score_layer = nn.Linear(
            hidden_dim,
            1,
            bias=False
        )

    def forward(self,
                decoder_hidden,
                encoder_states):

        src_length = encoder_states.size(0)

        # (num_layers, batch, hidden) -> (batch, hidden)
        last_hidden = decoder_hidden[-1]

        expanded_hidden = (
            last_hidden
            .unsqueeze(1)
            .repeat(1, src_length, 1)
        )

        encoder_states = encoder_states.permute(1, 0, 2)

        energy = torch.tanh(
            self.energy_layer(
                torch.cat(
                    (expanded_hidden, encoder_states),
                    dim=2
                )
            )
        )

        scores = self.score_layer(energy).squeeze(2)

        attention_weights = torch.softmax(
            scores,
            dim=1
        )

        return attention_weights


# ----------------------------------------------------------

class AttentionDecoder(nn.Module):

    def __init__(self,
                 vocab_size,
                 embedding_dim,
                 hidden_dim,
                 attention_module):

        super().__init__()

        self.embedding_layer = nn.Embedding(
            vocab_size,
            embedding_dim
        )

        self.recurrent_layer = nn.LSTM(
            embedding_dim + hidden_dim,
            hidden_dim
        )

        self.output_projection = nn.Linear(
            hidden_dim * 2,
            vocab_size
        )

        self.attention_module = attention_module

    def forward(self,
                input_token,
                hidden_state,
                cell_state,
                encoder_states):

        input_token = input_token.unsqueeze(0)

        embedded_token = self.embedding_layer(
            input_token
        )

        attention_scores = self.attention_module(
            hidden_state,
            encoder_states
        )

        attention_scores = attention_scores.unsqueeze(1)

        encoder_states = encoder_states.permute(1, 0, 2)

        context_vector = torch.bmm(
            attention_scores,
            encoder_states
        )

        context_vector = context_vector.permute(1, 0, 2)

        rnn_input = torch.cat(
            (embedded_token, context_vector),
            dim=2
        )

        rnn_output, (next_hidden, next_cell) = self.recurrent_layer(
            rnn_input,
            (hidden_state, cell_state)
        )

        combined_output = torch.cat(
            (
                rnn_output.squeeze(0),
                context_vector.squeeze(0)
            ),
            dim=1
        )

        logits = self.output_projection(
            combined_output
        )

        return logits, next_hidden, next_cell, attention_scores

In [None]:
class MultiplicativeAttention(nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()

        self.projection_layer = nn.Linear(
            hidden_dim,
            hidden_dim
        )

    def forward(self,
                decoder_hidden,
                encoder_states):

        # decoder_hidden: (num_layers, batch, hidden)
        # encoder_states: (src_len, batch, hidden)

        last_hidden = decoder_hidden[-1].unsqueeze(2)

        encoder_states = encoder_states.permute(1, 0, 2)

        transformed_states = self.projection_layer(
            encoder_states
        )

        scores = torch.bmm(
            transformed_states,
            last_hidden
        ).squeeze(2)

        attention_weights = torch.softmax(
            scores,
            dim=1
        )

        return attention_weights


# ----------------------------------------------------------

class LuongStyleDecoder(nn.Module):

    def __init__(self,
                 vocab_size,
                 embedding_dim,
                 hidden_dim,
                 attention_module):

        super().__init__()

        self.embedding_layer = nn.Embedding(
            vocab_size,
            embedding_dim
        )

        self.recurrent_layer = nn.LSTM(
            embedding_dim,
            hidden_dim
        )

        self.output_projection = nn.Linear(
            hidden_dim * 2,
            vocab_size
        )

        self.attention_module = attention_module

    def forward(self,
                input_token,
                hidden_state,
                cell_state,
                encoder_states):

        input_token = input_token.unsqueeze(0)

        embedded_token = self.embedding_layer(
            input_token
        )

        rnn_output, (next_hidden, next_cell) = self.recurrent_layer(
            embedded_token,
            (hidden_state, cell_state)
        )

        attention_scores = self.attention_module(
            next_hidden,
            encoder_states
        )

        attention_scores = attention_scores.unsqueeze(1)

        encoder_states = encoder_states.permute(1, 0, 2)

        context_vector = torch.bmm(
            attention_scores,
            encoder_states
        )

        combined = torch.cat(
            (
                rnn_output.squeeze(0),
                context_vector.squeeze(1)
            ),
            dim=1
        )

        logits = self.output_projection(
            combined
        )

        return logits, next_hidden, next_cell, attention_scores

In [None]:
def execute_training_epoch(model,
                           optimizer,
                           label="Model"):

    model.train()
    total_epoch_loss = 0.0

    iterator = tqdm(
        train_data_loader,
        desc=label,
        leave=False
    )

    for source_batch, target_batch in iterator:

        optimizer.zero_grad()

        # ---------- Standard Seq2Seq (no attention) ----------
        if isinstance(model.decoder_module, SeqDecoder):

            logits = model(
                source_batch,
                target_batch
            )

        # ---------- Attention-based Seq2Seq ----------
        else:

            seq_length = target_batch.size(0)
            batch_sz = target_batch.size(1)

            logits = torch.zeros(
                seq_length,
                batch_sz,
                TGT_VOCAB_SIZE,
                device=DEVICE
            )

            encoder_states, hidden_state, cell_state = \
                model.encoder_module(source_batch)

            decoder_input = target_batch[0]

            for step in range(1, seq_length):

                step_logits, hidden_state, cell_state, _ = \
                    model.decoder_module(
                        decoder_input,
                        hidden_state,
                        cell_state,
                        encoder_states
                    )

                logits[step] = step_logits

                apply_teacher = random.random() < 0.5
                predicted_tokens = step_logits.argmax(dim=1)

                decoder_input = (
                    target_batch[step]
                    if apply_teacher
                    else predicted_tokens
                )

        # ---------- Loss calculation ----------

        vocab_dim = logits.size(-1)

        logits = logits[1:].reshape(-1, vocab_dim)
        targets_flat = target_batch[1:].reshape(-1)

        batch_loss = loss_function(
            logits,
            targets_flat
        )

        batch_loss.backward()
        optimizer.step()

        total_epoch_loss += batch_loss.item()

        iterator.set_postfix(loss=batch_loss.item())

    return total_epoch_loss / len(train_data_loader)

In [None]:
# ============================
# Bahdanau (Additive Attention)
# ============================

bahdanau_enc = SeqEncoder(
    SRC_VOCAB_SIZE,
    EMBEDDING_SIZE,
    RNN_HIDDEN_SIZE
)

bahdanau_attn = AdditiveAttention(
    RNN_HIDDEN_SIZE
)

bahdanau_dec = AttentionDecoder(
    TGT_VOCAB_SIZE,
    EMBEDDING_SIZE,
    RNN_HIDDEN_SIZE,
    bahdanau_attn
)

bahdanau_seq2seq = SequenceToSequenceModel(
    bahdanau_enc,
    bahdanau_dec
).to(DEVICE)

bahdanau_optim = optim.Adam(
    bahdanau_seq2seq.parameters()
)


# ============================
# Luong (Multiplicative Attention)
# ============================

luong_enc = SeqEncoder(
    SRC_VOCAB_SIZE,
    EMBEDDING_SIZE,
    RNN_HIDDEN_SIZE
)

luong_attn = MultiplicativeAttention(
    RNN_HIDDEN_SIZE
)

luong_dec = LuongStyleDecoder(
    TGT_VOCAB_SIZE,
    EMBEDDING_SIZE,
    RNN_HIDDEN_SIZE,
    luong_attn
)

luong_seq2seq = SequenceToSequenceModel(
    luong_enc,
    luong_dec
).to(DEVICE)

luong_optim = optim.Adam(
    luong_seq2seq.parameters()
)

In [None]:
TOTAL_EPOCHS = 50

for epoch_idx in range(TOTAL_EPOCHS):

    print(f"\nEpoch {epoch_idx + 1}:")

    vanilla_epoch_loss = execute_training_epoch(
        translation_model,
        translation_optimizer,
        label="Vanilla"
    )

    bahdanau_epoch_loss = execute_training_epoch(
        bahdanau_seq2seq,
        bahdanau_optim,
        label="Bahdanau"
    )

    luong_epoch_loss = execute_training_epoch(
        luong_seq2seq,
        luong_optim,
        label="Luong"
    )

    print(f"Vanilla Loss:   {vanilla_epoch_loss:.4f}")
    print(f"Bahdanau Loss:  {bahdanau_epoch_loss:.4f}")
    print(f"Luong Loss:     {luong_epoch_loss:.4f}")


Epoch 1:


                                                                       

Vanilla Loss:   0.4997
Bahdanau Loss:  0.5395
Luong Loss:     0.5402

Epoch 2:


                                                                       

Vanilla Loss:   0.4851
Bahdanau Loss:  0.5272
Luong Loss:     0.5332

Epoch 3:


                                                                       

Vanilla Loss:   0.4673
Bahdanau Loss:  0.5036
Luong Loss:     0.5237

Epoch 4:


                                                                       

Vanilla Loss:   0.4743
Bahdanau Loss:  0.4935
Luong Loss:     0.5072

Epoch 5:


                                                                       

Vanilla Loss:   0.4650
Bahdanau Loss:  0.4809
Luong Loss:     0.4999

Epoch 6:


                                                                       

Vanilla Loss:   0.4507
Bahdanau Loss:  0.4614
Luong Loss:     0.4998

Epoch 7:


                                                                       

Vanilla Loss:   0.4500
Bahdanau Loss:  0.4546
Luong Loss:     0.4815

Epoch 8:


                                                                       

Vanilla Loss:   0.4530
Bahdanau Loss:  0.4445
Luong Loss:     0.4728

Epoch 9:


                                                                       

Vanilla Loss:   0.4363
Bahdanau Loss:  0.4308
Luong Loss:     0.4744

Epoch 10:


                                                                       

Vanilla Loss:   0.4272
Bahdanau Loss:  0.4210
Luong Loss:     0.4466

Epoch 11:


                                                                       

Vanilla Loss:   0.4310
Bahdanau Loss:  0.4260
Luong Loss:     0.4466

Epoch 12:


                                                                       

Vanilla Loss:   0.4157
Bahdanau Loss:  0.4184
Luong Loss:     0.4373

Epoch 13:


                                                                       

Vanilla Loss:   0.4133
Bahdanau Loss:  0.4123
Luong Loss:     0.4271

Epoch 14:


                                                                       

Vanilla Loss:   0.4165
Bahdanau Loss:  0.4037
Luong Loss:     0.4202

Epoch 15:


                                                                       

Vanilla Loss:   0.4007
Bahdanau Loss:  0.4069
Luong Loss:     0.4272

Epoch 16:


                                                                       

Vanilla Loss:   0.4115
Bahdanau Loss:  0.3857
Luong Loss:     0.4155

Epoch 17:


                                                                       

Vanilla Loss:   0.4113
Bahdanau Loss:  0.3814
Luong Loss:     0.4090

Epoch 18:


                                                                       

Vanilla Loss:   0.3924
Bahdanau Loss:  0.3700
Luong Loss:     0.3994

Epoch 19:


                                                                        

Vanilla Loss:   0.3893
Bahdanau Loss:  0.3737
Luong Loss:     0.3953

Epoch 20:


                                                                       

Vanilla Loss:   0.3837
Bahdanau Loss:  0.3647
Luong Loss:     0.3926

Epoch 21:


                                                                       

Vanilla Loss:   0.3735
Bahdanau Loss:  0.3641
Luong Loss:     0.3892

Epoch 22:


                                                                       

Vanilla Loss:   0.3805
Bahdanau Loss:  0.3636
Luong Loss:     0.3745

Epoch 23:


                                                                       

Vanilla Loss:   0.3763
Bahdanau Loss:  0.3575
Luong Loss:     0.3734

Epoch 24:


                                                                       

Vanilla Loss:   0.3757
Bahdanau Loss:  0.3551
Luong Loss:     0.3679

Epoch 25:


                                                                       

Vanilla Loss:   0.3708
Bahdanau Loss:  0.3486
Luong Loss:     0.3639

Epoch 26:


                                                                        

Vanilla Loss:   0.3739
Bahdanau Loss:  0.3396
Luong Loss:     0.3594

Epoch 27:


                                                                        

Vanilla Loss:   0.3673
Bahdanau Loss:  0.3416
Luong Loss:     0.3557

Epoch 28:


                                                                       

Vanilla Loss:   0.3641
Bahdanau Loss:  0.3349
Luong Loss:     0.3510

Epoch 29:


                                                                       

Vanilla Loss:   0.3588
Bahdanau Loss:  0.3308
Luong Loss:     0.3514

Epoch 30:


                                                                       

Vanilla Loss:   0.3550
Bahdanau Loss:  0.3283
Luong Loss:     0.3464

Epoch 31:


                                                                       

Vanilla Loss:   0.3490
Bahdanau Loss:  0.3293
Luong Loss:     0.3482

Epoch 32:


                                                                       

Vanilla Loss:   0.3497
Bahdanau Loss:  0.3322
Luong Loss:     0.3463

Epoch 33:


                                                                       

Vanilla Loss:   0.3401
Bahdanau Loss:  0.3290
Luong Loss:     0.3426

Epoch 34:


                                                                       

Vanilla Loss:   0.3378
Bahdanau Loss:  0.3232
Luong Loss:     0.3439

Epoch 35:


                                                                       

Vanilla Loss:   0.3415
Bahdanau Loss:  0.3200
Luong Loss:     0.3405

Epoch 36:


                                                                       

Vanilla Loss:   0.3352
Bahdanau Loss:  0.3146
Luong Loss:     0.3308

Epoch 37:


                                                                       

Vanilla Loss:   0.3347
Bahdanau Loss:  0.3095
Luong Loss:     0.3213

Epoch 38:


                                                                       

Vanilla Loss:   0.3416
Bahdanau Loss:  0.3126
Luong Loss:     0.3211

Epoch 39:


                                                                       

Vanilla Loss:   0.3344
Bahdanau Loss:  0.3070
Luong Loss:     0.3209

Epoch 40:


                                                                       

Vanilla Loss:   0.3264
Bahdanau Loss:  0.3022
Luong Loss:     0.3158

Epoch 41:


                                                                       

Vanilla Loss:   0.3153
Bahdanau Loss:  0.3037
Luong Loss:     0.3152

Epoch 42:


                                                                       

Vanilla Loss:   0.3233
Bahdanau Loss:  0.2973
Luong Loss:     0.3152

Epoch 43:


                                                                       

Vanilla Loss:   0.3165
Bahdanau Loss:  0.3049
Luong Loss:     0.3092

Epoch 44:


                                                                        

Vanilla Loss:   0.3237
Bahdanau Loss:  0.3001
Luong Loss:     0.3111

Epoch 45:


                                                                       

Vanilla Loss:   0.3150
Bahdanau Loss:  0.2900
Luong Loss:     0.3173

Epoch 46:


                                                                       

Vanilla Loss:   0.3127
Bahdanau Loss:  0.2993
Luong Loss:     0.3045

Epoch 47:


                                                                       

Vanilla Loss:   0.3125
Bahdanau Loss:  0.2925
Luong Loss:     0.3031

Epoch 48:


                                                                       

Vanilla Loss:   0.3185
Bahdanau Loss:  0.2888
Luong Loss:     0.3072

Epoch 49:


                                                                       

Vanilla Loss:   0.3117
Bahdanau Loss:  0.2965
Luong Loss:     0.3023

Epoch 50:


                                                                       

Vanilla Loss:   0.3057
Bahdanau Loss:  0.2868
Luong Loss:     0.3106




In [None]:
def compute_bleu_score(model):

    model.eval()
    bleu_scores = []

    with torch.no_grad():

        for src_text, tgt_text in test_data[:1000]:

            src_ids = (
                [source_vocab.token_to_index["<sos>"]]
                + source_vocab.encode(src_text)
                + [source_vocab.token_to_index["<eos>"]]
            )

            src_tensor = torch.tensor(
                src_ids
            ).unsqueeze(1).to(DEVICE)

            encoder_states, hidden_state, cell_state = \
                model.encoder_module(src_tensor)

            current_token = torch.tensor(
                [target_vocab.token_to_index["<sos>"]],
                device=DEVICE
            )

            generated_tokens = []

            for _ in range(20):

                # -------- Vanilla Decoder --------
                if isinstance(model.decoder_module, SeqDecoder):

                    logits, hidden_state, cell_state = \
                        model.decoder_module(
                            current_token,
                            hidden_state,
                            cell_state
                        )

                # -------- Attention Decoder --------
                else:

                    logits, hidden_state, cell_state, _ = \
                        model.decoder_module(
                            current_token,
                            hidden_state,
                            cell_state,
                            encoder_states
                        )

                predicted = logits.argmax(dim=1)

                if predicted.item() == \
                   target_vocab.token_to_index["<eos>"]:
                    break

                generated_tokens.append(
                    target_vocab.index_to_token[
                        predicted.item()
                    ]
                )

                current_token = predicted

            reference = [tgt_text.split()]
            bleu_scores.append(
                sentence_bleu(reference, generated_tokens)
            )

    return float(np.mean(bleu_scores))


print("\nBLEU Scores:")
print("Vanilla:", compute_bleu_score(translation_model))
print("Bahdanau:", compute_bleu_score(bahdanau_seq2seq))
print("Luong:", compute_bleu_score(luong_seq2seq))


BLEU Scores:
Vanilla: 0.034573508129145496
Bahdanau: 0.03648960622222519
Luong: 0.035924854104106264
