# Text summarization with T5 on XSum

We are going to fine-tune the [T5 model, implemented by HuggingFace](https://huggingface.co/t5-small), for text summarization on the [Extreme Summarization (XSum)](https://huggingface.co/datasets/xsum) dataset.
The dataset consists of BBC news articles, each paired with a single-sentence summary.

We will use the following model sizes available on HuggingFace:

| Variant                                     |   Parameters    |
|:-------------------------------------------:|----------------:|
| [T5-small](https://huggingface.co/t5-small) |    60,506,624   | 
| [T5-large](https://huggingface.co/t5-large) |   737,668,096   | 
| [T5-3b](https://huggingface.co/t5-3b)       | 2,851,598,336   | 
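To get a rough idea of the hardware these sizes imply, the sketch below multiplies each parameter count by a number of bytes per parameter. It assumes fp32 storage (4 bytes per parameter) for the weights alone. For training it uses the common rule of thumb of 16 bytes per parameter, which covers fp32 weights, gradients, and the two Adam moment buffers. Activations, the data batch, and framework overhead are not included, so actual memory use will be higher.

```python
# Parameter counts copied from the table above
PARAMS = {
    "t5-small": 60_506_624,
    "t5-large": 737_668_096,
    "t5-3b": 2_851_598_336,
}


def gigabytes(n_params: int, bytes_per_param: int) -> float:
    """Memory in GB (10**9 bytes) for n_params values of the given size."""
    return n_params * bytes_per_param / 1e9


for name, n in PARAMS.items():
    weights = gigabytes(n, 4)   # fp32 weights only (assumption)
    training = gigabytes(n, 16)  # weights + grads + 2 Adam moments, no activations
    print(f"{name:>9}: weights {weights:6.2f} GB, weights+grads+Adam {training:6.2f} GB")
```

By this estimate, T5-3b needs roughly 11.4 GB just for its fp32 weights, and more than 45 GB once gradients and optimizer state are added. This is why the larger variants usually call for mixed precision, sharding, or several GPUs.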


More info:
* This notebook is based on the script [run_summarization_no_trainer.py](https://github.com/huggingface/transformers/blob/v4.12.5/examples/pytorch/summarization/run_summarization_no_trainer.py) from HuggingFace
* [T5 on HuggingFace docs](https://huggingface.co/transformers/model_doc/t5.html)

In [None]:
import os
import datasets
from datasets import load_dataset, load_metric
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq
from torch.utils.data import DataLoader

In [None]:
from datasets.utils import disable_progress_bar
from datasets import disable_caching


disable_progress_bar()
disable_caching()

## The data

In [None]:
hf_dataset = load_dataset('xsum')

In [None]:
hf_dataset

In [None]:
sample = 188948

In [None]:
hf_dataset['train'][sample]['id']

In [None]:
hf_dataset['train'][sample]['summary']

In [None]:
hf_dataset['train'][sample]['document']

## The tokenizer

In [None]:
hf_model = 't5-small'
t5_cache = os.path.join(os.getcwd(), 'cache')

tokenizer = AutoTokenizer.from_pretrained(
    hf_model,
    use_fast=True,
    cache_dir=os.path.join(t5_cache, f'{hf_model}_tokenizer')
)

In [None]:
encoded_text = tokenizer("What's up tokenizer!",
                         max_length=1024,
                         padding=False,
                         truncation=True)

In [None]:
encoded_text

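 * `attention_mask` marks real tokens with 1 and padding with 0. Because we passed `padding=False`, this single sentence contains no padding, so its mask is all ones.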
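The mask only becomes informative once sequences of different lengths are padded into a batch. The plain-Python sketch below pads each sequence on the right up to the longest one and writes 0 into `attention_mask` at every padded position. The token ids are made up, and `pad_to_longest` is a hypothetical helper rather than a tokenizer method. The pad id 0 matches T5's `pad_token_id`, and each sequence ends in 1, T5's `</s>` id.

```python
from typing import Dict, List


def pad_to_longest(sequences: List[List[int]], pad_token_id: int = 0) -> Dict[str, List[List[int]]]:
    """Right-pad token id lists to the longest one and build the attention mask."""
    longest = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_token_id] * (longest - len(seq)) for seq in sequences]
    # 1 for real tokens, 0 for the padding appended above
    attention_mask = [[1] * len(seq) + [0] * (longest - len(seq)) for seq in sequences]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


example = pad_to_longest([[10, 11, 12, 1], [20, 1]])
print(example["input_ids"])
print(example["attention_mask"])
```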

In [None]:
tokenizer.batch_decode(encoded_text['input_ids'])

In [None]:
with tokenizer.as_target_tokenizer():
    encoded_text = tokenizer("What's up tokenizer!", max_length=1024,
                             padding=False, truncation=True)

In [None]:
encoded_text

## Tokenizing the data

In [None]:
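def preprocess_function(examples):
    # With batched=True, `examples` maps each column name to a list of values
    inputs = examples['document']
    targets = examples['summary']
    # T5 is a text-to-text model: the task is selected with a text prefix
    inputs = [f'summarize: {inp}' for inp in inputs]

    # Tokenize the articles, truncated to at most 1024 tokens;
    # padding is left to the data collator
    model_inputs = tokenizer(inputs, max_length=1024,
                             padding=False, truncation=True)

    # Set up the tokenizer for targets; summaries are truncated to 128 tokens
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=128,
                           padding=False, truncation=True)

    # The token ids of the summaries become the training labels
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs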
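To make the contract of a batched preprocessing function concrete without downloading a tokenizer, the sketch below mirrors `preprocess_function` using a hypothetical whitespace tokenizer (`toy_tokenize`) in place of T5's tokenizer. It receives a column-oriented batch (a dict of lists), adds the `summarize:` prefix, truncates inputs and targets to separate maximum lengths, and returns the new columns. Unlike the real T5 tokenizer, the toy version does not append `</s>`. The function returns only the new keys, but `map` would still keep the original columns unless `remove_columns` drops them.

```python
from typing import Dict, List

# Hypothetical vocabulary shared by the toy tokenizer below
VOCAB: Dict[str, int] = {}


def toy_tokenize(texts: List[str], max_length: int) -> Dict[str, List[List[int]]]:
    """Hypothetical whitespace tokenizer standing in for the T5 tokenizer."""
    input_ids = []
    for text in texts:
        # Assign each previously unseen word the next free integer id
        ids = [VOCAB.setdefault(word, len(VOCAB) + 1) for word in text.split()]
        input_ids.append(ids[:max_length])  # truncation=True, padding=False
    attention_mask = [[1] * len(ids) for ids in input_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def toy_preprocess(examples: Dict[str, List[str]]) -> Dict[str, List[List[int]]]:
    # Same structure as preprocess_function: prefix, tokenize inputs, tokenize targets
    inputs = [f"summarize: {doc}" for doc in examples["document"]]
    model_inputs = toy_tokenize(inputs, max_length=8)
    labels = toy_tokenize(examples["summary"], max_length=4)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# One batch in the column-oriented format that map(batched=True) passes in
toy_batch = {
    "document": ["a b c d e f g h i j", "x y"],
    "summary": ["s t u v w", "z"],
    "id": ["1", "2"],
}
out = toy_preprocess(toy_batch)
print(out["input_ids"])
print(out["labels"])
```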

In [None]:
%%time
processed_datasets = hf_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=hf_dataset["train"].column_names,
    desc="Running tokenizer on dataset",
    num_proc=12
)

In [None]:
processed_datasets

In [None]:
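# For training sequence-to-sequence models, we need a special kind of data collator,
# which pads not only the inputs but also the labels to the maximum length in the batch.
# Padding labels with tokenizer.pad_token_id keeps them decodable with tokenizer.batch_decode.
# The reference script defaults to -100 instead, which the cross-entropy loss ignores,
# so that padding positions do not contribute to the training loss.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    label_pad_token_id=tokenizer.pad_token_id
)

per_device_train_batch_size = 128

train_dataset = processed_datasets["train"]

# shuffle=False keeps the batch order deterministic;
# the reference script shuffles the training set (shuffle=True)
train_dataloader = DataLoader(
    train_dataset,
    shuffle=False,
    collate_fn=data_collator,
    batch_size=per_device_train_batch_size
)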
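The choice of `label_pad_token_id` changes the loss. The plain-Python sketch below uses made-up token ids and made-up per-token loss values to show the difference. A mean-reduced cross-entropy with PyTorch's default `ignore_index=-100` skips every position labelled `-100`. By contrast, positions padded with T5's pad id 0 are treated as real targets and are averaged into the loss. `pad_labels` and `masked_mean` are hypothetical helpers that mimic this behaviour. They are not the collator's or PyTorch's implementation.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_labels(labels: List[List[int]], label_pad_token_id: int) -> List[List[int]]:
    """Right-pad every label sequence to the longest one in the batch."""
    longest = max(len(seq) for seq in labels)
    return [seq + [label_pad_token_id] * (longest - len(seq)) for seq in labels]


def masked_mean(per_token_loss: List[List[float]], padded_labels: List[List[int]],
                ignore_index: int = IGNORE_INDEX) -> float:
    """Average the losses of positions whose label is not ignore_index."""
    total, count = 0.0, 0
    for loss_row, label_row in zip(per_token_loss, padded_labels):
        for loss, label in zip(loss_row, label_row):
            if label != ignore_index:
                total += loss
                count += 1
    return total / count


labels = [[5, 6, 1], [7, 1]]                    # made-up ids, 1 is T5's </s>
per_token_loss = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # made-up values

loss_ignoring_pad = masked_mean(per_token_loss, pad_labels(labels, IGNORE_INDEX))
loss_counting_pad = masked_mean(per_token_loss, pad_labels(labels, 0))
print(loss_ignoring_pad, loss_counting_pad)
```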

In [None]:
# Iterate up to step 16 and keep that (17th) batch for inspection
for step, batch in enumerate(train_dataloader):
    if step > 15:
        break
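The loop above stops at step 16 and leaves `batch` bound to the batch at that step. `itertools.islice` is a more explicit way to fetch the n-th item of any iterable. The helper name `nth_item` is ours and is not part of any library.

```python
from itertools import islice
from typing import Iterable, TypeVar

T = TypeVar("T")


def nth_item(iterable: Iterable[T], n: int) -> T:
    """Return the item at zero-based position n (raises StopIteration if too short)."""
    return next(islice(iterable, n, None))


print(nth_item(range(100), 16))
```

With the dataloader, `batch = nth_item(train_dataloader, 16)` selects the same batch as the loop. Like the loop, it still collates the 16 batches that come before it.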

In [None]:
type(batch)

In [None]:
batch.keys()

In [None]:
batch['input_ids'].shape

In [None]:
batch['input_ids'][0]

In [None]:
batch['attention_mask']  # indicates what's text and what's padding

In [None]:
batch['attention_mask'][0]

In [None]:
tokenizer.decode(batch['input_ids'][0][batch['attention_mask'][0]==1])

In [None]:
batch['labels'][0]

In [None]:
tokenizer.batch_decode(batch['labels'])[0]
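If the collator pads labels with `-100` instead (the reference script's default), decoding them directly typically fails, because `-100` is not a valid token id. Before decoding, the reference script maps `-100` back to the pad id. Here is a plain-Python version of that step, using the hypothetical helper name `replace_ignore_index`:

```python
from typing import List


def replace_ignore_index(labels: List[int], pad_token_id: int, ignore_index: int = -100) -> List[int]:
    """Swap ignored label positions back to the pad id so the ids can be decoded."""
    return [pad_token_id if tok == ignore_index else tok for tok in labels]


print(replace_ignore_index([5, 6, 1, -100, -100], pad_token_id=0))
```

In this notebook, for example: `tokenizer.decode(replace_ignore_index(batch['labels'][0].tolist(), tokenizer.pad_token_id), skip_special_tokens=True)`. Here `skip_special_tokens=True` also drops the restored `<pad>` and `</s>` tokens from the decoded text.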