In [None]:
!pip install --quiet transformers
!pip install --quiet pytorch-lightning
!pip install --quiet tokenizers
!pip install --quiet SentencePiece 

In [None]:
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from termcolor import colored
import textwrap

from transformers import(
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer
)
from tqdm.auto import tqdm

In [None]:
# Fix the global random seed up front so later stochastic steps are repeatable.
pl.seed_everything(seed=20)

In [None]:
# Read the raw paraphrase pairs from the CSV file into a DataFrame.
dataset_path = Path("../input/persian-paraphrase-dataset/Chunk_4_official.csv")
df = pd.read_csv(dataset_path)

In [None]:
# Show the first five rows as a quick sanity check of the loaded columns.
df.head(n=5)

In [None]:
# Rename the raw columns to the "question"/"answer" names that the dataset
# class in this notebook reads.
column_aliases = {"input_text": "question", "target_text": "answer"}
df = df.rename(columns=column_aliases)


In [None]:
# Hold out 20% of the pairs. An explicit random_state keeps the split identical
# no matter how often, or in which order, this cell is re-run (without it the
# split depends on the current state of the global NumPy RNG).
train_df, test_df = train_test_split(df, test_size=0.2, random_state=20)

len(train_df)

# Dataset

In [None]:
class ParaphraseDataset(Dataset):
    """Map-style dataset that tokenizes (question, answer) paraphrase pairs.

    Every row of ``data`` must provide a ``"question"`` column (source text)
    and an ``"answer"`` column (target paraphrase). Rows are tokenized lazily
    in ``__getitem__`` and padded/truncated to fixed lengths.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: "T5Tokenizer",
        text_max_token_len: int = 512,
        paraphrase_max_token_len: int = 512,
    ):
        """Store the data frame, the tokenizer and the length limits.

        Args:
            data: Frame with ``"question"`` and ``"answer"`` columns.
            tokenizer: Tokenizer used for both the source and the target text.
            text_max_token_len: Padded/truncated length of the source text.
            paraphrase_max_token_len: Padded/truncated length of the target.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.text_max_token_len = text_max_token_len
        self.paraphrase_max_token_len = paraphrase_max_token_len

    def __len__(self) -> int:
        """Return the number of rows in the underlying frame."""
        return len(self.data)

    def _encode(self, text, max_length: int):
        """Tokenize ``text`` into ``(1, max_length)`` PyTorch tensors."""
        # Use the tokenizer handed to this dataset, not a notebook-level global,
        # so the dataset keeps working inside DataLoader worker processes and
        # when several tokenizers coexist in the session.
        return self.tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

    def __getitem__(self, index: int) -> dict:
        """Return the tokenized example stored at positional ``index``.

        Returns:
            A dict with the raw ``text`` and ``paraphrase`` strings plus the
            flattened ``text_input_ids``, ``text_attention_mask``, ``labels``
            and ``labels_attention_mask`` tensors.
        """
        data_row = self.data.iloc[index]
        text = data_row["question"]
        paraphrase = data_row["answer"]

        text_encoding = self._encode(text, self.text_max_token_len)
        paraphrase_encoding = self._encode(paraphrase, self.paraphrase_max_token_len)

        # Replace padding positions in the target with -100, the index that the
        # cross-entropy loss ignores, so padding does not contribute to the loss.
        labels = paraphrase_encoding["input_ids"]
        labels[labels == self.tokenizer.pad_token_id] = -100

        return dict(
            text=text,
            paraphrase=paraphrase,
            text_input_ids=text_encoding["input_ids"].flatten(),
            text_attention_mask=text_encoding["attention_mask"].flatten(),
            labels=labels.flatten(),
            labels_attention_mask=paraphrase_encoding["attention_mask"].flatten(),
        )

