# Text summarization with T5 on XSum

We are going to fine-tune the [T5 model, implemented by HuggingFace](https://huggingface.co/t5-small), for text summarization on the [Extreme Summarization (XSum)](https://huggingface.co/datasets/xsum) dataset.
The data consists of news articles and their corresponding summaries.

We will use the following model sizes, all available from HuggingFace:

| Variant                                     |   Parameters    |
|:-------------------------------------------:|----------------:|
| [T5-small](https://huggingface.co/t5-small) |    60,506,624   | 
| [T5-large](https://huggingface.co/t5-large) |   737,668,096   | 
| [T5-3b](https://huggingface.co/t5-3b)       | 2,851,598,336   | 
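
For a rough sense of scale, the parameter counts in the table translate directly into the memory needed for the weights alone. The sketch below assumes 4 bytes per parameter (float32) and ignores gradients, optimizer state, and activations, so the memory actually needed for fine-tuning will be considerably larger. The helper name `weight_memory_gib` is made up for this illustration.

```python
# Parameter counts copied from the table above.
PARAMS = {
    "t5-small": 60_506_624,
    "t5-large": 737_668_096,
    "t5-3b": 2_851_598_336,
}

BYTES_PER_PARAM = 4  # assumption: weights stored as float32


def weight_memory_gib(num_params, bytes_per_param=BYTES_PER_PARAM):
    """Memory needed to hold the weights only, in GiB."""
    return num_params * bytes_per_param / 2**30


for name, num_params in PARAMS.items():
    print(f"{name:>8}: {weight_memory_gib(num_params):6.2f} GiB")
```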


More info:
* This notebook is based on the script [run_summarization_no_trainer.py](https://github.com/huggingface/transformers/blob/v4.12.5/examples/pytorch/summarization/run_summarization_no_trainer.py) from HuggingFace
* [T5 on HuggingFace docs](https://huggingface.co/transformers/model_doc/t5.html)

In [2]:
import os
import datasets
import nltk
import numpy as np
import torch
from datasets import load_dataset, load_metric
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq
from torch.utils.data import DataLoader

In [3]:
from datasets.utils import disable_progress_bar
from datasets import disable_caching


disable_progress_bar()
disable_caching()

## The data

In [5]:
hf_dataset = load_dataset('xsum')

Using custom data configuration default
Reusing dataset xsum (/users/sarafael/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934)


In [6]:
hf_dataset

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

In [9]:
# sample = 33609  # twin peaks
# sample = 192550 # twin peaks
sample = 188948   # blues

In [10]:
hf_dataset['train']['id'][sample]

'15575668'

In [11]:
hf_dataset['train']['summary'][sample]

'BB King was hailed as one of the greatest blues musicians of all time.'

In [12]:
hf_dataset['train']['document'][sample]

'His vibrato style of playing influenced a generation of rock and blues guitarists, including Eric Clapton, Mike Bloomfield and Stevie Ray Vaughan.\nRolling Stone magazine once ranked BB King in third place in its list of the 100 greatest guitarists of all time, just below Jimi Hendrix and Duane Allman.\nHis output crossed musical barriers, from jazz and blues to mainstream pop.\nHe was born Riley B King in Indianola, Mississippi, on 16 September 1925. His parents were sharecroppers and, as a young boy, he helped them work in the fields.\nThe family struggled. "When you live in a house that you can always peek out of and see what kind of day it is," King later said, "you\'re not doing so well."\nThe sound of his co-workers hollering the blues was his first introduction to the style of music that he was to help take from a purely black American audience into the mainstream.\nHe bought his first guitar when he was barely a teenager so he could play at church services. In 1947 he moved to

## The tokenizer

In [14]:
hf_model = 't5-small'
t5_cache = os.path.join(os.getcwd(), 'cache')

tokenizer = AutoTokenizer.from_pretrained(
    hf_model,
    use_fast=True,
    cache_dir=os.path.join(t5_cache, f'{hf_model}_tokenizer')
)

In [15]:
encoded_text = tokenizer("What's up tokenizer!", max_length=1024,
                         padding=False, truncation=True)

In [16]:
encoded_text

{'input_ids': [363, 31, 7, 95, 14145, 8585, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}

In [17]:
tokenizer.batch_decode(encoded_text['input_ids'])

['What', "'", 's', 'up', 'token', 'izer', '!', '</s>']

In [18]:
with tokenizer.as_target_tokenizer():
    encoded_text = tokenizer("What's up tokenizer!", max_length=1024,
                             padding=False, truncation=True)

In [19]:
encoded_text

{'input_ids': [363, 31, 7, 95, 14145, 8585, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}

## Tokenizing the data

In [20]:
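def preprocess_function(examples):
    """Tokenize a batch of articles (inputs) and summaries (labels)."""
    inputs = examples['document']
    targets = examples['summary']

    # T5 is a text-to-text model: the task is selected with a text prefix
    inputs = [f'summarize: {inp}' for inp in inputs]
    model_inputs = tokenizer(inputs, max_length=1024,
                             padding=False, truncation=True)

    # Setup the tokenizer for targets
    # (newer transformers releases deprecate `as_target_tokenizer` in favor
    # of the `text_target` argument of the tokenizer call)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=128,
                           padding=False, truncation=True)

    # No padding here: sequences are padded per batch by the data collator
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs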
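
To see the shape of what `preprocess_function` returns without downloading the tokenizer, here is a minimal sketch with a whitespace tokenizer. Both `toy_tokenize` and `toy_preprocess` are hypothetical stand-ins invented for this illustration. They are not the T5 SentencePiece tokenizer, they keep tokens as strings rather than integer ids, and they do not append the `</s>` token. The sketch only mirrors the three steps: prepend the `summarize: ` prefix, truncate to a maximum length, and store the tokenized summaries under `labels`.

```python
def toy_tokenize(texts, max_length):
    """Hypothetical stand-in for the real tokenizer: split on whitespace,
    truncate to max_length tokens, and build an all-ones attention mask."""
    input_ids = [text.split()[:max_length] for text in texts]
    attention_mask = [[1] * len(ids) for ids in input_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def toy_preprocess(examples, max_source_length=8, max_target_length=4):
    # Same structure as preprocess_function, with tiny length limits
    # so that truncation is visible on a one-sentence example.
    inputs = [f"summarize: {doc}" for doc in examples["document"]]
    model_inputs = toy_tokenize(inputs, max_length=max_source_length)
    labels = toy_tokenize(examples["summary"], max_length=max_target_length)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


toy_examples = {
    "document": ["BB King influenced a generation of rock and blues guitarists"],
    "summary": ["BB King was a great blues musician"],
}
toy_output = toy_preprocess(toy_examples)
print(toy_output["input_ids"][0])
print(toy_output["labels"][0])
```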

In [21]:
%%time
processed_datasets = hf_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=hf_dataset["train"].column_names,
    desc="Running tokenizer on dataset",
    num_proc=12
)

CPU times: user 1.03 s, sys: 394 ms, total: 1.42 s
Wall time: 35.5 s


In [22]:
processed_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 11334
    })
})

In [23]:
# For training Sequence to Sequence models, we need a special kind of data collator,
# which will not only pad the inputs to the maximum length in the batch,
# but also the labels.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    label_pad_token_id=tokenizer.pad_token_id
)

per_device_train_batch_size = 4

train_dataset = processed_datasets["train"]

train_dataloader = DataLoader(
    train_dataset,
    shuffle=False,
    collate_fn=data_collator,
    batch_size=per_device_train_batch_size
)
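
The padding performed by the collator can be sketched in plain Python. This is only an illustration of dynamic (per-batch) padding, not the `DataCollatorForSeq2Seq` implementation: the function name `pad_batch` is hypothetical, it assumes right-side padding with pad id 0, and it returns lists rather than tensors.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=0):
    """Pad inputs and labels to the longest sequence of each kind in the batch."""
    max_input_len = max(len(f["input_ids"]) for f in features)
    max_label_len = max(len(f["labels"]) for f in features)

    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_input_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # Padded positions get attention 0 so the model ignores them
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)

        n_label_pad = max_label_len - len(f["labels"])
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_label_pad)
    return batch


# Two made-up examples of different lengths (the ids are illustrative only)
features = [
    {"input_ids": [21603, 10, 1363, 1], "attention_mask": [1, 1, 1, 1], "labels": [6920, 1]},
    {"input_ids": [21603, 10, 1], "attention_mask": [1, 1, 1], "labels": [363, 31, 7, 1]},
]
padded = pad_batch(features)
print(padded)
```

Padding to the longest sequence in each batch, rather than to a fixed global length, keeps batches of short articles small.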

In [24]:
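# Pull a few batches from the dataloader; after the loop,
# `batch` holds the batch at step 6, which is inspected below
for step, batch in enumerate(train_dataloader):
    if step > 5:
        break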

In [25]:
type(batch)

transformers.tokenization_utils_base.BatchEncoding

In [26]:
batch.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [27]:
batch['input_ids'].shape

torch.Size([4, 1024])

In [28]:
batch['input_ids'][0]

tensor([21603,    10,  1363,  ...,     0,     0,     0])

In [29]:
batch['attention_mask']

tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]])

In [30]:
batch['attention_mask'][0]

tensor([1, 1, 1,  ..., 0, 0, 0])

In [31]:
tokenizer.decode(batch['input_ids'][0][batch['attention_mask'][0]==1])

'summarize: Mr Fox, 54, from London, denies eight counts of indecent assault and two counts of sexual assault between 1988 and 2014. He said there was often "horseplay" with colleagues, involving "piggybacks, tickling and squeezing". But he told Westminster Magistrates\' Court such behaviour was consensual. Mr Fox, who uses the nicknames Dr Fox and Foxy, became well known for presenting the chart show on Capital Radio, and was a judge on the ITV show Pop Idol between 2001 and 2003 alongside Simon Cowell. He joined Magic 105.4 in 2005, where he presents the breakfast show, Foxy in the Morning. He is currently not hosting the show. Giving evidence on Wednesday, Mr Fox said he had worked with "hundreds" of female colleagues during his career, but had never been accused of sexually inappropriate behaviour until last year. Under questioning from his defence counsel, Jonathan Caplan QC, he told the court his teams had kept their energy up during live broadcasts by playing loud music, dancing

In [32]:
batch['labels'][0]

tensor([ 6920, 17906,  7547,    65,  1219,     3,     9,  1614,     3,    88,
         5908,    16,    96,  7348,    75,    63,   121,  7916,    44,   161,
            6,    68, 11244,   263,   581,   376,    16,  1412,   130,     8,
          166,    16,   112,  2838,    18,  1201,  1415,     5,     1,     0,
            0,     0,     0,     0])

In [36]:
tokenizer.batch_decode(batch['labels'])

['DJ Neil Fox has told a court he engaged in "saucy" behaviour at work, but complaints made against him in 2014 were the first in his 29-year career.</s><pad><pad><pad><pad><pad>',
 "Tottenham midfielder Dele Alli has been ruled out of England's World Cup qualifier with Scotland and friendly against Spain after suffering a knee injury in training.</s><pad><pad><pad><pad><pad><pad><pad>",
 "From when Karen Morgan was 12, until she was well into her teens, she was sexually abused by her uncle - a ministerial servant with the Jehovah's Witnesses.</s>",
 'Former Greek Finance Minister Yanis Varoufakis has told the BBC that economic reforms imposed on his country by creditors are "going to fail", ahead of talks on a huge bailout.</s><pad>']
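
Because the collator in this notebook pads the labels with `tokenizer.pad_token_id` (0 for T5, as the label tensor shows), the padded positions decode to `<pad>`, but they would also be scored by the loss during training. By default, the HuggingFace script this notebook is based on pads the labels with -100 instead, which is the index that PyTorch's cross-entropy loss ignores by default. A minimal sketch of that masking step on plain Python lists follows; `mask_label_padding` is a hypothetical helper, not part of the notebook or of `transformers`.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss
PAD_TOKEN_ID = 0     # T5 pad token id, as seen in the label tensor above


def mask_label_padding(labels, pad_token_id=PAD_TOKEN_ID):
    """Replace padding ids in the labels so the loss skips those positions."""
    return [
        [IGNORE_INDEX if tok == pad_token_id else tok for tok in row]
        for row in labels
    ]


# Made-up, shortened label rows: the first one ends with two padding tokens
example_labels = [[6920, 17906, 1, 0, 0], [363, 31, 7, 95, 1]]
masked_labels = mask_label_padding(example_labels)
print(masked_labels)
```

Labels masked this way can no longer be passed to `tokenizer.batch_decode` directly, since -100 is not a valid token id; the -100 entries would have to be mapped back to the pad id first.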