Generation is too slow; could you add support for generating with a batch_size greater than 1? #45

Closed
huosu opened this issue Aug 29, 2019 · 6 comments

Comments

@huosu

huosu commented Aug 29, 2019

As the title says. Thanks!

@fengzuo97

The generate function does have room for optimization: it does not use past. Making use of past would speed it up considerably.
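
For readers unfamiliar with past, here is a toy sketch of the idea, not this project's generate function. It assumes a single attention head with made-up 2x2 weights (`W_Q`, `W_K`, `W_V`); all function names are hypothetical. Without a cache, every generation step re-projects the keys and values of all earlier tokens. With a cache, only the newest token is projected and its key/value pair is appended.

```python
import math

# Made-up projection weights for a single attention head (illustrative only).
W_Q = [[0.5, -0.2], [0.1, 0.3]]
W_K = [[0.4, 0.1], [-0.3, 0.2]]
W_V = [[0.2, 0.7], [0.6, -0.1]]


def project(vec, matrix):
    """Multiply a row vector by a matrix given as a list of rows."""
    return [sum(x * row[j] for x, row in zip(vec, matrix)) for j in range(len(matrix[0]))]


def attend(query, keys, values):
    """Scaled dot-product attention for one query over all keys and values."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


def step_without_cache(tokens):
    """Re-project keys and values for the whole prefix on every step."""
    keys = [project(t, W_K) for t in tokens]
    values = [project(t, W_V) for t in tokens]
    return attend(project(tokens[-1], W_Q), keys, values)


def step_with_cache(token, past):
    """Project only the new token and append its key/value to the cache."""
    past["keys"].append(project(token, W_K))
    past["values"].append(project(token, W_V))
    return attend(project(token, W_Q), past["keys"], past["values"])


def compare(tokens):
    """Return per-step outputs computed with and without the cache."""
    past = {"keys": [], "values": []}
    cached = [step_with_cache(t, past) for t in tokens]
    full = [step_without_cache(tokens[: i + 1]) for i in range(len(tokens))]
    return cached, full
```

Both paths give the same outputs; the cached one does a constant amount of projection work per step instead of work that grows with the prefix length. In a real multi-layer model the same reasoning applies per layer, because causal masking keeps the states of earlier positions unchanged.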

@Morizeyao
Owner

I merged fengzuo97's PR. Could you give it a try?

@DevelMayCry-MrChen

May I ask where your dataset comes from? Could you upload a sample of it?

@Morizeyao
Owner

The Chinese datasets are listed in the project's README.

@HamQ

HamQ commented Dec 3, 2019

It would be great to have a tutorial on converting the GPT2-ML 1.5B Chinese model to this project's PyTorch format.

@ScottishFold007

ScottishFold007 commented Feb 6, 2020

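> It would be great to have a tutorial on converting the GPT2-ML 1.5B Chinese model to this project's PyTorch format.

I tried it. Something is missing from the checkpoint, so it cannot be converted for now. If you don't believe me, try the code below: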
"""Convert OpenAI GPT checkpoint."""

import argparse
import logging

import torch

from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2

logging.basicConfig(level=logging.INFO)


def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
    # Construct model (fall back to the default GPT-2 config if no file is given)
    if gpt2_config_file == "":
        config = GPT2Config()
    else:
        config = GPT2Config.from_json_file(gpt2_config_file)
    model = GPT2Model(config)

    # Load weights from numpy
    load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--gpt2_checkpoint_path",
        default=r'C:\Users\gaochangkuan\Desktop\2020.02.01 Chinese_news_generation_gpt2-ml-Chinse\chinese_model',
        type=str,
        help="Path to the TensorFlow checkpoint path.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=r'C:\Users\gaochangkuan\Desktop\2020.02.01 Chinese_news_generation_gpt2-ml-Chinse\chinese_model',
        type=str,
        help="Path to the output PyTorch model.",
    )
    parser.add_argument(
        "--gpt2_config_file",
        default="",
        type=str,
        help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
        "This specifies the model architecture.",
    )
    # An empty list makes argparse ignore the command line and use the defaults above
    # (handy in a notebook); use parser.parse_args() to honor command-line flags.
    args = parser.parse_args([])
    convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path)
