# Summarization

In [24]:
from datasets import load_dataset, load_metric
import nltk
from nltk.tokenize import sent_tokenize
import pandas as pd
import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Download the NLTK 'punkt' package (presumably the sentence-splitting data that
# sent_tokenize relies on - confirm against the installed NLTK version).
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [13]:
# Seed passed to Dataset.shuffle when the evaluation sample is drawn from the test split.
SEED = 42

In [2]:
# Load CNN/DailyMail; the displayed DatasetDict has train/validation/test splits,
# each with 'article', 'highlights' and 'id' columns.
# NOTE(review): `version` is passed as a keyword rather than as the config name, and the
# loader logs "Using custom data configuration default" - confirm this is intended.
dataset = load_dataset('cnn_dailymail', version='3.0.0')
dataset

Downloading builder script:   0%|          | 0.00/3.23k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

Using custom data configuration default


Downloading and preparing dataset cnn_dailymail/default to /root/.cache/huggingface/datasets/cnn_dailymail/default/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de...


Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/159M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/376M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/661k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/572k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Dataset cnn_dailymail downloaded and prepared to /root/.cache/huggingface/datasets/cnn_dailymail/default/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 11490
    })
})

In [3]:
# Peek at one raw training record (index 1) to see what the fields look like.
dataset['train'][1]

