
PR #10453 breaks negative-valued `argnums` in `jax.grad` #10630

Closed
ajbrock opened this issue May 9, 2022 · 4 comments
Labels
better_errors (Improve the error reporting), bug (Something isn't working)

Comments

@ajbrock

ajbrock commented May 9, 2022

#10453 breaks the ability to use negative values (à la negative indexing) for the `argnums` argument of `jax.grad`, resulting in a `TypeError` about positional arguments.

Here is a minimal repro (yay!) that mirrors the way I use this in practice.

```python
import jax

def f(x, y):
  return x.sum() * y.sum()

g = jax.grad(f, argnums=-1)
x = jax.random.normal(jax.random.PRNGKey(0), (16, 16))
y = jax.random.normal(jax.random.PRNGKey(1), (16, 16))
g(x, y)
```

yields this stack trace:

```
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-4-64035b120794> in <module>()
      8 y = jax.random.normal(jax.random.PRNGKey(1), (16, 16))
----> 9 g(x, y)

jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

jax/_src/api.py in grad_f(*args, **kwargs)
    904   def grad_f(*args, **kwargs):
--> 905     _, g = value_and_grad_f(*args, **kwargs)
    906     return g

jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

jax/_src/api.py in value_and_grad_f(*args, **kwargs)
    980     if not has_aux:
--> 981       ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
    982     else:

jax/_src/api.py in _vjp(fun, has_aux, reduce_axes, *primals)
   2442     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 2443     out_primal, out_vjp = ad.vjp(
   2444         flat_fun, primals_flat, reduce_axes=reduce_axes)

jax/interpreters/ad.py in vjp(traceable, primals, has_aux, reduce_axes)
    129   if not has_aux:
--> 130     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    131   else:

jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
    118   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 119   jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
    120   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)

/jax/_src/profiler.py in wrapper(*args, **kwargs)
    205     with TraceAnnotation(name, **decorator_kwargs):
--> 206       return func(*args, **kwargs)
    207     return wrapper

jax/interpreters/partial_eval.py in trace_to_jaxpr_nounits(fun, pvals, instantiate)
    607     fun = trace_to_subjaxpr_nounits(fun, main, instantiate)
--> 608     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    609     assert not env

jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    167     try:
--> 168       ans = self.f(*args, **dict(self.params, **kwargs))
    169     except:

UnfilteredStackTrace: TypeError: f() takes 2 positional arguments but 3 were given

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
<ipython-input-4-64035b120794> in <module>()
      7 x = jax.random.normal(jax.random.PRNGKey(0), (16, 16))
      8 y = jax.random.normal(jax.random.PRNGKey(1), (16, 16))
----> 9 g(x, y)
TypeError: f() takes 2 positional arguments but 3 were given
```
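One plausible way for `argnums=-1` to produce "takes 2 positional arguments but 3 were given" is shown in this simplified, pure-Python sketch. The helper `naive_argnums_partial` is hypothetical and is not JAX's actual code; it only assumes that differentiated arguments are looked up with `args[i]` while the remaining "fixed" arguments are picked by comparing `enumerate` positions against `argnums`. Indexing quietly accepts `-1` and selects `y`, but `enumerate` never yields `-1`, so `y` is also kept as a fixed argument and ends up being passed twice.

```python
from typing import Any, Callable, List, Sequence, Tuple


def naive_argnums_partial(
    f: Callable[..., Any], dyn_argnums: Sequence[int], args: Tuple[Any, ...]
) -> Tuple[Callable[..., Any], List[Any]]:
    """Hypothetical, simplified split of args into differentiated and fixed parts."""
    # Python indexing accepts -1, so this selects the last argument.
    dyn_args = [args[i] for i in dyn_argnums]
    # enumerate() only yields 0..len(args)-1, so -1 never matches here and the
    # last argument is ALSO kept as a fixed argument.
    fixed_args = [arg for i, arg in enumerate(args) if i not in dyn_argnums]

    def f_partial(*dyn: Any) -> Any:
        # Argument re-ordering is simplified: fixed args first, then dynamic ones.
        return f(*fixed_args, *dyn)

    return f_partial, dyn_args


def f(x, y):
    return x * y


f_partial, dyn_args = naive_argnums_partial(f, (-1,), (2.0, 3.0))
message = ""
try:
    f_partial(*dyn_args)
except TypeError as e:
    message = str(e)
print(message)  # f() takes 2 positional arguments but 3 were given
```

With `(1,)` instead of `(-1,)`, `fixed_args` is `[2.0]` and the call succeeds, which is why an explicit non-negative index works as a workaround in this sketch.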
@ajbrock ajbrock added the bug Something isn't working label May 9, 2022
@jakevdp
Collaborator

jakevdp commented May 9, 2022

Thanks for the report – I'm not sure negative indices for `argnums` were ever intentionally supported; at least, we never mention them in the docs or cover them in tests (let me know if I'm mistaken on that).

My inclination here is to "fix" this by adding assertions that `argnums` must be non-negative – what do you think?
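The assertion-based approach could look roughly like this minimal sketch. `validate_argnums` is a hypothetical helper, not part of JAX's API; it assumes the number of positional arguments is known at call time, and it rejects negative or out-of-range entries with an explicit error instead of letting them fail later with a confusing `TypeError`.

```python
from typing import Sequence, Tuple, Union


def validate_argnums(argnums: Union[int, Sequence[int]], num_args: int) -> Tuple[int, ...]:
    """Hypothetical helper: reject negative or out-of-range argnums up front."""
    argnums_tuple = (argnums,) if isinstance(argnums, int) else tuple(argnums)
    for i in argnums_tuple:
        if not isinstance(i, int):
            raise TypeError(f"argnums must be integers, got {type(i).__name__}")
        if i < 0:
            raise ValueError(f"argnums must be non-negative, got {i}")
        if i >= num_args:
            raise ValueError(
                f"argnums entry {i} is out of range for {num_args} positional arguments"
            )
    return argnums_tuple


print(validate_argnums(1, 2))  # (1,)
try:
    validate_argnums(-1, 2)
except ValueError as e:
    print(e)  # argnums must be non-negative, got -1
```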

Paging @mattjj for input here.

@JeppeKlitgaard
Contributor

JeppeKlitgaard commented May 9, 2022

@jakevdp I'm working on argument 'annotation' validation for `jax.jit` now, with the intention of later expanding it to other places where argument annotation is used, including `grad`. @hawkinsp has kindly provided feedback over in #10603.

I started a discussion issue #10614 that outlines how I think the annotation feature could be improved and made more consistent across different functions.

If there is appetite for negative `argnums` support, I'd be happy to add it to the work I hope to do as part of #10614.
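Supporting negative `argnums` would instead amount to normalizing them the way Python normalizes negative sequence indices, once the number of positional arguments is known. A minimal sketch, using a hypothetical `normalize_argnums` helper that is not JAX's API:

```python
from typing import Sequence, Tuple, Union


def normalize_argnums(argnums: Union[int, Sequence[int]], num_args: int) -> Tuple[int, ...]:
    """Hypothetical helper: map negative argnums to positions, like Python indexing."""
    argnums_tuple = (argnums,) if isinstance(argnums, int) else tuple(argnums)
    normalized = []
    for i in argnums_tuple:
        # Valid indices for a sequence of length num_args are -num_args..num_args-1.
        if not -num_args <= i < num_args:
            raise ValueError(
                f"argnums entry {i} is out of range for {num_args} positional arguments"
            )
        # For negative i this yields i + num_args; non-negative i is unchanged.
        normalized.append(i % num_args)
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"argnums {argnums_tuple} refer to the same argument more than once")
    return tuple(normalized)


print(normalize_argnums(-1, 2))       # (1,)
print(normalize_argnums((0, -1), 3))  # (0, 2)
```

Because the result depends on how many positional arguments are passed, this normalization would have to run when the transformed function is called rather than when `jax.grad` is applied. The duplicate check guards against cases like `argnums=(1, -1)` for a two-argument function, where both entries refer to `y`.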

@froystig froystig added the better_errors Improve the error reporting label May 10, 2022
@froystig
Member

I assigned Matt, who authored #10453, so that we have an assignee. It also seems that @JeppeKlitgaard may be interested in contributing here!

@JeppeKlitgaard
Contributor

This was fixed by @mattjj in #10669.


6 participants