In [None]:
! pip install datasets

In [None]:
! pip install transformers

In [None]:
import os
import string
import unicodedata

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, BertTokenizer

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

# Load Dataset

In [None]:
# ParsiNLU Persian-English translation corpus (train / validation / test splits are used below).
dataset = load_dataset(path="persiannlp/parsinlu_translation_fa_en")
dataset

In [None]:
# One DataFrame per split, keeping only the source sentence and its list of target translations.
df_train, df_test, df_val = (
    pd.DataFrame(dataset[split_name], columns=['source', 'targets'])
    for split_name in ('train', 'test', 'validation')
)

In [None]:
# Use the first reference translation of each row as the single training target.
# Every split must read its OWN 'targets' column: mapping df_train's series onto
# df_test / df_val aligns by index and pairs their sources with training targets.
df_train['target'] = df_train['targets'].map(lambda refs: refs[0])
df_test['target'] = df_test['targets'].map(lambda refs: refs[0])
df_val['target'] = df_val['targets'].map(lambda refs: refs[0])

In [None]:
# Preview the first rows, including the derived 'target' column.
df_train.head()

# Preprocessing

In [None]:
def remove_punctuation(s):
    """Return ``s`` with punctuation characters removed.

    Two kinds of character are dropped:

    * every character in ``string.punctuation``. This is ASCII-only and also
      contains symbols such as ``$``, ``+`` and ``|``.
    * any Unicode character whose general category is punctuation (``P*``).

    Relying on the ASCII set alone leaves Persian marks such as '،', '؛', '؟'
    and the guillemets '«»' in the text. It also misses typographic quotes in
    English text, such as U+2019 (right single quotation mark).

    Format characters are kept. In particular ZWNJ (U+200C), used inside
    Persian words, is category ``Cf`` and is not removed.
    """
    return "".join(
        ch for ch in s
        if ch not in string.punctuation
        and not unicodedata.category(ch).startswith("P")
    )

In [None]:
def preprocess_persian_string(s):
    """Normalise a Persian source sentence; currently this only strips punctuation."""
    return remove_punctuation(s)

In [None]:
def preprocess_english_string(s):
    """Normalise an English target sentence: strip punctuation, lowercase, then trim surrounding whitespace."""
    return remove_punctuation(s).lower().strip()

# Tokenizer

In [None]:
# Source-side (Persian) BERT WordPiece tokenizer.
fa_model = "HooshvareLab/bert-fa-base-uncased"
fa_tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=fa_model)

In [None]:
# Target-side (English) BERT WordPiece tokenizer.
en_model = "bert-base-uncased"
en_tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=en_model)

In [None]:
# Special tokens of the Persian tokenizer and their ids.
# The 'SPACIAL' spelling is kept because later cells reference these names.
FA_SPACIAL_TOKENS, FA_SPACIAL_TOKENS_IDS = fa_tokenizer.all_special_tokens, fa_tokenizer.all_special_ids

print(FA_SPACIAL_TOKENS, FA_SPACIAL_TOKENS_IDS, sep="\n")

In [None]:
# Special tokens of the English tokenizer and their ids.
# The 'SPACIAL' spelling is kept because later cells reference these names.
EN_SPACIAL_TOKENS, EN_SPACIAL_TOKENS_IDS = en_tokenizer.all_special_tokens, en_tokenizer.all_special_ids

print(EN_SPACIAL_TOKENS, EN_SPACIAL_TOKENS_IDS, sep="\n")

# Dataloader

In [None]:
def generate_data(df):
    """Encode each row of ``df`` into an ``(fa_ids, en_ids)`` pair of id tensors.

    The ``'source'`` column is cleaned with ``preprocess_persian_string`` and
    encoded with ``fa_tokenizer``. The ``'target'`` column is cleaned with
    ``preprocess_english_string`` and encoded with ``en_tokenizer``.

    Returns a list of pairs in row order.
    """
    def encode(tokenizer, text):
        # input_ids is a list of ints; wrap it as a tensor.
        return torch.tensor(tokenizer(text).input_ids)

    return [
        (encode(fa_tokenizer, preprocess_persian_string(src_text)),
         encode(en_tokenizer, preprocess_english_string(tgt_text)))
        for src_text, tgt_text in zip(df['source'], df['target'])
    ]

In [None]:
# Encode every split; each entry is an (fa_ids, en_ids) tensor pair.
train_data, test_data, val_data = map(generate_data, (df_train, df_test, df_val))

In [None]:
# Persist each split's (fa_ids, en_ids) pairs to its own CSV file.
# The test split previously went to 'val.csv', overwriting the validation export.
# NOTE(review): each cell holds a tensor repr string, not a plain id list.
train = pd.DataFrame(train_data)
val = pd.DataFrame(val_data)
test = pd.DataFrame(test_data)

