Pegasus finetune script: add --adafactor (#6811)

sshleifer committed Aug 29, 2020
1 parent ac47458 commit 0f58903
Showing 4 changed files with 48 additions and 70 deletions.
2 changes: 1 addition & 1 deletion examples/seq2seq/finetune_pegasus_xsum.sh
@@ -10,5 +10,5 @@ python finetune.py \
--n_val 1000 \
--val_check_interval 0.25 \
--max_source_length 512 --max_target_length 56 \
--freeze_embeds --max_target_length 56 --label_smoothing 0.1 \
--freeze_embeds --label_smoothing 0.1 --adafactor --task summarization_xsum \
"$@"
55 changes: 16 additions & 39 deletions src/transformers/configuration_pegasus.py
@@ -47,46 +47,23 @@
activation_function="relu",
)
# Config values that vary between checkpoints: for testing and conversion
max_gen_length = {
# See appendix C of paper
"xsum": 64,
"cnn_dailymail": 128,
"newsroom": 128,
"wikihow": 256,
"multi_news": 256,
"reddit_tifu": 128,
"big_patent": 256,
"arxiv": 256,
"pubmed": 256,
"gigaword": 32,
"aeslc": 32,
"billsum": 256,
"large": 256, # @sshleifer chose arbitrarily
task_specific_params = {
# These are task specific params for pegasus-large and normal params for finetuned checkpoints
"summarization_xsum": {"length_penalty": 0.8, "max_length": 64, "max_position_embeddings": 512},
"summarization_cnn_dailymail": {"length_penalty": 0.8, "max_length": 128, "max_position_embeddings": 1024},
"summarization_newsroom": {"length_penalty": 0.8, "max_length": 128, "max_position_embeddings": 512},
"summarization_wikihow": {"length_penalty": 0.6, "max_length": 256, "max_position_embeddings": 512},
"summarization_multi_news": {"length_penalty": 0.8, "max_length": 256, "max_position_embeddings": 1024},
"summarization_reddit_tifu": {"length_penalty": 0.6, "max_length": 128, "max_position_embeddings": 512},
"summarization_big_patent": {"length_penalty": 0.7, "max_length": 256, "max_position_embeddings": 1024},
"summarization_arxiv": {"length_penalty": 0.8, "max_length": 256, "max_position_embeddings": 1024},
"summarization_pubmed": {"length_penalty": 0.8, "max_length": 256, "max_position_embeddings": 1024},
"summarization_gigaword": {"length_penalty": 0.6, "max_length": 32, "max_position_embeddings": 128},
"summarization_aeslc": {"length_penalty": 0.6, "max_length": 32, "max_position_embeddings": 512},
"summarization_billsum": {"length_penalty": 0.6, "max_length": 256, "max_position_embeddings": 1024},
# this last entry is useless -- just for consistency
"summarization_large": {"length_penalty": 0.8, "max_length": 256, "max_position_embeddings": 1024},
}
max_model_length = {
"xsum": 512,
"cnn_dailymail": 1024,
"newsroom": 512,
"wikihow": 512,
"multi_news": 1024,
"reddit_tifu": 512,
"big_patent": 1024,
"arxiv": 1024,
"pubmed": 1024,
"gigaword": 128,
"aeslc": 512,
"billsum": 1024,
"large": 1024,
}
expected_alpha = {
"multinews": 0.9,
"wikihow": 0.6,
"reddit_tifu": 0.6,
"big_patent": 0.7,
"gigaword": 0.6,
"aeslc": 0.6,
"billsum": 0.6,
} # otherwise 0.8


@add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
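The refactor above folds three per-dataset dicts (max_gen_length, max_model_length, expected_alpha) into a single task_specific_params mapping keyed by "summarization_<dataset>". An illustrative lookup, with expected values copied from the diff (not part of the commit):

# Illustrative check of the new mapping; the expected values come from the diff.
from transformers.configuration_pegasus import task_specific_params

xsum = task_specific_params["summarization_xsum"]
assert xsum["max_length"] == 64                # previously max_gen_length["xsum"]
assert xsum["max_position_embeddings"] == 512  # previously max_model_length["xsum"]
assert xsum["length_penalty"] == 0.8           # previously expected_alpha.get("xsum", 0.8)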
22 changes: 13 additions & 9 deletions src/transformers/convert_pegasus_tf_to_pytorch.py
@@ -14,6 +14,7 @@
# limitations under the License.

import argparse
import os
from pathlib import Path
from typing import Dict

@@ -22,7 +23,7 @@
from tqdm import tqdm

from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
from transformers.configuration_pegasus import DEFAULTS, expected_alpha, max_gen_length, max_model_length
from transformers.configuration_pegasus import DEFAULTS, task_specific_params


PATTERNS = [
@@ -101,23 +102,25 @@ def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict:
return tf_weights


def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir):
def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str):
# save tokenizer first
dataset = Path(ckpt_path).parent.name
desired_max_model_length = max_model_length[dataset]
desired_max_model_length = task_specific_params[f"summarization_{dataset}"]["max_position_embeddings"]
tok = PegasusTokenizer.from_pretrained("sshleifer/pegasus", model_max_length=desired_max_model_length)
assert tok.model_max_length == desired_max_model_length
tok.save_pretrained(save_dir)

# convert model
tf_weights = get_tf_weights_as_numpy(ckpt_path)
cfg_updates = dict(
max_length=max_gen_length[dataset],
length_penalty=expected_alpha.get(dataset, 0.8),
max_position_embeddings=desired_max_model_length,
)
cfg_updates = task_specific_params[f"summarization_{dataset}"]
if dataset == "large":
cfg_updates["task_specific_params"] = task_specific_params
torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
torch_model.save_pretrained(save_dir)
sd = torch_model.state_dict()
sd.pop("model.decoder.embed_positions.weight")
sd.pop("model.encoder.embed_positions.weight")
torch.save(sd, Path(save_dir) / "pytorch_model.bin")


if __name__ == "__main__":
@@ -127,5 +130,6 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir):
parser.add_argument("save_dir", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
if args.save_dir is None:
args.save_dir = f"pegasus/{Path(args.tf_ckpt_path).parent.name}"
dataset = Path(args.tf_ckpt_path).parent.name
args.save_dir = os.path.join("pegasus", dataset)
convert_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir)
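
After this change the converter derives every checkpoint-specific setting from task_specific_params, keyed by the parent directory of the TensorFlow checkpoint. A hedged usage sketch, reusing the default path shown in get_tf_weights_as_numpy (TensorFlow and the Pegasus sentencepiece tokenizer must be available):

# Usage sketch: the dataset name ("aeslc") is inferred from the checkpoint's
# parent directory, so the "summarization_aeslc" settings are applied.
from transformers.convert_pegasus_tf_to_pytorch import convert_pegasus_ckpt_to_pytorch

convert_pegasus_ckpt_to_pytorch("./ckpt/aeslc/model.ckpt-32000", "pegasus/aeslc")
# pegasus/aeslc now holds the tokenizer files plus pytorch_model.bin. The
# position-embedding weights are popped before saving, presumably because
# Pegasus uses deterministic sinusoidal embeddings that can be rebuilt at load time.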
39 changes: 18 additions & 21 deletions tests/test_modeling_pegasus.py
@@ -1,9 +1,10 @@
import unittest

from transformers import AutoConfig, AutoTokenizer, is_torch_available
from transformers.configuration_pegasus import max_gen_length, max_model_length
from transformers.configuration_pegasus import task_specific_params
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils.logging import ERROR, set_verbosity

from .test_modeling_bart import PGE_ARTICLE
from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest
@@ -14,6 +15,8 @@

XSUM_ENTRY_LONGER = """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """

set_verbosity(ERROR)


@require_torch
class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
@@ -50,31 +53,25 @@ def test_pegasus_xsum_summary(self):


class PegasusConfigTests(unittest.TestCase):
def test_all_config_max_lengths(self):
@slow
def test_task_specific_params(self):
"""Test that task_specific params['summarization_xsum'] == config['pegasus_xsum'] """
failures = []
pegasus_prefix = "google/pegasus"
for dataset, max_len in max_gen_length.items():
n_prefix_chars = len("summarization_")
for task, desired_settings in task_specific_params.items():
dataset = task[n_prefix_chars:]
mname = f"{pegasus_prefix}-{dataset}"
cfg = AutoConfig.from_pretrained(mname)

if cfg.max_length != max_len:
failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}")

if cfg.max_position_embeddings < max_model_length[dataset]:
# otherwise you get IndexError for e.g. position 513
# see https://github.com/huggingface/transformers/issues/6599
failures.append(
f"config for {mname} had max_position_embeddings: {cfg.max_position_embeddings}, expected {max_model_length[dataset]}"
)

for k, v in desired_settings.items():
actual_value = getattr(cfg, k)
if actual_value != v:
failures.append(f"config for {mname} had {k}: {actual_value}, expected {v}")
tokenizer = AutoTokenizer.from_pretrained(mname)
if max_model_length[dataset] != tokenizer.model_max_length:
failures.append(
f"tokenizer.model_max_length {tokenizer.model_max_length} expected {max_model_length[dataset]}"
)
n_pos_embeds = desired_settings["max_position_embeddings"]
if n_pos_embeds != tokenizer.model_max_length:
failures.append(f"tokenizer.model_max_length {tokenizer.model_max_length} expected {n_pos_embeds}")

if failures == []:
return
# error
all_fails = "\n".join(failures)
raise AssertionError(f"The following configs have unexpected settings: {all_fails}")
assert not failures, f"The following configs have unexpected settings: {all_fails}"

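The rewritten test walks every released google/pegasus-<dataset> checkpoint and compares its config and tokenizer against task_specific_params. Because the converter above also stores the full mapping on the "large" checkpoint, per-task generation settings can be read straight from that config. A sketch of that read path, assuming the uploaded google/pegasus-large config was produced with this change:

# Sketch only: read per-task generation settings from the pegasus-large config.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("google/pegasus-large")
xsum_settings = cfg.task_specific_params["summarization_xsum"]
print(xsum_settings["max_length"], xsum_settings["length_penalty"])  # expect 64 0.8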