Pegasus

Overview

The Pegasus model was proposed in PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu on Dec 18, 2019.

According to the abstract,

  • Pegasus' pretraining task is intentionally similar to summarization: important sentences are removed/masked from an input document and are generated together as one output sequence from the remaining sentences, similar to an extractive summary.
  • Pegasus achieves SOTA summarization performance on all 12 downstream tasks, as measured by ROUGE and human eval.

This model was contributed by sshleifer. The authors' code can be found here.

Usage tips

  • Sequence-to-sequence model with the same encoder-decoder architecture as BART. Pegasus is pre-trained jointly on two self-supervised objectives: Masked Language Modeling (MLM) and Gap Sentence Generation (GSG), a novel summarization-specific pretraining objective.

    • MLM: encoder input tokens are randomly replaced by a mask token and have to be predicted by the encoder (as in BERT)
    • GSG: whole encoder input sentences are replaced by a second mask token and must be generated by the decoder, which uses a causal mask to hide future words like a regular auto-regressive transformer decoder (a toy sketch of this masking follows this list)
  • FP16 is not supported (help/ideas on this appreciated!).

  • The Adafactor optimizer is recommended for Pegasus fine-tuning (a minimal fine-tuning sketch appears at the end of the Checkpoints section below).
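
To make the GSG objective concrete, here is a small, self-contained toy sketch, not the authors' pretraining code: the gsg_mask helper, the naive length-based "importance" heuristic (the paper selects sentences with a ROUGE-based score) and the "[MASK1]" placeholder are illustrative assumptions only.

```python
# Toy illustration of Gap Sentence Generation (GSG) masking.
# Simplified sketch: sentence importance is approximated by length instead of
# the ROUGE-based selection used in the paper, and "[MASK1]" stands in for the
# sentence-level mask token.

def gsg_mask(document: str, mask_ratio: float = 0.3):
    sentences = [s.strip() for s in document.split(".") if s.strip()]
    n_masked = max(1, int(len(sentences) * mask_ratio))

    # Pretend the longest sentences are the most "important" ones.
    ranked = sorted(range(len(sentences)), key=lambda i: len(sentences[i]), reverse=True)
    masked_ids = set(ranked[:n_masked])

    # Masked sentences are removed from the encoder input...
    encoder_input = ". ".join(
        "[MASK1]" if i in masked_ids else s for i, s in enumerate(sentences)
    )
    # ...and concatenated into the single target sequence the decoder must generate.
    decoder_target = ". ".join(sentences[i] for i in sorted(masked_ids))
    return encoder_input, decoder_target


doc = (
    "PG&E scheduled blackouts in response to forecasts for high winds. "
    "The aim is to reduce the risk of wildfires. "
    "Nearly 800 thousand customers were expected to be affected."
)
enc, tgt = gsg_mask(doc)
print("encoder input :", enc)
print("decoder target:", tgt)
```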

Checkpoints

All checkpoints are fine-tuned for summarization except pegasus-large, from which the other checkpoints are fine-tuned:

  • Each checkpoint is 2.2 GB on disk and 568M parameters.
  • Summarizing XSum in FP32 takes about 400 ms/sample with default parameters on a V100 GPU.
  • Full replication results and correctly pre-processed data can be found in this Issue.
  • Distilled checkpoints are described in this paper.
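
Since the non-large checkpoints above were obtained by fine-tuning pegasus-large, and Adafactor is the recommended optimizer for Pegasus fine-tuning (see the usage tips), here is a minimal fine-tuning sketch. It is not the authors' training setup: it assumes the google/pegasus-large checkpoint, a single toy (document, summary) pair, and illustrative hyperparameters; a real run would iterate over a proper summarization dataset.

```python
# Minimal sketch of fine-tuning pegasus-large with Adafactor.
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from transformers.optimization import Adafactor

model_name = "google/pegasus-large"
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name)

# Roughly 568M parameters / 2.2 GB on disk, as noted above.
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters")

# Illustrative settings; with an explicit lr, relative_step must be False.
optimizer = Adafactor(
    model.parameters(),
    lr=1e-4,
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
)

documents = ["PG&E scheduled blackouts to reduce the risk of wildfires during high winds."]
summaries = ["PG&E cut power to lower wildfire risk."]

batch = tokenizer(documents, truncation=True, padding="longest", return_tensors="pt")
# Pegasus uses the same vocabulary for source and target, so the target can be
# tokenized with the same tokenizer.
labels = tokenizer(summaries, truncation=True, padding="longest", return_tensors="pt").input_ids

model.train()
loss = model(**batch, labels=labels).loss  # standard seq2seq cross-entropy
loss.backward()
optimizer.step()
optimizer.zero_grad()
```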

Implementation Notes

  • All models are transformer encoder-decoders with 16 layers in each component.
  • The implementation is completely inherited from [BartForConditionalGeneration].
  • Some key configuration differences:
    • static, sinusoidal position embeddings
    • the model starts generating with pad_token_id (which has a zero token embedding) as the prefix
    • more beams are used (num_beams=8)
  • All pretrained pegasus checkpoints are the same besides three attributes: tokenizer.model_max_length (the maximum input size), max_length (the maximum number of tokens to generate) and length_penalty; the sketch after this list shows how to inspect them.
  • The code to convert checkpoints trained in the authors' repo can be found in convert_pegasus_tf_to_pytorch.py.
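
These attributes are easy to inspect from the checkpoint itself. A small sketch, assuming the google/pegasus-xsum checkpoint; the values read from the config reflect what is stored in the checkpoint's config.json (on recent Transformers versions the generation defaults are also exposed via the model's generation_config).

```python
# Inspect the attributes that differ between pegasus checkpoints,
# plus a few of the shared architectural settings mentioned above.
from transformers import PegasusConfig, PegasusTokenizer

checkpoint = "google/pegasus-xsum"  # any pegasus checkpoint works here
config = PegasusConfig.from_pretrained(checkpoint)
tokenizer = PegasusTokenizer.from_pretrained(checkpoint)

print("encoder layers       :", config.encoder_layers)          # 16
print("decoder layers       :", config.decoder_layers)          # 16
print("decoder start token  :", config.decoder_start_token_id)  # same as pad_token_id (0)
print("max input length     :", tokenizer.model_max_length)
print("max generated tokens :", config.max_length)
print("length penalty       :", config.length_penalty)
print("beams for generation :", config.num_beams)
```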

Usage Example

>>> from transformers import PegasusForConditionalGeneration, PegasusTokenizer
>>> import torch

>>> src_text = [
...     """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
... ]

>>> model_name = "google/pegasus-xsum"
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> tokenizer = PegasusTokenizer.from_pretrained(model_name)
>>> model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
>>> batch = tokenizer(src_text, truncation=True, padding="longest", return_tensors="pt").to(device)
>>> translated = model.generate(**batch)
>>> tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
>>> assert (
...     tgt_text[0]
...     == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
... )
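
The generation defaults (num_beams, max_length, length_penalty) come from the checkpoint's configuration but can be overridden per call. A short continuation of the example above, reusing model, tokenizer and batch; the values below are illustrative, not tuned.

```python
# Override the checkpoint's generation defaults on a per-call basis.
summary_ids = model.generate(
    **batch,
    num_beams=4,          # pegasus checkpoints default to 8 beams
    max_length=32,        # cap the number of generated tokens
    length_penalty=0.8,   # exponent for beam-search length normalization
    early_stopping=True,
)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
```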

Resources

PegasusConfig

[[autodoc]] PegasusConfig

PegasusTokenizer

Warning: add_tokens does not work at the moment.

[[autodoc]] PegasusTokenizer

PegasusTokenizerFast

[[autodoc]] PegasusTokenizerFast

PegasusModel

[[autodoc]] PegasusModel
    - forward

PegasusForConditionalGeneration

[[autodoc]] PegasusForConditionalGeneration
    - forward

PegasusForCausalLM

[[autodoc]] PegasusForCausalLM
    - forward

TFPegasusModel

[[autodoc]] TFPegasusModel
    - call

TFPegasusForConditionalGeneration

[[autodoc]] TFPegasusForConditionalGeneration
    - call

FlaxPegasusModel

[[autodoc]] FlaxPegasusModel
    - __call__
    - encode
    - decode

FlaxPegasusForConditionalGeneration

[[autodoc]] FlaxPegasusForConditionalGeneration
    - __call__
    - encode
    - decode