In [None]:
!pip install -q transformers datasets evaluate rouge_score accelerate

# Summarization

Summarization creates a shorter version of a document or article that captures all of its important information. It is a sequence-to-sequence task: the model reads an input sequence (the document) and generates an output sequence (the summary).

Summarization can be:
* Extractive: select and copy the most relevant sentences or phrases from the document.
* Abstractive: generate new text that captures the most relevant information.

This notebook fine-tunes T5 on BillSum for abstractive summarization.
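The rest of this notebook is abstractive: T5 writes new text. For contrast, here is a toy extractive summarizer in plain Python. It is not part of the Transformers workflow. Its scoring rule (average word frequency per sentence) is an illustrative assumption, and real extractive systems score sentences far more carefully.

```python
import re
from collections import Counter


def extractive_summary(document: str, num_sentences: int = 1) -> str:
    """Toy extractive summarizer: keep the highest-scoring original sentences.

    Each sentence is scored by the average document frequency of its words,
    a deliberately naive heuristic used only for illustration.
    """
    # Split on sentence-ending punctuation followed by whitespace.
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", document) if s.strip()]
    freq = Counter(re.findall(r"[a-z']+", document.lower()))

    def score(sentence: str) -> float:
        tokens = re.findall(r"[a-z']+", sentence.lower())
        return sum(freq[t] for t in tokens) / max(len(tokens), 1)

    # Rank sentences by score, keep the top ones, then restore original order.
    ranked = sorted(range(len(sentences)), key=lambda i: score(sentences[i]), reverse=True)
    return " ".join(sentences[i] for i in sorted(ranked[:num_sentences]))


doc = ("The bill lowers drug costs. The bill also lowers energy costs. "
       "Weather was nice.")
print(extractive_summary(doc, num_sentences=1))  # The bill lowers drug costs.
```

Every output sentence is copied verbatim from the input. An abstractive model like T5 is free to rephrase.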

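## Load the BillSum dataset

Load the small California (`ca_test`) split of BillSum, then divide it into train and test sets.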

In [None]:
from datasets import load_dataset

billsum = load_dataset('billsum', split='ca_test')

In [None]:
billsum = billsum.train_test_split(test_size=0.2)
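`train_test_split(test_size=0.2)` holds out 20% of the rows for evaluation. The sketch below works out the split sizes under two assumptions: the `ca_test` split contains 1,237 bills, and the test size is rounded up (as in scikit-learn), with the remainder going to training.

```python
import math

# Assumed size of the BillSum ca_test split.
num_examples = 1237
# Test size is rounded up; the training split gets the remainder.
num_test = math.ceil(0.2 * num_examples)
num_train = num_examples - num_test
print(num_examples, num_train, num_test)  # 1237 989 248
```

No `seed` is passed here, so the rows are shuffled differently on every run. Passing e.g. `seed=42` to `train_test_split` makes the split reproducible.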

In [None]:
billsum['train'][0]

{'text': 'The people of the State of California do enact as follows:\n\n\nSECTION 1.\nSection 7612 of the Family Code is amended to read:\n7612.\n(a) Except as provided in Chapter 1 (commencing with Section 7540) and Chapter 3 (commencing with Section 7570) of Part 2 or in Section 20102, a presumption under Section 7611 is a rebuttable presumption affecting the burden of proof and may be rebutted in an appropriate action only by clear and convincing evidence.\n(b) If two or more presumptions arise under Section 7610 or 7611 that conflict with each other, or if a presumption under Section 7611 conflicts with a claim pursuant to Section 7610, the presumption which on the facts is founded on the weightier considerations of policy and logic controls.\n(c) In an appropriate action, a court may find that more than two persons with a claim to parentage under this division are parents if the court finds that recognizing only two parents would be detrimental to the child. In determining detrime

There are two fields that we want to use:
* `text`: the text of the bill, which is the input to the model.
* `summary`: a condensed version of `text` which will be the model target.

## Preprocess

We need to load a T5 tokenizer to process `text` and `summary`:

In [None]:
from transformers import AutoTokenizer

checkpoint = 'google-t5/t5-small'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


We also need a preprocessing function that:
* Prefixes the input with a prompt so T5 knows this is a summarization task. Some models that can perform multiple NLP tasks need a task-specific prompt like this.
* Uses the `text_target` keyword argument when tokenizing the labels.
* Truncates sequences to no longer than the maximum length set by the `max_length` parameter (1024 tokens for inputs and 128 for labels below).

In [None]:
prefix = 'summarize: '


def preprocess_function(examples):
    # With batched=True, `examples` is a dict of lists (one entry per row).
    inputs = [prefix + doc for doc in examples['text']]
    model_inputs = tokenizer(
        inputs,
        max_length=1024,
        truncation=True,
    )

    # `text_target` tokenizes the summaries as labels (targets).
    labels = tokenizer(
        text_target=examples['summary'],
        max_length=128,
        truncation=True,
    )
    model_inputs['labels'] = labels['input_ids']

    return model_inputs
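You can see what `preprocess_function` does without downloading a tokenizer. In this toy version, whitespace-separated words stand in for T5's SentencePiece tokens. `toy_tokenize` and `toy_preprocess` are hypothetical helpers, and the tiny length limits are chosen only to make truncation visible.

```python
prefix = 'summarize: '


def toy_tokenize(texts, max_length):
    # Hypothetical tokenizer: split on whitespace, keep at most max_length tokens.
    return [text.split()[:max_length] for text in texts]


def toy_preprocess(examples, max_input_length=8, max_target_length=3):
    # Same shape as preprocess_function: prefix inputs, tokenize, attach labels.
    inputs = [prefix + doc for doc in examples['text']]
    model_inputs = {'input_ids': toy_tokenize(inputs, max_input_length)}
    model_inputs['labels'] = toy_tokenize(examples['summary'], max_target_length)
    return model_inputs


batch = {
    'text': ['The people of the State of California do enact as follows.'],
    'summary': ['A short toy summary of the bill.'],
}
print(toy_preprocess(batch))
# {'input_ids': [['summarize:', 'The', 'people', 'of', 'the', 'State', 'of', 'California']], 'labels': [['A', 'short', 'toy']]}
```

The real tokenizer returns integer ids and an `attention_mask`. T5's tokenizer also appends an end-of-sequence token `</s>`. This toy version omits both.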

In [None]:
tokenized_billsum = billsum.map(preprocess_function, batched=True)

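With `batched=True`, `map` calls `preprocess_function` on slices of rows rather than one row at a time. This is why the function treats `examples['text']` as a list. `Dataset.map` uses `batch_size=1000` by default, so the 989 training rows in this run arrive in a single call. The hypothetical `batched_map` helper below roughly emulates that slicing. Unlike the real method, it does not merge the outputs back into the dataset.

```python
def batched_map(columns, function, batch_size=1000):
    # Hypothetical emulation: pass consecutive slices of rows as a dict of lists.
    num_rows = len(next(iter(columns.values())))
    outputs = []
    for start in range(0, num_rows, batch_size):
        # Slice every column to the same row range.
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        outputs.append(function(batch))
    return outputs


calls = batched_map({'text': ['x'] * 989}, lambda batch: len(batch['text']))
print(calls)  # [989]
```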

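Use `DataCollatorForSeq2Seq` to build batches. It pads the inputs and labels dynamically to the length of the longest sequence in each batch. This is more efficient than padding the whole dataset to a fixed maximum length.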

In [None]:
from transformers import DataCollatorForSeq2Seq

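# `model` expects a loaded model object (used to build decoder_input_ids), not a
# checkpoint name; T5 derives decoder_input_ids from the labels on its own.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)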
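Here is a minimal sketch of the per-batch padding described above. It assumes right padding, T5's pad id of 0, and the collator's default label pad id of -100, which marks positions the loss ignores. `pad_batch` is a hypothetical helper, not the library code. The real collator returns tensors and, when given a model object, can also prepare `decoder_input_ids`.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    # Pad to the longest sequence in this batch, not to a global maximum.
    max_input = max(len(f['input_ids']) for f in features)
    max_label = max(len(f['labels']) for f in features)
    batch = {'input_ids': [], 'attention_mask': [], 'labels': []}
    for f in features:
        n_pad = max_input - len(f['input_ids'])
        batch['input_ids'].append(f['input_ids'] + [pad_token_id] * n_pad)
        # 1 marks a real token, 0 marks padding.
        batch['attention_mask'].append([1] * len(f['input_ids']) + [0] * n_pad)
        # Labels use -100 so padded positions do not contribute to the loss.
        batch['labels'].append(f['labels'] + [label_pad_token_id] * (max_label - len(f['labels'])))
    return batch


features = [
    {'input_ids': [5, 6, 1], 'labels': [7, 1]},
    {'input_ids': [8, 1], 'labels': [9, 10, 1]},
]
print(pad_batch(features))
```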

## Evaluate

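Summaries are commonly evaluated with the ROUGE metric. It scores the n-gram and longest-common-subsequence overlap between generated and reference summaries. Load it with the Evaluate library: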

In [None]:
import evaluate

rouge = evaluate.load('rouge')

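ROUGE-1 is the simplest member of the family because it compares single words (unigrams). The sketch below computes a ROUGE-1 F1 score with plain whitespace tokenization. It is a simplified illustration, not the `rouge_score` implementation that `evaluate` uses. That implementation applies its own tokenization and, with `use_stemmer=True`, Porter stemming.

```python
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """Unigram-overlap F1, the idea behind ROUGE-1 (simplified tokenization)."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    # Clipped overlap: a word counts at most as often as it appears in both texts.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(round(rouge1_f1('the act lowers drug costs', 'the act lowers health care costs'), 4))  # 0.7273
```

Here 4 of the 5 predicted words match, giving precision 4/5. They cover 4 of the 6 reference words, giving recall 4/6. The resulting F1 is 8/11. ROUGE-2 applies the same idea to bigrams, and ROUGE-L uses the longest common subsequence.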

In [None]:
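import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # Generated ids can be padded with -100 when batches are gathered, and labels
    # use -100 for ignored positions; swap it for the pad id before decoding.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True,
    )

    # Average number of non-pad tokens in the generated summaries.
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id)
        for pred in predictions
    ]
    result['gen_len'] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}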
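The two list-processing steps in `compute_metrics` can be traced in plain Python with made-up token ids. The ids are illustrative only. The sketch assumes T5's pad id of 0 and `-100` as the ignore index. T5 also uses 0 as its decoder start token, so generated sequences begin with it.

```python
PAD_ID = 0        # T5 pad id (also its decoder start token)
IGNORE_ID = -100  # label positions the loss ignores; not a real token id

# Made-up ids for two examples.
labels = [[37, 1, -100, -100], [37, 42, 9, 1]]
predictions = [[0, 37, 1, 0], [0, 37, 42, 1]]

# Replace -100 with the pad id so the labels can be decoded.
clean_labels = [[tok if tok != IGNORE_ID else PAD_ID for tok in row] for row in labels]
# gen_len counts the non-pad tokens in each generated sequence, then averages.
gen_lens = [sum(tok != PAD_ID for tok in row) for row in predictions]
mean_gen_len = sum(gen_lens) / len(gen_lens)
print(clean_labels)  # [[37, 1, 0, 0], [37, 42, 9, 1]]
print(mean_gen_len)  # 2.5
```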

## Train

In [None]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)


In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir='my_billsum_model',
    eval_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=4,
    predict_with_generate=True,
    fp16=True,  # mixed precision; requires a GPU, set to False on CPU
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_billsum['train'],
    eval_dataset=tokenized_billsum['test'],
    processing_class=tokenizer,  # older Transformers versions use tokenizer=tokenizer
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()
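As a rough check on how long `trainer.train()` runs, the number of optimizer steps follows from the training arguments. The sketch assumes one device, no gradient accumulation, the 989 training rows from the 80/20 split, and the Trainer's default linear learning-rate schedule with no warmup.

```python
import math

# Assumptions: one device, no gradient accumulation, 989 training rows.
num_train = 989
batch_size = 16
epochs = 4
peak_lr = 2e-5

steps_per_epoch = math.ceil(num_train / batch_size)  # the last batch is partial
total_steps = steps_per_epoch * epochs


def linear_lr(step: int) -> float:
    # Linear decay from peak_lr at step 0 to 0 at total_steps (no warmup).
    return peak_lr * max(0.0, (total_steps - step) / total_steps)


print(steps_per_epoch, total_steps)  # 62 248
print(linear_lr(124))                # 1e-05
```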

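## Inference

Now use a fine-tuned model to summarize new text. T5 expects the same `summarize: ` prefix that was used during preprocessing, so include it in the input: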

In [None]:
text = "summarize: The Inflation Reduction Act lowers prescription drug costs, health care costs, and energy costs. It's the most aggressive action on tackling the climate crisis in American history, which will lift up American workers and create good-paying, union jobs across the country. It'll lower the deficit and ask the ultra-wealthy and corporations to pay their fair share. And no one making under $400,000 per year will pay a penny more in taxes."

In [None]:
from transformers import pipeline

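# Loads an already fine-tuned summarization checkpoint from the Hub. To use your own
# model, save it with trainer.save_model('my_billsum_model') and pass that path instead.
summarizer = pipeline('summarization', model='stevhliu/my_awesome_billsum_model')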

In [None]:
summarizer(text)

Your max_length is set to 200, but your input_length is only 103. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=51)


[{'summary_text': "the Inflation Reduction Act lowers prescription drug costs, health care costs, and energy costs. It's the most aggressive action on tackling the climate crisis in American history, which will lift up American workers and create good-paying, union jobs across the country."}]

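You can also replicate the pipeline's steps manually: tokenize the text, generate output token ids with the model, and decode them back into text. The result may differ slightly from the pipeline's. The pipeline can apply its own generation settings (note its `max_length` of 200 in the warning above), while the call below sets `max_new_tokens=100` and `do_sample=False` explicitly.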

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

checkpoint = 'stevhliu/my_awesome_billsum_model'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)



In [None]:
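# Tokenize the prefixed text into input ids (a PyTorch tensor).
inputs = tokenizer(text, return_tensors='pt').input_ids
# do_sample=False disables sampling, so decoding is deterministic.
outputs = model.generate(inputs, max_new_tokens=100, do_sample=False)

# Convert the generated ids back to text, dropping special tokens such as </s>.
tokenizer.decode(outputs[0], skip_special_tokens=True)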

"The Inflation Reduction Act lowers prescription drug costs, health care costs, and energy costs. It's the most aggressive action on tackling the climate crisis in American history. It'll ask the ultra-wealthy and corporations to pay their fair share."
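With `do_sample=False` and, assuming the model's generation config keeps the default `num_beams=1`, `generate` picks the single most probable next token at every step. This is greedy decoding. The toy loop below illustrates the idea. `next_token_scores` is a hypothetical hand-written table standing in for the model's decoder, and the words are not real T5 tokens.

```python
EOS = '</s>'


def next_token_scores(prefix):
    # Hypothetical stand-in for the decoder: scores keyed on the last token.
    table = {
        '<start>': {'The': 0.6, 'the': 0.4},
        'The': {'act': 0.7, 'bill': 0.3},
        'act': {'lowers': 0.8, EOS: 0.2},
        'lowers': {'costs': 0.9, EOS: 0.1},
        'costs': {EOS: 1.0},
    }
    return table[prefix[-1]]


def greedy_decode(max_new_tokens):
    tokens = ['<start>']
    for _ in range(max_new_tokens):
        scores = next_token_scores(tokens)
        best = max(scores, key=scores.get)  # always take the top-scoring token
        if best == EOS:
            break
        tokens.append(best)
    return ' '.join(tokens[1:])


print(greedy_decode(max_new_tokens=100))  # The act lowers costs
print(greedy_decode(max_new_tokens=2))    # The act
```

Beam search keeps several candidate sequences at each step instead of one. As a result, it can return a different summary than greedy decoding for the same model.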