In [1]:
import spacy
import random
import numpy as np
import datasets
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
import torch.nn.functional as F
import torch.optim as optim
from datetime import datetime
from collections import Counter
from torch.utils.data import DataLoader
from torcheval.metrics.functional import bleu_score
from torcheval.metrics import BLEUScore
from transformer import Transformer

In [2]:
device = torch.device('cpu')
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')

print(f"Run on device: {device}")

Run on device: mps


In [3]:
# ML cloud services
google_colab = False
azure_ml = False
kaggle = False

In [4]:
if google_colab:
    !mkdir -p data/Multi30k_HuggingFace
    !python -m spacy download en_core_web_sm
    !python -m spacy download de_core_news_sm
elif kaggle:
    !pip install torcheval
    !python -m spacy download en_core_web_sm
    !python -m spacy download de_core_news_sm
elif azure_ml:
    !python -m spacy download en_core_web_sm
    !python -m spacy download de_core_news_sm

In [5]:
dataset_path = "data/Multi30k_HuggingFace"
if azure_ml:
    dataset_path = "Users/khoi.hoangdai/" + "data/Multi30k_HuggingFace"
elif kaggle:
    dataset_path = "/kaggle/input/untitled"

In [6]:
dataset = datasets.load_dataset(dataset_path)
train_set, val_set, test_set = dataset['train'], dataset['validation'], dataset['test']

In [7]:
# Set random seeds
def setseed(seed):
    """Set all seeds and deterministic CuDNN behavior"""
    # Python random module
    random.seed(seed)
    
    # NumPy
    np.random.seed(seed)
    
    # PyTorch (CPU and all GPUs)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    # CuDNN configurations (critical for reproducibility)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

setseed(1711)

In [8]:
# Use tokenizer from spacy
en_nlp = spacy.load('en_core_web_sm')
de_nlp = spacy.load('de_core_news_sm')

In [9]:
# Build the token frequency dict, ignore tokens with low frequency
en_token_dict = Counter()
de_token_dict = Counter()
unk, pad, sos, eos = '<unk>', '<pad>', '<sos>', '<eos>'
special_tokens = [unk, pad, sos, eos]
min_freq = 2

for example in train_set:
    en_tokens = [token.text.lower() for token in en_nlp.tokenizer(example['en'])]
    de_tokens = [token.text.lower() for token in de_nlp.tokenizer(example['de'])]
    en_token_dict.update(en_tokens)
    de_token_dict.update(de_tokens)

# No need to keep track of the frequency
en_token_dict = [k for (k, v) in en_token_dict.items() if v >= min_freq]
en_token_dict = special_tokens + en_token_dict
en_token_dict = {value: index for (index, value) in enumerate(en_token_dict)}
en_idx_token_dict = {value: key for (key, value) in en_token_dict.items()}

de_token_dict = [k for (k, v) in de_token_dict.items() if v >= min_freq]
de_token_dict = special_tokens + de_token_dict
de_token_dict = {value: index for (index, value) in enumerate(de_token_dict)}
de_idx_token_dict = {value: key for (key, value) in de_token_dict.items()}

In [10]:
# Check if special tokens share the same index
for special in special_tokens:
    if not en_token_dict[special] == de_token_dict[special]:
        print(f"Token {special} mismatch between EN and DE dictionary")

In [11]:
# Create token list and token IDs for each sentence in the dataset
def tokenize_example(example, en_nlp, de_nlp, en_token_dict, de_token_dict, sos, eos):
    en_tokens, de_tokens = [], []
    en_ids, de_ids = [], []
    for token in en_nlp.tokenizer(example['en']):
        token = token.text.lower()
        if token not in en_token_dict:
            token = unk

        en_tokens.append(token)
        en_ids.append(en_token_dict[token])

    en_tokens = [sos] + en_tokens + [eos]
    en_ids = [en_token_dict[sos]] + en_ids + [en_token_dict[eos]]

    for token in de_nlp.tokenizer(example['de']):
        token = token.text.lower()
        if token not in de_token_dict:
            token = unk

        de_tokens.append(token)
        de_ids.append(de_token_dict[token])

    de_tokens = [sos] + de_tokens + [eos]
    de_ids = [de_token_dict[sos]] + de_ids + [de_token_dict[eos]]

    example['en_tokens'] = en_tokens
    example['en_ids'] = en_ids
    example['de_tokens'] = de_tokens
    example['de_ids'] = de_ids

    return example


