-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Additional input validation for transformations #10603
Additional input validation for transformations #10603
Conversation
19c4c6f
to
6a99228
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks promising, thanks.
6a99228
to
01a7399
Compare
dac9154
to
460ecf2
Compare
da0cb1a
to
728beac
Compare
728beac
to
16c95ea
Compare
an idea: |
If I understand correctly, that would fall outside the scope of this PR. Further, it only addresses cases where bind is used (static) and not other argument annotations (donate). This PR strictly deals with argument annotation validation at transformation time. I believe (but can't recall, and am traveling atm) that the binding process happens when the transformed function is called, not when the transformation is performed. I am working on another PR that builds on top of this one which goes a little further, but this PR aims to be compatible and uncontroversial. |
I don't mean that the bind for call, I mean that if we use |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking good to me, but I have a final few nits.
jax/_src/test_util.py
Outdated
_CPP_JITTER._name = "cpp" | ||
_PYTHON_JITTER = functools.partial(api._jit, False) | ||
_PYTHON_JITTER._name = "python" | ||
_NOOP_JITTER = lambda x: x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I'd leave this out of the list and add it locally in the test cases that explicitly want to test "without jit"? After all, it isn't a jit implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All the parameterised tests use the NOOP
jitter, so for now I think it should remain. Adding it in locally for every parameterised test would just lead to a lot of code duplication without any benefit, I think. Can always refactor this later if there are any cases where the no-op test case is not used.
16c95ea
to
25c9c4a
Compare
Implemented the requested changes. Outstanding nit: #10603 (review) Let me know what you think @hawkinsp |
7756c90
to
ce26ab3
Compare
Squashed to one commit and fixed a test. Ready for review/merge |
ce26ab3
to
e186b9e
Compare
e186b9e
to
90fe16f
Compare
90fe16f
to
838a053
Compare
Adds additional input validation for
*_argnames
and*_argnums
arguments tojax.jit
.Previously
static_argnums
andstatic_argnames
could easily lead to silently dynamic arguments. In fact, two such cases were found in thejax
source code using the new test coverage (c477ce3).Additional test coverage has been added.
Includes a small refactoring in
api.py
which introduces_jit
to reduce code duplication (this leads to a number of small changes in various tests). The refactoring has the additional benefit of making experimentation with #10476 easier.For discussion see issue: #10601
Other places that could use additional validation (just add plumbing to validation logic introduced by this PR)
jax.experiment.pjit
jax.pmap
jax.value_and_grad
jax.custom_vjp
jax.custom_jvp
jax.hessian
jax.jacrev
jax.jacfwd
jax.grad
I will add validation for the functions mentioned above in a separate PR once/if this one has been merged.
Handling of
*args
and**kwargs
In cases where variable positional and/or keyword arguments are used, we (rightly) assume that the function will know how to deal with these and thus it is safe to assume that for
*args
anyargnum
is valid. Similarly for**kwargs
anyargname
is valid.Importantly we still do some validation where possible. As an example, consider:
Additional improvements
argnums
could be patched to contain positional-only arguments given asstatic_argnames
usinginspect
.Currently:
This is due to the way
_infer_argnums_and_argnames
works, which might be worthwhile to change in a separate PR. See further discussion in #10614Fix: #10601
Fix: #10046