for frame, csv_path in ((train, 'train.csv'), (val, 'val.csv'), (test, 'test.csv')):
    frame.to_csv(csv_path, index=False, header=False)

In [None]:
# Inspect the first encoded (fa_ids, en_ids) training pair.
train_data[0]

In [None]:
# These values feed SOURCE_VOCAB_SIZE. A PyTorch embedding/output layer of size V
# accepts only ids 0..V-1, so each size is the largest id seen in the split PLUS ONE.
# The bare max id was used before, which is an off-by-one on the largest token.
# NOTE(review): len(fa_tokenizer) would also cover ids that never occur in the data.
FA_TRAIN_VOCAB_SIZE = max(int(fa_ids.max()) for fa_ids, _ in train_data) + 1
FA_VALIDATION_VOCAB_SIZE = max(int(fa_ids.max()) for fa_ids, _ in val_data) + 1
FA_TEST_VOCAB_SIZE = max(int(fa_ids.max()) for fa_ids, _ in test_data) + 1

In [None]:
# These values feed TARGET_VOCAB_SIZE, which sizes the output nn.Linear used as
# CrossEntropyLoss logits. Class indices must be < the number of classes, so each
# size is the largest id seen PLUS ONE (the bare max id was an off-by-one).
# NOTE(review): len(en_tokenizer) would also cover ids that never occur in the data.
EN_TRAIN_VOCAB_SIZE = max(int(en_ids.max()) for _, en_ids in train_data) + 1
EN_VALIDATION_VOCAB_SIZE = max(int(en_ids.max()) for _, en_ids in val_data) + 1
EN_TEST_VOCAB_SIZE = max(int(en_ids.max()) for _, en_ids in test_data) + 1

In [None]:
# Longest Persian id sequence in each split.
FA_TRAIN_MAX_LEN = max(len(fa_ids) for fa_ids, _ in train_data)
FA_VALIDATION_MAX_LEN = max(len(fa_ids) for fa_ids, _ in val_data)
FA_TEST_MAX_LEN = max(len(fa_ids) for fa_ids, _ in test_data)

In [None]:
# Longest English id sequence in each split.
EN_TRAIN_MAX_LEN = max(len(en_ids) for _, en_ids in train_data)
EN_VALIDATION_MAX_LEN = max(len(en_ids) for _, en_ids in val_data)
EN_TEST_MAX_LEN = max(len(en_ids) for _, en_ids in test_data)

In [None]:
# Longest sequence over all splits, per language.
SOURCE_MAX_LEN = max(FA_TRAIN_MAX_LEN, FA_VALIDATION_MAX_LEN, FA_TEST_MAX_LEN)
TARGET_MAX_LEN = max(EN_TRAIN_MAX_LEN, EN_VALIDATION_MAX_LEN, EN_TEST_MAX_LEN)

In [None]:
# Longest English id sequence over all splits; used as the padding length and as the model's seq_len.
TARGET_MAX_LEN

In [None]:
# NOTE(review): repeats the earlier train_data[0] inspection cell; consider removing.
train_data[0]

In [None]:
def _pad_or_truncate(ids, length, pad_id):
    """Return 1-D ``ids`` cut to, or right-padded with ``pad_id`` up to, exactly ``length`` items."""
    if len(ids) >= length:
        return ids[:length]
    padding = torch.full((length - len(ids),), pad_id, dtype=ids.dtype, device=ids.device)
    return torch.cat((ids, padding), dim=0)


def generate_batch(data_batch, max_len=None, src_pad_id=None, tgt_pad_id=None):
    """Bring every ``(fa, en)`` id pair to one fixed length so DataLoader can stack them.

    Both sides are padded to ``max_len``, matching the original behaviour. The
    default is ``TARGET_MAX_LEN``, the ``seq_len`` the model is built with.

    Sequences longer than ``max_len`` are now truncated. Previously a Persian
    sentence longer than ``TARGET_MAX_LEN`` was left untouched, so the default
    collate function could not stack the resulting ragged batch. Truncation
    drops the trailing tokens, including the final [SEP].

    Args:
        data_batch: iterable of ``(fa_ids, en_ids)`` 1-D integer tensors.
        max_len: output length for both sides. Defaults to ``TARGET_MAX_LEN``.
        src_pad_id: pad id for the Persian side. Defaults to
            ``fa_tokenizer.pad_token_id``; this replaces the fragile positional
            ``FA_SPACIAL_TOKENS_IDS[2]`` lookup.
        tgt_pad_id: pad id for the English side. Defaults to
            ``en_tokenizer.pad_token_id``.

    Returns:
        List of ``(fa, en)`` tensors, each exactly ``max_len`` long.
    """
    if max_len is None:
        max_len = TARGET_MAX_LEN
    if src_pad_id is None:
        src_pad_id = fa_tokenizer.pad_token_id
    if tgt_pad_id is None:
        tgt_pad_id = en_tokenizer.pad_token_id

    return [
        (_pad_or_truncate(fa, max_len, src_pad_id),
         _pad_or_truncate(en, max_len, tgt_pad_id))
        for fa, en in data_batch
    ]

