# Thesis Experiment: Pegasus Model
## Michael LeVine, April 14, 2024

The purpose of this notebook is to test the summarization capabilities of the Pegasus model.

Attribution: This approach is partially based on a training course by Janani Ravi, Certified Google Cloud Architect and Data Engineer, from the LinkedIn Learning course AI Text Summarization with Hugging Face, released 10/30/2.  In addition, some of the code was derived from the Hugging Face transformers summarization page: https://huggingface.co/docs/transformers/en/tasks/summarization

## Overview: Using a Transformer Model from Hugging Face: Pegasus

### The Pegasus model
The pre-trained model that we will use is the "Pegasus" model from Hugging Face, which can be found here: https://huggingface.co/docs/transformers/model_doc/pegasus

The "model card," which describes the model, notes that Pegasus is a sequence-to-sequence model with the same encoder-decoder architecture as BART. Pegasus is pre-trained jointly on two self-supervised objectives: Masked Language Modeling (MLM) and a novel summarization-specific pre-training objective called Gap Sentence Generation (GSG), in which whole sentences are removed from the input and the model learns to generate them.  The model has 568M parameters.
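
To make the GSG objective more concrete, here is a toy sketch; it is not the actual Pegasus pre-training code. It assumes a naive period-based sentence splitter and scores each sentence by simple word overlap with the rest of the document (the Pegasus authors use a ROUGE-based selection). `overlap_score` and `gsg_example` are hypothetical helpers, and `<mask_1>` is used as the sentence-mask token, following the Hugging Face Pegasus tokenizer.

```python
def overlap_score(sentence, others):
    """Fraction of the sentence's unique words that also appear in the other sentences."""
    words = set(sentence.lower().split())
    rest = set(" ".join(others).lower().split())
    return len(words & rest) / len(words) if words else 0.0


def gsg_example(document, mask_token="<mask_1>"):
    """Mask the single highest-scoring sentence; return (model_input, target)."""
    sentences = [s.strip() for s in document.split(".") if s.strip()]
    scores = [
        overlap_score(s, sentences[:i] + sentences[i + 1:])
        for i, s in enumerate(sentences)
    ]
    top = scores.index(max(scores))  # first sentence with the highest score
    masked = [mask_token if i == top else s for i, s in enumerate(sentences)]
    return ". ".join(masked) + ".", sentences[top] + "."


doc = "The cat sat on the mat. The dog barked. The cat and the dog sat on the mat together."
model_input, target = gsg_example(doc)
print(model_input)
print(target)
```

The model would receive `model_input` and be trained to produce `target`, which is why the objective resembles summarization: the masked sentence acts as a pseudo-summary of the remaining text.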

## Verifying the Compute Environment

### Graphics Processing Unit (GPU)
Running inference on transformer models can be done without a GPU.  However, for training, a GPU is recommended.  The following block of code checks whether a GPU is available for use in a PyTorch environment.

In [1]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
print("using", device, "device")

using cuda device


## Installing and importing required libraries and dependencies

In [2]:
#command-line pip install of the required libraries and dependencies
#the transformers library allows us to access the pre-trained Pegasus model
#the datasets library provides access to the Hugging Face datasets
#the evaluate library enables us to evaluate the summaries the model produces
#the rouge_score package implements ROUGE, a standard evaluation metric for text summarization
#the accelerate library supports distributed training on GPUs
#pip install transformers datasets evaluate rouge_score accelerate

### Importing the libraries

In [3]:
#this code imports needed libraries

import transformers
import datasets
import evaluate
import rouge_score
import accelerate

print(transformers.__version__) # verifies the transformers version

4.32.1


## Importing, Reducing, and Exploring the dataset

The experiment will use the CNN/Daily Mail dataset.  Two datasets will be created:
* Training Dataset
* Holdout Dataset (for running inference on the model)


### Instantiating a training dataset

In [4]:
#loading the dataset which was previously saved
from datasets import load_dataset
cnn_news_summary_ds = load_dataset("arrow", data_files={'train': 'data/cnn_news_summary_ds/train/data-00000-of-00001.arrow', 'test': 'data/cnn_news_summary_ds/test/data-00000-of-00001.arrow'})

In [5]:
cnn_news_summary_ds

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 2296
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 575
    })
})

As the output above shows, the dataset is split into two components:
* a `train` (training) dataset of 2296 articles
* a `test` dataset of 575 articles

### Instantiating a holdout dataset (200 records)

The holdout dataset is used to run inference on both the "off-the-shelf" model and the fine-tuned model.  Evaluating on records that the model did not see during training tests how well its performance generalizes.
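
One way to confirm that a holdout set really is separate from the training data is to compare the `id` values of the two sets. The sketch below uses made-up ids and a hypothetical helper, `find_leaked_ids`; it is an illustration of the check, not a step that was run in this experiment.

```python
def find_leaked_ids(train_ids, holdout_ids):
    """Return the sorted ids that appear in both splits (ideally an empty list)."""
    return sorted(set(train_ids) & set(holdout_ids))


# Made-up ids for illustration only
train_ids = ["a1", "b2", "c3"]
holdout_ids = ["d4", "e5", "c3"]

leaked = find_leaked_ids(train_ids, holdout_ids)
print(leaked)
```

With the real data, the two arguments would be the `id` columns of the training and holdout datasets.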

In [6]:
#load holdout set for inference from a local csv (to ensure same order)
cnn_holdout_ds = load_dataset ("csv", data_files='data/cnn_holdout_ds.csv', split = "train[0:200]")
cnn_holdout_ds

Dataset({
    features: ['article', 'highlights', 'id'],
    num_rows: 200
})

### Exploring the Dataset

In [7]:
#dataset shape
cnn_news_summary_ds.shape

{'train': (2296, 3), 'test': (575, 3)}

The above output shows the cnn_dailymail `train` set is 2296 rows x 3 columns.  The `test` set is 575 rows x 3 columns.

In [8]:
#dataset object type
type(cnn_news_summary_ds)

datasets.dataset_dict.DatasetDict

The above output shows that `cnn_news_summary_ds` is of type `DatasetDict` from the `datasets` library: a dictionary-like container that holds one `Dataset` per split.

In [9]:
#dataset structure
cnn_news_summary_ds

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 2296
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 575
    })
})

The dataset has three features:
* `article`: the full text of the news article
* `highlights`: the target summary, also known as the reference summary
* `id`: the unique id for each article/highlights pair


In [10]:
#looking at the features of the dataset
cnn_news_summary_ds['train'].features

{'article': Value(dtype='string', id=None),
 'highlights': Value(dtype='string', id=None),
 'id': Value(dtype='string', id=None)}

In [11]:
#examining the first record of the dataset.  
cnn_news_summary_ds['train'][0]

{'article': '(CNN) -- After almost 10 months, the FBI has zeroed in on a suspect in the case of missing Florida pilot Robert Wiles, who may have been kidnapped for ransom. Missing Florida pilot Robert Wiles is thought to have been kidnapped for ransom. "We\'re close to solving the case," said FBI special agent David Couvertier. He would not elaborate. Agents also would not identify the suspect, and they said the person is not in custody. Investigators would only reveal that the "key suspect" is in Florida, either in Orlando, Lakeland or Melbourne. "They\'re holding that back in hopes of getting additional information," said Couvertier. The FBI says it\'s also looking at several persons of interest in those same three Florida cities. Wiles, 27, was last seen in the family\'s aircraft maintenance business, National Flight Services, at Lakeland Linder Regional Airport on April 1, 2008. The day Wiles disappeared, he left behind his bags, his computer, and even his car. His father says the 

In [12]:
#examining the first record of the holdout dataset.  
cnn_holdout_ds[0]

{'article': "Manchester United have fallen off their perch. And they’re dropping like a stone towards mediocrity. That is the undeniable fact that has been hammered home relentlessly during the past six months. Whether we are talking about the events of Wednesday night at Olympiacos or before the startled eyes of the faithful at Old Trafford, the evidence is there for all to see. Can't stop the slump: David Moyes cannot believe it as he watches Manchester United lose at Olympiacos . Down and almost out: Robin van Persie lies on the floor during a defeat which sees United's Champions League campaign hanging by a thread . Disbelief: Wayne Rooney cries out in vain during another shambolic United display . Abject: The frustration shows on the Man United players' faces on taking the restart after conceding to Olympiacos . Coming to get you: Liverpool are looking to take United's place in the top four . Now is it time for Man United to sack Moyes? Out of the title race, out of the FA . Cup, 

In [13]:
cnn_holdout_ds[120]

{'article': "Every day Sportsmail takes a look at the European papers to see what are the biggest stories creating talking points on the continent. On Saturday, Italian newspapers Tuttosport and Corriere dello Sport both lead with reports that Juventus manager Antonio Conte could be set to leave the club this summer. Tuttosport claim that Conte, who has just led Juve to a third consecutive Serie A title but has failed to make progress in the Champions League, and club president Andrea Agnelli seem very distant and CDS says Juve and Conte are 'moving apart'. Uncertain future: Juventus manager Antonio Conte could leave the club this summer . Triple crown: Conte has led Juventus to three consecutive Serie A league titles . Conte has been linked with Monaco, who are also reportedly keen on Arsenal manager Arsene Wenger and Benfica's Jorge Jesus. Elsewhere in Italy, La Gazzetta dello Sport pays tribute to Inter Milan right-back Javier Zanetti who is set to retire at the end of the season, l

## Preprocessing and Cleaning the Data

In [14]:
#defining a text cleaning function.  This function iterates over the 'article'
# and 'highlights' fields, converts the text to lowercase, and replaces various
#text strings (like backslashes, new lines, etc.) with the empty string

def clean_txt(example):
  for txt in ['article', 'highlights']:
    example[txt] = example[txt].lower() #convert text to lowercase
    example[txt] = example[txt].replace('\\','')
    example[txt] = example[txt].replace('/','')
    example[txt] = example[txt].replace('\n','')
    example[txt] = example[txt].replace('``','')
    example[txt] = example[txt].replace('"','')
    example[txt] = example[txt].replace('--','')
  return example
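
Because each replacement above substitutes the empty string, a newline between two words fuses them, and removing `--` leaves a double space (visible as `(cnn)  after` in the cleaned output later in this notebook). The variant below is a hedged sketch of a whitespace-safe alternative; `clean_txt_spaced` is a hypothetical function that was not used in this experiment, and the sample record is made up.

```python
import re


def clean_txt_spaced(example):
    """Hypothetical variant of clean_txt: same removals, but whitespace-safe."""
    for txt in ["article", "highlights"]:
        text = example[txt].lower()
        text = text.replace("\n", " ")  # keep a word boundary where the newline was
        for junk in ["\\", "/", "``", '"', "--"]:
            text = text.replace(junk, "")
        example[txt] = re.sub(r"\s+", " ", text).strip()  # collapse repeated spaces
    return example


sample = {"article": '(CNN) -- "Close,"\nhe said.', "highlights": "FBI names\nsuspect"}
print(clean_txt_spaced(sample))
```

Swapping this in would change the tokenized inputs, so any scores would need to be recomputed rather than compared directly with the results in this notebook.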

### Mapping the text cleaning function to the training dataset and the holdout dataset

In [15]:
#Hugging Face datasets provide a .map() method that
#applies a function to every record in a dataset and
#returns the updated dataset.  In effect, .map()
#applies the `clean_txt` function to
#all the records in the `cnn_news_summary_ds` dataset.
#The result is that the training and
#test data are now cleaned
cleaned_cnn_news_summary_ds = cnn_news_summary_ds.map(clean_txt)

cleaned_cnn_news_summary_ds

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 2296
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 575
    })
})

In [16]:
#verifying that we have a clean training dataset by comparing a record from the original
#dataset to the cleaned dataset
print('\n\n==== Original dataset ====\n\n')

print (cnn_news_summary_ds['train']['article'][0])

print('\n\n==== Cleaned dataset ====\n\n')

cleaned_cnn_news_summary_ds['train']['article'][0]



==== Original dataset ====


(CNN) -- After almost 10 months, the FBI has zeroed in on a suspect in the case of missing Florida pilot Robert Wiles, who may have been kidnapped for ransom. Missing Florida pilot Robert Wiles is thought to have been kidnapped for ransom. "We're close to solving the case," said FBI special agent David Couvertier. He would not elaborate. Agents also would not identify the suspect, and they said the person is not in custody. Investigators would only reveal that the "key suspect" is in Florida, either in Orlando, Lakeland or Melbourne. "They're holding that back in hopes of getting additional information," said Couvertier. The FBI says it's also looking at several persons of interest in those same three Florida cities. Wiles, 27, was last seen in the family's aircraft maintenance business, National Flight Services, at Lakeland Linder Regional Airport on April 1, 2008. The day Wiles disappeared, he left behind his bags, his computer, and even his car. His fa

"(cnn)  after almost 10 months, the fbi has zeroed in on a suspect in the case of missing florida pilot robert wiles, who may have been kidnapped for ransom. missing florida pilot robert wiles is thought to have been kidnapped for ransom. we're close to solving the case, said fbi special agent david couvertier. he would not elaborate. agents also would not identify the suspect, and they said the person is not in custody. investigators would only reveal that the key suspect is in florida, either in orlando, lakeland or melbourne. they're holding that back in hopes of getting additional information, said couvertier. the fbi says it's also looking at several persons of interest in those same three florida cities. wiles, 27, was last seen in the family's aircraft maintenance business, national flight services, at lakeland linder regional airport on april 1, 2008. the day wiles disappeared, he left behind his bags, his computer, and even his car. his father says the next day, wiles was supp

In [17]:
cleaned_cnn_holdout_ds = cnn_holdout_ds.map(clean_txt)

cleaned_cnn_holdout_ds

Dataset({
    features: ['article', 'highlights', 'id'],
    num_rows: 200
})

In [18]:
#verifying that we have a clean holdout dataset by comparing a record from the original
#holdout dataset to the cleaned holdout dataset
print('\n\n==== Original holdout dataset ====\n\n')

print (cnn_holdout_ds['article'][0])

print('\n\n==== Cleaned holdout dataset ====\n\n')

cleaned_cnn_holdout_ds['article'][0]



==== Original holdout dataset ====


Manchester United have fallen off their perch. And they’re dropping like a stone towards mediocrity. That is the undeniable fact that has been hammered home relentlessly during the past six months. Whether we are talking about the events of Wednesday night at Olympiacos or before the startled eyes of the faithful at Old Trafford, the evidence is there for all to see. Can't stop the slump: David Moyes cannot believe it as he watches Manchester United lose at Olympiacos . Down and almost out: Robin van Persie lies on the floor during a defeat which sees United's Champions League campaign hanging by a thread . Disbelief: Wayne Rooney cries out in vain during another shambolic United display . Abject: The frustration shows on the Man United players' faces on taking the restart after conceding to Olympiacos . Coming to get you: Liverpool are looking to take United's place in the top four . Now is it time for Man United to sack Moyes? Out of the title r

"manchester united have fallen off their perch. and they’re dropping like a stone towards mediocrity. that is the undeniable fact that has been hammered home relentlessly during the past six months. whether we are talking about the events of wednesday night at olympiacos or before the startled eyes of the faithful at old trafford, the evidence is there for all to see. can't stop the slump: david moyes cannot believe it as he watches manchester united lose at olympiacos . down and almost out: robin van persie lies on the floor during a defeat which sees united's champions league campaign hanging by a thread . disbelief: wayne rooney cries out in vain during another shambolic united display . abject: the frustration shows on the man united players' faces on taking the restart after conceding to olympiacos . coming to get you: liverpool are looking to take united's place in the top four . now is it time for man united to sack moyes? out of the title race, out of the fa . cup, out of the l

## Instantiating an "off-the-shelf" Pegasus model and tokenizer

In [19]:
from transformers import pipeline, PegasusForConditionalGeneration, PegasusTokenizer, TrainingArguments, Trainer, \
                         DataCollatorForSeq2Seq
import pandas as pd
from datasets import Dataset
import random

In [20]:
from transformers import AutoModel, AutoTokenizer 

In [21]:
pip install sentencepiece

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [22]:
from transformers import AutoTokenizer

checkpoint = "google/pegasus-large"
#tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer = PegasusTokenizer.from_pretrained(checkpoint)

In [23]:
#looking at model architecture
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
model

Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-large and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


PegasusForConditionalGeneration(
  (model): PegasusModel(
    (shared): Embedding(96103, 1024, padding_idx=0)
    (encoder): PegasusEncoder(
      (embed_tokens): Embedding(96103, 1024, padding_idx=0)
      (embed_positions): PegasusSinusoidalPositionalEmbedding(1024, 1024)
      (layers): ModuleList(
        (0-15): 16 x PegasusEncoderLayer(
          (self_attn): PegasusAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): ReLU()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_no

## Preprocessing Datasets

In [24]:
#define a preprocessing function

#prefix = "summarize: "


def preprocess_function(examples):
    #inputs = [prefix + doc for doc in examples["article"]]
    inputs = [doc for doc in examples["article"]]
    #model_inputs = tokenizer(inputs, max_length=1000, truncation=True)
    model_inputs = tokenizer(inputs, truncation=True)

    #labels = tokenizer(text_target=examples["highlights"], max_length=256, truncation=True)
    labels = tokenizer(text_target=examples["highlights"], truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

### Preprocessing the training dataset

In [25]:
tokenized_cnn_training_ds = cleaned_cnn_news_summary_ds.map(preprocess_function, batched=True)
tokenized_cnn_training_ds

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 2296
    })
    test: Dataset({
        features: ['article', 'highlights', 'id', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 575
    })
})

### Preprocessing the holdout dataset

In [26]:
tokenized_cnn_holdout_ds = cleaned_cnn_holdout_ds.map(preprocess_function, batched=True)
tokenized_cnn_holdout_ds

Dataset({
    features: ['article', 'highlights', 'id', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 200
})

## Instantiating a Data Collator

In [27]:
from transformers import DataCollatorForSeq2Seq

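#pass the model object (not the checkpoint name) so the collator can prepare decoder inputs from the labels
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)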
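
Conceptually, a sequence-to-sequence data collator pads each batch to the length of its longest example, so that shorter sequences are not padded to a global maximum. The pure-Python sketch below illustrates that idea only. It assumes a pad token id of 0 (consistent with `padding_idx=0` in the model printout above) and a label padding value of -100, which is the default `label_pad_token_id` of `DataCollatorForSeq2Seq`; `pad_batch` is a hypothetical helper, and the real collator also returns tensors and handles further cases.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Pad every sequence in the batch to the longest one (pure-Python illustration)."""
    max_in = max(len(f["input_ids"]) for f in features)
    max_lab = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_in, n_lab = len(f["input_ids"]), len(f["labels"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * (max_in - n_in))
        batch["attention_mask"].append([1] * n_in + [0] * (max_in - n_in))
        batch["labels"].append(f["labels"] + [label_pad_token_id] * (max_lab - n_lab))
    return batch


# Made-up token ids; 1 stands in for the end-of-sequence token
features = [
    {"input_ids": [110, 80869, 1], "labels": [133, 1]},
    {"input_ids": [109, 1], "labels": [120, 117, 1]},
]
batch = pad_batch(features)
print(batch)
```

Padding the labels with -100 matters because the loss function ignores that value, so padded positions do not contribute to training.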

## Importing ROUGE Metric
*  ROUGE is a standard evaluation metric used in text summarization tasks
*  ROUGE provides an *objective metric* for comparing a model-produced summary with the dataset's reference summary.
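
As a rough illustration of what ROUGE-1 measures, the sketch below computes unigram-overlap precision, recall, and F1 on whitespace tokens. It is a simplified stand-in: the `rouge_score` package also applies its own tokenization and optional stemming, so its numbers can differ. The function name `rouge1` and the example sentences are made up for illustration.

```python
from collections import Counter


def rouge1(prediction, reference):
    """Unigram-overlap precision, recall and F1 on whitespace tokens."""
    pred = Counter(prediction.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((pred & ref).values())  # clipped count of shared unigrams
    precision = overlap / max(sum(pred.values()), 1)
    recall = overlap / max(sum(ref.values()), 1)
    f1 = 2 * precision * recall / (precision + recall) if overlap else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


scores = rouge1("the cat sat on the mat", "the cat lay on the mat")
print(scores)
```

Here five of the six words in each sentence match, so precision, recall, and F1 all come out to 5/6.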

### Importing the evaluate library

The 'evaluate' library from Hugging Face allows us to evaluate ML models.  The 'evaluate' library provides access to dozens of evaluation metrics across many ML domains (including NLP, computer vision, etc.).

In [28]:
import evaluate

rouge = evaluate.load("rouge")
rouge

EvaluationModule(name: "rouge", module_type: "metric", features: [{'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id=None)}, {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}], usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each prediction
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLsum"`: rougeLsum splits text using `"
"`.
        See details in https://github.com/huggingface/

## Running inference on holdout set with "off-the-shelf" Pegasus model

In [29]:
#creating 'holdout_article_texts' variable to hold articles from the holdout set
holdout_article_texts = tokenized_cnn_holdout_ds["article"]

#creating 'holdout_article_summaries' variable to hold reference summaries from the holdout set
holdout_article_summaries = tokenized_cnn_holdout_ds["highlights"]


In [30]:
tokenized_cnn_holdout_ds

Dataset({
    features: ['article', 'highlights', 'id', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 200
})

In [31]:
tokenized_cnn_holdout_ds[0]['input_ids']

[110,
 80869,
 10381,
 133,
 6852,
 299,
 153,
 28955,
 107,
 111,
 157,
 123,
 216,
 7325,
 172,
 114,
 1909,
 1239,
 64873,
 107,
 120,
 117,
 109,
 25839,
 617,
 120,
 148,
 174,
 30524,
 238,
 35198,
 333,
 109,
 555,
 1029,
 590,
 107,
 682,
 145,
 127,
 1767,
 160,
 109,
 702,
 113,
 73492,
 565,
 134,
 4429,
 445,
 65722,
 29475,
 116,
 132,
 269,
 109,
 52083,
 1525,
 113,
 109,
 10316,
 134,
 459,
 29857,
 44457,
 108,
 109,
 1812,
 117,
 186,
 118,
 149,
 112,
 236,
 107,
 137,
 131,
 144,
 923,
 109,
 26973,
 151,
 42841,
 11434,
 8086,
 967,
 697,
 126,
 130,
 178,
 8346,
 110,
 80869,
 10381,
 2019,
 134,
 4429,
 445,
 65722,
 29475,
 116,
 110,
 107,
 308,
 111,
 744,
 165,
 151,
 37741,
 4406,
 446,
 15316,
 4269,
 124,
 109,
 1030,
 333,
 114,
 6714,
 162,
 6659,
 10381,
 131,
 116,
 11256,
 3867,
 1541,
 4229,
 141,
 114,
 3926,
 110,
 107,
 35593,
 151,
 230,
 2979,
 20911,
 50399,
 28087,
 165,
 115,
 23213,
 333,
 372,
 31811,
 44022,
 10381,
 1381,
 110,
 107,
 718

In [32]:
len(tokenized_cnn_holdout_ds[0]['input_ids'])

1024
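
The length of 1024 suggests that the tokenizer truncated this article: with `truncation=True` and no explicit `max_length`, Hugging Face tokenizers fall back to the tokenizer's `model_max_length`, and the model printout earlier shows 1024 positions in `embed_positions`. The sketch below illustrates the idea only; `truncate_with_eos` is a hypothetical helper, the token ids are made up, and the end-of-sequence id of 1 is an assumption about the Pegasus vocabulary.

```python
def truncate_with_eos(token_ids, max_length, eos_id=1):
    """Keep at most max_length ids, reserving the final slot for end-of-sequence."""
    if max_length < 1:
        raise ValueError("max_length must be at least 1")
    kept = token_ids[: max_length - 1]  # leave room for the EOS token
    return kept + [eos_id]


long_article = list(range(100, 3100))  # 3,000 made-up token ids
encoded = truncate_with_eos(long_article, max_length=1024)
print(len(encoded), encoded[-1])
```

A practical consequence is that anything in an article beyond the first 1024 tokens is invisible to the model, so summaries can only reflect the opening portion of long articles.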

In [34]:
#instantiating a summarizer pipeline with the off-the-shelf Pegasus model
summarizer = pipeline('summarization', model=checkpoint, truncation=True) #added the truncation argument to the pipeline parameters

Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-large and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [35]:
summarizer ("The researchers investigate the feasibility of using BART to enhance machine translation decoders for translating into English. Using pre-trained encoders has been proven to improve models, while the benefits of incorporating pre-trained language models into decoders have been more limited. Using a set of encoder parameters learned from bitext, they demonstrate that the entire BART model can be used as a single pretrained decoder for machine translation. More specifically, they swap out the embedding layer of BART's encoder with a brand new encoder using random initialization. When the model is trained from start to end, the new encoder is trained to map foreign words into an input BART can then translate into English. In both stages of training, the cross-entropy loss is backpropagated from the BART model's output to train the source encoder. In the first stage, they fix most of BART's parameters and only update the randomly initialized source encoder, the BART positional embeddings, and the self-attention input projection matrix of BART's encoder first layer. Second, they perform a limited number of training iterations on all model parameters.")

Your max_length is set to 256, but your input_length is only 224. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=112)


[{'summary_text': "In the first stage, they fix most of BART's parameters and only update the randomly initialized source encoder, the BART positional embeddings, and the self-attention input projection matrix of BART's encoder first layer."}]

In [36]:
holdout_article_texts[0]

"manchester united have fallen off their perch. and they’re dropping like a stone towards mediocrity. that is the undeniable fact that has been hammered home relentlessly during the past six months. whether we are talking about the events of wednesday night at olympiacos or before the startled eyes of the faithful at old trafford, the evidence is there for all to see. can't stop the slump: david moyes cannot believe it as he watches manchester united lose at olympiacos . down and almost out: robin van persie lies on the floor during a defeat which sees united's champions league campaign hanging by a thread . disbelief: wayne rooney cries out in vain during another shambolic united display . abject: the frustration shows on the man united players' faces on taking the restart after conceding to olympiacos . coming to get you: liverpool are looking to take united's place in the top four . now is it time for man united to sack moyes? out of the title race, out of the fa . cup, out of the l

In [37]:
summarizer(holdout_article_texts[0])

[{'summary_text': "whether we are talking about the events of wednesday night at olympiacos or before the startled eyes of the faithful at old trafford, the evidence is there for all to see. can't stop the slump: david moyes cannot believe it as he watches manchester united lose at olympiacos . abject: the frustration shows on the man united players' faces on taking the restart after conceding to olympiacos . cup, out of the league cup, out of the top four and now in desperate ."}]

In [38]:
type (cnn_holdout_ds[0])

dict

In [39]:
type(holdout_article_texts[0])

str

In [40]:
len(holdout_article_texts[0])

7698

In [41]:
#running inference on holdout set with off the shelf model
from tqdm import tqdm

#instantiating an empty list to collect the off-the-shelf summaries
holdout_off_the_shelf_summaries = []

#prefix ='summarize: '

#the holdout set has 200 articles, so no slicing is needed
for text in tqdm(holdout_article_texts):

    #candidate = summarizer(prefix + text)
    candidate = summarizer(text)
    
    holdout_off_the_shelf_summaries.append(candidate[0]["summary_text"])

 46%|███████████████████▎                      | 92/200 [15:45<18:36, 10.34s/it]Your max_length is set to 256, but your input_length is only 200. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=100)
100%|█████████████████████████████████████████| 200/200 [34:22<00:00, 10.31s/it]
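
The loop above sends one article at a time to the pipeline, at roughly 10 seconds per article in this run. A possible alternative is to pass the articles in small lists. The sketch below assumes that the summarizer accepts a list of strings and returns one dictionary with a `summary_text` key per input, as Hugging Face summarization pipelines do; `summarize_in_chunks` and `fake_summarizer` are hypothetical names, and the stand-in summarizer exists only so the example runs without a model. Whether chunking actually reduces run time depends on the hardware and was not measured here.

```python
def summarize_in_chunks(summarizer, texts, chunk_size=8):
    """Call summarizer on lists of at most chunk_size texts and collect the summaries."""
    summaries = []
    for start in range(0, len(texts), chunk_size):
        chunk = texts[start:start + chunk_size]
        results = summarizer(chunk)  # expected: one {"summary_text": ...} per input
        summaries.extend(r["summary_text"] for r in results)
    return summaries


def fake_summarizer(batch):
    """Stand-in for a real pipeline: 'summarizes' each text to its first three words."""
    return [{"summary_text": " ".join(t.split()[:3])} for t in batch]


texts = [f"article number {i} has some body text" for i in range(20)]
summaries = summarize_in_chunks(fake_summarizer, texts, chunk_size=8)
print(len(summaries), summaries[0])
```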


## Evaluating the off-the-shelf Pegasus performance

In [42]:
#aggregated results of inference on holdout set
result_off_the_shelf_agg = rouge.compute(predictions = holdout_off_the_shelf_summaries,
                           references = holdout_article_summaries[:],
                           use_stemmer=True)

result_off_the_shelf_agg

{'rouge1': 0.3152476645473691,
 'rouge2': 0.11224910571748392,
 'rougeL': 0.19593952795768982,
 'rougeLsum': 0.1962393489745864}
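
The `rougeL` score is based on the longest common subsequence (LCS) between the prediction and the reference, so it rewards only words that match in the same order. That is one plausible reason it sits below `rouge1` in the results above. The sketch below is a simplified illustration on whitespace tokens without stemming (the run above used `use_stemmer=True`); `lcs_length` and `rouge_l_f1` are hypothetical helpers.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for tok_a in a:
        curr = [0]
        for j, tok_b in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if tok_a == tok_b else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(prediction, reference):
    """ROUGE-L F1 on whitespace tokens (no stemming)."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    lcs = lcs_length(pred, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(pred), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


score = rouge_l_f1("the cat sat on the mat", "the cat lay on the mat")
print(round(score, 4))
```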

In [43]:
#unaggregated results of inference on holdout set
result_off_the_shelf_unagg = rouge.compute(predictions = holdout_off_the_shelf_summaries,
                           references = holdout_article_summaries[:],
                           use_stemmer=True,
                             use_aggregator=False)

result_off_the_shelf_unagg

{'rouge1': [0.37499999999999994,
  0.3625730994152046,
  0.16981132075471697,
  0.5759162303664922,
  0.3146067415730337,
  0.12195121951219512,
  0.24864864864864866,
  0.32608695652173914,
  0.40287769784172667,
  0.21276595744680848,
  0.25477707006369427,
  0.3973509933774834,
  0.37333333333333335,
  0.3766233766233766,
  0.4297520661157025,
  0.3389830508474576,
  0.28125,
  0.35384615384615387,
  0.16091954022988506,
  0.3928571428571428,
  0.36538461538461536,
  0.43243243243243246,
  0.2809917355371901,
  0.24864864864864863,
  0.3707865168539326,
  0.2753623188405797,
  0.2266666666666667,
  0.13714285714285715,
  0.23376623376623376,
  0.3870967741935484,
  0.1772151898734177,
  0.2018348623853211,
  0.4094488188976378,
  0.17886178861788615,
  0.27450980392156865,
  0.23684210526315788,
  0.4772727272727273,
  0.31884057971014496,
  0.2994011976047904,
  0.4217687074829932,
  0.5341614906832298,
  0.5210084033613445,
  0.25925925925925924,
  0.3148148148148148,
  0.33333333
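
The per-example scores make it possible to look at the spread of results rather than only the aggregate. As a small sketch, the block below summarizes the first five ROUGE-1 values copied from the output above; with the full results, the list would instead be `result_off_the_shelf_unagg['rouge1']`. The choice of statistics is an illustrative assumption, not part of the original analysis.

```python
from statistics import mean, stdev

# First five per-example ROUGE-1 values copied from the output above
rouge1_scores = [
    0.37499999999999994,
    0.3625730994152046,
    0.16981132075471697,
    0.5759162303664922,
    0.3146067415730337,
]

summary = {
    "mean": mean(rouge1_scores),
    "stdev": stdev(rouge1_scores),
    "min": min(rouge1_scores),
    "max": max(rouge1_scores),
}
print({name: round(value, 3) for name, value in summary.items()})
```

Even in these five records the scores range from about 0.17 to about 0.58, which shows how much summary quality can vary from article to article.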

## Creating a DataFrame to hold the new summaries

In [44]:
cleaned_cnn_holdout_df = pd.read_csv('data/data_frame_after_t5_bart.csv')
cleaned_cnn_holdout_df

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,article,highlights,id,t5_summaries,t5_fine_tuned_summaries,bart_summaries,bart_fine_tuned_summaries
0,0,0,manchester united have fallen off their perch....,manchester united were beaten 2-0 in their cha...,5b3a626078390cb0e05327b4019753fd11cb8cea,manchester united lost 1-0 to olympiacos in th...,manchester united lost 1-0 to olympiacos in th...,manchester united have fallen off their perch....,manchester united have fallen off their perch....
1,1,1,a mother whose russian husband snatched their ...,"rachael neustadt's sons - daniel, eight and jo...",59d478d4a4299e2192997e56a9db9003fa2bac2d,"rachael neustadt's sons daniel, eight, and jon...","rachael neustadt's sons daniel, eight, and jon...",a mother whose russian husband snatched their ...,rachael neustadt and her two sons were freed i...
2,2,2,claim: supporters of mayor lutfur rahman alleg...,islamic voters allegedly told to be 'good musl...,ec961b7d0912e7753dffe4360b77481eba96f2e1,supporters of mayor lutfur rahman allegedly ha...,supporters of mayor lutfur rahman allegedly ha...,claim: supporters of mayor lutfur rahman alleg...,claim: supporters of mayor lutfur rahman hande...
3,3,3,the 15-year-old cousin of a palestinian boy wh...,"mohammed abu khder, 16, abducted and burned to...",092d90d61eb105b3955820cc4894ac2c4995ad1b,"mohammed abu khder, 16, was abducted from his ...","mohammed abu khder, 16, was burned to death in...",the 15-year-old cousin of a palestinian boy wh...,cousin of palestinian boy burned to death in i...
4,4,4,it may have made its way up the pole to become...,spearmint rhino records £2.1m loss in 2011 .lo...,d0d59018cdf48aaeb6e1838c0323f8555e800765,spearmint rhino has recorded a loss of £2.1mil...,spearmint rhino has recorded a loss of £2.1m i...,it may have made its way up the pole to become...,spearmint rhino has filed accounts showing tha...
...,...,...,...,...,...,...,...,...,...
195,195,195,reality tv show the block has been accused of ...,'the block' caught out faking a visit from the...,985b1bf7fc710e4ffdd9dd02e72d889a7997e89d,reality tv show the block has been accused of ...,reality tv show the block has been accused of ...,reality tv show the block has been accused of ...,reality tv show the block has been accused of ...
196,196,196,the average cost of raising a child to seconda...,average cost of raising a child from birth up ...,e466296e19d7a14cf4916d70a2cbc296e4659c99,average cost of raising a child to secondary s...,average cost of raising a child to secondary s...,the average cost of raising a child to seconda...,the average cost of raising a child to seconda...
197,197,197,thai police investigating the murder of two br...,"pornprasit sukdam claims he was offered £13,30...",a07624a84fe59a3321e83f153d6fd615207a8545,"pornprasit sukdam claims he was offered 700,00...","pornprasit sukdam claims he was offered 700,00...",thai police investigating the murder of two br...,thai police investigating the murder of two br...
198,198,198,from clumpy flat shoes that seem to shorten a ...,clumpy flat shoes that seem to shorten a woman...,e16474b52bbf45f49434fc4a0b1d68e2d3fba3c3,kim carillo says she feels surprisingly sexy i...,"kim carillo, who usually favours a more alluri...",from clumpy flat shoes that seem to shorten a ...,kim carillo tests some of the latest man-repel...


In [45]:
cleaned_cnn_holdout_df['pegasus_summaries'] = holdout_off_the_shelf_summaries
cleaned_cnn_holdout_df

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,article,highlights,id,t5_summaries,t5_fine_tuned_summaries,bart_summaries,bart_fine_tuned_summaries,pegasus_summaries
0,0,0,manchester united have fallen off their perch....,manchester united were beaten 2-0 in their cha...,5b3a626078390cb0e05327b4019753fd11cb8cea,manchester united lost 1-0 to olympiacos in th...,manchester united lost 1-0 to olympiacos in th...,manchester united have fallen off their perch....,manchester united have fallen off their perch....,whether we are talking about the events of wed...
1,1,1,a mother whose russian husband snatched their ...,"rachael neustadt's sons - daniel, eight and jo...",59d478d4a4299e2192997e56a9db9003fa2bac2d,"rachael neustadt's sons daniel, eight, and jon...","rachael neustadt's sons daniel, eight, and jon...",a mother whose russian husband snatched their ...,rachael neustadt and her two sons were freed i...,a mother whose russian husband snatched their ...
2,2,2,claim: supporters of mayor lutfur rahman alleg...,islamic voters allegedly told to be 'good musl...,ec961b7d0912e7753dffe4360b77481eba96f2e1,supporters of mayor lutfur rahman allegedly ha...,supporters of mayor lutfur rahman allegedly ha...,claim: supporters of mayor lutfur rahman alleg...,claim: supporters of mayor lutfur rahman hande...,a petition brought before the high court claim...
3,3,3,the 15-year-old cousin of a palestinian boy wh...,"mohammed abu khder, 16, abducted and burned to...",092d90d61eb105b3955820cc4894ac2c4995ad1b,"mohammed abu khder, 16, was abducted from his ...","mohammed abu khder, 16, was burned to death in...",the 15-year-old cousin of a palestinian boy wh...,cousin of palestinian boy burned to death in i...,the 15-year-old cousin of a palestinian boy wh...
4,4,4,it may have made its way up the pole to become...,spearmint rhino records £2.1m loss in 2011 .lo...,d0d59018cdf48aaeb6e1838c0323f8555e800765,spearmint rhino has recorded a loss of £2.1mil...,spearmint rhino has recorded a loss of £2.1m i...,it may have made its way up the pole to become...,spearmint rhino has filed accounts showing tha...,"despite the losses, spearmint rhino says it ha..."
...,...,...,...,...,...,...,...,...,...,...
195,195,195,reality tv show the block has been accused of ...,'the block' caught out faking a visit from the...,985b1bf7fc710e4ffdd9dd02e72d889a7997e89d,reality tv show the block has been accused of ...,reality tv show the block has been accused of ...,reality tv show the block has been accused of ...,reality tv show the block has been accused of ...,the show says it was a case of 'misidentificat...
196,196,196,the average cost of raising a child to seconda...,average cost of raising a child from birth up ...,e466296e19d7a14cf4916d70a2cbc296e4659c99,average cost of raising a child to secondary s...,average cost of raising a child to secondary s...,the average cost of raising a child to seconda...,the average cost of raising a child to seconda...,"the bulk of the total £83,627-a-year bill come..."
197,197,197,thai police investigating the murder of two br...,"pornprasit sukdam claims he was offered £13,30...",a07624a84fe59a3321e83f153d6fd615207a8545,"pornprasit sukdam claims he was offered 700,00...","pornprasit sukdam claims he was offered 700,00...",thai police investigating the murder of two br...,thai police investigating the murder of two br...,a spokesman for the royal thai police confirme...
198,198,198,from clumpy flat shoes that seem to shorten a ...,clumpy flat shoes that seem to shorten a woman...,e16474b52bbf45f49434fc4a0b1d68e2d3fba3c3,kim carillo says she feels surprisingly sexy i...,"kim carillo, who usually favours a more alluri...",from clumpy flat shoes that seem to shorten a ...,kim carillo tests some of the latest man-repel...,"here, kim carillo, who usually favours a more ..."


In [46]:
cleaned_cnn_holdout_df.to_csv('data/data_frame_after_t5_bart_pegasus.csv')
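The `Unnamed: 0`, `Unnamed: 0.1`, and `Unnamed: 0.2` columns in the output above are what typically appears when a DataFrame is written to CSV with its index and then read back several times. The following is a minimal sketch of that behavior using a small made-up frame (the column names and values are illustrative, not the thesis data), along with the `index=False` option that avoids it.

```python
import io

import pandas as pd

# Small stand-in frame; the real notebook uses cleaned_cnn_holdout_df.
toy_df = pd.DataFrame({"article": ["a1", "a2"], "highlights": ["h1", "h2"]})


def round_trip(df, keep_index):
    """Write df to an in-memory CSV and read it back."""
    buffer = io.StringIO()
    df.to_csv(buffer, index=keep_index)
    buffer.seek(0)
    return pd.read_csv(buffer)


# Saving with the index (the default) adds one unnamed column per round trip.
with_index = round_trip(round_trip(toy_df, keep_index=True), keep_index=True)
unnamed_columns = [c for c in with_index.columns if c.startswith("Unnamed")]

# Saving with index=False keeps the original columns unchanged.
without_index = round_trip(round_trip(toy_df, keep_index=False), keep_index=False)

print(list(with_index.columns))
print(list(without_index.columns))
```

If downstream notebooks do not rely on the extra columns, passing `index=False` to `to_csv` keeps the saved file limited to the data columns.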

## Fine-tuning the Pegasus model

In [47]:
#defining a custom compute metrics function
#assumes `tokenizer` and `rouge` were defined in earlier cells
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    #labels are padded with -100 so the loss ignores them; swap in the pad token id before decoding
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    #compare generated summaries against the reference summaries
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    #average generated length, counted in non-padding tokens
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}
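To make the label handling in `compute_metrics` concrete, here is a small sketch with made-up token ids. It assumes a pad token id of 0 purely for illustration (the notebook reads the real value from `tokenizer.pad_token_id`) and shows how `-100` entries are replaced before decoding and how the `gen_len` value is derived.

```python
import numpy as np

PAD_TOKEN_ID = 0     # assumed for illustration; use tokenizer.pad_token_id in the notebook
IGNORE_INDEX = -100  # value used to mask padded label positions from the loss

# Two made-up label sequences, padded with -100.
labels = np.array([
    [12, 47, 1, -100, -100],
    [8, 1, -100, -100, -100],
])

# Replace -100 with the pad id so the sequences can be decoded by a tokenizer.
restored = np.where(labels != IGNORE_INDEX, labels, PAD_TOKEN_ID)

# Two made-up generated sequences, padded with the pad id.
predictions = np.array([
    [0, 12, 47, 1, 0, 0],
    [0, 8, 1, 0, 0, 0],
])

# Count non-padding tokens per prediction, then average them.
gen_lens = [int(np.count_nonzero(p != PAD_TOKEN_ID)) for p in predictions]
mean_gen_len = float(np.mean(gen_lens))

print(restored)
print(gen_lens, mean_gen_len)
```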

In [48]:
training_args = Seq2SeqTrainingArguments(
    output_dir="pegasus_results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=4,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=False,
    load_best_model_at_end=True,  #even if we overtrain the model by accident, we still
    #reload the checkpoint that had the lowest evaluation loss

    #evaluation_strategy can be "steps" or "epoch"; it controls how often training pauses to evaluate the model
    eval_steps=50,  #only used when evaluation_strategy="steps"; ignored with "epoch"
    save_strategy="epoch",  #save a checkpoint after every epoch (must match evaluation_strategy when load_best_model_at_end=True)
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_cnn_training_ds["train"],
    eval_dataset=tokenized_cnn_training_ds["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.evaluate()  #evaluation before training; note to self: consider adding a max_new_tokens parameter

dataloader_config = DataLoaderConfiguration(dispatch_batches=None)


OutOfMemoryError: CUDA out of memory. Tried to allocate 94.00 MiB. GPU 0 has a total capacty of 47.50 GiB of which 81.12 MiB is free. Including non-PyTorch memory, this process has 46.63 GiB memory in use. Of the allocated memory 46.16 GiB is allocated by PyTorch, and 161.55 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
CUDA out of memory
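The out-of-memory error above is commonly addressed by trading batch size for gradient accumulation and by limiting how much is generated and held on the GPU during evaluation. The sketch below is one possible set of overrides, not a tested fix for this notebook: the values are illustrative guesses, and only the keyword names are taken from `Seq2SeqTrainingArguments`. It also checks that the effective training batch size stays at the original 16.

```python
# Hypothetical lower-memory overrides; the numbers are illustrative, not tuned.
low_memory_overrides = {
    "per_device_train_batch_size": 2,   # was 16
    "per_device_eval_batch_size": 2,    # was 16
    "gradient_accumulation_steps": 8,   # accumulate gradients over 8 small batches
    "gradient_checkpointing": True,     # recompute activations to save memory
    "eval_accumulation_steps": 4,       # move predictions off the GPU periodically
    "generation_max_length": 64,        # cap summary length during evaluation
}


def effective_train_batch_size(settings, num_devices=1):
    """Number of examples contributing to each optimizer step."""
    return (
        settings["per_device_train_batch_size"]
        * settings["gradient_accumulation_steps"]
        * num_devices
    )


effective = effective_train_batch_size(low_memory_overrides)
print("effective train batch size:", effective)
```

These entries could be passed as keyword arguments alongside the existing ones, for example `Seq2SeqTrainingArguments(output_dir="pegasus_results", **low_memory_overrides, ...)`, after removing the duplicated batch-size arguments. Whether this is enough for a model of this size on a single GPU would need to be confirmed by rerunning the cell.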