In [82]:
import numpy as np
import torch
import pathlib
import re
import unicodedata
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
from collections import Counter
from tqdm import tqdm
from pprint import pprint

torch.set_grad_enabled(True)
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x16115c03150>

In [2]:
config = {
    'MAX_VOCAB_SIZE': 13000,
    'BATCH_SIZE': 8,
    'raw_dataset_path': './dataset/por.txt',
    'MAX_SEQ_LEN': 16,
    'BUFFER_SIZE': 1000,
    'UNITS': 256,
    'EPOCHS': 10
}

In [3]:
# data loader
dataset_path = pathlib.Path(config['raw_dataset_path'])
text_data = dataset_path.read_text(encoding = 'utf-8')

lines = text_data.splitlines()
pairs = [line.split('\t') for line in lines][:1200]

context_en = np.array([context for context, target, _ in pairs])
target_por = np.array([target for context, target, _ in pairs])

sentences = np.array((context_en, target_por))

In [4]:
def tokenizer(text):
    text = unicodedata.normalize("NFKD", text)
    text = text.lower()
    text = re.sub(r"[^ a-z.?!,¿]", "", text)
    text = re.sub(r"([.?!,¿])", r" \1 ", text)
    text = text.strip()
    return text.split()

tokenizer(context_en[34]), tokenizer(target_por[34])

(['go', 'on', '.'], ['siga', 'em', 'frente', '.'])

In [188]:
# build a vocabulary

class Vocabulary:
    def __init__(self, freq_threshold, max_vocab_size):
        # maintain two different mappings
        self.itos = {0: '[PAD]', 1: '[SOS]', 2: '[EOS]', 3: '[UNK]'}
        self.stoi = {'[PAD]': 0, '[SOS]': 1, '[EOS]': 2, '[UNK]': 3}
        self.freq_threshold = freq_threshold
        self.max_vocab_size = max_vocab_size
        
        self.pad_id = self.stoi['[PAD]']
        self.sos_id = self.stoi['[SOS]']
        self.eos_id = self.stoi['[EOS]']
        self.oov_id = self.stoi['[UNK]']

    def __len__(self):
        return len(self.itos)

    def vocab_size(self):
        return len(self.itos)

    def get_vocabulary(self):
        return self.stoi

    def token_to_ids(self, tokens):
        if isinstance(tokens, str): # handle a single word or sentence here
            token_list = self.tokenizer(tokens)
            return [self.stoi[t] if t in self.stoi else self.stoi['[UNK]'] for t in token_list]

        elif isinstance(tokens, list):
            return [self.stoi[t] if t in self.stoi else self.stoi['[UNK]'] for t in tokens]
        
        else:
            raise TypeError("Input must be either String or List of words.")

    def ids_to_token(self, ids):
        return [self.itos[id] if id in self.itos else self.itos[3] for id in ids]

    # building vocab with the input sentence list
    def adapt(self, sentences, tokenizer):
        self.tokenizer = tokenizer
        idx = len(self.itos)
        token_freqs = {}

        for sentence in sentences:
            for token in self.tokenizer(sentence):
                if token not in self.stoi:
                    token_freqs[token] = 1
                else:
                    token_freqs[token] += 1
                
                if (token_freqs[token] == self.freq_threshold) and (idx < self.max_vocab_size):
                    self.itos[idx] = token
                    self.stoi[token] = idx
                    idx += 1

In [189]:
# english vocabulary
en_vocab = Vocabulary(freq_threshold = 1, max_vocab_size = config['MAX_VOCAB_SIZE'])
en_vocab.adapt(context_en, tokenizer)

# portuguese vocabulary
por_vocab = Vocabulary(freq_threshold = 1, max_vocab_size = config['MAX_VOCAB_SIZE'])
por_vocab.adapt(target_por, tokenizer)

en_vocab.vocab_size(), por_vocab.vocab_size()

(419, 800)

In [7]:
# test
test_idx = 789
en_translation = context_en[test_idx]
por_translation = target_por[test_idx]

print(en_translation, '--------->', por_translation)

max_seq_len = 16
context_tokens = en_vocab.token_to_ids(en_translation)
target_tokens = por_vocab.token_to_ids(por_translation)