In [None]:
train_padded = generate_batch(data_batch=train_data)  # fixed-length training pairs

In [None]:
val_padded = generate_batch(data_batch=val_data)  # fixed-length validation pairs

In [None]:
test_padded = generate_batch(data_batch=test_data)  # fixed-length test pairs

In [None]:
BATCH_SIZE = 16

# Only the training split is shuffled; validation and test batches keep a fixed order
# (DataLoader's default is shuffle=False).
train_iter = DataLoader(train_padded, batch_size=BATCH_SIZE, shuffle=True)
valid_iter = DataLoader(val_padded, batch_size=BATCH_SIZE)
test_iter = DataLoader(test_padded, batch_size=BATCH_SIZE)

# Model

In [None]:
# Vocabulary sizes across all splits, as plain Python ints.
SOURCE_VOCAB_SIZE = int(max(FA_TEST_VOCAB_SIZE, FA_TRAIN_VOCAB_SIZE, FA_VALIDATION_VOCAB_SIZE))
TARGET_VOCAB_SIZE = int(max(EN_TEST_VOCAB_SIZE, EN_TRAIN_VOCAB_SIZE, EN_VALIDATION_VOCAB_SIZE))

# NOTE(review): the TransformerModel construction later in the notebook passes
# literal hyper-parameters and does not read these three constants.
NUM_LAYER = 6
EMBED_DIMENSION = 512
NUMBER_HEADS = 8

In [None]:
# Execute transformer.ipynb in this kernel; it must define `Transformer`, which TransformerModel instantiates.
%run transformer.ipynb

In [None]:
class TransformerModel(nn.Module):
    """Wrap the notebook-defined ``Transformer`` and add a final linear projection to the target vocabulary.

    ``Transformer`` comes from ``%run transformer.ipynb``. It is called
    positionally as ``(embed_dim, src_vocab_size, target_vocab_size, seq_len,
    num_layers, expansion_factor, n_heads)``.
    """

    def __init__(self, src_vocab_size, target_vocab_size, embed_dim, seq_len, num_layers, expansion_factor=4, n_heads=8):
        super().__init__()

        self.embed_dim = embed_dim
        self.transformer = Transformer(embed_dim, src_vocab_size, target_vocab_size, seq_len, num_layers, expansion_factor, n_heads)
        # NOTE(review): assumes the Transformer output's last dim is embed_dim —
        # confirm against transformer.ipynb.
        self.output = nn.Linear(embed_dim, target_vocab_size)

    def forward(self, src, tgt):
        """Return ``self.output(self.transformer(src, tgt))``.

        The per-call debug ``print`` of ``tgt.shape`` was removed; it flooded the
        output on every batch.
        """
        output = self.transformer(src, tgt)
        return self.output(output)

