In [28]:
from transformer import Transformer
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import importlib
from multi_head_attention import MultiHeadAttention

In [29]:
def reload_modules():
    importlib.reload(Transformer)
    importlib.reload(MultiHeadAttention)

In [30]:
START_TOKEN = '<s>'
END_TOKEN = '<\s>'
PADDING_TOKEN = '<pad>'
english_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
                    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                    ':', '<', '=', '>', '?', '@', ';',
                    '[', '\\', ']',
                    '^', '_', '`', 
                    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                    'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 
                    'y', 'z', 
                    '{', '|', '}', '~', PADDING_TOKEN, END_TOKEN
                    ]

persian_vocabulary = [
START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ';',
':', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', 
'آ', 'ا', 'ب', 'پ', 'ت', 'ث', 'ج', 'چ', 'ح', 'خ', 'د', 'ذ', 'ر', 'ز', 'ژ', 'س', 'ش', 
'ص', 'ض', 'ط', 'ظ', 'ع', 'غ', 'ف', 'ق', 'ک', 'گ', 'ل', 'م', 'ن', 'و', 'ه', 'ی',
'ء', 'ۀ', 'ؤ', 'ي', 'ك', 'ة', '‌', 'ٔ', 'ى', PADDING_TOKEN, END_TOKEN
]
index_to_persian = {k:v for k,v in enumerate(persian_vocabulary)}
persian_to_index = {v:k for k,v in enumerate(persian_vocabulary)}
index_to_english = {k:v for k,v in enumerate(english_vocabulary)}
english_to_index = {v:k for k,v in enumerate(english_vocabulary)}

In [31]:
df = pd.read_csv('../dataset/shortened_dataset.csv')
df.head()

Unnamed: 0,persian,english
0,گلدان روی میز چای حاضر و آماده بود.,the vase filled with water was ready in the ce...
1,آن وقت قاضی چه کرد؟,What did the justice do?
2,به روزگار فيلماي ؛ نقطه تلاقي ؛ ماري کثيف يا ه...,"vanishing point days , the dirty mary crazy la..."
3,افراد مورد اعتماد زیردستهایشان به عنوان سرپرست...,with the trust of his subordinates as the head...
4,زودتر برویم. من حاضرم.,"I am ready, my son, said Mercedes."


In [32]:
df['english'] = df['english'].astype(str)
df['persian'] = df['persian'].astype(str)

In [33]:
def helper_english(x:str):
        for c in x:
            if not c in english_vocabulary:
                x = x.replace(c, '')
        return x

def helper_persian(x:str):
    for c in x:
        if not c in persian_vocabulary:
            x = x.replace(c, '')
    return x

df['english'] = df['english'].apply(str.lower)
df['english'] = df['english'].apply(helper_english)
df['persian'] = df['persian'].apply(helper_persian)
persian_sentences = df['persian'].to_list()
english_sentences = df['english'].to_list()
enlish_sentences = df['english'].to_list()
persian_sentences = df['persian'].to_list()

In [34]:
model_dim = 512
batch_size = 30
hidden_fc = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 200
persian_vocab_size = len(persian_vocabulary)

transformer = Transformer((batch_size, max_sequence_length, model_dim),
                          model_dim, 
                          hidden_fc,
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length,
                          persian_vocab_size,
                          english_to_index,
                          persian_to_index,
                          START_TOKEN, 
                          END_TOKEN, 
                          PADDING_TOKEN)

In [35]:
class TranslateDataset(Dataset):
    def __init__(self, english_sentences, persian_sentences):
        super().__init__()
        self.english_sentences = english_sentences
        self.persian_sentences = persian_sentences

    def __len__(self):
        return len(self.persian_sentences)

    def __getitem__(self, idx):
        return self.english_sentences[idx], self.persian_sentences[idx]    
    

dataset = TranslateDataset(english_sentences, persian_sentences)
train_loader = DataLoader(dataset, batch_size)

In [36]:
criterion = nn.CrossEntropyLoss(ignore_index=persian_to_index[PADDING_TOKEN])
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
device = torch.device('cuda')

In [37]:

NEG_INFTY = -1e9  

