Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the TensorFlow version of some JAX utilities #59

Merged
merged 28 commits into from
Aug 26, 2020

Conversation

DarrenZhang01
Copy link
Contributor

@DarrenZhang01 DarrenZhang01 commented Aug 24, 2020

As titled, this Pull Request contains some necessary TF-based helper APIs for Neural Tangents, and they will be served as the main support.

Copy link
Contributor

@romanngg romanngg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look great! Left a few minor comments.

tf_helpers/tf_dot_general.py Outdated Show resolved Hide resolved
tf_helpers/tf_dot_general.py Outdated Show resolved Hide resolved
tf_helpers/tf_dot_general.py Outdated Show resolved Hide resolved
Construct an equivalent general dot operation as that in JAX -
<https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dot_general.html>

Although there is an implementation in TF XLA, avoid directly using XLA when
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, great question! I think Ashish mentioned that we want the newly constructed APIs - general conv, etc to be able to inter-op with the rest of the TF ecosystem, but I guess TF XLA is too low-level and independent. @wangpengmit Would you mind verifying this?

tf_helpers/tf_dot_general.py Outdated Show resolved Hide resolved
tf_helpers/tf_dot_general.py Outdated Show resolved Hide resolved
tf_helpers/tf_dot_general.py Outdated Show resolved Hide resolved
@DarrenZhang01
Copy link
Contributor Author

Tests coming soon.

@DarrenZhang01 DarrenZhang01 changed the title Add the TensorFlow version of general dot operation to tf_hlpers Add the TensorFlow version of some JAX lax utilities Aug 25, 2020
tf_helpers/lax.py Outdated Show resolved Hide resolved
tf_helpers/lax.py Outdated Show resolved Hide resolved
tf_helpers/lax.py Outdated Show resolved Hide resolved
tf_helpers/lax.py Show resolved Hide resolved
tf_helpers/lax_tests.py Outdated Show resolved Hide resolved
tf_helpers/lax_tests.py Outdated Show resolved Hide resolved
.travis.yml Outdated Show resolved Hide resolved
.travis.yml Outdated Show resolved Hide resolved
tf_helpers/lax.py Outdated Show resolved Hide resolved
@DarrenZhang01
Copy link
Contributor Author

I will also add ostax into this pull request shortly.

@DarrenZhang01
Copy link
Contributor Author

Probably not going to add test cases for ostax since there are some dependencies that are still not there.

@DarrenZhang01 DarrenZhang01 changed the title Add the TensorFlow version of some JAX lax utilities Add the TensorFlow version of some JAX utilities Aug 26, 2020
tf_helpers/tf_jax_stax.py Outdated Show resolved Hide resolved
tf_helpers/tf_jax_stax.py Outdated Show resolved Hide resolved
tf_helpers/lax.py Show resolved Hide resolved
.travis.yml Outdated Show resolved Hide resolved
tests/lax_test.py Outdated Show resolved Hide resolved
tf_helpers/tf_jax_stax.py Outdated Show resolved Hide resolved
tf_helpers/tf_jax_stax.py Outdated Show resolved Hide resolved
tf_helpers/tf_jax_stax.py Outdated Show resolved Hide resolved
tf_helpers/lax.py Outdated Show resolved Hide resolved
tf_helpers/lax.py Outdated Show resolved Hide resolved
DarrenZhang01 added 2 commits August 26, 2020 16:16
the TF pool shape checker in order to make the TF `reduce_window` API
consistent with JAX `reduce_window`.
@DarrenZhang01
Copy link
Contributor Author

DarrenZhang01 commented Aug 26, 2020

tf_helpers/stax.py Outdated Show resolved Hide resolved
tf_helpers/stax.py Outdated Show resolved Hide resolved
tf_helpers/lax.py Outdated Show resolved Hide resolved
@romanngg romanngg merged commit a240b24 into google:neural-tangents-tf Aug 26, 2020
DarrenZhang01 pushed a commit to DarrenZhang01/neural-tangents that referenced this pull request Sep 2, 2020
DarrenZhang01 pushed a commit to DarrenZhang01/neural-tangents that referenced this pull request Sep 2, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants