Summarisation produces a shorter version of a document or article that keeps its most important information. Like translation, it can be framed as a sequence-to-sequence task: the model reads an input sequence (the document) and generates an output sequence (the summary).

Summarisation can be:
- Extractive: select the most relevant sentences or passages from the document as they are.
- Abstractive: generate new text that captures the most relevant information.
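To make the extractive/abstractive distinction concrete, here is a minimal plain-Python sketch of a naive extractive summariser. It is not part of this guide's T5 pipeline and not a production method: it assumes sentences end in `.`, `!` or `?`, scores each sentence by the average document-wide frequency of its words, and returns the top-scoring sentences verbatim. An abstractive model such as T5 writes new sentences instead. The function name `extractive_summary` is hypothetical.

```python
import re
from collections import Counter


def extractive_summary(document: str, num_sentences: int = 2) -> str:
    """Hypothetical helper: return the highest-scoring sentences, in original order."""
    # Naive sentence split: break after ., ! or ? followed by whitespace.
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", document) if s.strip()]
    # Word frequencies over the whole document act as a crude importance signal.
    freq = Counter(re.findall(r"[a-z']+", document.lower()))

    def score(sentence: str) -> float:
        tokens = re.findall(r"[a-z']+", sentence.lower())
        # Average rather than total frequency, so long sentences are not favoured automatically.
        return sum(freq[t] for t in tokens) / max(len(tokens), 1)

    # Rank sentence indices by score, keep the best ones, then restore document order.
    ranked = sorted(range(len(sentences)), key=lambda i: score(sentences[i]), reverse=True)
    return " ".join(sentences[i] for i in sorted(ranked[:num_sentences]))


doc = "The bill funds schools. The bill also funds roads. Cats are nice."
print(extractive_summary(doc, num_sentences=2))
```

Every sentence in the output is copied from the input, which is exactly what makes the method extractive.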

This guide shows how to:
1. Finetune T5 on the California state bill subset of the BillSum dataset for abstractive summarisation.
2. Use the finetuned model for inference.


# Libraries

In [None]:
%pip install transformers datasets evaluate rouge_score

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

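# Use the Apple-silicon GPU (MPS) if available, otherwise CUDA, otherwise CPU
backend = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(backend)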

# Data Loading

In [None]:
# Load the smaller California state bill subset of the BillSum dataset
billsum = load_dataset("billsum", split="ca_test")

In [None]:
# Use train_test_split to split the dataset
billsum = billsum.train_test_split(test_size=0.2)

# The two fields used for modelling:
# text: the text of the bill, which is the input to the model.
# summary: a condensed version of text, which is the model's target.
billsum["train"][0]
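`train_test_split` shuffles the rows before splitting, so each run produces a different split unless you pass a seed (for example `billsum.train_test_split(test_size=0.2, seed=42)`). The plain-Python sketch below illustrates what a seeded 80/20 split does to row indices. `train_test_split_indices` is a hypothetical helper, and rounding the test count up to a whole row is an assumption of this sketch rather than a documented guarantee of `datasets`.

```python
import math
import random


def train_test_split_indices(n_samples: int, test_size: float = 0.2, seed: int = 42):
    """Hypothetical helper: shuffle row indices with a fixed seed, then cut off a test fraction."""
    # Assumption for this sketch: the test share is rounded up to a whole row.
    n_test = math.ceil(test_size * n_samples)
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded RNG, so the same seed gives the same split
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = train_test_split_indices(10, test_size=0.2, seed=0)
print(len(train_idx), len(test_idx))
```

Every row lands in exactly one of the two splits, and re-running with the same seed reproduces the same split.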

# Preprocess

In [None]:
# Load the T5 tokenizer, used to process both the text and summary fields
checkpoint = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Prefix the input with a prompt so T5 knows this is a summarization task
prefix = "summarize: "

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)

    # Tokenize the summaries as labels via the text_target keyword argument;
    # truncation=True cuts each label sequence to at most max_length (128) tokens.
    labels = tokenizer(text_target=examples["summary"], max_length=128, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
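With `batched=True`, `map` passes `preprocess_function` a batch in columnar form: a dict mapping each column name to a list of values. The sketch below mimics that contract with a hypothetical whitespace tokenizer (`toy_tokenize`, with a made-up vocabulary `VOCAB`) so the prefixing, truncation and `labels` wiring can be run without downloading anything. The real T5 tokenizer splits text into subword pieces, appends an end-of-sequence token and also returns an `attention_mask`, none of which this toy version does.

```python
# Hypothetical vocabulary: each new whitespace-separated word gets the next free id.
VOCAB = {}


def toy_tokenize(texts, max_length):
    """Stand-in for the real tokenizer: word-level ids, truncated to max_length."""
    batch = []
    for text in texts:
        ids = [VOCAB.setdefault(word, len(VOCAB)) for word in text.split()]
        batch.append(ids[:max_length])  # truncation, as truncation=True does
    return {"input_ids": batch}


def toy_preprocess(examples, prefix="summarize: ", max_input=8, max_target=3):
    # Same shape as preprocess_function: prefix inputs, tokenize, attach labels.
    inputs = [prefix + doc for doc in examples["text"]]
    model_inputs = toy_tokenize(inputs, max_input)
    labels = toy_tokenize(examples["summary"], max_target)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


batch = {
    "text": ["an act to amend section 1 of the code", "a short bill"],
    "summary": ["amends section 1 of the code", "short"],
}
print(toy_preprocess(batch))
```

The returned dict has one list per new column (`input_ids`, `labels`), each with one entry per example, which is what `map` merges back into the dataset.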

In [None]:
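from transformers import AutoModelForSeq2SeqLM
tokenized_billsum = billsum.map(preprocess_function, batched=True)
# Pass a model object (not the checkpoint string) so the collator can prepare decoder_input_ids from labels
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)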
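`DataCollatorForSeq2Seq` pads dynamically: each batch is padded only to its own longest sequence rather than to `max_length`. Labels are padded with `-100` (the collator's default `label_pad_token_id`), which PyTorch's cross-entropy loss ignores by default. The simplified plain-Python sketch below reproduces that padding step for right-padding. It assumes T5's pad token id of 0 and leaves out what the real collator also does, such as building `decoder_input_ids` from the labels when a model is passed and converting the batch to tensors. `pad_seq2seq_batch` is a hypothetical name.

```python
def pad_seq2seq_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Hypothetical sketch: right-pad inputs and labels to the longest sequence in this batch."""
    max_in = max(len(f["input_ids"]) for f in features)
    max_lab = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_in - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # 1 marks real tokens, 0 marks padding the encoder should not attend to.
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        # -100 is skipped by the loss, so padded label positions contribute nothing.
        batch["labels"].append(f["labels"] + [label_pad_token_id] * (max_lab - len(f["labels"])))
    return batch


features = [{"input_ids": [5, 6, 7], "labels": [8]}, {"input_ids": [9], "labels": [10, 11]}]
print(pad_seq2seq_batch(features))
```

Padding per batch keeps short batches short, which saves computation compared with padding every example to 1024 tokens up front.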