mBART Conversion script (#6230)
sshleifer committed Aug 4, 2020
1 parent 268bf34 commit d5b0a0e
Showing 2 changed files with 36 additions and 13 deletions.
13 changes: 0 additions & 13 deletions src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py
@@ -78,19 +78,6 @@ def load_xsum_checkpoint(checkpoint_path):
    return hub_interface


def convert_checkpoint_from_disk(checkpoint_path, **config_kwargs):
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    remove_ignore_keys_(state_dict)
    vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
    state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
    mbart_config = BartConfig(vocab_size=vocab_size, **config_kwargs)
    model = BartForConditionalGeneration(mbart_config)
    model.model.load_state_dict(state_dict)
    if hasattr(model, "lm_head"):
        model.lm_head = _make_linear_from_emb(model.model.shared)
    return model


@torch.no_grad()
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
"""
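Both the removed helper above and the new mBART script below call remove_ignore_keys_, imported from the BART conversion script. Its body is not shown in this diff; a plausible sketch, assuming it strips the usual fairseq bookkeeping tensors (the exact key list is an assumption):

def remove_ignore_keys_(state_dict):
    # Drop fairseq-internal entries that have no counterpart in the
    # HuggingFace model; the state dict is modified in place.
    ignore_keys = [
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "_float_tensor",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)
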
36 changes: 36 additions & 0 deletions src/transformers/convert_mbart_original_checkpoint_to_pytorch.py
@@ -0,0 +1,36 @@
import argparse

import torch

from transformers import BartForConditionalGeneration, MBartConfig

from .convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_


def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    remove_ignore_keys_(state_dict)
    # Infer the vocabulary size from the fairseq embedding matrix instead of
    # trusting the config default.
    vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
    mbart_config = MBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size)
    # fairseq keeps separate encoder/decoder embedding tables; the HF model
    # also expects a shared embedding under "shared.weight".
    state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
    model = BartForConditionalGeneration(mbart_config)
    model.model.load_state_dict(state_dict)
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "fairseq_path", type=str, help="Path to a local fairseq mBART checkpoint (a model.pt file)."
    )
    parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument(
        "--hf_config",
        default="facebook/mbart-large-cc25",
        type=str,
        help="Which huggingface config to load, e.g. facebook/mbart-large-cc25.",
    )
    args = parser.parse_args()
    model = convert_fairseq_mbart_checkpoint_from_disk(args.fairseq_path, hf_config_path=args.hf_config)
    model.save_pretrained(args.pytorch_dump_folder_path)
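
A usage sketch for the new script, assuming it is importable as a transformers module and that a fairseq checkpoint exists locally; the checkpoint and output paths below are placeholders, not from the commit:

# Command-line equivalent (placeholder paths):
#   python convert_mbart_original_checkpoint_to_pytorch.py ./mbart.cc25/model.pt ./mbart-cc25-hf
from transformers.convert_mbart_original_checkpoint_to_pytorch import (
    convert_fairseq_mbart_checkpoint_from_disk,
)

model = convert_fairseq_mbart_checkpoint_from_disk(
    "./mbart.cc25/model.pt", hf_config_path="facebook/mbart-large-cc25"
)
model.save_pretrained("./mbart-cc25-hf")  # writes the config and model weights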