In [None]:
class ParaphraseDataModule(pl.LightningDataModule):
    """Lightning data module that wraps the train and test paraphrase frames.

    The validation and the test loader both iterate over ``test_df``; there is
    no separate validation split.
    """

    # Number of worker processes used by every DataLoader built here.
    NUM_WORKERS = 2

    def __init__(
        self,
        train_df: pd.DataFrame,
        test_df: pd.DataFrame,
        tokenizer: T5Tokenizer,
        batch_size: int = 8,
        text_max_token_len: int = 512,
        paraphrase_max_token_len: int = 128
    ):
        """Store the frames, the tokenizer and the batching parameters.

        Args:
            train_df: Frame used to build the training dataset.
            test_df: Frame used to build the validation/test dataset.
            tokenizer: Tokenizer forwarded to every ``ParaphraseDataset``.
            batch_size: Batch size of all loaders.
            text_max_token_len: Padded/truncated length of the source text.
            paraphrase_max_token_len: Padded/truncated length of the target.
        """
        super().__init__()

        self.train_df = train_df
        self.test_df = test_df

        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.text_max_token_len = text_max_token_len
        self.paraphrase_max_token_len = paraphrase_max_token_len

    def _build_dataset(self, frame: pd.DataFrame) -> ParaphraseDataset:
        """Wrap ``frame`` in a ParaphraseDataset using the stored settings."""
        return ParaphraseDataset(
            frame,
            self.tokenizer,
            self.text_max_token_len,
            self.paraphrase_max_token_len,
        )

    def _build_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Create a DataLoader over ``dataset`` with the shared batch settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.NUM_WORKERS,
        )

    def setup(self, stage= None):
        """Create the train and test datasets from the stored frames."""
        self.train_dataset = self._build_dataset(self.train_df)
        self.test_dataset = self._build_dataset(self.test_df)

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled validation loader (iterates the test dataset)."""
        return self._build_loader(self.test_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled loader over the test dataset."""
        # The original passed the misspelled keyword `num_woekers`, which makes
        # DataLoader raise a TypeError; the shared helper uses `num_workers`.
        return self._build_loader(self.test_dataset, shuffle=False)
        
        
        
        

In [None]:
# Identifier of the pretrained checkpoint; the model class defined later in
# this notebook loads its weights from the same name.
MODEL_NAME = "google/mt5-small"

# Tokenizer loaded from the same checkpoint name.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

In [None]:
# Token length of every training question and answer, used to inspect the
# length distribution. Iterating the columns directly avoids the per-row
# overhead of DataFrame.iterrows(); tqdm shows progress on large splits.
text_token_counts = [
    len(tokenizer.encode(question))
    for question in tqdm(train_df["question"], desc="Tokenizing questions")
]
paraphrase_token_counts = [
    len(tokenizer.encode(answer))
    for answer in tqdm(train_df["answer"], desc="Tokenizing answers")
]


In [None]:
# Longest question and longest answer, in tokens, as a (question, answer) pair.
tuple(max(counts) for counts in (text_token_counts, paraphrase_token_counts))

In [None]:
# Side-by-side histograms of the token counts. A wider figure keeps the two
# panels from overlapping, and each x-axis is labelled so the plot stands alone.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

sns.histplot(text_token_counts, ax=ax1)
ax1.set_title("Question token counts")
ax1.set_xlabel("Tokens per question")

sns.histplot(paraphrase_token_counts, ax=ax2)
ax2.set_title("Answer token counts")
ax2.set_xlabel("Tokens per answer")

# Trailing semicolon suppresses the return-value repr in the cell output.
fig.tight_layout();

In [None]:
# Training hyper-parameters.
N_EPOCHS = 3
BATCH_SIZE = 8

# Data module over the two splits; the token-length limits keep their defaults.
data_module = ParaphraseDataModule(
    train_df=train_df,
    test_df=test_df,
    tokenizer=tokenizer,
    batch_size=BATCH_SIZE,
)


# Model