print("\nEncoder Input IDs: ")
print([en_vocab.sos_id] + context_tokens + [en_vocab.eos_id] + (max_seq_len - len(context_tokens) - 2) * [en_vocab.pad_id])
print("\nPre-Attention Decoder Input IDs (Shifted to the Right): ")
print([en_vocab.sos_id] + target_tokens + [en_vocab.eos_id] + (max_seq_len - len(context_tokens) - 2) * [en_vocab.pad_id])
print("\nPost-Attention Decoder Input IDs: ")
print(target_tokens + [en_vocab.eos_id] + (max_seq_len - len(context_tokens) - 1) * [en_vocab.pad_id])

Here I am. ---------> Aqui estou.

Encoder Input IDs: 
[1, 194, 20, 62, 5, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Pre-Attention Decoder Input IDs (Shifted to the Right): 
[1, 402, 47, 5, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Post-Attention Decoder Input IDs: 
[402, 47, 5, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


### Neural Machine Translation Custom Dataset 

In [8]:
class NMT_dataset(Dataset):
    def __init__(self, translation_pairs, tokenizer, vocabularies, max_seq_len, device = 'cpu'):
        print(translation_pairs.shape)
        self.translation_pairs = translation_pairs
        self.tokenizer = tokenizer
        self.en_vocab, self.por_vocab = vocabularies
        self.max_seq_len = max_seq_len
        self.device = device

        # for convenience 
        self.sos_id = self.en_vocab.sos_id
        self.eos_id = self.en_vocab.eos_id
        self.pad_id = self.en_vocab.pad_id
        self.oov_id = self.en_vocab.oov_id

    def __len__(self):
        return self.translation_pairs.shape[-1]

    def __getitem__(self, idx):
        req_pair = self.translation_pairs[:, idx]
        en_translation, por_translation = req_pair

        context_tokens = self.en_vocab.token_to_ids(en_translation)
        target_tokens = self.por_vocab.token_to_ids(por_translation)

        # encoder input tokens
        encoder_input = (
            [self.sos_id] + 
            context_tokens + 
            [self.eos_id] + 
            (self.max_seq_len - len(context_tokens) - 2) * [self.pad_id]
            )
        
        # pre-attention decoder input tokens
        pre_decoder_input = (
            [self.sos_id] + 
            target_tokens + 
            [self.eos_id] + 
            (self.max_seq_len - len(target_tokens) - 2) * [self.pad_id] 
        )

        # post-attention decoder output tokens
        post_decoder_output = (
            target_tokens + 
            [self.eos_id] + 
            (self.max_seq_len - len(target_tokens) - 1) * [self.pad_id]
        )

        encoder_input_tensor = torch.tensor(encoder_input[:max_seq_len], dtype = torch.long).to(self.device)
        pre_decoder_input_tensor = torch.tensor(pre_decoder_input[:max_seq_len], dtype = torch.long).to(self.device)
        post_decoder_output_tensor = torch.tensor(post_decoder_output[:max_seq_len], dtype = torch.long).to(self.device)

        return encoder_input_tensor, pre_decoder_input_tensor, post_decoder_output_tensor

In [9]:
# train and val_dataset
is_train = np.random.uniform(size = (sentences.shape[-1],)) < 0.85
train_raw_set = sentences[:, is_train]
val_raw_set = sentences[:, ~is_train]

train_raw_set.shape, val_raw_set.shape

((2, 1002), (2, 198))

In [10]:
train_dataset = NMT_dataset(train_raw_set, tokenizer, (en_vocab, por_vocab), config['MAX_SEQ_LEN'])
val_dataset = NMT_dataset(val_raw_set, tokenizer, (en_vocab, por_vocab), config['MAX_SEQ_LEN'])

(2, 1002)
(2, 198)


In [11]:
max_len_dict = {
    'encoder_in': 0, 
    'target_in': 0, 
    'target_out': 0
}
context_target_token_lengths = {
    'encoder_in': [], 
    'target_in': [], 
    'target_out': []
}

for data in tqdm(train_dataset, ncols = 100):
    encoder_in, target_in, target_out = data
    context_target_token_lengths['encoder_in'].append(len(encoder_in))
    context_target_token_lengths['target_in'].append(len(target_in))
    context_target_token_lengths['target_out'].append(len(target_out))
    
    if len(encoder_in) > max_len_dict['encoder_in']:
        max_len_dict['encoder_in'] = len(encoder_in)
        
    if len(target_in) > max_len_dict['target_in']:
        max_len_dict['target_in'] = len(target_in)
        
    if len(target_out) > max_len_dict['target_out']:
        max_len_dict['target_out'] = len(target_out)

max_len_dict

100%|████████████████████████████████████████████████████████| 1002/1002 [00:00<00:00, 11207.57it/s]


{'encoder_in': 16, 'target_in': 16, 'target_out': 16}

In [12]:
Counter(context_target_token_lengths['encoder_in']).most_common(1), Counter(context_target_token_lengths['target_in']).most_common(1), Counter(context_target_token_lengths['target_out']).most_common(1)

([(16, 1002)], [(16, 1002)], [(16, 1002)])

In [13]:
def collate_fn(batch):
    # pad sequences according to the maximum len in a batch, and then standardize all the batch inputs to same length
    contexts, target_ins, target_outs = zip(*batch)
    max_len = max([len(x) for x in contexts])

    # padding sequences
    padded_contexts = [torch.nn.functional.pad(x, (0, max_len - len(x))) for x in contexts]
    padded_target_ins = [torch.nn.functional.pad(x, (0, max_len - len(x))) for x in target_ins]
    padded_target_outs = [torch.nn.functional.pad(x, (0, max_len - len(x))) for x in target_outs]

    padded_contexts = torch.stack(padded_contexts)
    padded_target_ins = torch.stack(padded_target_ins)
    padded_target_outs = torch.stack(padded_target_outs)

    return padded_contexts, padded_target_ins, padded_target_outs


In [14]:
train_loader = DataLoader(
    train_dataset, 
    batch_size = config['BATCH_SIZE'], 
    shuffle = True,
#     collate_fn = collate_fn
)

val_loader = DataLoader(
    val_dataset, 
    batch_size = config['BATCH_SIZE'], 
    shuffle = False,
#     collate_fn = collate_fn
)

In [15]:
len(train_loader) * config['BATCH_SIZE'], len(val_loader) * config['BATCH_SIZE']

(1008, 200)

In [16]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, units):
        super(Encoder, self).__init__()
        
        # input embedding
        self.embedding = nn.Embedding(
            num_embeddings = vocab_size,
            embedding_dim = units,
            padding_idx = 0 
        )
        
        # bi-directional LSTM
        self.rnn = nn.LSTM(
            input_size = units,
            hidden_size = units,
            batch_first = True,
            bidirectional = True
        )

    def forward(self, context):
        x = self.embedding(context)
        x, _ = self.rnn(x)
        x = x[:, :, :self.rnn.hidden_size] + x[:, :, self.rnn.hidden_size:]

        return x
    
vocab_size = config['MAX_VOCAB_SIZE']
units = config['UNITS']

encoder = Encoder(vocab_size, units)

input_tensor = next(iter(train_dataset))
input_tensor = input_tensor[0].unsqueeze(0)
output = encoder(input_tensor)
print(output.shape) 

torch.Size([1, 16, 256])


In [17]:
class CrossAttention(nn.Module):
    def __init__(self, units):

        super(CrossAttention, self).__init__()

        self.multihead_attn = nn.MultiheadAttention(
            embed_dim = units,  # the size of Q, K, V dims is the embedding dimension
            num_heads = 1,  
            batch_first = True  # (batch_size, sequence_length, embedding_dim) either (seq_len, batch_size, embedding_dim)
        )

        self.layernorm = nn.LayerNorm(units)
        
        # to accumulate the inputs from encoder output (encoder_in) and the pre-attention decoder (target_in)
        self.add = nn.ModuleList([nn.Linear(units, units) for _ in range(2)])

    def forward(self, context, target):
        # query is the target_in from the pre-attention decoder, while key/value are the context from the encoded context from the encoder
        attn_output, _ = self.multihead_attn(query = target, key = context, value = context)
        
        x = self.add[0](target) + self.add[1](attn_output)
        return self.layernorm(x)


units = config['UNITS']

cross_attention = CrossAttention(units)
context_tensor = torch.randn(8, 16, units)  
target_tensor = torch.randn(8, 16, units)    
output = cross_attention(context_tensor, target_tensor)
print(output.shape) 

torch.Size([8, 16, 256])


In [18]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, units):
        super(Decoder, self).__init__()

        self.embedding = nn.Embedding(vocab_size, units, padding_idx = 0)
        self.pre_attention_rnn = nn.LSTM(units, units, batch_first = True)
        self.attention = CrossAttention(units)
        self.post_attention_rnn = nn.LSTM(units, units, batch_first = True)
        self.output_layer = nn.Linear(units, vocab_size)

    def forward(self, context, target, state = None, return_state = False):
        
        x = self.embedding(target)
        
        # pre-attention-LSTM (decoder)
        if state is None:
            x, (hidden_state, cell_state) = self.pre_attention_rnn(x)
        else:
            x, (hidden_state, cell_state) = self.pre_attention_rnn(x, state)
        
        # cross-attention between pre-attention-LSTM output and the encoded context
        x = self.attention(context, x)
        
        # post-attention-LSTM (decoder)
        x, _ = self.post_attention_rnn(x)

        # last linear (dense) layer
        logits = self.output_layer(x)
        logits = F.log_softmax(logits, dim = -1)

        if return_state:
            return logits, (hidden_state, cell_state)

        return logits

vocab_size = config['MAX_VOCAB_SIZE']
units = config['UNITS']
decoder = Decoder(vocab_size, units)

context_tensor = torch.randn(8, 15, units)  
target_tensor = torch.randint(0, vocab_size, (8, 16)).long() 

output = decoder(context_tensor, target_tensor)
print(output.shape) # (batch_size, max_seq_len, vocab_size)

torch.Size([8, 16, 13000])


In [63]:
class NMT_Translator(nn.Module):
    def __init__(self, vocab_size, units):
        super(NMT_Translator, self).__init__()
        
        self.encoder = Encoder(vocab_size, units)
        self.decoder = Decoder(vocab_size, units)
        
    def forward(self, context, target):
        encoded_context = self.encoder(context)
        logits = self.decoder(encoded_context, target) # this target here is the target_in (target_out is used for training)
        
        return logits # (batch_size, max_seq_len, vocab_size)

In [64]:
vocab_size = config['MAX_VOCAB_SIZE']
units = config['UNITS']

translator = NMT_Translator(vocab_size, units)

context_tensor = torch.randint(0, vocab_size, (8, 15)).long()  # (batch_size, seq_len, units)
target_tensor = torch.randint(0, vocab_size, (8, 16)).long()  # (batch_size, seq_len)
print(context_tensor.shape, target_tensor.shape)


output = translator(context_tensor, target_tensor)
print(output.shape)  # (batch_size, max_seq_len)

torch.Size([8, 15]) torch.Size([8, 16])
torch.Size([8, 16, 13000])


In [95]:
class MaskedAcc(nn.Module):
    def __init__(self):
        super(MaskedAcc, self).__init__()

    def forward(self, y_pred, y_true):
        y_pred = y_pred.argmax(dim = -1)

        mask = (y_true != 0).float()
        correct = (y_true == y_pred).float() * mask
        return correct.sum() / mask.sum()
    
class SparseCategoricalMaskedLoss(nn.Module):
    def __init__(self):
        super(SparseCategoricalMaskedLoss, self).__init__()
    
    def forward(self, y_pred, y_true):
        batch_size, seq_len, vocab_size = y_pred.shape
    
        # mask to avoid the padding sequence
        # for loss mask
        mask = (y_true != 0).float()

        y_pred = y_pred.view(-1, y_pred.size(-1))
        y_true = y_true.view(-1)
        
        loss = F.cross_entropy(y_pred, y_true.long(), reduction='none')

        loss = loss.view(batch_size, seq_len)
        loss *= mask

        return loss.sum() / mask.sum()

In [96]:
vocab_size = config['MAX_VOCAB_SIZE']
units = config['UNITS']
model = NMT_Translator(vocab_size, units)

optimizer = optim.Adam(model.parameters(), lr = 0.001)
loss = nn.CrossEntropyLoss()
masked_loss_fn = MaskedLoss()
masked_acc_fn = MaskedAcc()

sparse_masked_loss_fn = SparseCategoricalMaskedLoss()

In [97]:
best_val_loss = float('inf')
patience_counter = 0

for epoch in range(config['EPOCHS']):
    model.train()
    total_loss = 0
    total_acc = 0

    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['EPOCHS']}", leave = False)
    for batch in train_loader_tqdm:
        
        # encoder input, pre-attention-decoder input and post-attention-decoder target output
        context, target_in, target_out = batch

        optimizer.zero_grad()
        output = model(context, target_in)

        loss = sparse_masked_loss_fn(output.float(), target_out.float())
        acc = masked_acc_fn(output.float(), target_out.float())
        
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_acc += acc.item()
        train_loader_tqdm.set_postfix(loss = loss.item(), accuracy = acc.item())

    avg_loss = total_loss / len(train_loader)
    avg_acc = total_acc / len(train_loader)

    print(f"Epoch {epoch + 1}/{config['EPOCHS']}")
    print(f"Training loss: {avg_loss:.4f}, accuracy: {avg_acc:.4f}")

    # Validation
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for batch in val_loader:
            context, target_in, target_out = batch
            output = model(context, target_in)
            loss = sparse_masked_loss_fn(output.float(), target_out.float())
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Validation loss: {avg_val_loss:.4f}")

    # Early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= 3:
            print("Early stopping")
            break

