Save tokenizer in conversion script (#128)
* feature: save tokenizer based on script args

* chore: use none instead of empty str for consistency

* fix: rm duplicate args, save `tokenizer_class` key

* Update tools/convert_checkpoint/deepspeed_to_transformers.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

Co-authored-by: Jake Tae <>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
jaketae and stas00 committed Oct 7, 2021
1 parent 323bf5c commit 23dded0
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions in tools/convert_checkpoint/deepspeed_to_transformers.py
@@ -10,17 +10,23 @@
 # the import was tested to work with this version
 # https://github.com/huggingface/transformers/commit/0af901e83 if it diverges we may consider
 # copying that version here instead
-from transformers.models.megatron_gpt2.convert_megatron_gpt2_checkpoint import convert_megatron_checkpoint
-from transformers import GPT2Config
+from transformers.models.megatron_gpt2.convert_megatron_gpt2_checkpoint import (
+    convert_megatron_checkpoint,
+)
+from transformers import GPT2Config, AutoTokenizer
-
-def main():
+
+def main():
     # this first part comes mainly from deepspeed_to_megatron.main
     args = parse_arguments()
-    print(f'Converting DeepSpeed checkpoint in {args.input_folder} to HF Transformers checkpoint in {args.output_folder}')
+    print(
+        f"Converting DeepSpeed checkpoint in {args.input_folder} to HF Transformers checkpoint in {args.output_folder}"
+    )

-    ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, args.target_pp)
-    iteration = ds_checkpoint.get_iteration()
+    ds_checkpoint = DeepSpeedCheckpoint(
+        args.input_folder, args.target_tp, args.target_pp
+    )
+    ds_args = ds_checkpoint.get_args()
     input_state_dict = _create_rank_checkpoint(ds_checkpoint, 0, 0, args.for_release)

     # the 2nd part comes from transformers.models.megatron_gpt2.convert_megatron_gpt2_checkpoint.main
@@ -59,14 +65,28 @@ def main():
     os.makedirs(basename, exist_ok=True)

     # Print the structure of converted state dict.
-    #if args.print_checkpoint_structure:
+    # if args.print_checkpoint_structure:
     #     recursive_print(None, output_state_dict)

     # Store the config to file.
     output_config_file = os.path.join(basename, "config.json")
     output_config = config.to_dict()
     output_config["architectures"] = ["GPT2LMHeadModel"]
     output_config["model_type"] = "gpt2"
+
+    # Add tokenizer class info to config.json
+    # see https://github.com/huggingface/transformers/issues/13906)
+    tokenizer_type = ds_args.tokenizer_type
+    if tokenizer_type == "GPT2BPETokenizer":
+        tokenizer_model_name = "gpt2"
+    elif tokenizer_type == "PretrainedFromHF":
+        tokenizer_model_name = ds_args.tokenizer_name_or_path
+    else:
+        raise ValueError(f"Unrecognized tokenizer_type {tokenizer_type}")
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name)
+    tokenizer_class = type(tokenizer).__name__
+    output_config["tokenizer_class"] = tokenizer_class

     print(f'Saving config to "{output_config_file}"')
     with open(output_config_file, "w") as f:
         json.dump(output_config, f)
@@ -76,7 +96,9 @@ def main():
     print(f'Saving checkpoint to "{output_checkpoint_file}"')
     torch.save(output_state_dict, output_checkpoint_file)

-    print("Now add tokenizer files and upload to the hub")
+    # Save tokenizer based on args
+    print(f"Adding {tokenizer_class} tokenizer files")
+    tokenizer.save_pretrained(basename)


 if __name__ == "__main__":
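With this change the conversion output folder holds config.json (now carrying a tokenizer_class entry) and the tokenizer files alongside the converted weights, so the result should be loadable directly with transformers rather than requiring the tokenizer files to be added by hand afterwards. A minimal follow-up sketch, not part of the commit, assuming "converted_gpt2" is a stand-in for whatever args.output_folder pointed at:

from transformers import AutoTokenizer, GPT2LMHeadModel

# "converted_gpt2" is a hypothetical path standing in for the conversion
# script's output folder; it is not a name defined by this commit.
output_dir = "converted_gpt2"

# config.json now includes "tokenizer_class", so AutoTokenizer can resolve the
# matching tokenizer from the files written by tokenizer.save_pretrained().
tokenizer = AutoTokenizer.from_pretrained(output_dir)
model = GPT2LMHeadModel.from_pretrained(output_dir)

inputs = tokenizer("Converted with deepspeed_to_transformers.py", return_tensors="pt")
logits = model(**inputs).logits
print(logits.shape)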
