Why is it so slow to get results from estimator.predict(input_fn)? #287
Comments
I have the same problem: using estimator.predict() costs me a lot of time. Can you tell me how to solve it? Thanks.
btw, I'm using estimator.predict() to get results from BERT.
Because of the generator, each call has to recover the "context", i.e. all the resources tied to that function. You can see this function in this repo; the author uses pyzmq.
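To illustrate the cost (a hedged sketch; `make_input_fn`, `streaming_input_fn`, and `feed` are illustrative names, not from this repo): every fresh call to estimator.predict() rebuilds the graph and restores the checkpoint, while a single long-lived generator only pays for the forward pass:

```
# Slow pattern: each call re-enters predict(), rebuilding the graph and
# restoring the checkpoint before a single forward pass.
for text in texts:
    pred = next(estimator.predict(input_fn=make_input_fn(text)))

# Fast pattern: create the predict() generator once and keep it alive;
# each next() then only runs the forward pass.
predictions = estimator.predict(input_fn=streaming_input_fn)
for text in texts:
    feed(text)                # hypothetical helper that pushes features
    pred = next(predictions)  # into whatever streaming_input_fn reads from
```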
@nlpwhp can you give me a link to where the author uses pyzmq? Really appreciated.
you can see
@nlpwhp I'm really embarrassed that I still don't know how to solve my problem. Can you give me an example of how to use this for prediction? I'm a newbie. Thanks bro, hahaha.
Em, sorry, I just use this repo; it's convenient... and I'm not sure what your specific task is. If you want to deploy a BERT-based service with a framework you're familiar with (Flask, Django, ...), you can implement one server that starts BERT and waits for data via pyzmq, and one client that takes requests and sends the data to the server. Or, if you just want to get your service running quickly, using this repo is a good way. btw, since this repo is built around the encoding task, if you want to change the kind of service, you can change your model here:
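As a rough illustration of that split (a sketch only; the port, route name, and JSON shape are arbitrary choices, not from bert-as-service): the client half can be a small Flask app that forwards each HTTP request to the BERT server over a ZeroMQ REQ socket:

```
# client half: forward HTTP requests to the BERT server over ZeroMQ
import zmq
from flask import Flask, jsonify, request

app = Flask(__name__)
ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)            # pairs with a REP socket on the server
sock.connect("tcp://localhost:5555")  # arbitrary example port

@app.route("/encode", methods=["POST"])
def encode():
    sock.send_json({"text": request.json["text"]})  # hand the text to BERT
    return jsonify(sock.recv_json())                # block until the reply
```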
Really, really, thanks for your patient reply. I just use google-research/bert for a text-classification task, and I found it really slow when calling estimator.predict().
Oh, I see, you have no requirements on the service side. In that case, this repo can't help you much. I guess you're running your BERT classification on a personal PC with a CPU.
@nlp4whp I think I'd better explain my problem in Chinese. Actually I do have a machine: Ubuntu, with a GPU and 20 GB of RAM. I use the following code (quoted in full in the reply below):
OK, I see. There are two options. One is to modify the bert-as-service source code and reinstall it with pip; that requires reading the source and adapting the model you define as well as the input/output data formats.
The other is to write server-client code yourself. The rough idea is this: estimator.predict keeps pulling data from the predict_input_fn function, so as long as predict_input_fn keeps producing data, it never stops. You therefore need a separate client.py responsible for sending data to predict_input_fn and then receiving the results computed on the server side. That is what pyzmq does, so you need to learn a bit about how pyzmq is used. The flow looks roughly like this:
```
# server.py
# As long as recv() is being called here, predict_input_fn never ends: it
# keeps waiting for the client to send data, so the estimator.predict call
# below never finishes either, and you don't need to run it more than once.
def predict_input_fn():
    while True:
        request_data = pyzmq_obj_s.recv()
        yield preprocess(request_data)

def send_to_client(result):
    res = postprocess(result)  # package the raw prediction for the reply
    pyzmq_obj_s.send(res)

# Here, each time BERT finishes a prediction, it is returned to the client
for result in estimator.predict(input_fn=predict_input_fn):
    send_to_client(result)
```
```
# client.py
def send_to_server(data):
    pyzmq_obj_c.send(data)

def wait_res():
    res = pyzmq_obj_c.recv()
    return postprocess(res)  # parse the reply from the server
```
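For completeness, the pyzmq_obj_s / pyzmq_obj_c objects are left undefined in the sketch above; one plausible setup (assuming the plain REQ/REP pattern and an arbitrary port) would be:

```
import zmq

# in server.py: a REP socket that answers one request at a time
ctx_s = zmq.Context()
pyzmq_obj_s = ctx_s.socket(zmq.REP)
pyzmq_obj_s.bind("tcp://*:5555")

# in client.py: a REQ socket that sends, then waits for the reply
ctx_c = zmq.Context()
pyzmq_obj_c = ctx_c.socket(zmq.REQ)
pyzmq_obj_c.connect("tcp://localhost:5555")
```

With REQ/REP, each send() must be answered by a recv() before the next send(), which matches the one-request-one-prediction loop in the server sketch.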
Biaocsu <notifications@github.com> wrote on Mon, Jul 1, 2019, 1:03 PM:
@nlp4whp I think I'd better explain my problem in Chinese. Actually I do have a machine: Ubuntu, with a GPU and 20 GB of RAM. I use the following code:

```
def predicts(file_dic):
    if FLAGS.do_predict:
        predict_examples = []
        for (index, file) in enumerate(file_dic):
            with open(file, 'r', encoding='utf-8', errors='ignore') as f:
                text = f.read()
            # text = re.sub(r"[^\u4E00-\u9FA5]", '', text)
            # text = re.sub("转载|首发|喜讯|推荐|编辑|人气|禁闻|大纪元|新唐人|文化大观|翻墙必看|健康医疗|一周大事解读|热点透视", '', text)
            # if len(text) > 512:
            #     text = text[:128] + text[-382:]
            guid = 'test-%d' % index
            text_a = tokenization.convert_to_unicode(text)
            predict_examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label='0'))
        predict_features = convert_examples_to_features(predict_examples, label_list, FLAGS.max_seq_length, tokenizer)
        predict_input_fn = input_fn_builder(
            features=predict_features,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)
        result = estimator.predict(input_fn=predict_input_fn)
        # file_based_convert_examples_to_features(predict_examples, label_list, FLAGS.max_seq_length, tokenizer, predict_file)
        # predict_input_fn = file_based_input_fn_builder(
        #     input_file=predict_file,
        #     seq_length=FLAGS.max_seq_length,
        #     is_training=False,
        #     drop_remainder=predict_drop_remainder)
        # result = estimator.predict(input_fn=predict_input_fn)
        res_dic = {}
        tf.logging.info("***** Predict results *****")
        j = 0
        for prediction in result:
            probabilities = prediction["probabilities"]
            dicts = {}
            for i in range(len(probabilities)):
                dicts[label_list[i]] = probabilities[i]
            dicts = sorted(dicts.items(), key=lambda x: x[1], reverse=True)
            res_dic[j] = [dicts[0][0], dicts[0][1]]
            j += 1
        return res_dic
```

After processing all the files under the file_dic directory, I predict them in a single call to estimator.predict (getting the results for all files at once); that way each file takes about 0.4 s on average. But for other reasons I cannot feed all the data through estimator.predict in one batch. I need to call estimator.predict multiple times (i.e., once per txt file), and then it becomes very time-consuming (about 5 s per file on average). Do you know whether bert-as-service can solve this problem? I'm more or less just getting started with NLP, so I don't know where this "upstream pro" could usefully be applied. Thank you again!
Thanks brother, thanks a lot.
Hi,
Thanks in advance for all the help.
I tried to implement a simple version: just create the estimator in Flask and run predict. What confuses me is that after results = estimator.predict(), the loop over the results always costs more than 10 seconds. I have no idea why... does anyone hit issues like this?
btw, I notice that zmq is used, but I don't think that's the reason in my case....
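One way to avoid that cost in Flask (a hedged sketch; `serving_queue`, `make_features`, and the feature keys are illustrative and must match your model_fn) is to build the estimator and the predict generator once at startup and feed each request through a queue, instead of re-entering estimator.predict() per request:

```
import queue
import tensorflow as tf
from flask import Flask, jsonify, request

app = Flask(__name__)
serving_queue = queue.Queue()

def gen():
    while True:
        yield serving_queue.get()  # blocks until a request arrives

def predict_input_fn():
    # output_types must match the feature dict your model_fn expects
    return tf.data.Dataset.from_generator(
        gen,
        output_types={"input_ids": tf.int32,
                      "input_mask": tf.int32,
                      "segment_ids": tf.int32}).batch(1)

predictions = estimator.predict(input_fn=predict_input_fn)  # built once

@app.route("/predict", methods=["POST"])
def serve():
    serving_queue.put(make_features(request.json["text"]))  # hypothetical featurizer
    return jsonify(next(predictions)["probabilities"].tolist())
```

Note this sketch is single-threaded; with several Flask workers you would need locking around the shared generator, which is essentially what bert-as-service handles for you.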