In [None]:
def train_transformer(model, iterator, optimizer, loss_fn, device, clip=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: seq2seq model called as ``model(src=..., tgt=...)``. Its output is
            flattened to ``(-1, output.shape[2])`` for the loss, so it is assumed
            to be (batch, tgt_len, vocab). TODO confirm against Transformer.
        iterator: DataLoader yielding ``(src, tgt)`` token-id batches.
        optimizer: optimizer over ``model.parameters()``.
        loss_fn: criterion called as ``loss_fn(logits_2d, targets_1d)``.
        device: device each batch is moved to.
        clip: max gradient norm for ``clip_grad_norm_``. ``None`` disables
            clipping. Previously this argument was accepted but silently ignored.

    Returns:
        Mean training loss over the batches of ``iterator``.
    """
    model.train()

    epoch_loss = 0
    with tqdm(total=len(iterator), leave=False) as t:
        for i, (src, tgt) in enumerate(iterator):
            src = src.to(device)
            tgt = tgt.to(device)

            # Teacher forcing: the decoder sees every token except the last and
            # must predict every token except the first.
            tgt_inp, tgt_out = tgt[:, :-1], tgt[:, 1:]

            optimizer.zero_grad()
            output = model(src=src, tgt=tgt_inp)

            # tgt_out is a column slice and therefore non-contiguous; .view(-1)
            # raises on it for batch size > 1, so flatten with .reshape instead.
            loss = loss_fn(output.reshape(-1, output.shape[2]),
                           tgt_out.reshape(-1))
            loss.backward()

            if clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

            optimizer.step()
            epoch_loss += loss.item()

            avg_loss = epoch_loss / (i + 1)
            t.set_postfix(loss='{:05.3f}'.format(avg_loss),
                          ppl='{:05.3f}'.format(np.exp(avg_loss)))
            t.update()

    return epoch_loss / len(iterator)

In [None]:
def evaluate_transformer(model, iterator, loss_fn, device):
    """Return the mean per-batch loss of ``model`` over ``iterator`` without updating it.

    Uses the same shift-by-one target split as training. Runs under
    ``torch.no_grad()`` with the model in eval mode and shows a tqdm bar with
    the running loss and perplexity (exp of the running loss).

    Args:
        model: seq2seq model called as ``model(src=..., tgt=...)``.
        iterator: DataLoader yielding ``(src, tgt)`` token-id batches.
        loss_fn: criterion called as ``loss_fn(logits_2d, targets_1d)``.
        device: device each batch is moved to.

    Returns:
        Mean loss over the batches of ``iterator``.
    """
    model.eval()

    epoch_loss = 0
    with torch.no_grad():
        with tqdm(total=len(iterator), leave=False) as t:
            for i, (src, tgt) in enumerate(iterator):
                src = src.to(device)
                tgt = tgt.to(device)

                # Decoder input drops the last token; the prediction target drops the first.
                tgt_inp, tgt_out = tgt[:, :-1], tgt[:, 1:]

                output = model(src=src, tgt=tgt_inp)

                # tgt_out is a non-contiguous column slice, so .view(-1) raises on it;
                # .reshape copies only when needed.
                loss = loss_fn(output.reshape(-1, output.shape[2]),
                               tgt_out.reshape(-1))

                epoch_loss += loss.item()

                avg_loss = epoch_loss / (i + 1)
                t.set_postfix(loss='{:05.3f}'.format(avg_loss),
                              ppl='{:05.3f}'.format(np.exp(avg_loss)))
                t.update()

    return epoch_loss / len(iterator)

In [None]:
def count_params(model, return_int=False):
    """Count the trainable (``requires_grad``) parameters of ``model``.

    Args:
        model: any ``torch.nn.Module``.
        return_int: if true, return the count. Otherwise print a formatted
            summary and return ``None``.

    Returns:
        The parameter count as an ``int`` when ``return_int`` is true, else ``None``.
    """
    # numel() is always an int. The previous prod-of-shape trick yielded the
    # float 1.0 for 0-dim parameters, which turned the whole sum into a float.
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if return_int:
        return params
    print("There are {:,} trainable parameters in this model.".format(params))

In [None]:
# Hyper-parameters are given as literals here (embed_dim=256, num_layers=2),
# not taken from NUM_LAYER / EMBED_DIMENSION / NUMBER_HEADS.
transformer = TransformerModel(
    src_vocab_size=SOURCE_VOCAB_SIZE,
    target_vocab_size=TARGET_VOCAB_SIZE,
    embed_dim=256,
    seq_len=TARGET_MAX_LEN,
    num_layers=2,
    expansion_factor=4,
    n_heads=8,
).to(device)

In [None]:
count_params(model=transformer)  # prints the trainable-parameter summary

In [None]:
xf_optim = torch.optim.AdamW(params=transformer.parameters(), lr=1e-4)  # AdamW, learning rate 1e-4

In [None]:
# The loss is computed on English target tokens (tgt_out), so ignore the ENGLISH
# tokenizer's padding id. This was previously the Persian id, looked up by
# position in all_special_ids.
PAD_IDX = en_tokenizer.pad_token_id
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

In [None]:
%%time
N_EPOCHS = 50
CLIP = 16 # clipping value, or None to prevent gradient clipping
EARLY_STOPPING_EPOCHS = 5
SAVE_DIR = ''
import os
    
model_path = os.path.join(SAVE_DIR, 'transformer_en_fr.pt')
transformer_metrics = {}
best_valid_loss = float("inf")
early_stopping_count = 0
for epoch in tqdm(range(N_EPOCHS), desc="Epoch"):
    train_loss = train_transformer(transformer, train_iter, xf_optim, loss_fn, device, clip=CLIP)
    valid_loss = evaluate_transformer(transformer, valid_iter, loss_fn, device)
    
    if valid_loss < best_valid_loss:
        tqdm.write(f"Checkpointing at epoch {epoch + 1}")
        best_valid_loss = valid_loss
        torch.save(transformer.state_dict(), model_path)
        early_stopping_count = 0
    elif epoch > EARLY_STOPPING_EPOCHS:
        early_stopping_count += 1
    
    transformer_metrics[epoch+1] = dict(
        train_loss = train_loss,
        train_ppl = np.exp(train_loss),
        valid_loss = valid_loss,
        valid_ppl = np.exp(valid_loss)
    )
    
    if early_stopping_count == EARLY_STOPPING_EPOCHS:
        tqdm.write(f"Early stopping triggered in epoch {epoch + 1}")
        break