
Using fine tuned Model #25

Closed
erichen510 opened this issue Nov 5, 2019 · 23 comments

@erichen510

I'm posting this to check whether I'm doing the right thing.

I just added a dense layer with softmax to fine-tune the model:

albert_model = build_bert_model(config_path, checkpoint_path, albert=True)
out = Lambda(lambda x: x[:, 0])(albert_model.output)
output = Dense(units=class_num, activation='softmax')(out)

After training the model, I try to load it with

model = load_model(model_dir)

and I get an error saying the custom layer 'TokenEmbedding' is missing.
After that, I try

custom_objects = {'MaskedGlobalPool1D': MaskedGlobalPool1D}
custom_objects.update(get_bert_custom_objects())

get_bert_custom_objects() comes from keras_bert; it basically just defines some custom layers, while MaskedGlobalPool1D (also from keras_bert) is meant to strip the mask from the model's output.

I don't know if I'm doing this right, since the predictions are not good enough.
Can someone explain what the TokenEmbedding layer is? Is it the dense layer I defined?
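For reference, here is a minimal end-to-end sketch of the setup described above. The paths, class_num, and the saved filename are placeholders, and it assumes the 0.2.x-era build_bert_model signature used in this thread:

# Hedged sketch of the fine-tuning setup above; config_path, checkpoint_path,
# class_num and 'finetuned_albert.h5' are placeholders.
from keras.layers import Lambda, Dense
from keras.models import Model
from bert4keras.bert import build_bert_model

config_path = 'albert_config.json'       # placeholder
checkpoint_path = 'albert_model.ckpt'    # placeholder
class_num = 2                            # placeholder

albert_model = build_bert_model(config_path, checkpoint_path, albert=True)

# take the [CLS] (first token) vector and put a softmax classifier on top
cls = Lambda(lambda x: x[:, 0])(albert_model.output)
output = Dense(units=class_num, activation='softmax')(cls)

model = Model(albert_model.input, output)
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
# ... model.fit(...) ...
model.save('finetuned_albert.h5')        # the file that later fails in load_model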

@bojone
Owner

bojone commented Nov 5, 2019

Are you sure you are using bert4keras? keras_bert is here: https://github.com/CyberZHG/keras-bert

If you are using bert4keras, please do from bert4keras.layers import * before load_model.
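A minimal sketch of that suggestion, assuming the layer names reported in this thread (TokenEmbedding, plus whatever else the errors mention) are exposed by bert4keras.layers in this version, and using a placeholder model path:

from keras.models import load_model
from bert4keras.layers import *  # make bert4keras's custom layer classes available

# Pass the custom layers explicitly so Keras can deserialize them;
# 'TokenEmbedding' is the name from the error above, add more keys as needed.
custom_objects = {'TokenEmbedding': TokenEmbedding}
model = load_model('finetuned_albert.h5', custom_objects=custom_objects)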

@erichen510
Author

Yes, I'm using bert4keras:
from bert4keras.bert import build_bert_model
I only used the function from keras_bert because I got this error when loading the fine-tuned model:

Unknown layer: TokenEmbedding

@erichen510
Author

When I directly load the model fine-tuned with bert4keras, it reports Unknown layer: TokenEmbedding.
After adding TokenEmbedding to custom_objects, it then reports Unknown layer: MaskedGlobalMaxPool1D.
Only when I write it like this:

custom_objects = {'MaskedGlobalPool1D': MaskedGlobalPool1D}
custom_objects.update(get_bert_custom_objects())
model = load_model(model_dir, custom_objects=custom_objects)

does it stop raising errors.

@bojone
Owner

bojone commented Nov 5, 2019

So it's solved then? Since you're using third-party layers, I can't give a definitive fix; I can only deal with each issue as it comes up~

@erichen510
Author

Then it should be solved... though I'm quite confused about why the model's output carries a mask at all.

@erichen510
Author

The thing is, when I defined the model I didn't use any third-party layers like MaskedGlobalPool1D; I only used a single dense layer.

@bojone
Owner

bojone commented Nov 5, 2019

Where is MaskedGlobalPool1D imported from?

@erichen510
Author

MaskedGlobalPool1D is defined in keras_bert.

@erichen510
Author

I think I know where I went wrong...

@erichen510
Author

This time it returns Unknown layer: FactorizedEmbedding.

@erichen510
Author

Sorry, really sorry, I had been loading the wrong model file earlier... Is this FactorizedEmbedding defined in bert4keras?

@bojone
Owner

bojone commented Nov 5, 2019

FactorizedEmbedding is from bert4keras, but it has been removed in the latest version.

@erichen510
Author

So do I just need to import FactorizedEmbedding, or do I need to retrain?
I remember my version was built locally from the git repo I pulled the day before yesterday; you update really fast, Su.

@bojone
Owner

bojone commented Nov 5, 2019

It was only just removed; you can load it with version 0.2.3:
https://github.com/bojone/bert4keras/releases/tag/v0.2.3

@erichen510
Author

Can a model trained with the old version also be loaded with the new version?

@bojone
Owner

bojone commented Nov 5, 2019

FactorizedEmbedding is gone from 0.2.4 and all later versions; if you need FactorizedEmbedding you can only use 0.2.3 or earlier.

@erichen510
Author

OK, I'll try again. Thanks, Su.

@erichen510
Author

I'm now using version 0.2.2, and after adding all the custom_objects, I get the error: 'tuple' object has no attribute 'layer'.

@erichen510
Author

Posts online say similar errors are caused by the Keras version. My Keras version is 2.3.1; it seems the old bert4keras version still can't load the model.


@bojone
Owner

bojone commented Nov 5, 2019

Could you post the error message in more detail? Also, I suggest trying 0.2.3.

@bojone
Owner

bojone commented Nov 5, 2019

If it's not too much trouble, just retrain it...
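Not something stated in this thread, but a standard Keras pattern that avoids these Unknown-layer errors when retraining: save only the weights and rebuild the architecture in code before loading them, so load_model never has to deserialize custom layers. The filename below is a placeholder.

# after training: save weights only
model.save_weights('finetuned_albert_weights.h5')

# later, in a fresh session: rebuild the exact same architecture
# (build_bert_model + Lambda + Dense as in the first comment), then
model.load_weights('finetuned_albert_weights.h5')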

@erichen510
Author

OK, I'll retrain tomorrow. Thanks, Su.
