Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional input validation for transformations #10603

Merged

Conversation

JeppeKlitgaard
Copy link
Contributor

@JeppeKlitgaard JeppeKlitgaard commented May 6, 2022

Adds additional input validation for *_argnames and *_argnums arguments to jax.jit.

Previously static_argnums and static_argnames could easily lead to silently dynamic arguments. In fact, two such cases were found in the jax source code using the new test coverage (c477ce3).

Additional test coverage has been added.

Includes a small refactoring in api.py which introduces _jit to reduce code duplication (this leads to a number of small changes in various tests). The refactoring has the additional benefit of making experimentation with #10476 easier.

For discussion see issue: #10601

Other places that could use additional validation (just add plumbing to validation logic introduced by this PR)

  • jax.experiment.pjit
  • jax.pmap
  • jax.value_and_grad
  • jax.custom_vjp
  • jax.custom_jvp
  • jax.hessian
  • jax.jacrev
  • jax.jacfwd
  • jax.grad

I will add validation for the functions mentioned above in a separate PR once/if this one has been merged.

Handling of *args and **kwargs

In cases where variable positional and/or keyword arguments are used, we (rightly) assume that the function will know how to deal with these and thus it is safe to assume that for *args any argnum is valid. Similarly for **kwargs any argname is valid.

Importantly we still do some validation where possible. As an example, consider:

def f(a, /, b, *, c): ...

jit(f, static_argnames=("a",))   # This will fail since we know that `a` must be positional

jit(f, static_argnums=(2,))  # This will fail since `c` must be keyword

Additional improvements

argnums could be patched to contain positional-only arguments given as static_argnames using inspect.

Currently:

def f(a, /, b, *, c):
    print(a, b, c)

ff = jit(f, static_argnames=("a", "b", "c"))  # with validation disabled
ff(1, 2, c=3)
> Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 2 3
# Expected: 1 2 3

ff2 = jit(f, static_argnames=("b", "c"), static_argnums=(0,))
ff2(1, 2, c=3)
> 1 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 3
# Expected 1 2 3

This is due to the way _infer_argnums_and_argnames works, which might be worthwhile to change in a separate PR. See further discussion in #10614

Fix: #10601
Fix: #10046

jax/_src/api.py Show resolved Hide resolved
jax/_src/api.py Show resolved Hide resolved
jax/_src/lax/qdwh.py Outdated Show resolved Hide resolved
jax/_src/scipy/linalg.py Show resolved Hide resolved
Copy link
Member

@hawkinsp hawkinsp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks promising, thanks.

jax/_src/api.py Outdated Show resolved Hide resolved
jax/_src/api.py Show resolved Hide resolved
jax/_src/api.py Outdated Show resolved Hide resolved
jax/_src/api.py Show resolved Hide resolved
jax/_src/api_util.py Outdated Show resolved Hide resolved
jax/_src/api_util.py Outdated Show resolved Hide resolved
jax/_src/api_util.py Outdated Show resolved Hide resolved
jax/_src/api_util.py Outdated Show resolved Hide resolved
jax/_src/api.py Outdated Show resolved Hide resolved
@hawkinsp hawkinsp added the pull ready Ready for copybara import and testing label May 9, 2022
@JeppeKlitgaard JeppeKlitgaard force-pushed the transformation-input-validation branch from 6a99228 to 01a7399 Compare May 9, 2022 16:52
@JeppeKlitgaard JeppeKlitgaard force-pushed the transformation-input-validation branch 3 times, most recently from dac9154 to 460ecf2 Compare May 12, 2022 12:37
@JeppeKlitgaard
Copy link
Contributor Author

@hawkinsp changes since last time are primarily support for negative argnums as per #10669

jax/_src/api.py Outdated Show resolved Hide resolved
jax/_src/api.py Outdated Show resolved Hide resolved
jax/_src/api.py Outdated Show resolved Hide resolved
jax/_src/lax/qdwh.py Outdated Show resolved Hide resolved
@JeppeKlitgaard JeppeKlitgaard force-pushed the transformation-input-validation branch 4 times, most recently from da0cb1a to 728beac Compare May 12, 2022 18:50
@JeppeKlitgaard JeppeKlitgaard marked this pull request as draft May 12, 2022 18:52
@JeppeKlitgaard JeppeKlitgaard force-pushed the transformation-input-validation branch from 728beac to 16c95ea Compare May 12, 2022 21:59
@JeppeKlitgaard JeppeKlitgaard marked this pull request as ready for review May 12, 2022 22:00
jax/_src/api_util.py Outdated Show resolved Hide resolved
@YouJiacheng
Copy link
Contributor

YouJiacheng commented May 13, 2022

an idea:
If we do arg and kwargs inference using inspect tools following #10614 (comment) (there are some bugs in that comment), we might only need to catch error from bind? But this might need a breaking change...

@JeppeKlitgaard
Copy link
Contributor Author

only

If I understand correctly, that would fall outside the scope of this PR. Further, it only addresses cases where bind is used (static) and not other argument annotations (donate).

This PR strictly deals with argument annotation validation at transformation time. I believe (but can't recall, and am traveling atm) that the binding process happens when the transformed function is called, not when the transformation is performed.

I am working on another PR that builds on top of this one which goes a little further, but this PR aims to be compatible and uncontroversial.

@YouJiacheng
Copy link
Contributor

YouJiacheng commented May 13, 2022

I don't mean that the bind for call, I mean that if we use Signature.bind_partial to inference argnums and argnames.
Okay I find it is hard to use Signature.bind_partial in current JAX API interface...

Copy link
Member

@hawkinsp hawkinsp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking good to me, but I have a final few nits.

_CPP_JITTER._name = "cpp"
_PYTHON_JITTER = functools.partial(api._jit, False)
_PYTHON_JITTER._name = "python"
_NOOP_JITTER = lambda x: x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'd leave this out of the list and add it locally in the test cases that explicitly want to test "without jit"? After all, it isn't a jit implementation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the parameterised tests use the NOOP jitter, so for now I think it should remain. Adding it in locally for every parameterised test would just lead to a lot of code duplication without any benefit, I think. Can always refactor this later if there are any cases where the no-op test case is not used.

jax/_src/test_util.py Outdated Show resolved Hide resolved
jax/_src/api_util.py Outdated Show resolved Hide resolved
@JeppeKlitgaard
Copy link
Contributor Author

Implemented the requested changes.

Outstanding nit: #10603 (review)

Let me know what you think @hawkinsp

@JeppeKlitgaard JeppeKlitgaard force-pushed the transformation-input-validation branch 2 times, most recently from 7756c90 to ce26ab3 Compare May 17, 2022 22:05
@JeppeKlitgaard
Copy link
Contributor Author

JeppeKlitgaard commented May 17, 2022

Squashed to one commit and fixed a test. Ready for review/merge

tests/x64_context_test.py Show resolved Hide resolved
tests/jax_jit_test.py Outdated Show resolved Hide resolved
@JeppeKlitgaard JeppeKlitgaard force-pushed the transformation-input-validation branch from 90fe16f to 838a053 Compare May 18, 2022 20:54
@copybara-service copybara-service bot merged commit 478a95a into google:main May 19, 2022
@JeppeKlitgaard JeppeKlitgaard deleted the transformation-input-validation branch May 19, 2022 21:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
4 participants