def create_masks(eng_batch, persian_batch, number_of_heads):  
    num_sentences = len(eng_batch)  
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length], True)  
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)  
    
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length], False)  
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length], False)  
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length], False)  

    for idx in range(num_sentences):  
        eng_sentence_length, persian_sentence_length = len(eng_batch[idx]), len(persian_batch[idx])  
        eng_chars_to_padding_mask = np.arange(eng_sentence_length, max_sequence_length)  
        persian_chars_to_padding_mask = np.arange(persian_sentence_length, max_sequence_length)  

        encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True  
        encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True  
        
        decoder_padding_mask_self_attention[idx, persian_chars_to_padding_mask, :] = True  
        decoder_padding_mask_self_attention[idx, :, persian_chars_to_padding_mask] = True  
        
        decoder_padding_mask_cross_attention[idx, persian_chars_to_padding_mask, :] = True  
        decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True  

    # Create self-attention masks  
    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)  
    decoder_self_attention_mask = torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)  
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)  

    # Expand the masks to have the number of heads  
    encoder_self_attention_mask = encoder_self_attention_mask.unsqueeze(1).repeat(1, number_of_heads, 1, 1)  
    decoder_self_attention_mask = decoder_self_attention_mask.unsqueeze(1).repeat(1, number_of_heads, 1, 1)  
    decoder_cross_attention_mask = decoder_cross_attention_mask.unsqueeze(1).repeat(1, number_of_heads, 1, 1)  

    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask  


In [38]:
eng_batch = ['hello', 'how are you']
persian_batch = ['سلام', 'چطوری']
encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, persian_batch, 8)

In [39]:
encoder_self_attention_mask.shape

torch.Size([2, 8, 200, 200])

In [40]:
transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 5

for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        eng_batch, per_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, per_batch, num_heads)
        optim.zero_grad()
        persian_predictions = transformer(eng_batch,
                                     per_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     encoder_start_token=False,
                                     encoder_end_token=False,
                                     decoder_start_token=True,
                                     decoder_end_token=True)
        labels = transformer.decoder.sentence_embedding.batch_tokenize(per_batch, start_token=False, end_token=True)
        loss = criterion(
            persian_predictions.view(-1, persian_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == persian_to_index[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        #train_losses.append(loss.item())
        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"English: {eng_batch[0]}")
            print(f"Persian Translation: {per_batch[0]}")
            persian_sentence_predicted = torch.argmax(persian_predictions[0], axis=1)
            predicted_sentence = ""
            for idx in persian_sentence_predicted:
              if idx == persian_to_index[END_TOKEN]:
                break
              predicted_sentence += index_to_persian[idx.item()]
            print(f"Persian Prediction: {predicted_sentence}")


            transformer.eval()
            persian_sentence = ("",)
            eng_sentence = ("should we go to the mall?",)
            for word_counter in range(max_sequence_length):
                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, persian_sentence, num_heads)
                predictions = transformer(eng_sentence,
                                          persian_sentence,
                                          encoder_self_attention_mask.to(device), 
                                          decoder_self_attention_mask.to(device), 
                                          decoder_cross_attention_mask.to(device),
                                          encoder_start_token=False,
                                          encoder_end_token=False,
                                          decoder_start_token=True,
                                          decoder_end_token=False)
                next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                next_token_index = torch.argmax(next_token_prob_distribution).item()
                next_token = index_to_persian[next_token_index]
                persian_sentence = (persian_sentence[0] + next_token, )
                if next_token == END_TOKEN:
                  break
            
            print(f"Evaluation translation (should we go to the mall?) : {persian_sentence}")
            print("-------------------------------------------")

Epoch 0
Iteration 0 : 0.003514297306537628
English: the vase filled with water was ready in the center of the tea table.
Persian Translation: گلدان روی میز چای حاضر و آماده بود.
Persian Prediction: >هظ*ك/2$1غ"تظ]ص<_#ۀ#&ا
Evaluation translation (should we go to the mall?) : ('ۀ/ظ*?/2$1غ"ن5لص%_#خ9&ا<\\s>',)
-------------------------------------------
Iteration 100 : 0.0025248141027987003
English: including you .
Persian Translation: تو را هم شامل ميشه .
Persian Prediction: اا                                             ا                                            ا       ی  ایا  ا
Evaluation translation (should we go to the mall?) : ('اا                                             ا                                            ا       یر  یا  ن<\\s>',)
-------------------------------------------
Iteration 200 : 0.002507735276594758
English: miguel centellas of pronto* takes a look at the worrisome inflation rate in bolivia, which has affected the lower middle classes and small business own

KeyboardInterrupt: 