# Load the best model
# model.load_state_dict(torch.load('best_model.pth'))


                                                                                                                                                                                                                  

Epoch 1/10
Training loss: 4.3761, accuracy: 0.4669
Validation loss: 3.5949


                                                                                                                                                                                                                  

Epoch 2/10
Training loss: 2.9841, accuracy: 0.5087
Validation loss: 3.3240


                                                                                                                                                                                                                  

Epoch 3/10
Training loss: 2.6018, accuracy: 0.5450
Validation loss: 3.2649


                                                                                                                                                                                                                  

Epoch 4/10
Training loss: 2.3407, accuracy: 0.5679
Validation loss: 3.1431


                                                                                                                                                                                                                  

Epoch 5/10
Training loss: 2.1225, accuracy: 0.5872
Validation loss: 3.0771


                                                                                                                                                                                                                  

Epoch 6/10
Training loss: 1.9355, accuracy: 0.6114
Validation loss: 3.1084


                                                                                                                                                                                                                  

Epoch 7/10
Training loss: 1.7610, accuracy: 0.6262
Validation loss: 2.9952


                                                                                                                                                                                                                  

Epoch 8/10
Training loss: 1.5568, accuracy: 0.6572
Validation loss: 2.9404


                                                                                                                                                                                                                  

Epoch 9/10
Training loss: 1.3914, accuracy: 0.6735
Validation loss: 2.8968


                                                                                                                                                                                                                  

