
The trained model cannot be loaded; the program automatically downloads the model from HuggingFace instead. What is the cause? #35

Closed
TGLTommyAI opened this issue Jun 9, 2022 · 0 comments

Comments


TGLTommyAI commented Jun 9, 2022

Hello, when calling the load_model_directly() method in inference.py, the trained model cannot be loaded. The code is as follows:

① Code:

def load_model_directly():
    # Trained checkpoint and the matching training config
    ckpt_file = 'SoftMaskedBert/epoch=05-val_loss=0.03253.ckpt'
    config_file = 'csc/train_SoftMaskedBert.yml'

    from bbcm.config import cfg
    cp = get_abs_path('checkpoints', ckpt_file)
    cfg.merge_from_file(get_abs_path('configs', config_file))
    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
    print("### tokenizer loaded")
    print("### tokenizer: ", tokenizer)
    if cfg.MODEL.NAME in ['bert4csc', 'macbert4csc']:
        model = BertForCsc.load_from_checkpoint(cp,
                                                cfg=cfg,
                                                tokenizer=tokenizer)
    else:
        print("### loading model")
        print("### cp: ", cp)
        model = SoftMaskedBertModel.load_from_checkpoint(cp,
                                                         cfg=cfg,
                                                         tokenizer=tokenizer)
    print("### model loaded")
    model.eval()
    model.to(cfg.MODEL.DEVICE)
    return model

② Problem:
This code does not seem to take effect: the ckpt file is never loaded, and the program still downloads the model from HuggingFace automatically.

    model = SoftMaskedBertModel.load_from_checkpoint(cp,
                                                     cfg=cfg,
                                                     tokenizer=tokenizer)

I also looked at the load_from_checkpoint() method itself, but I do not understand how the cp and cfg arguments are passed through.
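For reference, my current understanding of PyTorch Lightning's load_from_checkpoint() (a minimal sketch using a made-up ToyCsc module, not the project's actual classes): the checkpoint path is the first positional argument, and any extra keyword arguments such as cfg and tokenizer are forwarded to the module's __init__, which runs before the saved state dict is applied. So if __init__ calls from_pretrained() with a HuggingFace hub name rather than a local path, a download can still be triggered even though a local ckpt is supplied.

# Hypothetical ToyCsc module, only to illustrate load_from_checkpoint's behavior.
import pytorch_lightning as pl
from transformers import BertForMaskedLM

class ToyCsc(pl.LightningModule):
    def __init__(self, cfg, tokenizer):
        super().__init__()
        self.cfg = cfg
        self.tokenizer = tokenizer
        # This runs during load_from_checkpoint() as well; if cfg.MODEL.BERT_CKPT
        # is a hub name like 'bert-base-chinese' rather than a local directory,
        # HuggingFace downloads the weights here before the ckpt is applied.
        self.bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT)

# ToyCsc.load_from_checkpoint(cp, cfg=cfg, tokenizer=tokenizer) roughly does:
#   1. read the checkpoint file at the local path cp
#   2. model = ToyCsc(cfg=cfg, tokenizer=tokenizer)      # __init__ runs first
#   3. model.load_state_dict(checkpoint['state_dict'])   # then the trained weights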