In [None]:
class ParaphraseModel(pl.LightningModule):
    """Lightning wrapper around a pretrained T5-style seq2seq model.

    The wrapped network is loaded from ``MODEL_NAME`` and trained with the
    loss that the underlying model computes from the supplied labels.
    """

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        """Run the wrapped model and return ``(loss, logits)``.

        Args:
            input_ids: Token ids of the source text.
            attention_mask: Attention mask of the source text.
            decoder_attention_mask: Attention mask of the target text.
            labels: Target token ids; forwarded unchanged to the wrapped model.
        """
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        return output.loss, output.logits

    def _shared_step(self, batch, metric_name):
        """Compute the loss for ``batch`` and log it under ``metric_name``."""
        loss, _ = self(
            input_ids=batch["text_input_ids"],
            attention_mask=batch["text_attention_mask"],
            decoder_attention_mask=batch["labels_attention_mask"],
            labels=batch["labels"],
        )
        self.log(metric_name, loss, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss, logged as ``train_loss``."""
        return self._shared_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss, logged as ``val_loss``."""
        return self._shared_step(batch, "val_loss")

    def test_step(self, batch, batch_idx):
        """Return the test loss, logged as ``test_loss``."""
        return self._shared_step(batch, "test_loss")

    def configure_optimizers(self):
        """Return the AdamW optimizer over all parameters (lr = 1e-4)."""
        # torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
        # weight_decay are set explicitly to that implementation's defaults so
        # the optimizer hyper-parameters stay the same as before.
        return torch.optim.AdamW(
            self.parameters(),
            lr=0.0001,
            eps=1e-6,
            weight_decay=0.0,
        )
        

In [None]:
# Instantiate the Lightning module; its __init__ loads the pretrained weights
# named by MODEL_NAME.
model = ParaphraseModel()

In [None]:
# Load the TensorBoard notebook extension and open it on the log directory
# that the TensorBoardLogger in this notebook writes to.
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs


In [None]:
# Keep only the single checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoint",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

logger = TensorBoardLogger("lightning_logs", name="Paraphraser")

# The checkpoint callback must be registered through `callbacks`;
# `enable_checkpointing` is only a boolean switch, so passing the callback
# object there never registers it and its val_loss monitor is not applied.
# `accelerator`/`devices` replace the removed `gpus` argument.
trainer = pl.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback],
    enable_checkpointing=True,
    max_epochs=N_EPOCHS,
    accelerator="gpu",
    devices=2,
)


In [None]:
# Train the model; the data module supplies the training and validation loaders.
trainer.fit(model, data_module)

# Test

In [None]:
# Re-create the tokenizer for the same pretrained checkpoint name.
MODEL_NAME = "google/mt5-small"
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

# Restore the weights from the best checkpoint recorded by the trainer and
# freeze them for inference.
best_checkpoint_path = trainer.checkpoint_callback.best_model_path
trained_model = ParaphraseModel.load_from_checkpoint(best_checkpoint_path)
trained_model.freeze()

In [None]:
def paraphraser(text):
    """Generate a paraphrase of ``text`` with the fine-tuned model.

    Uses the notebook-level ``tokenizer`` and ``trained_model``.

    Args:
        text: Input passed to the tokenizer. If it yields several sequences,
            the decoded outputs are concatenated without a separator.

    Returns:
        The decoded generation(s) joined into a single string.
    """
    # truncation=True is required for max_length to actually cap the input.
    text_encoding = tokenizer(
        text,
        max_length=90,
        padding=True,
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    # Run generation on whichever device the model lives on.
    device = trained_model.device

    # `repetition_penalty` / `length_penalty` were misspelled in the original,
    # so generate() either ignored them or rejected them as unknown arguments.
    generated_ids = trained_model.model.generate(
        input_ids=text_encoding["input_ids"].to(device),
        attention_mask=text_encoding["attention_mask"].to(device),
        max_length=512,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )

    preds = [
        tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for gen_id in generated_ids
    ]

    return "".join(preds)