Epoch 10/10
Training loss: 1.2401, accuracy: 0.6970
Validation loss: 2.9446


  model.load_state_dict(torch.load('best_model.pth'))


<All keys matched successfully>

#### Next Token Generation

In [148]:
def generate_next_token(decoder, context, next_token, done, state, eos_id, temperature=0.0):
    
    logits, state = decoder(context, next_token, state = state, return_state = True)
    
    # Shape: [batch_size, vocab_size]
    logits = logits.squeeze(0)    
    
    if temperature == 0.0:
        next_token = torch.argmax(logits, dim=-1)
        
    else:
        logits = logits / temperature
        next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    

    logits = logits.squeeze()
    next_token = next_token.squeeze()
    
    logit = logits[next_token].item()  
    
    
    next_token = next_token.view(1, 1)
    
    if next_token.item() == eos_id:
        done = True
    
    return next_token, logit, state, done


In [149]:
model_path = './best_model.pth'

inf_model = NMT_Translator(vocab_size, units)
inf_model.load_state_dict(torch.load(model_path, weights_only = True))

<All keys matched successfully>

In [205]:
def translate(model, context, context_tokenizer, sos_id = 1, eos_id = 2, max_gen_len = 16, temp = 0.7):
    
    tokens, logits = [], []

    context = context_tokenizer(test_en_sentence)
    context = torch.tensor(en_vocab.token_to_ids(context), dtype = torch.long).to('cpu')
    context = model.encoder(context.unsqueeze(0))

    next_token = torch.tensor([[sos_id]], dtype=torch.long)

    state = [
        torch.zeros((1, 1, config['UNITS']), dtype=torch.float),
        torch.zeros((1, 1, config['UNITS']), dtype=torch.float)
        ]

    done = False

    for i in range(max_gen_len):
        next_token, logit, state, done = generate_next_token(
                decoder = model.decoder,
                context = context,
                next_token = next_token,
                done = done,
                state = state,
                eos_id = eos_id,
                temperature = temp
            )
        if done:
            break


        tokens.append(next_token)

        logits.append(logit)

    tokens = torch.cat(tokens, dim=-1)
    return " ".join(por_vocab.ids_to_token(tokens.squeeze().tolist()))

temp = 0.7
test_en_sentence = 'Hello, there'
sos_id = 1
eos_id = 2
max_gen_len = 16

translate(inf_model, test_en_sentence, tokenizer, sos_id, eos_id, max_gen_len, temp)

'salve o [UNK] ele o [UNK] o [UNK] tom meu e uma [UNK] o [UNK] e'

In [211]:
import sacrebleu


candidates = [
    "The cat is on the mat.",
    "There is a cat on the mat."
]

references = [
    ["The cat is sitting on the mat."],
    ["There is a cat on the mat."]
]

bleu = sacrebleu.corpus_bleu(candidates, references)

print(f"BLEU score: {bleu.score}")


BLEU score: 51.54486831107658