{'article': 'Editor\'s note: In our Behind the Scenes series, CNN correspondents share their experiences in covering news and analyze the stories behind the events. Here, Soledad O\'Brien takes users inside a jail where many of the inmates are mentally ill. An inmate housed on the "forgotten floor," where many mentally ill inmates are housed in Miami before trial. MIAMI, Florida (CNN) -- The ninth floor of the Miami-Dade pretrial detention facility is dubbed the "forgotten floor." Here, inmates with the most severe mental illnesses are incarcerated until they\'re ready to appear in court. Most often, they face drug charges or charges of assaulting an officer --charges that Judge Steven Leifman says are usually "avoidable felonies." He says the arrests often result from confrontations with police. Mentally ill people often won\'t do what they\'re told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid, delusional, and less li

In [9]:
def three_sent_summary(text):
    """Build a lead-sentence summary of `text`.

    The text is split into sentences with `sent_tokenize`; at most the first
    three sentences are kept (fewer if the text is shorter) and joined with
    newline characters.

    Args:
        text: The document to summarise, as a string.

    Returns:
        A string holding the leading sentences, one per line.
    """
    sentences = sent_tokenize(text)
    leading = sentences[:3]
    return '\n'.join(leading)

In [10]:
# Try the three-sentence summary on the article of training record 1.
three_sent_summary(dataset['train'][1]['article'])

'Editor\'s note: In our Behind the Scenes series, CNN correspondents share their experiences in covering news and analyze the stories behind the events.\nHere, Soledad O\'Brien takes users inside a jail where many of the inmates are mentally ill. An inmate housed on the "forgotten floor," where many mentally ill inmates are housed in Miami before trial.\nMIAMI, Florida (CNN) -- The ninth floor of the Miami-Dade pretrial detention facility is dubbed the "forgotten floor."'

## Evaluating PEGASUS

In [11]:
def eval_summaries_baseline(dataset, metric, column_text='article', column_summary='highlights'):
    """Score the three-sentence baseline against reference summaries.

    Args:
        dataset: Object indexable by column name; `dataset[column_text]` must
            be iterable over document strings.
        metric: Object exposing `add_batch(predictions=..., references=...)`
            and `compute()`.
        column_text: Name of the column holding the source documents.
        column_summary: Name of the column holding the reference summaries.

    Returns:
        Whatever `metric.compute()` returns after the single batch is added.
    """
    predictions = list(map(three_sent_summary, dataset[column_text]))
    references = dataset[column_summary]
    metric.add_batch(predictions=predictions, references=references)
    return metric.compute()

In [18]:
# ROUGE metric object handed to the evaluation functions in this notebook.
# NOTE(review): load_metric may be deprecated in newer `datasets` releases - confirm
# against the installed version.
rouge_metric = load_metric('rouge')
# Keys looked up in the computed score when the result tables are built.
rouge_names = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']

### Baseline

In [21]:
# Evaluate on a seeded shuffle of the test split, keeping the first 1000 records.
test_sample = dataset['test'].shuffle(seed=SEED).select(range(1000))

score = eval_summaries_baseline(test_sample, rouge_metric)
# Reduce each ROUGE entry to its mid F-measure.
rouge_dict = {rn: score[rn].mid.fmeasure for rn in rouge_names}

# One-row table: row label 'baseline', one column per ROUGE variant.
pd.DataFrame.from_dict(rouge_dict, orient='index', columns=['baseline']).T



Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum
baseline,0.389121,0.171397,0.245077,0.354013


In [23]:
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def chunks(elements, batch_size):
    """Yield consecutive slices of `elements`, each at most `batch_size` long.

    Args:
        elements: A sized, sliceable sequence (list, string, ...).
        batch_size: Maximum number of items per slice; must be positive.

    Yields:
        `elements[i:i + batch_size]` for i = 0, batch_size, 2 * batch_size, ...
        The final slice is shorter when the length is not a multiple of
        `batch_size`; an empty sequence yields nothing.

    Raises:
        ValueError: If `batch_size` is zero or negative. The check runs when
            the generator is first advanced.
    """
    # A negative step would make the range empty and silently drop every
    # element; a zero step only fails with an opaque range() error.
    if batch_size <= 0:
        raise ValueError(f'batch_size must be a positive integer, got {batch_size!r}')
    for start in range(0, len(elements), batch_size):
        yield elements[start:start + batch_size]
        

def eval_summaries_pegasus(dataset, metric, model, tokenizer, batch_size=16, device=device, column_text='article', column_summary='highlights'):
    """Generate summaries batch by batch with `model` and score them with `metric`.

    Args:
        dataset: Object indexable by column name, returning sliceable sequences.
        metric: Object exposing `add_batch(predictions=..., references=...)`
            and `compute()`.
        model: Object exposing `generate(...)`; its inputs are moved to
            `device`, so the model is presumably expected to live there already.
        tokenizer: Callable that encodes a batch of texts into a mapping with
            'input_ids' and 'attention_mask', and that exposes `decode(...)`.
        batch_size: Number of documents processed per generation call.
        device: Device the encoded inputs are moved to before generation.
        column_text: Name of the column holding the source documents.
        column_summary: Name of the column holding the reference summaries.

    Returns:
        Whatever `metric.compute()` returns once every batch has been added.
    """
    article_batches = list(chunks(dataset[column_text], batch_size))
    target_batches = list(chunks(dataset[column_summary], batch_size))
    batch_pairs = zip(article_batches, target_batches)

    for articles, references in tqdm.tqdm(batch_pairs, total=len(article_batches)):
        # Every batch is truncated/padded to exactly 1024 tokens.
        encoded = tokenizer(
            articles,
            max_length=1024,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )
        generated = model.generate(
            input_ids=encoded['input_ids'].to(device),
            attention_mask=encoded['attention_mask'].to(device),
            length_penalty=0.8,
            num_beams=8,
            max_length=128,
        )

        predictions = []
        for token_ids in generated:
            text = tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            # Replace the literal '<n>' marker with a space (presumably the
            # model's line-break token - confirm).
            predictions.append(text.replace('<n>', ' '))

        metric.add_batch(predictions=predictions, references=references)

    return metric.compute()

In [None]:
# Load the tokenizer and the seq2seq model for this checkpoint; the model is moved to `device`.
ckpt = 'google/pegasus-cnn_dailymail'
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt).to(device)

score = eval_summaries_pegasus(test_sample, rouge_metric, model, tokenizer, batch_size=8)
# Reduce each ROUGE entry to its mid F-measure.
rouge_dict = {rn: score[rn].mid.fmeasure for rn in rouge_names}

# One-row table: row label 'pegasus', one column per ROUGE variant.
pd.DataFrame.from_dict(rouge_dict, orient='index', columns=['pegasus']).T

Downloading:   0%|          | 0.00/88.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.09k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.12G [00:00<?, ?B/s]

  0%|          | 0/125 [00:05<?, ?it/s]
