
Add Keras LSTM support #1752

Merged — 1 commit merged into onnx:master on Oct 25, 2021

Conversation

q-ycong-p (Contributor)

This commit adds Keras LSTM support in response to this issue.
The changes include:

  1. Modified the existing LSTM pattern-matching and rewriters to handle Keras LSTM;
  2. Added unit tests verifying that ONNX models converted from Keras LSTM contain no loops (following the earlier commit that added Keras GRU support); a sketch of this kind of check is shown below.
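
For illustration, here is a minimal sketch of the kind of check the new tests perform, using the public tf2onnx.convert.from_keras API; the model shape, opset, and assertions are illustrative rather than the exact test code:

```python
# Minimal sketch (illustrative, not the exact test code): convert a small
# tf.keras LSTM model and check that the resulting ONNX graph contains a
# fused LSTM node and no control-flow Loop left over from an unmatched pattern.
import tensorflow as tf
import tf2onnx

model = tf.keras.Sequential([
    tf.keras.Input(shape=(5, 3)),
    tf.keras.layers.LSTM(8, return_sequences=True),
])

spec = (tf.TensorSpec((None, 5, 3), tf.float32, name="input"),)
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13)

op_types = [node.op_type for node in onnx_model.graph.node]
assert "LSTM" in op_types      # rewriter produced a single ONNX LSTM op
assert "Loop" not in op_types  # no looped subgraph remains after rewriting
```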

lgtm-com bot commented Oct 22, 2021

This pull request introduces 1 alert when merging ca4d655 into 42e800d - view on LGTM.com

new alerts:

  • 1 for Unused local variable

lgtm-com bot commented Oct 23, 2021

This pull request introduces 1 alert when merging 2583cd3 into 42e800d - view on LGTM.com

new alerts:

  • 1 for Unused local variable

lgtm-com bot commented Oct 23, 2021

This pull request introduces 1 alert when merging c227c7a into 42e800d - view on LGTM.com

new alerts:

  • 1 for Unused local variable

q-ycong-p force-pushed the keras_lstm_dev branch 2 times, most recently from 78460f4 to e4a26b8 on October 23, 2021 at 08:36
q-ycong-p (Contributor, Author) commented Oct 23, 2021

Hi @TomWildenhain-Microsoft and other contributors. Any suggestions on the issue below?

I've encountered a failure against TF 2.6 in tests/test_backend.py::test_rfft_ops (line 5497). The full error log is in the pipeline history. In short, it complains RuntimeError: Failed to run tfjs model: Error: Argument 'x' passed to 'cos' must be float32 tensor, but got complex64 tensor when executing tf.cos(tf.signal.rfft(x), name=_TFOUTPUT). I looked into it but cannot connect this error to my changes, which only modified the pattern-matching and rewriter for Keras LSTM. I traced it to def _rfft(...) but have no further clues.

I've tested in my local conda env (Linux): tensorflow==2.6.0; onnxruntime==1.9.0; python==3.9.0; onnx==1.10.0; numpy==1.21.3, which matches the failing pipeline configuration, but I couldn't reproduce the error locally; all tests pass for me. Any suggestions are appreciated!

TomWildenhain-Microsoft (Contributor)

@q-ycong-p The issue is a failing tfjs test. I don't think it is related to your change. Add an @skip_tfjs("Fails to run tfjs model") decorator to the test.
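
For reference, the change looks roughly like this (skip_tfjs is a tf2onnx test helper; the import path and surrounding class shown here are assumptions for illustration):

```python
# Rough sketch of the suggested fix in tests/test_backend.py; the helper
# import path and class names are assumptions, shown only for context.
from common import skip_tfjs  # tf2onnx test utility (path assumed)

class BackendTests(Tf2OnnxBackendTestBase):  # existing test class (name assumed)

    @skip_tfjs("Fails to run tfjs model")  # skip only when running the tfjs backend
    def test_rfft_ops(self):
        ...  # existing test body unchanged
```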

q-ycong-p (Contributor, Author)

@TomWildenhain-Microsoft I added the decorator as suggested, see here (though we might also want to figure out why it failed?). The pipeline is now green and the PR is ready for review/merge. Thanks!

TomWildenhain-Microsoft (Contributor)

Awesome, thanks!

> we might want to also figure out why it failed?

Yeah, we have to skip tfjs tests fairly regularly. It might be that the script for running the tfjs model isn't converting from float to complex, or that TensorFlow's tfjs converter isn't converting the model properly (not uncommon). In any case, the test fails before it even reaches our converter: it runs fine as a TF model, but fails to run after being passed through TensorFlow's tfjs model converter, so it doesn't tell us anything about whether our conversion is working.

Review thread on:

onnx_model = convert_keras(model, name=model.name)
if gru_class.__module__.split('.')[-1] == "recurrent_v2":
Contributor:
I'm not a huge fan of this check. Can you just do gru_class == recurrent_v2.GRU?

Contributor:
Or make GRU_CLASSES a tuple like you did with LSTM_CLASSES.

Contributor (Author):
gru_class == recurrent_v2.GRU alone doesn't seem to work, because in my current env the class evaluates to tensorflow.python.keras.layers.recurrent_v2.GRU. However, I was hesitant to match exactly that path; I assume it varies depending on which keras module the TF version imports, see lines 40 to 50 here.

I propose changing GRU_CLASSES to the same structure as LSTM_CLASSES, i.e. a list of tuples where each tuple carries the rnn_version to use, like here. A rough sketch of what I mean is below. Would that be good practice? Thanks!
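
A hypothetical sketch of that structure (module paths, class names, and version values below are illustrative; the actual rewriter code may differ):

```python
# Hypothetical sketch of the proposal: pair each Keras GRU class with the
# rnn_version the rewriter should use, mirroring LSTM_CLASSES, so the code
# no longer parses __module__ strings. Module paths vary across TF versions.
from tensorflow.python.keras.layers import recurrent, recurrent_v2

GRU_CLASSES = [
    (recurrent.GRU, 1),     # legacy keras.layers implementation
    (recurrent_v2.GRU, 2),  # TF2 recurrent_v2 implementation
]

def find_gru_version(gru_class):
    """Return the rnn_version for a matched GRU class, or None if no match."""
    for cls, rnn_version in GRU_CLASSES:
        if gru_class is cls:
            return rnn_version
    return None
```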

Contributor:
Ok, the tuple option is good then.

Contributor (Author):
Done! Pipeline is green.

TomWildenhain-Microsoft (Contributor) left a review comment:
Overall great, really appreciate it. Very thorough tests. LSTM patterns can be tricky.

Signed-off-by: congyc <congyc@amazon.com>
TomWildenhain-Microsoft merged commit f19a036 into onnx:master on Oct 25, 2021
q-ycong-p (Contributor, Author)

Thank you @TomWildenhain-Microsoft for reviewing and helping fix this issue!

q-ycong-p (Contributor, Author)

@TomWildenhain-Microsoft I saw the errors in the onnxruntime-nightly-unittest-matrix for the latest commit. This issue seems directly related. The suggested solutions are to change the Python and/or NumPy versions, or to change the TensorFlow source code as suggested there; neither seems doable here.

Should we add a @skip_tf_versions("2.1", "Bug in TF 2.1") decorator to the problematic test LSTMTests.test_keras_lstm? Or is there another cause, or a better way to fix it?

guschmue (Collaborator)

Thanks for debugging it! Yes, there's little we can do since it is a TF 2.1 issue.
@skip_tf_versions("2.1", "Bug in TF 2.1") seems to be the best option.
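
For reference, the fix mirrors the earlier skip_tfjs change (decorator as quoted above; the helper import path and test class shown are assumptions):

```python
# Sketch of the agreed fix: skip the Keras LSTM test on TF 2.1 only.
from common import skip_tf_versions  # tf2onnx test utility (path assumed)

class LSTMTests(Tf2OnnxBackendTestBase):  # existing test class (name assumed)

    @skip_tf_versions("2.1", "Bug in TF 2.1")
    def test_keras_lstm(self):
        ...  # existing test body unchanged
```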


Successfully merging this pull request may close: tf.keras.layers.LSTM not converted to ONNX LSTM layer