In [None]:
import datasets
import spacy
import random
import numpy as np
import evaluate
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
import torch.nn.functional as F
import torch.optim as optim
from datetime import datetime
from collections import Counter
from torch.utils.data import DataLoader

In [None]:
# Commands used for Google Colab.
# Set google_colab to True to run the shell commands below: they create the
# local dataset directory, upgrade the `datasets` package and download the
# two spaCy models that are loaded later in this notebook.
google_colab = False
if google_colab:
  !mkdir -p data/Multi30k_HuggingFace
  !pip install -U datasets
  !python -m spacy download en_core_web_sm
  !python -m spacy download de_core_news_sm

In [None]:
# Seed every random number generator used in this notebook
def setseed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make CuDNN deterministic."""
    # Python random module, NumPy, then PyTorch (CPU followed by all GPUs)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # CuDNN configuration (critical for reproducibility)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


setseed(1711)

In [None]:
# Pick the compute device: Apple MPS is checked first, then CUDA, otherwise CPU
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
    print(f"MPS device available: {device}")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"CUDA device available: {device}")
else:
    device = torch.device('cpu')

In [None]:
# Load the splits from the local directory with datasets.load_dataset()
dataset = datasets.load_dataset("data/Multi30k_HuggingFace")
train_set = dataset['train']
val_set = dataset['validation']
test_set = dataset['test']
# Display the first training example as the cell output
train_set[0]

In [None]:
# Load the English and German spaCy pipelines; their tokenizers are used below
en_nlp, de_nlp = (spacy.load(name) for name in ("en_core_web_sm", "de_core_news_sm"))

In [None]:
# Build the vocabularies: count lower-cased token frequencies over the training
# split and keep only tokens seen at least min_freq times
unk, pad, sos, eos = '<unk>', '<pad>', '<sos>', '<eos>'
special_tokens = [unk, pad, sos, eos]
min_freq = 2

en_counts, de_counts = Counter(), Counter()
for example in train_set:
    en_counts.update(token.text.lower() for token in en_nlp.tokenizer(example['en']))
    de_counts.update(token.text.lower() for token in de_nlp.tokenizer(example['de']))

# The frequencies are dropped from here on. Special tokens are placed first,
# then the kept tokens in first-seen order; the index of a token is its position.
en_vocab = special_tokens + [token for token, count in en_counts.items() if count >= min_freq]
en_token_dict = {token: index for index, token in enumerate(en_vocab)}
en_idx_token_dict = {index: token for token, index in en_token_dict.items()}

de_vocab = special_tokens + [token for token, count in de_counts.items() if count >= min_freq]
de_token_dict = {token: index for index, token in enumerate(de_vocab)}
de_idx_token_dict = {index: token for token, index in de_token_dict.items()}

In [None]:
# Sanity check: each special token must have the same index in both vocabularies
for special in special_tokens:
    if en_token_dict[special] != de_token_dict[special]:
        print(f"Token {special} mismatch between EN and DE dictionary")

In [None]:
# Create token list and token IDs for each sentence in the dataset
def _encode_sentence(text, nlp, token_dict, sos, eos, unk):
    """Tokenize `text` and return (tokens, ids), both wrapped in sos/eos."""
    lowered = (token.text.lower() for token in nlp.tokenizer(text))
    # Tokens missing from the vocabulary are replaced by the unknown token
    tokens = [token if token in token_dict else unk for token in lowered]
    tokens = [sos] + tokens + [eos]
    return tokens, [token_dict[token] for token in tokens]


def tokenize_example(example, en_nlp, de_nlp, en_token_dict, de_token_dict, sos, eos, unk='<unk>'):
    """Add token and token-id columns for the English and German sentence.

    Args:
        example: mapping with the raw sentences under 'en' and 'de'.
        en_nlp, de_nlp: objects whose `.tokenizer(text)` yields tokens with a
            `.text` attribute.
        en_token_dict, de_token_dict: token -> index vocabularies; they must
            contain `sos`, `eos` and `unk`.
        sos, eos: start / end of sequence tokens added around every sentence.
        unk: token substituted for out-of-vocabulary tokens. It is an explicit
            argument (instead of a notebook global) so the function is
            self-contained.

    Returns:
        The same `example`, with 'en_tokens', 'en_ids', 'de_tokens' and
        'de_ids' added.
    """
    # Just add both sos and eos
    # sos and eos tokens will be processed later in the collate_fn when merging data into batch
    example['en_tokens'], example['en_ids'] = _encode_sentence(
        example['en'], en_nlp, en_token_dict, sos, eos, unk
    )
    example['de_tokens'], example['de_ids'] = _encode_sentence(
        example['de'], de_nlp, de_token_dict, sos, eos, unk
    )
    return example


In [None]:
# Keyword arguments forwarded to tokenize_example for every example
fn_kwargs = dict(
    en_nlp=en_nlp,
    de_nlp=de_nlp,
    en_token_dict=en_token_dict,
    de_token_dict=de_token_dict,
    sos=sos,
    eos=eos,
)
# Tokenize the three splits in the order train, validation, test
train_set, val_set, test_set = (
    split.map(tokenize_example, fn_kwargs=fn_kwargs)
    for split in (train_set, val_set, test_set)
)

In [None]:
# Show the raw sentence, tokens and ids of the first training example, EN then DE
for column in ('en', 'en_tokens', 'en_ids', 'de', 'de_tokens', 'de_ids'):
    print(train_set[0][column])

In [None]:
# Build the collate_fn that pads variable-length sequences into batch tensors for the DataLoader
def get_collate_fn(pad_index=1):
    """Return a collate function that pads every sequence with `pad_index`.

    The returned function maps a list of examples (each with 'en_ids' and
    'de_ids') to a tuple (encoder_input, decoder_input, decoder_output) of
    batch-first padded tensors.
    """
    def _pad(batch, key, window):
        # Slice each id list, then right-pad all of them to the longest one
        sequences = [torch.tensor(sample[key][window]) for sample in batch]
        return rnn.pad_sequence(sequences, batch_first=True, padding_value=pad_index)

    def collate_fn(batch):
        return (
            _pad(batch, 'en_ids', slice(1, None)),   # Encoder input: <sequence> + <eos>
            _pad(batch, 'de_ids', slice(None, -1)),  # Decoder input: <sos> + <sequence>
            _pad(batch, 'de_ids', slice(1, None)),   # Decoder output: <sequence> + <eos>
        )

    return collate_fn

In [None]:
# One DataLoader per split, all sharing the same padding collate function
collate_fn = get_collate_fn()
batch_size = 128
# Only the training split is shuffled between epochs
train_dl = DataLoader(train_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
# Evaluation splits keep the dataset order so runs and logged outputs are reproducible
val_dl = DataLoader(val_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)
test_dl = DataLoader(test_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)

In [None]:
class Encoder(nn.Module):
    """Embedding layer followed by a multi-layer GRU.

    The latent representation is the final hidden state of every GRU layer.
    """

    def __init__(self, input_dim, embedding_dim, rnn_hidden_dim, rnn_num_layers):
        """
        Args:
            input_dim: number of entries in the embedding table.
            embedding_dim: size of each embedding vector (GRU input size).
            rnn_hidden_dim: GRU hidden size.
            rnn_num_layers: number of stacked GRU layers.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=embedding_dim)
        self.encoder = nn.GRU(
            input_size=embedding_dim,
            hidden_size=rnn_hidden_dim,
            num_layers=rnn_num_layers,
            bias=True,
            batch_first=True,
        )

    def forward(self, x):
        """Embed `x` and return the final hidden state at every layer: (layer, batch, hidden_dim)."""
        # The per-position outputs of the last layer (batch, sequence, hidden_dim) are discarded
        _, final_hidden = self.encoder(self.embedding(x))
        return final_hidden

