-
Notifications
You must be signed in to change notification settings - Fork 227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add the TensorFlow version of some JAX utilities #59
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Look great! Left a few minor comments.
tf_helpers/tf_dot_general.py
Outdated
Construct an equivalent general dot operation as that in JAX - | ||
<https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dot_general.html> | ||
|
||
Although there is an implementation in TF XLA, avoid directly using XLA when |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious, why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yeah, great question! I think Ashish mentioned that we want the newly constructed APIs - general conv, etc to be able to inter-op with the rest of the TF ecosystem, but I guess TF XLA is too low-level and independent. @wangpengmit Would you mind verifying this?
Tests coming soon. |
tf_hlpers
TensorFlow-related ecosystem and the tests folder contain both the `lax` and its test cases.
I will also add |
Probably not going to add test cases for |
the TF pool shape checker in order to make the TF `reduce_window` API consistent with JAX `reduce_window`.
stax back to TF stax.
Mark UPDATE (2nd September, 2020): Finished in #63. |
As titled, this Pull Request contains some necessary TF-based helper APIs for Neural Tangents, and they will be served as the main support.