add a shape mismatch check and error to custom_vjp #19009
Merged: copybara-service merged 1 commit into google:main from mattjj:custom-vjp-shape-mismatch-error on Mar 14, 2024
froystig approved these changes on Dec 16, 2023
Nice find and nice fix!
google-ml-butler (bot) added the labels kokoro:force-run and pull ready (Ready for copybara import and testing) on Dec 16, 2023
mattjj force-pushed the custom-vjp-shape-mismatch-error branch from c9834ea to 9292bfc on December 16, 2023 01:44
mattjj force-pushed the custom-vjp-shape-mismatch-error branch 4 times, most recently from 06b4c0b to e8dd4c3 on December 16, 2023 02:35
mattjj added a commit to mattjj/jax that referenced this pull request on Dec 20, 2023:
* remove the dead code KeyTangentTy
* replace TyRules.make_tangent with TyRules.zero
* removed ad.instantiate_zeros_aval, which was redundant with ad.instantiate_zeros ever since (1) we removed units and (2) we made Zero carry an aval on it
* fix a bug in backward_pass where we instantiated a Zero at the primal type rather than the corresponding tangent type
* fix _f_bwd in test_keyarray_custom_vjp, which had the wrong type (need to return cotangents for all inputs, we were returning a (float_tangent, key_tangent) pair instead of a (float_tangent, (float_tangent, key_tangent)) nested tuple), see google#19009 for a check which catches this and hence includes the same test change

We probably also need a TyRules.add for any extended dtypes that can occur as tangent dtypes, but we currently don't have any tests that exercise that (because all extended dtype tangent types are currently float0). I have some follow-up work to add such a case though!
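The "cotangents for all inputs" requirement mentioned in that commit message can be sketched with ordinary float inputs. This is only an illustration, not the test from the commit: the function `f` and its arguments are hypothetical, and plain floats stand in for the key array used in test_keyarray_custom_vjp. The point is that the bwd rule returns one cotangent per primal input, nested the same way the inputs are nested.

```python
import jax
import jax.numpy as jnp

# Hypothetical function taking a scalar and a nested pair: f(x, (y, z)).
@jax.custom_vjp
def f(x, pair):
    y, z = pair
    return x * y + z

def f_fwd(x, pair):
    y, z = pair
    # Save x and y as residuals for the backward pass.
    return x * y + z, (x, y)

def f_bwd(res, g):
    x, y = res
    # One cotangent per primal input, mirroring the input structure (x, (y, z)).
    # Returning a flat (ct_x, ct_y) pair here would not match the inputs.
    return g * y, (g * x, g)

f.defvjp(f_fwd, f_bwd)

x = jnp.array(2.0)
pair = (jnp.array(3.0), jnp.array(4.0))
grads = jax.grad(f, argnums=(0, 1))(x, pair)
print(grads)
```

The gradient comes back with the same nesting as the differentiated arguments: a value for `x` and a pair for `(y, z)`.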
mattjj force-pushed the custom-vjp-shape-mismatch-error branch from e8dd4c3 to 92bad24 on March 14, 2024 02:42
no idea how we lasted so long without this...
mattjj force-pushed the custom-vjp-shape-mismatch-error branch from 92bad24 to 1326c74 on March 14, 2024 02:57
copybara-service bot pushed a commit that referenced this pull request on Mar 14, 2024 (PiperOrigin-RevId: 615872948)
copybara-service bot pushed a commit that referenced this pull request on Mar 14, 2024 (PiperOrigin-RevId: 615883208)
no idea how we lasted so long without this...
Before this PR, a custom_vjp bwd rule that returned a cotangent of the wrong shape could lead to really confusing UnshapedArray errors, though I didn't make a repro of that. It could also lead to downstream shape mismatch errors deep in the backward pass interpreter, like in the test case.
After this PR:
We could make this error even better by referencing the corresponding input argument name and path, but I'm going to leave that for another day.
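For readers who want to see the kind of mistake this check targets, here is a minimal sketch. The functions `scale` and `scale_bad` are hypothetical and are not taken from the PR's test case. The second bwd rule returns a scalar cotangent for an input of shape (3,). What happens next depends on the installed JAX version, so the sketch records the outcome instead of assuming a particular error type or message.

```python
import jax
import jax.numpy as jnp

# Correct rule: the cotangent returned by bwd has the same shape as x.
@jax.custom_vjp
def scale(x):
    return 2.0 * x

def scale_fwd(x):
    return 2.0 * x, None  # no residuals needed

def scale_bwd(_, g):
    return (2.0 * g,)

scale.defvjp(scale_fwd, scale_bwd)

# Buggy rule: bwd collapses the cotangent to a scalar, so its shape
# no longer matches the shape of the primal input x.
@jax.custom_vjp
def scale_bad(x):
    return 2.0 * x

def scale_bad_fwd(x):
    return 2.0 * x, None

def scale_bad_bwd(_, g):
    return (jnp.sum(2.0 * g),)

scale_bad.defvjp(scale_bad_fwd, scale_bad_bwd)

x = jnp.arange(3.0)
good_grad = jax.grad(lambda v: scale(v).sum())(x)

# Record what the installed JAX does with the mismatched cotangent.
try:
    jax.grad(lambda v: scale_bad(v).sum())(x)
    outcome = "no error raised"
except Exception as e:
    outcome = type(e).__name__

print(good_grad.shape, outcome)
```

On a JAX release that includes this change, the second call is expected to fail at the custom_vjp boundary. On older releases it may fail later in the backward pass or silently return a wrongly shaped gradient.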