In [None]:
class Decoder(nn.Module):
    """Embedding layer, multi-layer GRU and a linear projection onto the output vocabulary."""

    def __init__(self, output_dim, embedding_dim, rnn_hidden_dim, rnn_num_layers):
        """
        Args:
            output_dim: number of entries in the embedding table and size of the output logits.
            embedding_dim: size of each embedding vector (GRU input size).
            rnn_hidden_dim: GRU hidden size.
            rnn_num_layers: number of stacked GRU layers.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=output_dim, embedding_dim=embedding_dim)
        self.decoder = nn.GRU(
            input_size=embedding_dim,
            hidden_size=rnn_hidden_dim,
            num_layers=rnn_num_layers,
            bias=True,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=rnn_hidden_dim, out_features=output_dim)

    def forward(self, x, latent):
        """Run the GRU from the initial hidden state `latent`.

        Returns:
            (logits, hidden): logits for every position, obtained from the last
            layer's outputs (batch, sequence, hidden_dim), and the final hidden
            state at every layer (layer, batch, hidden_dim).
        """
        embedded = self.embedding(x)
        per_step_hidden, final_hidden = self.decoder(embedded, latent)
        logits = self.fc(per_step_hidden)
        return logits, final_hidden

In [None]:
class Seq2Seq(nn.Module):
    """Wraps an encoder and a decoder: the encoder output initializes the decoder."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, encoder_input, decoder_input):
        """Encode `encoder_input`, then decode `decoder_input` from the encoded state.

        Returns:
            The pair (decoder output, decoder state) produced by the decoder.
        """
        latent = self.encoder(encoder_input)
        logits, final_state = self.decoder(decoder_input, latent)
        return logits, final_state


