In [2]:
from model import Transformer
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import zipfile

In [None]:
!kaggle datasets download -d jigarpanjiyar/english-to-bengali-dataset -p dataset
zip_ref = zipfile.ZipFile('/content/dataset/english-to-bengali-dataset.zip', 'r')
zip_ref.extractall('/content/dataset')
zip_ref.close()

In [3]:
english_file = './dataset/1_Eng.txt'
bengali_file = './dataset/1_Bengali.txt'
START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'
bengali_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '<', '=', '>', '?', 'ˌ', 
    'অ', 'আ', 'ই', 'ঈ', 'উ', 'ঊ', 'ঋ', 'ৠ', 'ঌ', 'এ', 'ঐ', 'ও', 'ঔ', 
    'ক', 'খ', 'গ', 'ঘ', 'ঙ', 
    'চ', 'ছ', 'জ', 'ঝ', 'ঞ', 
    'ট', 'ঠ', 'ড', 'ঢ', 'ণ', 
    'ত', 'থ', 'দ', 'ধ', 'ন', 
    'প', 'ফ', 'ব', 'ভ', 'ম', 
    'য', 'র', 'ল', 'শ', 'ষ', 'স', 'হ', 
    'ড়', 'ঢ়','য়',
    '০', '১', '২', '৩', '৪', '৫', '৬', '৭', '৮', '৯', 
    '।', '৳', '...', ',,','া','ে','ং','ি','\n','্','ো','ু','ী',';',PADDING_TOKEN, END_TOKEN,'়']

english_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
                        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                        ':', '<', '=', '>', '?', '@',
                        '[', '\\', ']', '^', '_', '`', 
                        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                        'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 
                        'y', 'z','{', '|', '}', '~',';',PADDING_TOKEN, END_TOKEN]

In [18]:
index_to_bengali = {k:v for k,v in enumerate(bengali_vocabulary)}
bengali_to_index = {v:k for k,v in enumerate(bengali_vocabulary)}
index_to_english = {k:v for k,v in enumerate(english_vocabulary)}
english_to_index = {v:k for k,v in enumerate(english_vocabulary)}

In [5]:
with open(english_file, 'r') as file:
    english_sentences = file.readlines()
with open(bengali_file, 'r') as file:
    bengali_sentences = file.readlines()
TOTAL_SENTENCES = 200000
english_sentences = english_sentences[:TOTAL_SENTENCES]
bengali_sentences = bengali_sentences[:TOTAL_SENTENCES]
english_sentences = [sentence.rstrip('\n').lower() for sentence in english_sentences]
bengali_sentences = [sentence.rstrip('\n') for sentence in bengali_sentences]

In [6]:
max_sequence_length = 200

def is_valid_tokens(sentence, vocab):
    for token in list(set(sentence)):
        if token not in vocab:
            return False
    return True

def is_valid_length(sentence, max_sequence_length):
    return len(list(sentence)) < (max_sequence_length - 1)

valid_sentence_indicies = []
for index in range(len(bengali_sentences)):
    bengali_sentence, english_sentence = bengali_sentences[index], english_sentences[index]
    if is_valid_length(bengali_sentence, max_sequence_length) \
      and is_valid_length(english_sentence, max_sequence_length) \
      and is_valid_tokens(bengali_sentence, bengali_vocabulary):
        valid_sentence_indicies.append(index)

print(f"Number of sentences: {len(bengali_sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indicies)}")

Number of sentences: 200000
Number of valid sentences: 111557


In [7]:
bengali_sentences = [bengali_sentences[i] for i in valid_sentence_indicies] 
english_sentences = [english_sentences[i] for i in valid_sentence_indicies]

In [8]:
d_model = 512
batch_size = 30
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 200
bn_vocab_size = len(bengali_vocabulary)

transformer = Transformer(d_model, 
                          ffn_hidden,
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length,
                          bn_vocab_size,
                          english_to_index,
                          bengali_to_index,
                          START_TOKEN, 
                          END_TOKEN, 
                          PADDING_TOKEN)

In [9]:
class TextDataset(Dataset):

    def __init__(self, english_sentences, bengali_sentences):
        self.english_sentences = english_sentences
        self.bengali_sentences = bengali_sentences

    def __len__(self):
        return len(self.english_sentences)

    def __getitem__(self, idx):
        return self.english_sentences[idx], self.bengali_sentences[idx]

In [10]:
dataset = TextDataset(english_sentences, bengali_sentences)

In [11]:
dataset[1]

('the agency does not believe that the primary reason for the spread of the virus is the importation or packaging of salmon.',
 'স্যামন মাছের আমদানি বা প্যাকেজিং থেকেই ভাইরাসের সংক্রমণ ছড়াচ্ছে, এটাই প্রাথমিক কারণ হতে পারে না বলে মনে করছে সংস্থাটি।')

In [12]:
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

In [13]:
for batch_num, batch in enumerate(iterator):
    if batch_num > 3:
        break

In [14]:
from torch import nn

criterian = nn.CrossEntropyLoss(ignore_index=bengali_to_index[PADDING_TOKEN],
                                reduction='none')
optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [15]:
NEG_INFTY = -1e9

def create_masks(eng_batch, bn_batch):
    num_sentences = len(eng_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    for idx in range(num_sentences):
      eng_sentence_length, bn_sentence_length = len(eng_batch[idx]), len(bn_batch[idx])
      eng_chars_to_padding_mask = np.arange(eng_sentence_length + 1, max_sequence_length)
      bn_chars_to_padding_mask = np.arange(bn_sentence_length + 1, max_sequence_length)
      encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
      encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, bn_chars_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, bn_chars_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, bn_chars_to_padding_mask, :] = True
    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [None]:
transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 10
train_losses=[]
for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        eng_batch, bn_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, bn_batch)
        optim.zero_grad()
        bn_predictions = transformer(eng_batch,
                                     bn_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
        labels = transformer.decoder.sentence_embedding.batch_tokenize(bn_batch, start_token=False, end_token=True)
        loss = criterian(
            bn_predictions.view(-1, bn_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == bengali_to_index[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        train_losses.append(loss.item())
        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"English: {eng_batch[0]}")
            print(f"Bengali Translation: {bn_batch[0]}")
            bn_sentence_predicted = torch.argmax(bn_predictions[0], axis=1)
            predicted_sentence = ""
            for idx in bn_sentence_predicted:
              if idx == bengali_to_index[END_TOKEN]:
                break
              predicted_sentence += index_to_bengali[idx.item()]
            print(f"Bengali Prediction: {predicted_sentence}")


            transformer.eval()
            bn_sentence = ("",)
            eng_sentence = ("should we go to the mall?",)
            for word_counter in range(max_sequence_length):
                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, bn_sentence)
                predictions = transformer(eng_sentence,
                                          bn_sentence,
                                          encoder_self_attention_mask.to(device), 
                                          decoder_self_attention_mask.to(device), 
                                          decoder_cross_attention_mask.to(device),
                                          enc_start_token=False,
                                          enc_end_token=False,
                                          dec_start_token=True,
                                          dec_end_token=False)
                next_token_prob_distribution = predictions[0][word_counter]
                next_token_index = torch.argmax(next_token_prob_distribution).item()
                next_token = index_to_bengali[next_token_index]
                bn_sentence = (bn_sentence[0] + next_token, )
                if next_token == END_TOKEN:
                  break
            
            print(f"Evaluation translation (should we go to the mall?) : {bn_sentence}")
            print("-------------------------------------------")