Summarisation creates a shorter version of a document or an article that captures all the important information. Along with translation, it is another example of a task that can be formulated as a sequence-to-sequence task. 

Summarisation can be either:
- Extractive: extract the most relevant information from a document.
- Abstractive: generate new text that captures the most relevant information.

This guide shows how to:
1. Fine-tune T5 on the California state bill subset of the BillSum dataset for abstractive summarisation.
2. Use the finetuned model for inference.


# Libraries

In [None]:
pip install transformers datasets evaluate rouge_score

In [None]:
import torch
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import pipeline, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM
# Pick the best available accelerator: Apple-silicon MPS, then CUDA, then CPU.
# A hard-coded torch.device("mps") only works on Apple silicon; elsewhere,
# moving the model onto it fails. The name `mps_device` is kept so later
# cells (e.g. `model.to(mps_device)`) work unchanged on any machine.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
elif torch.cuda.is_available():
    mps_device = torch.device("cuda")
else:
    mps_device = torch.device("cpu")

# Data Load

In [None]:
# Fetch only the California state bill subset of BillSum (published as the
# "ca_test" split), which is much smaller than the full dataset.
billsum = load_dataset(path="billsum", split="ca_test")

In [None]:
# Hold out 20% of the bills for evaluation. A fixed seed makes the split
# (and therefore the reported ROUGE scores) reproducible across runs.
billsum = billsum.train_test_split(test_size=0.2, seed=42)

# The two fields to use for modeling:
# text: the text of the bill, which will be the input to the model.
# summary: a condensed version of the text, which will be the model target.
billsum["train"][0]

# Preprocess

In [None]:
# One T5 tokenizer handles both the inputs (bill text) and targets (summaries).
checkpoint = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)

In [None]:
# Task prompt prepended to every input so T5 knows this is summarisation.
prefix = "summarize: "

def preprocess_function(examples):
    """Tokenise a batch of BillSum rows for seq2seq training.

    Args:
        examples: batched columns as handed over by ``Dataset.map(batched=True)``;
            the "text" and "summary" lists are read.

    Returns:
        The tokenizer encoding of the prompted bill texts (truncated to 1024
        tokens), with an added "labels" entry holding the summary token ids
        (truncated to 128 tokens).
    """
    prompted = [prefix + bill for bill in examples["text"]]
    encoded = tokenizer(prompted, max_length=1024, truncation=True)

    # text_target= tokenises the summaries as targets; only their ids are kept.
    targets = tokenizer(
        text_target=examples["summary"],
        max_length=128,
        truncation=True,
    )
    encoded["labels"] = targets["input_ids"]
    return encoded

In [None]:
# Tokenise every split in batches.
tokenized_billsum = billsum.map(preprocess_function, batched=True)

# Pads inputs and labels dynamically per batch. Its `model=` argument expects a
# model *instance*: the collator only uses it when it has a
# `prepare_decoder_input_ids_from_labels` method, so a checkpoint string was a
# silent no-op and is omitted here.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

# Evaluation

In [None]:
# ROUGE scores the overlap between generated and reference summaries;
# it is used to evaluate this summarisation task.
rouge = evaluate.load(path="rouge")

In [None]:
# Function that passes predictions and labels to compute() to calculate the ROUGE metric.
# Called by the Trainer at every evaluation.
def compute_metrics(eval_pred):
    """Compute ROUGE and mean generated length for a Seq2SeqTrainer evaluation.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays. With
            ``predict_with_generate=True`` the predictions are generated ids.

    Returns:
        dict mapping each ROUGE key plus ``"gen_len"`` to a float rounded to
        4 decimals.
    """
    predictions, labels = eval_pred
    # Some models return a tuple whose first element holds the token ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # The Trainer's evaluation loop pads predictions and labels with -100 when
    # concatenating batches of different lengths. -100 is not a valid token id
    # (batch_decode fails on it), so map it back to the pad token first.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    # Generated length without padding (after the -100 -> pad mapping above,
    # so former -100 positions are no longer counted as tokens).
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(float(v), 4) for k, v in result.items()}

# Training

In [None]:
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
# Move the weights onto the selected device; .to() returns the model, which the cell displays.
model.to(device=mps_device)

In [None]:
# Define training hyperparameters in Seq2SeqTrainingArguments().
# The only required parameter is output_dir.
# At the end of each epoch the Trainer evaluates ROUGE and saves a checkpoint.
training_args = Seq2SeqTrainingArguments(
    output_dir="summarisation_model",
    # `evaluation_strategy` was renamed `eval_strategy`; the old name has been
    # removed from recent transformers releases.
    eval_strategy="epoch",
    # The default save strategy is "steps" (every 500 steps), which a run this
    # short may never reach; save per epoch to match the evaluation cadence.
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=4,
    predict_with_generate=True,
    # FP16 mixed precision (AMP) targets CUDA; transformers may reject
    # fp16=True on CPU/MPS, so only enable it when a CUDA device is present.
    fp16=torch.cuda.is_available(),
)

# Pass the training arguments to Seq2SeqTrainer along with the model, datasets,
# tokenizer, data collator, and compute_metrics() function.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_billsum["train"],
    eval_dataset=tokenized_billsum["test"],
    # `processing_class` replaces the deprecated `tokenizer=` argument.
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Call train() to fine-tune the model.
trainer.train()

# Intermediate checkpoints go to output_dir/checkpoint-*; save the final model
# and tokenizer at the root of output_dir so they can be loaded later from
# "summarisation_model".
trainer.save_model()

# Inference

In [None]:
# Text you'd like to summarise.
# For T5, you need to prefix your input depending on the task you're working on,
# e.g. for summarisation, prefix your input as shown below.
# Adjacent string literals are concatenated; each piece carries its own
# trailing space so sentences and words stay separated.
text = (
    "summarize: The Inflation Reduction Act is a proposed piece of legislation "
    "that is supposed to lower prescription drug costs, health care costs, and energy costs. "
    "It's the most aggressive action on tackling the climate crisis (and maybe inflation?) in American history, "
    "which will lift up American workers and create good-paying, union jobs across the country. "
    "It promises to lower the deficit and ask the ultra-wealthy and corporations to pay their fair share. "
    "And it nearly guarantees that no one making under $400,000 per year will pay a penny more in taxes."
)

In [None]:
# --> inference with a pipeline(), reading the model saved in "summarisation_model"
summarizer = pipeline(task="summarization", model="summarisation_model")
summarizer(text)

In [None]:
# --> inference using PyTorch objects
# Load the tokenizer and the model from the same fine-tuned checkpoint
# ("summarisation_model") so the cell runs the model trained in this notebook.
tokenizer = AutoTokenizer.from_pretrained("summarisation_model")
inputs = tokenizer(text, return_tensors="pt").input_ids

# Use the generate() method to produce the summary token ids; do_sample=False
# disables sampling so the output is deterministic.
model = AutoModelForSeq2SeqLM.from_pretrained("summarisation_model")
outputs = model.generate(inputs, max_new_tokens=100, do_sample=False)
# Decode the generated token ids back into text
tokenizer.decode(outputs[0], skip_special_tokens=True)