In [12]:
fn_kwargs = {
    'en_nlp': en_nlp,
    'de_nlp': de_nlp,
    'en_token_dict': en_token_dict,
    'de_token_dict': de_token_dict,
    'sos': sos,
    'eos': eos,
}
train_set = train_set.map(tokenize_example, fn_kwargs=fn_kwargs)
val_set = val_set.map(tokenize_example, fn_kwargs=fn_kwargs)
test_set = test_set.map(tokenize_example, fn_kwargs=fn_kwargs)

In [13]:
print(train_set[0]['de'])
print(train_set[0]['de_tokens'])
print(train_set[0]['de_ids'])
print(train_set[0]['en'])
print(train_set[0]['en_tokens'])
print(train_set[0]['en_ids'])

Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.
['<sos>', 'zwei', 'junge', 'weiße', 'männer', 'sind', 'im', 'freien', 'in', 'der', 'nähe', 'vieler', 'büsche', '.', '<eos>']
[2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 3]
Two young, White males are outside near many bushes.
['<sos>', 'two', 'young', ',', 'white', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.', '<eos>']
[2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 3]


In [14]:
# Write a collate_fn to pad sequences with variable length into a batch of tensors for Dataloader
def get_collate_fn(pad_index=1):
    def collate_fn(batch):
        # Encoder input: <sequence> + <eos>
        encoder_input = [torch.tensor(sequence['de_ids'][1:]) for sequence in batch]
        encoder_input = rnn.pad_sequence(encoder_input, padding_value=pad_index, batch_first=True)

        # Decode input: <sos> + <sequence>
        decoder_input = [torch.tensor(sequence['en_ids'][:-1]) for sequence in batch]
        decoder_input = rnn.pad_sequence(decoder_input, padding_value=pad_index, batch_first=True)

        # Decode output: <sequence> + <eos>
        decoder_output = [torch.tensor(sequence['en_ids'][1:]) for sequence in batch]
        decoder_output = rnn.pad_sequence(decoder_output, padding_value=pad_index, batch_first=True)

        return encoder_input, decoder_input, decoder_output

    return collate_fn

In [15]:
collate_fn = get_collate_fn()
batch_size = 128
train_dl = DataLoader(train_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)

In [16]:
def train_fn(model, optimizer, epochs, loss_fn=F.cross_entropy, dataloader=train_dl, pad_idx=1):
    total_loss = 0
    for epoch in range(epochs):
        epoch_loss = 0
        epoch_start = datetime.now()
        next_chunk = 0
        for idx, dl in enumerate(dataloader):
            batch_start = datetime.now()
            encoder_input, decoder_input, decoder_output = dl
            encoder_input = encoder_input.to(device)
            decoder_input = decoder_input.to(device)
            decoder_output = decoder_output.to(device)
            output = model(encoder_input, decoder_input)
            loss = loss_fn(output.permute(0, 2, 1), decoder_output, ignore_index=pad_idx, reduction='mean')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            batch_runtime = datetime.now() - batch_start
            if idx == next_chunk:
                print(f"Chunk={idx}: loss={loss.item():.2f}, batch runtime={batch_runtime.total_seconds()*1000:.2f} ms")
                next_chunk += len(train_dl) // 10

        total_loss += epoch_loss
        epoch_runtime = datetime.now() - epoch_start
        print(f"Epoch={epoch}: Loss={epoch_loss / len(train_dl):.2f}, epoch runtime={epoch_runtime.seconds:.2f} seconds")

    return total_loss / len(train_dl)

In [17]:
input_dim = len(de_token_dict)
output_dim = len(en_token_dict)
emb_dim = 512
attn_dim = 64
att_heads = 8
ffn_dim = 2048
layers = 6
max_seq_len = 50
epochs = 20

# Transformer
transformer_model = Transformer(input_dim, output_dim,
        emb_dim=emb_dim, attn_dim=attn_dim, attn_heads=att_heads, 
        ffn_dim=ffn_dim, layers=layers, max_seq_len=max_seq_len
    ).to(device)
transformer_optimizer = optim.Adam(transformer_model.parameters(), lr=1e-3)

In [18]:

MODEL_DIR = "saved_weights"
MODEL_FILE = f"transfomer_epochs={epochs}.pt"

In [19]:
inference = True
if not inference:
    model = transformer_model
    optimizer = transformer_optimizer
    train_gru_err = train_fn(model, optimizer, epochs)
    torch.save(model.state_dict(), MODEL_FILE)

In [20]:
def get_blue_tokenizer(en_nlp, en_token_dict, unk):
    def blue_tokenizer(s):
        en_tokens = [token.text.lower() if token.text.lower() in en_token_dict else unk for token in en_nlp.tokenizer(s)]
        return en_tokens

    return blue_tokenizer

In [21]:
def en_idx_to_sentence(indices, en_idx_token_dict, pad_idx=1):
    sentence = [en_idx_token_dict[idx.item()] for idx in indices if idx.item() != pad_idx]
    return " ".join(sentence)

In [None]:
def evaluate(model, model_file, val_set=val_set, collate_fn=collate_fn, batch_size=batch_size, en_idx_token_dict=en_idx_token_dict, eos=eos, device=device, max_output_len=30):
    val_dl = DataLoader(val_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
    bleu_metric = BLEUScore(
        n_gram=4
    )
    with open(f"bleu_output/{model_file}_bleu.txt", "w") as f:
        with torch.no_grad():
            for dl in val_dl:
                encoder_input, _, decoder_output = dl
                encoder_input = encoder_input.to(device)
                decoder_output = decoder_output.to(device)

                for i, seq_input in enumerate(encoder_input):
                    translated = model.translate(seq_input, en_idx_token_dict=en_idx_token_dict, device=device, eos=eos, sos_idx=2, max_output_len=max_output_len)
                    #  = model.translate(model, seq_input, en_idx_token_dict, device)
                    en_groud_truth = en_idx_to_sentence(decoder_output[i], en_idx_token_dict)
                    translated_sentence = " ".join(translated)
                    bleu_metric.update([en_groud_truth], [[translated_sentence]])
                    
                    try:
                        blue_results = bleu_score([translated_sentence], [en_groud_truth])
                        print(en_groud_truth)
                        print(translated_sentence)
                        print(blue_results)       
                        
                        f.write(en_groud_truth + "\n")
                        f.write(translated_sentence + "\n")
                        f.write(str(blue_results.item()) + "\n")
                        f.write("=======================" +  "\n")
                    except ValueError as e:
                        print(f"An unexpected error occurred: {e}")

            scopus_level_bleu = bleu_metric.compute().item()
            print("Scopus level blue: " + str(scopus_level_bleu))
            f.write(f"Scopus level blue: {scopus_level_bleu}")
    

In [None]:
if inference:
    MODEL_DIR = "saved_weights"
    MODEL_FILE= "transfomer_epochs=20.pt"
    # Load the Seq2Seq model by first initializing the architecture of Encoder and Decoder
    transformer_model = Transformer(input_dim, output_dim,
            emb_dim=emb_dim, attn_dim=attn_dim, attn_heads=att_heads, 
            ffn_dim=ffn_dim, layers=layers, max_seq_len=max_seq_len
        ).to(device)
    transformer_model.load_state_dict(torch.load(MODEL_DIR + "/" + MODEL_FILE, weights_only=False, map_location=device))
    evaluate(transformer_model, MODEL_FILE)

tensor([[[2.3388e-03, 2.2880e-09, 2.8605e-09,  ..., 1.9816e-09,
          3.9709e-09, 1.2049e-08]]], device='mps:0')
tensor([[[2.3388e-03, 2.2880e-09, 2.8605e-09,  ..., 1.9816e-09,
          3.9709e-09, 1.2049e-08],
         [5.9888e-03, 1.0259e-08, 7.7283e-09,  ..., 2.4228e-07,
          9.1170e-08, 1.3454e-08]]], device='mps:0')
tensor([[[2.3388e-03, 2.2880e-09, 2.8605e-09,  ..., 1.9816e-09,
          3.9709e-09, 1.2049e-08],
         [5.9888e-03, 1.0259e-08, 7.7283e-09,  ..., 2.4228e-07,
          9.1170e-08, 1.3454e-08],
         [6.6872e-03, 3.4287e-09, 9.2962e-09,  ..., 7.7340e-09,
          2.0528e-05, 5.4652e-09]]], device='mps:0')
tensor([[[2.3388e-03, 2.2880e-09, 2.8605e-09,  ..., 1.9816e-09,
          3.9709e-09, 1.2049e-08],
         [5.9888e-03, 1.0259e-08, 7.7283e-09,  ..., 2.4228e-07,
          9.1170e-08, 1.3454e-08],
         [6.6872e-03, 3.4287e-09, 9.2962e-09,  ..., 7.7340e-09,
          2.0528e-05, 5.4652e-09],
         [6.1581e-03, 3.3082e-09, 4.4945e-09,  ..., 7.4