Add Keras LSTM support #1752
Conversation
48b00d5 to ca4d655 (Compare)
This pull request introduces 1 alert when merging ca4d655 into 42e800d - view on LGTM.com new alerts:
ca4d655 to 2583cd3 (Compare)
This pull request introduces 1 alert when merging 2583cd3 into 42e800d - view on LGTM.com new alerts:
2583cd3 to c227c7a (Compare)
This pull request introduces 1 alert when merging c227c7a into 42e800d - view on LGTM.com new alerts:
78460f4 to e4a26b8 (Compare)
Hi @TomWildenhain-Microsoft and other contributors. Any suggestion on the issue below? I've encountered a failure against TF 2.6 in the pipeline. I've tested in my local conda env (Linux): tensorflow==2.6.0; onnxruntime==1.9.0; python==3.9.0; onnx==1.10.0; numpy==1.21.3, which matches the pipeline test case, but I couldn't reproduce this error locally; all tests pass for me. Any suggestion is appreciated!
e4a26b8 to a816aa1 (Compare)
@q-ycong-p The issue is a failing tfjs test. I don't think it is related to your change. Add an
a816aa1 to 57a4206 (Compare)
@TomWildenhain-Microsoft I added the decorator as suggested, see here (though we might also want to figure out why it failed?). The pipeline is now green. The PR is ready for review/merge. Thanks!
Awesome, thanks!
Yeah, we have to skip tfjs tests fairly regularly. It might be that the script for running the tfjs model isn't converting from float to complex, or that TensorFlow's tfjs converter isn't properly converting the model (not uncommon). In any case, the test is failing before it even gets to our converter: the model runs as a tf model, but fails to run after TensorFlow's tfjs model conversion. So the failure doesn't tell us anything about whether our conversion is working.
onnx_model = convert_keras(model, name=model.name)
if gru_class.__module__.split('.')[-1] == "recurrent_v2":
I'm not a huge fan of this check. Can you just do gru_class == recurrent_v2.GRU?
Or make GRU_CLASSES a tuple like you did with LSTM_CLASSES.
gru_class == recurrent_v2.GRU alone doesn't seem to work, because the class name evaluates to tensorflow.python.keras.layers.recurrent_v2.GRU in my current env. However, I was hesitant to match exactly that, since I assume it would vary depending on which TF version import keras relies on; see line 40 to line 50 here. I propose changing GRU_CLASSES to the same structure as LSTM_CLASSES, i.e. a list of tuples where each tuple contains an rnn_version, to be used like here. Would that be good practice? Thanks!
Ok, the tuple option is good then.
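The tuple approach agreed on above can be sketched as follows. This is a minimal illustration, not the PR's actual code: the two stand-in classes are hypothetical placeholders for recurrent.GRU and recurrent_v2.GRU so the pattern runs without TensorFlow installed, and find_gru_version is an assumed helper name.

```python
# Stand-in classes (hypothetical) for the Keras GRU implementations;
# in the real code these would be recurrent.GRU and recurrent_v2.GRU.
class GRU_v1:
    pass

class GRU_v2:
    pass

# Like LSTM_CLASSES: each entry pairs a layer class with its rnn_version,
# replacing the fragile module-name string check
# gru_class.__module__.split('.')[-1] == "recurrent_v2".
GRU_CLASSES = [
    (GRU_v1, 1),
    (GRU_v2, 2),
]

def find_gru_version(gru_class):
    """Return the rnn_version for a supported GRU class, else None."""
    for cls, rnn_version in GRU_CLASSES:
        if gru_class is cls:
            return rnn_version
    return None
```

Comparing class identities against a registry of (class, version) tuples avoids depending on the module path, which moves around between TF releases.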
Done! Pipeline is green.
Overall great, really appreciate it. Very thorough tests. LSTM patterns can be tricky.
Signed-off-by: congyc <congyc@amazon.com>
57a4206 to 8a14f66 (Compare)
Thank you @TomWildenhain-Microsoft for reviewing and helping fix this issue!
@TomWildenhain-Microsoft I saw the errors in onnxruntime-nightly-unittest-matrix for the latest commit. This issue seems to be directly related. The suggested solution is to either change the Python and/or NumPy versions, or change the TensorFlow source code as suggested; neither seems doable here. Should we add a @skip_tf_versions("2.1", "Bug in TF 2.1") decorator to the problematic test LSTMTests.test_keras_lstm? Or is there another cause, or a better way to fix it?
Thanks for debugging it! Yes, there is little we can do about it since it is a TF 2.1 issue.
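For context, a decorator like the @skip_tf_versions("2.1", "Bug in TF 2.1") mentioned above can be built on unittest.skipIf. This is an illustrative sketch, not tf2onnx's actual helper: the real one would read the version from tf.__version__, whereas here current_version is a parameter so the sketch runs without TensorFlow.

```python
import unittest

def skip_tf_versions(versions, reason, current_version="2.6.0"):
    """Skip a test (or test class) when the running TF version matches.

    `current_version` would normally come from tf.__version__; it is a
    plain parameter here (assumption) to keep the sketch self-contained.
    """
    if isinstance(versions, str):
        versions = [versions]
    should_skip = any(current_version.startswith(v) for v in versions)
    return unittest.skipIf(should_skip, reason)

# Usage mirroring the suggestion above: the whole class is skipped only
# when the detected TF version starts with "2.1".
@skip_tf_versions("2.1", "Bug in TF 2.1")
class LSTMTests(unittest.TestCase):
    def test_keras_lstm(self):
        self.assertTrue(True)
```

Version-prefix matching keeps the skip broad enough to cover patch releases (2.1.0, 2.1.1, ...) without listing each one.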
This commit adds Keras LSTM support in response to this issue.
The changes include: