grad + vmap + odeint AssertionError #8783
Comments
I'm not sure, but I think @danieldjohnson told me about this error a long time ago (maybe back in May?), and shared this repro which hits the same thing:

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def foo_custom_vjp(kernel_closure, aux_args):
    return kernel_closure(*aux_args)

def foo_fwd(kernel_closure, aux_args):
    out = foo_custom_vjp(kernel_closure, aux_args)
    # Output must be saved for the error to occur.
    return out, (out,)

def foo_bwd(kernel_closure, saved, table_bar):
    return ([jnp.zeros([8, 10])],)
    raise NotImplementedError("JAX problem occurs before backward pass executes")

foo_custom_vjp.defvjp(foo_fwd, foo_bwd)

def repro(differentiable_arg, batched_arg):
    def kernel_closure(differentiable_arg):
        # This function closes over batched_arg, which is a batched NDArray;
        # it then gets nondiff-argnumed into foo_custom_vjp
        return jnp.exp(differentiable_arg + batched_arg)
    result = foo_custom_vjp(kernel_closure, [differentiable_arg])
    return jnp.mean(result)

def batched_repro(soft_prediction, samples, batch_size):
    def go(sample):
        return repro(soft_prediction, sample)
    return jnp.mean(jax.vmap(go)(samples))

@functools.partial(jax.jit)
def trigger(soft_prediction, samples):
    grads = jax.grad(batched_repro)(soft_prediction, samples, batch_size=3)
    return grads

trigger(jnp.zeros([8, 10]), jnp.zeros([7, 10], dtype=jnp.int32))

I vaguely recall that:
Then I failed to land that branch... and now I've forgotten everything I figured out before!
Okay, I'm starting to get the issue again... I'm going to jot some notes here but they probably won't be comprehensible without detailed knowledge of JAX internals. These examples exercise The problem is that the number of outputs to this Backing up, the reason this is happening, and the reason The solution is to unwrap and re-wrap output Tracers, and call back into the Traces we forgot to process so they're aware that a call primitive was bound. That's exactly what the Coming back to this particular issue, starting in the case of We should plumb enough information into this function so that it knows what's going to happen and so that it can check and manipulate
One last thing to note: in the
Not sure if this is the place for it, but can you explain what
means? Specifically: what does it mean for a function to "close over Tracers"? I've seen this statement appear in a bunch of places (while looking up this particular issue). Even without understanding the JAX internals much it seems like something I might be able to recognize in the future.
Sure! Apologies for the jargon. Here's an example program, without any JAX for now:

def f(x):
    def g(y):
        return x * y
    return g(3.)

f(2.)

Here I'd say the inner function This is interesting because in some sense the value of Let's add JAX to the example:

def f(x):
    @jax.jit
    def g(y):
        return x * y
    return g(3.)

jax.grad(f)(2.)

Here again the function But notice that when we call the Because the Tracers that a Python callable closes over can't easily be inspected ahead-of-time, you might only find out what Tracers are in a function's closure (and hence which transformations must be applied to the function) after running it. You might imagine that makes tracing a bit tricky. The general mechanism which sorts out these issues is the thing I alluded to in above comments. And it's buggy for WDYT?
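To make "closing over a Tracer" concrete, here is a minimal sketch built on the JAX example above. It assumes `jax.core.Tracer` is the base class of the tracer objects involved; the `seen` dictionary is a hypothetical name added purely for illustration. It records, while `g` is being traced, whether the `x` that `g` reads from the enclosing scope is a Tracer.

```python
import jax
from jax.core import Tracer  # assumed to be the common base class of tracers

seen = {}  # hypothetical helper: records what g observed while it was traced

def f(x):
    @jax.jit
    def g(y):
        # x is not an argument of g; g picks it up from the enclosing scope.
        # Under jax.grad, that x is a Tracer, so g "closes over a Tracer".
        seen["x_is_tracer"] = isinstance(x, Tracer)
        return x * y
    return g(3.)

grad_value = jax.grad(f)(2.)
print(seen["x_is_tracer"], grad_value)
```

Note that the check only runs once `g` is actually called, which fits the point above: what a callable closes over may only become apparent by running it.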
Ah yeah, makes sense! Thanks for the clarification. I saw in a previous issue you'd recommended swapping the order of grad and vmap to get around this, would you still recommend doing this here?
I'm not sure. But this is a bad bug and I plan to fix this in the next few days...
Okay, let's try to crush this bug! To start, here's a repro of the

import jax
import jax.numpy as jnp

def h(z):
    def f(x):
        @jax.custom_jvp
        def g(y):
            return x * y

        # NOTE: rule closes over vmap tracer
        @g.defjvp
        def g_jvp(primals, tangents):
            (y,), (ydot,) = primals, tangents
            return x * y, x * ydot

        return g(z)  # NOTE: no vmapped arg
    return jax.vmap(f)(jnp.arange(3.))

jax.jvp(h, (1.,), (2.,))
And here's a corresponding

import jax
import jax.numpy as jnp

def h(z):
    def f(x):
        @jax.custom_vjp
        def g(y):
            return x * y

        def g_fwd(y):
            return x * y, (x * y, y)

        def g_rev(xys, w_bar):
            xy, _ = xys
            return (xy * w_bar,)

        g.defvjp(g_fwd, g_rev)

        return g(z)
    return jax.vmap(f)(jnp.arange(3.)).sum()

jax.grad(h)(1.)

I have a fix, in #8915, but I need to look it over and decide if there are things to clean up, and whether to break it into multiple PRs (because there were a couple cleanups I did which made the fix easier but could be made independent).
Actually I just noticed that #8915 fixes my repro as well as the code in the comment above, but the code in the OP runs into a different error. I think it's a separate bug for me to fix...
related to google#8783, doesn't completely fix it
#8915 finally went in, but IIRC last I tried it the repro in the OP now runs into a distinct second issue.
Ah! The distinct second issue is just that odeint has a check for floating point time values, and it raises an error if it's given complex values (as in this example). But if we change the repro to pass
Discussed in #8782
Originally posted by DanPuzzuoli December 2, 2021
Hi,
I've seen a bunch of discussions and issues surrounding this, so I apologize if I'm re-raising something that has already been addressed elsewhere. I don't understand the internals of JAX well enough to tell whether this is some version of an issue that has already been raised, though a lot of what I'm seeing is in already-closed issues, so I assume those are solved and hence this is different.
When I run the following code on the latest release versions of jax/jaxlib (a self-contained version of my actual code):
I get the error:
I've definitely reverse-mode differentiated this code in the past with success, though it was some time ago, so it would have been on a much older version of jax/jaxlib. Based on other similar issues I've seen, it seems like this has something to do with the interaction of reverse-mode autodiff, vmap, and the control flow used in odeint, but again, the issues I've seen raising these kinds of errors seem to have been solved?
Thanks!