#10453 breaks the ability to use negative values (à la negative indexing) for the `argnums` argument of `jax.grad`, resulting in a `TypeError` about positional arguments.
Here is a minimal repro (yay!) that mirrors the way I use this in practice.
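As a possible stopgap (this is not the reporter's repro), negative indices can be resolved to non-negative ones before they reach `jax.grad`. The sketch below is a hypothetical helper, `resolve_argnums`, which is not part of JAX. It assumes the function has a fixed number of positional parameters, counts them with `inspect.signature`, and maps negative values the way Python sequence indexing does.

```python
import inspect
from typing import Callable, Sequence, Tuple, Union


def resolve_argnums(
    fun: Callable, argnums: Union[int, Sequence[int]]
) -> Union[int, Tuple[int, ...]]:
    """Hypothetical helper (not part of JAX): map negative argnums to
    non-negative positions counted from the end of fun's positional
    parameters, as in Python sequence indexing.

    Assumes fun has a fixed number of positional parameters (no *args).
    """
    params = list(inspect.signature(fun).parameters.values())
    if any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params):
        raise ValueError("cannot resolve argnums statically for a function with *args")
    positional = (inspect.Parameter.POSITIONAL_ONLY,
                  inspect.Parameter.POSITIONAL_OR_KEYWORD)
    nargs = sum(p.kind in positional for p in params)

    single = isinstance(argnums, int)
    nums = (argnums,) if single else tuple(argnums)
    resolved = []
    for i in nums:
        if not -nargs <= i < nargs:
            raise ValueError(f"argnum {i} is out of range for {nargs} positional parameters")
        resolved.append(i % nargs)  # -1 -> nargs - 1; in-range non-negative i is unchanged
    # Keep the int-vs-tuple shape of the input: jax.grad returns a single
    # gradient for an int argnums and a tuple of gradients for a tuple.
    return resolved[0] if single else tuple(resolved)
```

For `def f(x, y, z): ...`, `resolve_argnums(f, -1)` returns `2`, so `jax.grad(f, argnums=resolve_argnums(f, -1))` should differentiate with respect to `z`.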
Thanks for the report. I'm not sure negative indices for `argnums` were ever intentionally supported; at least, we never mention them in the docs or cover them in tests (let me know if I'm mistaken on that).
My inclination here is to "fix" this by adding assertions that `argnums` must be non-negative. What do you think?
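A minimal sketch of the proposed check, assuming it would run before `jax.grad` builds its wrapper. `check_argnums` is a hypothetical name, not an existing JAX function. Note that `0` is a valid argnum, so the condition is "non-negative" rather than "positive".

```python
from typing import Sequence, Tuple, Union


def check_argnums(argnums: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    """Hypothetical validation: reject negative argnums with a clear error
    instead of letting them surface later as a confusing TypeError."""
    nums = (argnums,) if isinstance(argnums, int) else tuple(argnums)
    for i in nums:
        if not isinstance(i, int):
            raise TypeError(f"argnums must be integers, got {type(i).__name__}")
        if i < 0:
            raise ValueError(f"argnums must be non-negative, got {i}")
    return nums
```

The sketch raises `ValueError` rather than using a bare `assert`, because `assert` statements are skipped when Python runs with `-O`.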
@jakevdp I'm working on argument 'annotation' validation for `jax.jit` now (with the intention of later extending it to other places where argument annotations are used, including `grad`). @hawkinsp has kindly provided feedback over in #10603.
I started a discussion issue #10614 that outlines how I think the annotation feature could be improved and made more consistent across different functions.
If there is appetite for negative argnum support, I'd be happy to add it to the work I hope to do as part of #10614.
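If negative argnums were supported, one option is to resolve them against the number of positional arguments actually passed at call time, mirroring Python's negative sequence indexing. Resolving at call time also covers functions that take `*args`. This is a hypothetical sketch under that assumption, not JAX's implementation; `normalize_argnums` and `select_args` are illustrative names.

```python
from typing import Sequence, Tuple, Union


def normalize_argnums(
    argnums: Union[int, Sequence[int]], num_args: int
) -> Tuple[int, ...]:
    """Hypothetical sketch: map negative argnums to non-negative positions
    given the number of positional arguments passed at call time."""
    nums = (argnums,) if isinstance(argnums, int) else tuple(argnums)
    resolved = []
    for i in nums:
        if not -num_args <= i < num_args:
            raise ValueError(
                f"argnum {i} is out of bounds for {num_args} positional arguments")
        resolved.append(i + num_args if i < 0 else i)
    # Design choice of this sketch: reject argnums that alias the same
    # argument after normalization, e.g. (0, -3) with three arguments.
    if len(set(resolved)) != len(resolved):
        raise ValueError(f"argnums {nums} refer to the same argument more than once")
    return tuple(resolved)


def select_args(args: tuple, argnums: Union[int, Sequence[int]]) -> tuple:
    """Pick out the positional arguments a grad-like wrapper would differentiate."""
    return tuple(args[i] for i in normalize_argnums(argnums, len(args)))
```

For example, `select_args((1.0, 2.0, 3.0), -1)` picks out the last argument, `(3.0,)`.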
I assigned Matt, who authored #10453, so that the issue has an assignee. It also seems that @JeppeKlitgaard may be interested in contributing here!