Hello — calling the `load_model_directly()` method in inference.py fails to load the trained model. The code is as follows:

① Code:
```python
def load_model_directly():
    ckpt_file = 'SoftMaskedBert/epoch=05-val_loss=0.03253.ckpt'
    config_file = 'csc/train_SoftMaskedBert.yml'

    from bbcm.config import cfg
    cp = get_abs_path('checkpoints', ckpt_file)
    cfg.merge_from_file(get_abs_path('configs', config_file))
    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
    print("### tokenizer loaded")
    print("### tokenizer: ", tokenizer)
    if cfg.MODEL.NAME in ['bert4csc', 'macbert4csc']:
        model = BertForCsc.load_from_checkpoint(cp, cfg=cfg, tokenizer=tokenizer)
    else:
        print("### loading model")
        print("### cp: ", cp)
        model = SoftMaskedBertModel.load_from_checkpoint(cp, cfg=cfg, tokenizer=tokenizer)
    print("### model loaded")
    model.eval()
    model.to(cfg.MODEL.DEVICE)
    return model
```
② Problem:

This code does not seem to take effect — the .ckpt file is never loaded, and the program still downloads a model from HuggingFace automatically.

```python
model = SoftMaskedBertModel.load_from_checkpoint(cp, cfg=cfg, tokenizer=tokenizer)
```

I looked into the `load_from_checkpoint()` method, but I don't understand how the arguments `cp` and `cfg` are passed.
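For context on the parameter passing: in PyTorch Lightning, `load_from_checkpoint` takes the checkpoint path as its first positional argument (`cp` here) and forwards any extra keyword arguments (here `cfg=` and `tokenizer=`) to the module's `__init__`, while the weights come from the checkpoint file itself. Note also that `BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)` may itself download from the HuggingFace hub when that value is a model name rather than a local directory, which can look like the checkpoint not being used. Below is a minimal pure-Python sketch of the kwarg-forwarding pattern — `FakeLightningModule` is a hypothetical stand-in, not Lightning's actual implementation:

```python
import json
import os
import tempfile


class FakeLightningModule:
    """Simplified sketch of the load_from_checkpoint pattern:
    the first positional argument is the checkpoint path; all extra
    keyword arguments are forwarded to the class constructor, while
    the weights are read from the checkpoint file itself."""

    def __init__(self, cfg=None, tokenizer=None):
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.weights = None

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, **kwargs):
        # Read the saved checkpoint (JSON here, in place of a real .ckpt).
        with open(checkpoint_path) as f:
            checkpoint = json.load(f)
        # Extra kwargs (cfg=..., tokenizer=...) land in __init__ here.
        model = cls(**kwargs)
        # The weights come from the checkpoint, not from the kwargs.
        model.weights = checkpoint["state_dict"]
        return model


# Demonstrate: the checkpoint supplies the weights, the kwargs
# supply the constructor arguments.
path = os.path.join(tempfile.mkdtemp(), "epoch=05.ckpt")
with open(path, "w") as f:
    json.dump({"state_dict": {"w": [1.0, 2.0]}}, f)

m = FakeLightningModule.load_from_checkpoint(
    path, cfg={"device": "cpu"}, tokenizer="bert-tok"
)
print(m.weights, m.cfg, m.tokenizer)
```

So in the original call, `cp` selects which checkpoint file supplies the weights, while `cfg` and `tokenizer` are constructor arguments for the module being restored.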