[Proposal] Consistent argnums and argnames parameters for transformations #10614

Open
15 tasks
JeppeKlitgaard opened this issue May 6, 2022 · 20 comments
Labels
enhancement New feature or request

Comments

@JeppeKlitgaard
Contributor

JeppeKlitgaard commented May 6, 2022

Hey JAX team,

I have been trying to wrap my head around 'argument annotation' in JAX for a bit in the hope of finding a more intuitive and consistent implementation, which has led me to the big block of text below. I would be super keen to hear your thoughts as I try to dive deeper into the inner workings of JAX.

Lately there have been a number of issues requesting improvements to the *_argnums and *_argnames parameters used in transformations, along with other ergonomics improvements related to declaring which function arguments should be annotated with a given property. I figured it might be helpful to make an overarching issue with the end goal of having a consistent, ergonomic way of specifying these parameters. Managing argument 'annotations' in transformations has definitely been one of the more frustrating experiences of learning JAX (which is otherwise entirely amazing, of course).

Related issues:

jax.jit correctly implements static_argnames even for cases with keyword-only arguments, which would suggest that it should be possible to add argnames equivalents to any function that currently only implements argnums.

An easier but less robust fix could be to map argnames to argnums using inspect (see discussion: #1159). This would likely not work for keyword-only arguments (though it might for things like donate_arg...?)
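As a rough illustration of the "map argnames to argnums using inspect" idea, here is a minimal sketch that uses only the standard library. The helper name argnames_to_argnums is hypothetical and is not part of JAX; it simply shows why keyword-only parameters are the awkward case: they have no positional index to map to.

```python
import inspect


def argnames_to_argnums(fun, argnames):
    """Map parameter names to positional indices where a position exists.

    Hypothetical helper, not a JAX API. Returns (argnums, unmapped), where
    `unmapped` holds names that have no positional index (keyword-only
    parameters, or names that are not in the signature at all).
    """
    params = inspect.signature(fun).parameters
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    # Positional parameters always come first in a signature, so their
    # enumeration index equals their position in a call.
    index_of = {
        name: i
        for i, (name, p) in enumerate(params.items())
        if p.kind in positional_kinds
    }
    argnums = tuple(index_of[n] for n in argnames if n in index_of)
    unmapped = tuple(n for n in argnames if n not in index_of)
    return argnums, unmapped


def f(a, /, b, *, c):
    return a, b, c


print(argnames_to_argnums(f, ("a", "b", "c")))  # ((0, 1), ('c',))
```

The keyword-only parameter c ends up in the unmapped tuple, which is the limitation mentioned above.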

Current shortcomings

Currently even the most robust implementation of the 'argument annotation' mechanism behaves in a somewhat counter-intuitive way (although this is suggested in the fine print of the docstring, if one reads it with sufficient care):

from jax import jit

def f(a, /, b, *, c):
    print(a, b, c)

jf = jit(f, static_argnames=("a", "b", "c"))
jf(1, 2, c=3)
> Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 2 3
# Expected: 1 2 3

jf2 = jit(f, static_argnames=("b", "c"), static_argnums=(0,))
jf2(1, 2, c=3)
> 1 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 3
# Expected: 1 2 3

jf2(1, b=2, c=3)
> 1 2 3
# As expected

jf3 = jit(f, static_argnums=(0, 1, 2))
jf3(1, 2, c=3)
> 1 2 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
# Expected: 1 2 3

jf3(1, b=2, c=3)
> 1 2 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
# Expected: 1 2 3

The fact that we have one instance where we are able to get the expected result gives hope that a solution should be possible by inspecting the function and arguments and modifying static_argnums and static_argnames accordingly – or perhaps a better solution exists? Ideally we would want to avoid inspecting the arguments at call-time.

I have started toying with validation of static_argnums and static_argnames in #10603

Goals

My suggestion is to first find a solution for jax.jit that fixes the inconsistencies above (or, in the worst case, documents them thoroughly).

Once that is done, it would be great to see *_argnames and keyword-arg support added to other functions:

  • jax.experimental.pjit
  • jax.pmap
  • jax.value_and_grad
  • jax.custom_vjp
  • jax.custom_jvp
  • jax.hessian
  • jax.jacrev
  • jax.jacfwd
  • jax.grad

Additionally #10476 can be explored (could live in jax.experimental.annotations, if there is any interest for this feature at all)

Progress

  • Get feedback and decide on: (this issue)
    • Interface (potential changes in function signatures for argument annotations)
    • Behaviour
  • Document interface and behaviour (initial PR: [WIP] Document argument annotations #10677)
  • Make tests and ensure consistency for functions
@patrick-kidger
Collaborator

patrick-kidger commented May 10, 2022

This is totally doable using inspect.signature. For example:

sig = inspect.signature(f)
sig = ...  # replace all defaults with False; elided for space
static_args = tuple(True if i in static_argnums else False for i in range(static_argnums))
static_kwargs = {k: True for k in static_argnames}
bound = sig.bind_partial(*static_args, **static_kwargs)
bound.apply_defaults()
static_args = bound.static_args
static_kwargs = bound.static_kwargs

which canonicalises args and kwargs based on the signature of the function.

Then if necessary do the same thing to the actual args and kwargs passed at runtime, and match them up.

(This is exactly how Equinox handles filter_{jit,grad,vmap,pmap,value_and_grad}.)

@JeppeKlitgaard
Contributor Author

@patrick-kidger I think a native JAX solution might even be able to just lean on the existing argnames_partial/argnums_partial functions. This way the only changes needed would be to generate the extra entries in argnums and argnames at the beginning of methods like jit, grad, ... The lookup logic would then also work for nodiff and donate type arguments.
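To make "generate the extra entries in argnums and argnames" concrete, here is a small sketch using only inspect. The function expand_argnums_and_argnames is a hypothetical name and does not call or reproduce JAX's argnums_partial/argnames_partial; it only shows the decoration-time expansion, assuming non-negative argnums.

```python
import inspect


def expand_argnums_and_argnames(fun, argnums=(), argnames=()):
    """Make argnums and argnames describe the same set of parameters.

    Hypothetical sketch, not JAX's implementation. A parameter that can be
    passed both ways gets an entry in both collections; a positional-only
    parameter selected by name gets its index added.
    """
    P = inspect.Parameter
    nums, names = set(argnums), set(argnames)
    for i, p in enumerate(inspect.signature(fun).parameters.values()):
        if p.kind is P.POSITIONAL_OR_KEYWORD:
            if i in nums or p.name in names:
                nums.add(i)
                names.add(p.name)
        elif p.kind is P.POSITIONAL_ONLY and p.name in names:
            nums.add(i)
        # Keyword-only parameters have no position, so nothing to add.
    return tuple(sorted(nums)), tuple(sorted(names))


def f(a, /, b, *, c):
    return a, b, c


print(expand_argnums_and_argnames(f, argnames=("a", "b", "c")))
# ((0, 1), ('a', 'b', 'c'))
print(expand_argnums_and_argnames(f, argnums=(1,)))
# ((1,), ('b',))
```

Because the expansion happens once, when the transformation is created, nothing needs to be inspected at call time.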

@JeppeKlitgaard
Contributor Author

From #10669 it appears that negative argnums should also be made valid for static_argnums, in which case I think that convention should be carried across to nodiff, argnums, etc.

This will require some additional work on #10603.

Could @mattjj weigh in?

@mattjj
Member

mattjj commented May 11, 2022

Supporting negative indices everywhere SGTM. (Sorry, haven't had time to reply to the broader proposal in more detail...)
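One possible reading of "negative indices everywhere" is to resolve them against the number of positional parameters in the signature, in the same way Python indexes sequences. The sketch below assumes that interpretation; normalize_argnums is a hypothetical helper and not a JAX function, and it refuses negative indices when the function takes *args, since the positional count is then unknown until call time.

```python
import inspect


def normalize_argnums(fun, argnums):
    """Resolve negative argnums against the positional parameters of `fun`.

    Hypothetical helper under the assumption that -1 means "the last
    positional parameter in the signature".
    """
    P = inspect.Parameter
    params = list(inspect.signature(fun).parameters.values())
    n_positional = sum(
        p.kind in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD) for p in params
    )
    has_var_positional = any(p.kind is P.VAR_POSITIONAL for p in params)

    resolved = set()
    for n in argnums:
        if n < 0:
            if has_var_positional:
                raise ValueError(
                    "negative argnums cannot be resolved ahead of time "
                    "for a function that accepts *args"
                )
            n += n_positional
        # Without *args, every index must name an actual positional parameter.
        if not has_var_positional and not 0 <= n < n_positional:
            raise IndexError(f"argnum out of range for {n_positional} positional parameters")
        resolved.add(n)
    return tuple(sorted(resolved))


def g(x, y, z):
    return x + y + z


print(normalize_argnums(g, (-1, 0)))  # (0, 2)
```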

@YouJiacheng
Contributor

YouJiacheng commented May 13, 2022

@patrick-kidger It seems that your example has some typos. Also, bind_partial needs static_args and static_kwargs not to overlap, so static_args = tuple(True if i in static_argnums else False for i in range(static_argnums)) seems wrong (and range(static_argnums) should be a typo, I think).
I'll have a look at Equinox.
Okay, it seems that Equinox has a different API, so it doesn't have this problem.

@patrick-kidger
Collaborator

patrick-kidger commented May 13, 2022

If static_args and static_kwargs overlap then that's a user error, analogous to def f(a) ... f(1, a=2).

Indeed it should be range(len(static_argnums)).

This was typed out without testing. If in doubt use the Equinox version, that definitely works ;) Equinox provides a superset of the interface being considered here - it also handles mapping over PyTrees, filter functions, auxiliary outputs, etc.

@YouJiacheng
Contributor

YouJiacheng commented May 13, 2022

@patrick-kidger It is not a user error, since static_args = tuple(True if i in static_argnums else False for i in range(static_argnums)) generates a full set of args. And I suspect that it should be range(max(static_argnums)).

And static_args, static_kwargs = bound.args, bound.kwargs doesn't meet JAX's needs, since bound.kwargs only contains keyword-only arguments and var kwargs.

@patrick-kidger
Collaborator

Ah! I see what you're saying. Sorry, yes, being a bit slow today. JAX uses an index-based way of selecting arguments and I was thinking of a mask-based way.

The basic principle holds, but you're right that the parsing would be a bit more involved.

@YouJiacheng
Contributor

I have tried to implement it using bind_partial, and in the end I found that it might be better to parse the signature manually using inspect, similar to what JAX currently does.

@YouJiacheng
Contributor

YouJiacheng commented May 13, 2022

My trial:

import inspect

# Negative indices are set aside here; they still need separate handling.
neg_argnums = tuple(argnum for argnum in argnums if argnum < 0)
argnums_set = set(argnum for argnum in argnums if argnum >= 0)

sentinel = object()
# + 1 so that the largest requested index is itself included
args = tuple(None if i in argnums_set else sentinel for i in range(max(argnums_set, default=-1) + 1))
kwargs = {k: None for k in argnames}
sig = inspect.signature(fun)
# merge the two partial bindings (dict union needs Python 3.9+)
ba = inspect.BoundArguments(sig, sig.bind_partial(*args).arguments | sig.bind_partial(**kwargs).arguments)
args = ba.args
kwargs = ba.kwargs
# JAX needs POSITIONAL_OR_KEYWORD, KEYWORD_ONLY and VAR_KEYWORD
# but ba.kwargs only contains KEYWORD_ONLY and VAR_KEYWORD

@JeppeKlitgaard
Contributor Author

JeppeKlitgaard commented May 19, 2022

Idea - Interface discussion:

Don't use *_argnums and *_argnames at all, just a *_args parameter of type Sequence[str | int] where integers are taken as positions and strings are taken as argument names.

This is not only more succinct, but also allows us to maintain full backwards compatibility: argnums and argnames would continue to work as they currently do, but would give rise to a deprecation warning for a few versions before being removed (potentially until JAX 1.0, but preferably sooner).

Using the container class approach as proposed (proposal not finished) in #10746 would enable relatively painless support of argnames+argnums and args for the period of deprecation.
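A minimal sketch of how a combined *_args parameter could be split back into positions and names while the old parameters are still accepted. ArgSpec and make_arg_spec are hypothetical names chosen for illustration; they are not taken from #10746 or from JAX, and the sketch assumes the old parameters are given as sequences.

```python
import warnings
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ArgSpec:
    """Immutable container for one kind of argument annotation (hypothetical)."""
    argnums: Tuple[int, ...] = ()
    argnames: Tuple[str, ...] = ()


def make_arg_spec(args=(), *, argnums=None, argnames=None):
    """Build an ArgSpec from a mixed sequence of ints (positions) and strs (names).

    The legacy `argnums`/`argnames` keywords still work but emit a
    DeprecationWarning, mirroring the deprecation period described above.
    """
    if argnums is not None or argnames is not None:
        warnings.warn(
            "argnums/argnames are deprecated in this sketch; use args instead",
            DeprecationWarning,
            stacklevel=2,
        )
    nums = list(argnums or ())
    names = list(argnames or ())
    for a in args:
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(a, bool) or not isinstance(a, (int, str)):
            raise TypeError(f"expected int or str, got {type(a).__name__}")
        (names if isinstance(a, str) else nums).append(a)
    return ArgSpec(tuple(nums), tuple(names))


print(make_arg_spec((0, "c", 2)))
# ArgSpec(argnums=(0, 2), argnames=('c',))
```

Since the dataclass is frozen, an ArgSpec is hashable and can be passed around or used as part of a cache key without risk of mutation.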

@JeppeKlitgaard
Contributor Author

@mattjj I would love to hear some more thoughts from the JAX team. Could I get you to ping people that might be interested in this? @hawkinsp was really helpful in getting #10603 merged.

@danijar

danijar commented Jun 2, 2022

@mattjj @hawkinsp I just wanted to flag this issue again, it would be great to make progress on this to improve JAX usability.

The approach of just having e.g. static_args that takes Sequence[str | int] seems nice, regardless of whether it's implemented via #10746 or something easier.

Background: The current behavior in JAX is somewhat broken: arguments marked via static_argnums cannot be passed as kwargs, and arguments marked via static_argnames cannot be passed positionally. Moreover, counting argnums carries a good amount of mental overhead, especially when using function transformations that change the function signature, whereas argnames isn't supported everywhere, e.g. in pmap.

@YouJiacheng
Contributor

@danijar Actually if you only pass static_argnums or only pass static_argnames, static arguments can be passed as kwargs and can be passed positionally. However, if you pass static_argnums and static_argnames, the behavior is what you describe.
I am trying to unify this behavior in #10724.

@JeppeKlitgaard
Contributor Author

@danijar Thank you for highlighting this. Even after having spent a good bit of time with this particular part of the JAX source, the behaviour still manages to confuse me from time to time. #10746 is intended as more of a rough sketch, but I think having an immutable dataclass object and passing that around might be a good option.

Having had a very cursory look at the code, I think in most places I would be able to figure out how to expand the code to accept named arguments as well (many places have argnums but no equivalent argnames parameter).

@hawkinsp
Member

hawkinsp commented Jun 3, 2022

(I just wanted to note that looking at this is still on my radar but I haven't had time to do so between travel and other higher priorities. Sorry for the delay.)

@danijar

danijar commented Jun 4, 2022

Actually if you only pass static_argnums or only pass static_argnames, static arguments can be passed as kwargs and can be passed positionally.

Is this true for all functions? It's not what I've been seeing with jit or pmap.

Another failure case right now is default arguments. If I mark an input with a default value as static, it raises an error if the value isn't passed in.

This should all be pretty easy to do with the inspect module by canonicalizing the function signature.
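A sketch of what canonicalizing at call time could look like with the inspect module, including the default-argument case. split_call is a hypothetical helper, not a JAX API. Note that it inspects the arguments on every call, which the original post hoped to avoid, so it is a trade-off rather than a free fix.

```python
import inspect


def split_call(fun, static_names, /, *args, **kwargs):
    """Bind a call against fun's signature and split it into static and dynamic parts.

    Hypothetical helper. Defaults are filled in before splitting, so a static
    parameter that was not passed explicitly still gets a value.
    """
    bound = inspect.signature(fun).bind(*args, **kwargs)
    bound.apply_defaults()
    static = {k: v for k, v in bound.arguments.items() if k in static_names}
    dynamic = {k: v for k, v in bound.arguments.items() if k not in static_names}
    return static, dynamic


def f(x, y, mode="fast"):
    return x, y, mode


# `mode` is marked static but not passed: the default is picked up.
print(split_call(f, {"mode"}, 1.0, y=2.0))
# ({'mode': 'fast'}, {'x': 1.0, 'y': 2.0})

# The same parameter is found whether it is passed positionally or by keyword.
print(split_call(f, {"mode"}, 1.0, 2.0, "slow")[0])  # {'mode': 'slow'}
print(split_call(f, {"mode"}, 1.0, 2.0, mode="slow")[0])  # {'mode': 'slow'}
```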

@carlosgmartin
Contributor

I second the proposal of #10614 (comment) and #10614 (comment) to use *_args : Sequence[int | str]. Treat each element as an argument number if it's an int and argument name if it's a str.

@Conchylicultor
Member

Any update on this? Not being able to use kwargs in jax.grad makes the API very brittle (it is very easy to get the argument order wrong, especially when updating the function signature) and forces bad programming practices.

@patrick-kidger
Collaborator

jax.grad does support passing through keyword arguments -- just not as a differentiated argument.

Generally speaking I'd recommend against using jax.grad(argnums=...) -- instead just pack all of your differentiated quantities into a tuple that is passed through the first argument. This helps to avoid the issue you're describing.

In any case, I don't believe there are any plans to change the API for jax.grad.
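For illustration, the packing pattern recommended above might look like the following sketch. The loss function, its parameters, and the values are made up for the example; only jax.grad and jax.numpy are real APIs here.

```python
import jax
import jax.numpy as jnp


def loss(params, x, *, scale=1.0):
    # Everything to be differentiated lives in the first argument.
    w, b = params
    return scale * jnp.sum((w * x + b) ** 2)


params = (jnp.array(2.0), jnp.array(1.0))
x = jnp.array([1.0, 2.0])

# jax.grad differentiates with respect to argument 0 by default, so no
# argnums is needed; `scale` is passed through as a plain keyword argument.
grad_w, grad_b = jax.grad(loss)(params, x, scale=0.5)
print(grad_w, grad_b)
```

The gradient comes back with the same tuple structure as params, so adding or reordering the non-differentiated arguments does not require touching any argnums.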


8 participants