In [None]:
# Hyper-parameters: vocabulary sizes come from the token dictionaries
input_dim, output_dim = len(en_token_dict), len(de_token_dict)
encoder_embedding_dim = decoder_embedding_dim = 256
rnn_hidden_dim = 512
rnn_num_layers = 2

# Build the encoder first, then the decoder, and wrap both in the Seq2Seq model on `device`
encoder = Encoder(
    input_dim=input_dim,
    embedding_dim=encoder_embedding_dim,
    rnn_hidden_dim=rnn_hidden_dim,
    rnn_num_layers=rnn_num_layers,
)
decoder = Decoder(
    output_dim=output_dim,
    embedding_dim=decoder_embedding_dim,
    rnn_hidden_dim=rnn_hidden_dim,
    rnn_num_layers=rnn_num_layers,
)
seq2seq = Seq2Seq(encoder=encoder, decoder=decoder).to(device)
optimizer = optim.Adam(seq2seq.parameters(), lr=1e-3)
epochs = 20

def train_fn(model=seq2seq, optimizer=optimizer, loss_fn=F.cross_entropy, epochs=epochs, dataloader=train_dl, pad_idx=1):
    """Train `model` with teacher forcing and return the mean batch loss.

    Note: the default arguments are bound when this cell is executed, so
    re-run the cell after re-creating the model, optimizer or dataloader.

    Args:
        model: called as `model(encoder_input, decoder_input)`; must return
            (logits, state) with logits shaped (batch, sequence, classes).
        optimizer: optimizer over the parameters of `model`.
        loss_fn: called like `F.cross_entropy` (accepts `ignore_index` and
            `reduction`).
        epochs: number of passes over `dataloader`.
        dataloader: yields (encoder_input, decoder_input, decoder_output).
        pad_idx: target index ignored by the loss.

    Returns:
        Mean loss over every batch of every epoch, or 0.0 when nothing was
        trained (no epochs or an empty dataloader).
    """
    num_batches = len(dataloader)
    if epochs <= 0 or num_batches == 0:
        return 0.0

    # Move batches to wherever the model lives instead of relying on a global
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    # Report roughly ten times per epoch; at least every batch for tiny loaders
    log_every = max(1, num_batches // 10)

    # Make sure training-mode layers are active even if the model was put in eval mode before
    model.train()

    total_loss = 0.0
    for epoch in range(epochs):
        epoch_loss = 0.0
        epoch_start = datetime.now()
        for idx, (encoder_input, decoder_input, decoder_output) in enumerate(dataloader):
            batch_start = datetime.now()
            encoder_input = encoder_input.to(model_device)
            decoder_input = decoder_input.to(model_device)
            decoder_output = decoder_output.to(model_device)

            output, _ = model(encoder_input, decoder_input)
            # The loss expects the class dimension second: (batch, classes, sequence)
            loss = loss_fn(output.permute(0, 2, 1), decoder_output, ignore_index=pad_idx, reduction='mean')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            batch_runtime = datetime.now() - batch_start
            if idx % log_every == 0:
                print(f"Chunk={idx}: loss={batch_loss:.2f}, batch runtime={batch_runtime.total_seconds()*1000:.2f} ms")

        total_loss += epoch_loss
        epoch_runtime = datetime.now() - epoch_start
        # total_seconds() keeps the fractional part that `.seconds` would drop
        print(f"Epoch={epoch}: Loss={epoch_loss / num_batches:.2f}, epoch runtime={epoch_runtime.total_seconds():.2f} seconds")

    # Average over all optimisation steps, not only over the batches of one epoch
    return total_loss / (epochs * num_batches)

In [None]:
# Run the training loop on the Seq2Seq model and keep the value it returns
train_err = train_fn(model=seq2seq)

In [None]:
# Save the trained parameters: only the model's state_dict is written (not the
# module object itself), so the model classes must be instantiated before loading
MODEL_PATH = "seq2seq.pt"
torch.save(seq2seq.state_dict(), MODEL_PATH)

In [None]:
# Load the Seq2Seq model by first initializing the architecture of Encoder and Decoder
GOOGLE_COLAB_MODEL_PATH = "seq2seq_googlecolab.pt"
reload_encoder = Encoder(input_dim, encoder_embedding_dim, rnn_hidden_dim, rnn_num_layers)
reload_decoder = Decoder(output_dim, decoder_embedding_dim, rnn_hidden_dim, rnn_num_layers)
reload_seq2seq = Seq2Seq(reload_encoder, reload_decoder).to(device)
# map_location=device remaps the checkpoint tensors onto the device selected
# above, so the cell also works on machines without MPS.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict contains, and avoids executing code from the file.
reload_state_dict = torch.load(GOOGLE_COLAB_MODEL_PATH, weights_only=True, map_location=device)
reload_seq2seq.load_state_dict(reload_state_dict)

In [None]:
def translate(model, input_tokens, de_idx_token_dict, device, sos_idx=2, max_length=100, eos_token='<eos>'):
    """Greedily decode one source sequence into a list of target tokens.

    Args:
        model: exposes `encoder(x)` returning the initial decoder state and
            `decoder(x, state)` returning (logits, state).
        input_tokens: 1-D tensor of source token ids (no batch dimension).
        de_idx_token_dict: index -> token mapping of the target vocabulary.
        device: device on which decoding tensors are created.
        sos_idx: index of the start-of-sequence token fed as first input.
        max_length: maximum number of tokens to generate.
        eos_token: token that ends decoding; it is included in the output.

    Returns:
        The generated tokens, ending with `eos_token` unless `max_length`
        was reached first.
    """
    model.eval()
    output = []
    # Add the batch dimension and make sure the source is on the decoding device
    input_tokens = input_tokens.unsqueeze(0).to(device)

    with torch.no_grad():
        state = model.encoder(input_tokens)
        word = torch.tensor([[sos_idx]], device=device)
        for _ in range(max_length):
            # The decoder returns exactly two values: logits and the new hidden state
            logits, state = model.decoder(word, state)
            # Greedy choice: most likely next token id
            word_idx = torch.argmax(logits, dim=-1).item()
            word_token = de_idx_token_dict[word_idx]
            output.append(word_token)
            if word_token == eos_token:
                break
            # Feed the prediction back as the next decoder input
            word = torch.tensor([[word_idx]], device=device)

    return output

In [None]:
def get_blue_tokenizer(de_nlp, de_token_dict, unk):
    """Return a tokenizer callable for the BLEU metric.

    The callable lower-cases the tokens produced by `de_nlp.tokenizer` and
    replaces every token that is not in `de_token_dict` with `unk`.
    """
    def blue_tokenizer(s):
        lowered = (token.text.lower() for token in de_nlp.tokenizer(s))
        return [text if text in de_token_dict else unk for text in lowered]

    return blue_tokenizer

In [None]:
def de_idx_to_sentence(indices, de_idx_token_dict, pad_idx=1):
    """Turn a sequence of token-id tensors into a space-joined sentence, skipping `pad_idx`."""
    words = []
    for idx in indices:
        token_id = idx.item()
        if token_id == pad_idx:
            continue
        words.append(de_idx_token_dict[token_id])
    return " ".join(words)

In [None]:
# Evaluate the reloaded model on the validation split: translate every sentence
# greedily, log a per-sentence BLEU to bleu.txt, then compute the corpus BLEU.
# The split is not shuffled so the log file has the same order on every run.
val_dl = DataLoader(val_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)
max_length = 30
model = reload_seq2seq
model.eval()
tokenizer_fn = get_blue_tokenizer(de_nlp, de_token_dict, unk)
bleu = evaluate.load("bleu")

all_translated = []
all_ground_truth = []
# Explicit UTF-8 so German characters are written correctly on every platform
with open("bleu.txt", "w", encoding="utf-8") as f, torch.no_grad():
    for encoder_input, _, decoder_output in val_dl:
        encoder_input = encoder_input.to(device)
        # The targets are only read element by element, so they stay on the CPU

        for seq_input, seq_target in zip(encoder_input, decoder_output):
            # translate() takes (model, tokens, index->token dict, device); the
            # length limit is passed by keyword so it is not mistaken for sos_idx
            translated = translate(model, seq_input, de_idx_token_dict, device, max_length=max_length)
            translated_sentence = " ".join(translated)
            ground_truth = de_idx_to_sentence(seq_target, de_idx_token_dict)
            all_translated.append(translated_sentence)
            all_ground_truth.append(ground_truth)
            sentence_bleu = bleu.compute(
                predictions=[translated_sentence], references=[ground_truth], tokenizer=tokenizer_fn
            )
            print(translated_sentence)
            print(ground_truth)
            print(sentence_bleu)
            f.write(translated_sentence + "\n")
            f.write(ground_truth + "\n")
            f.write(str(sentence_bleu) + "\n")
            f.write("=======================" + "\n")

final_bleu = bleu.compute(predictions=all_translated, references=all_ground_truth, tokenizer=tokenizer_fn)
print(final_bleu)