In [None]:
!pip install rouge_score
!pip install py7zr

### Work overview  (fine tuning a pretrained model for text summarization)

- base_mdoel - google pegasus pretrained on cnn_dailymail(daily news articles and its summaries)
- fine tuned on -samsum(samsumng customer conversation data and its summaries)

In [None]:
import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
import transformers
from transformers import AutoModelForSeq2SeqLM,AutoTokenizer,AutoModel
from datasets import load_dataset,load_metric

In [None]:
samsum_data = load_dataset("samsum")
# samsum -customer conversations and its summaries
# cnn_dailymail -news articles and its summaries

In [5]:
class Config:
    device='cuda' if torch.cuda.is_available() else 'cpu'
    infer_batch_size=4
    model_checkpoint='google/pegasus-cnn_dailymail' #pretrained on cnn_dailymail

class LoadData:
    def __init__(self,data,x,y,tokenizer=None):
        self.article=data[x]
        self.summary=data[y]
        self.tokenizer=tokenizer
    def __len__(self):
        return len(self.article)
    def __getitem__(self,idx):
        article=self.article.iloc[idx]
        summary=self.summary.iloc[idx]
        if self.tokenizer is not None:
            article=tokenizer(self.article.iloc[idx])
        return {'x':article,'y':summary}

tokenizer=AutoTokenizer.from_pretrained(Config.model_checkpoint)
model=AutoModelForSeq2SeqLM.from_pretrained(Config.model_checkpoint).to(Config.device)
model=torch.nn.DataParallel(model)
rouge_metric=load_metric('rouge')

def infer(data,model,tokenizer,metric):
    """function tokenizes and computes the output of the data 
    from the model,evaluates the score and returns it"""
    for batch in tqdm(data,total=len(data)):
        articles=batch['x']
        summaries=batch['y']
        tokens=tokenizer(articles,max_length=1024,padding='max_length',truncation=True,return_tensors='pt')
        out=model.module.generate(input_ids=tokens['input_ids'].to(Config.device),attention_mask=tokens['attention_mask'].to(Config.device),length_penalty=0.8,num_beams=8,max_length=128)
        pred_summaries=[tokenizer.decode(item,skip_special_tokens=True,clean_up_tokenization_spaces=True) for item in out]
        pred_summaries=[item.replace('<n>'," ") for item in pred_summaries]
        metric.add_batch(predictions=pred_summaries,references=summaries) 
        torch.cuda.empty_cache()
    score=metric.compute()
    return {i:score[i].mid.fmeasure for i in score.keys()}


# loading the test data and inferring on it to see the base performance
samsum_test=pd.DataFrame(samsum_data['test'])
samsum_test_loader=DataLoader(LoadData(samsum_test,x='dialogue',y='summary'),batch_size=Config.infer_batch_size,shuffle=False)
score_befor_fine_tuning=infer(samsum_test_loader,model,tokenizer,rouge_metric)

  0%|          | 0/205 [00:00<?, ?it/s]

In [6]:
def make_tokens(example_batch):
    """prepares tokens for training"""
    input_encodings=tokenizer(example_batch["dialogue"],max_length=1024,truncation=True)    
    with tokenizer.as_target_tokenizer():
        target_encodings=tokenizer(example_batch["summary"],max_length=128,truncation=True)
        
    return {"input_ids":input_encodings["input_ids"], 
            "attention_mask":input_encodings["attention_mask"], 
            "labels":target_encodings["input_ids"]
           }

samsum_tokens=samsum_data.map(make_tokens,batched=True)
columns=["input_ids", "labels","attention_mask"]
samsum_tokens.set_format(type="torch",columns=columns)

  0%|          | 0/15 [00:00<?, ?ba/s]



  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

#### training on huggingface Trainer class

In [7]:
from transformers import DataCollatorForSeq2Seq,TrainingArguments,Trainer

seq2seq_data_collator=DataCollatorForSeq2Seq(tokenizer,model=model.module)
training_args=TrainingArguments(output_dir='/kaggle/working/hf',
                                num_train_epochs=1,
                                warmup_steps=500,
                                per_device_train_batch_size=1, 
                                per_device_eval_batch_size=1,
                                weight_decay=0.01, 
                                logging_steps=10,
                                evaluation_strategy='steps',
                                eval_steps=500,
                                save_steps=1e6,
                                gradient_accumulation_steps=16)

trainer=Trainer(model=model.module, 
                  args=training_args,
                  tokenizer=tokenizer,
                  data_collator=seq2seq_data_collator,
                  train_dataset=samsum_tokens["train"], 
                  eval_dataset=samsum_tokens["validation"]) 

trainer.train()

#inferring on test data after fine tuning
score_after_fine_tuning=infer(data=samsum_test_loader,model=torch.nn.DataParallel(trainer.model),tokenizer=tokenizer,metric=rouge_metric)

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


You're using a PegasusTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss
500,1.6958,1.48378


  0%|          | 0/205 [00:00<?, ?it/s]

In [29]:
print(f'scores_before_tuning: {score_befor_fine_tuning} \n\nscores_after_tuning: {score_after_fine_tuning}')

scores_before_tuning: {'rouge1': 0.29624950139134154, 'rouge2': 0.08791421594725683, 'rougeL': 0.22920481105677556, 'rougeLsum': 0.22912836825557523} 

scores_after_tuning: {'rouge1': 0.43118236794877784, 'rouge2': 0.2004617211248804, 'rougeL': 0.3394828236863836, 'rougeLsum': 0.33964179236118636}
