
grad + vmap + odeint AssertionError #8783

Open

mattjj opened this issue Dec 3, 2021 · Discussed in #8782 · 12 comments

Assignees: mattjj
Labels: bug (Something isn't working)

Comments

@mattjj (Member) commented Dec 3, 2021

Discussed in #8782

Originally posted by DanPuzzuoli December 2, 2021
Hi,

I've seen a bunch of discussions and issues surrounding this, so I apologize if I'm re-raising something that has already been addressed elsewhere. I don't understand the internals of JAX well enough to tell whether this is a version of an issue that's already been raised; most of the similar reports I've found are in issues that are already closed, so I assume those were solved and that this is something different.

When I run the following code on the latest release versions of jax/jaxlib (a self-contained version of my actual code):

from jax.experimental.ode import odeint
from jax import jit, value_and_grad, vmap
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
config.update('jax_platform_name', 'cpu')


T = 1.

X = -1j * jnp.array([[0., 1.], [1., 0.]], dtype=complex)
Y = -1j * jnp.array([[0., -1j], [1j, 0.]], dtype=complex)

def err_obj(a, b_vals):

    def err(b):
        def rhs(y, t):
            return (b * X + a * (t**2) * Y) @ y

        res = odeint(rhs, y0=jnp.eye(2, dtype=complex), t=jnp.array([0, T], dtype=complex),
                     rtol=1e-6, atol=1e-6)

        return jnp.abs((X * res[-1]).sum())**2 / 4

    all_err = vmap(err)(b_vals)
    return all_err.sum()

b_vals = jnp.array([1., 2., 3., 4., 5.])
jit(value_and_grad(lambda a: err_obj(a, b_vals)))(1.)

I get the error:

AssertionError: length mismatch: [1, 4]

I've definitely reverse-mode differentiated this code successfully in the past, though it was some time ago, so that would have been on a much older version of jax/jaxlib. Based on other similar issues I've seen, it seems like this has something to do with the interaction of reverse-mode autodiff, vmap, and the control flow used in odeint, but again, the issues I've seen raising these kinds of errors seem to have already been solved.

Thanks!

@mattjj mattjj added the bug Something isn't working label Dec 3, 2021
@mattjj mattjj self-assigned this Dec 3, 2021
@mattjj (Member, Author) commented Dec 3, 2021

I'm not sure, but I think @danieldjohnson told me about this error a long time ago (maybe back in May?), and shared this repro which hits the same thing:

import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def foo_custom_vjp(kernel_closure, aux_args):
  return kernel_closure(*aux_args)

def foo_fwd(kernel_closure, aux_args):
  out = foo_custom_vjp(kernel_closure, aux_args)
  # Output must be saved for the error to occur.
  return out, (out,)

def foo_bwd(kernel_closure, saved, table_bar):
  return ([jnp.zeros([8, 10])],)
  raise NotImplementedError("JAX problem occurs before backward pass executes")


foo_custom_vjp.defvjp(foo_fwd, foo_bwd)


def repro(differentiable_arg, batched_arg):

  def kernel_closure(differentiable_arg):
    # This function closes over batched_arg, which is a batched NDArray;
    # it then gets nondiff-argnumed into foo_custom_vjp
    return jnp.exp(differentiable_arg + batched_arg)

  result = foo_custom_vjp(kernel_closure, [differentiable_arg])
  return jnp.mean(result)


def batched_repro(soft_prediction, samples, batch_size):
  def go(sample):
    return repro(soft_prediction, sample)
  return jnp.mean(jax.vmap(go)(samples))

@functools.partial(jax.jit)
def trigger(soft_prediction, samples):
  grads = jax.grad(batched_repro)(soft_prediction, samples, batch_size=3)
  return grads

trigger(jnp.zeros([8, 10]), jnp.zeros([7, 10], dtype=jnp.int32))

I vaguely recall that:

  1. I tried to fix it, and at least figured out the problem, but then
  2. decided that I would instead land no-more-post-process soon and that would define the issue away.

Then I failed to land that branch... and now I've forgotten everything I figured out before!

@mattjj (Member, Author) commented Dec 3, 2021

Okay, I'm starting to understand the issue again... I'm going to jot some notes here, but they probably won't be comprehensible without detailed knowledge of JAX internals.

These examples exercise BatchTrace.post_process_custom_vjp_call, which is currently identical to BatchTrace.post_process_custom_jvp_call. But actually that implementation is broken!

The problem is that the number of outputs of this custom_vjp_call changes between when we run BatchTrace.post_process_custom_jvp_call and when we run the todo inside it. That is, vals in the outer lexical scope (that of the body of BatchTrace.post_process_custom_jvp_call) might be length 2 while vals inside the body of todo might be length 1. But we're using the same dims! In this particular case, the number of outputs is changing because we're packaging a primal-and-tangent pair up into a single JVPTracer.


Backing up, the reason this is happening, and the reason post_process_call methods exist at all, is that the JAX tracing machinery speculatively assumes that the only transformations that apply to a primitive are those whose Tracers box some of its arguments. We assume that even for call primitives (like the one underlying custom_vjp). But for call primitives that assumption can be broken when the function-valued argument closes over Tracers of Traces not represented in the arguments. In that case, we can end up with outputs that are boxed in more Tracer levels than we expected given the arguments!

The solution is to unwrap and re-wrap output Tracers, and call back into the Traces we forgot to process so they're aware that a call primitive was bound. That's exactly what the post_process_call methods do for call primitives: when the core system calls them (when we're about to return from the function called by the higher-order primitive), they dutifully unwrap their Tracers and return a todo callback so that the core can put them back in the right order (after the higher-order primitive has returned).


Coming back to this particular issue, starting with the case of custom_jvp (even though the repros above are for custom_vjp): the values returned by the called function are either just primals, or a flattened list of primal-and-tangent pairs (guaranteed to be twice the length of the just-primals version). In the latter case, dims will be of length 2N (and of structure [*dims_, *dims_]) when post_process_custom_jvp_call is first called, but then of length N (and of structure dims_) when the todo is called. (In the former case there's no change in dims between the two stages.)
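
(To make that bookkeeping concrete, here's a purely illustrative, JAX-free sketch of the 2N-vs-N mismatch; the names dims_, vals, and todo mirror the description above, but none of this is actual JAX internals code:)

n_out = 1                      # N: number of primal outputs
dims_ = [0] * n_out            # batch dims for the primal outputs

# When post_process_custom_jvp_call first runs, the flattened outputs are
# primal-and-tangent pairs, so there are 2N values and dims has structure
# [*dims_, *dims_]:
vals_at_post_process = ["primal", "tangent"]
dims_at_post_process = [*dims_, *dims_]
assert len(vals_at_post_process) == len(dims_at_post_process) == 2 * n_out

# By the time the todo callback runs, each primal/tangent pair has been packaged
# into a single JVPTracer, so there are only N values, but the same length-2N
# dims is reused; that's where the "length mismatch" assertion fires.
vals_at_todo = ["jvp_tracer"]
assert len(vals_at_todo) == n_out != len(dims_at_post_process)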

We should plumb enough information into this function so that it knows what's going to happen and so that it can check and manipulate dims as needed. That information is available, e.g. in this suppressed _ value in the caller of process_env_traces, though I'll have to think about how best to plumb it...

@mattjj (Member, Author) commented Dec 3, 2021

One last thing to note: in the custom_jvp example there's this 2N-or-N relationship, but for custom_vjp the outputs in question are the primals and the residuals, so there's no relationship in general. I think we have to plumb in how many residuals there are.
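
(For illustration, here's a small, hypothetical custom_vjp example, unrelated to odeint, showing that the number of residuals saved by the forward rule has no fixed relationship to the number of primal outputs:)

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.sin(x)                   # one primal output

def f_fwd(x):
  # three residuals saved for the backward pass; no fixed relationship to the
  # single primal output
  return jnp.sin(x), (x, jnp.cos(x), jnp.sin(x))

def f_bwd(res, g):
  _, cos_x, _ = res
  return (g * cos_x,)

f.defvjp(f_fwd, f_bwd)
print(jax.grad(f)(1.0))               # cos(1.0), about 0.5403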

@DanPuzzuoli (Contributor)

Not sure if this is the place for it, but can you explain what

But for call primitives that assumption can be broken when the function-valued argument closes over Tracers of Traces not represented in the arguments.

means? Specifically: what does it mean for a function to "close over Tracers"? I've seen this statement appear in a bunch of places (while looking up this particular issue). Even without understanding the JAX internals much, it seems like something I might be able to recognize in the future.

@mattjj (Member, Author) commented Dec 3, 2021

Sure! Apologies for the jargon.

Here's an example program, without any JAX for now:

def f(x):
  def g(y):
    return x * y
  return g(3.)

f(2.)

Here I'd say the inner function g closes over the variable x: that variable occurs in the body of g while being bound not as a parameter/argument of g but instead in a "lexically enclosing scope" (meaning here the scope introduced by f, which "lexically encloses" g in the sense that the text defining g is literally inside the body of f).

This is interesting because in some sense the value of x is an input to an application of g, even though it's not in g's parameters/arguments. So one could say there are actually two kinds of inputs to a Python function: arguments and closed-over values.

Let's add JAX to the example:

import jax

def f(x):
  @jax.jit
  def g(y):
    return x * y
  return g(3.)

jax.grad(f)(2.)

Here again the function g closes over x. The difference is that now JAX's tracing mechanism is at work. To evaluate jax.grad(f)(2.), we box up the value of 2. in a Tracer and then use it to trace (i.e. monitor) the operations that are applied to it to produce the output of f. So in the body of f, x will refer to a Tracer instance.

But notice that when we call the jit-decorated function in evaluating g(3.), that Tracer doesn't appear in the arguments to g. Yet it'll affect the output of g! In this case I'd say that the function g closes over a Tracer when evaluating jax.grad(f)(2.).

Because the Tracers that a Python callable closes over can't easily be inspected ahead-of-time, you might only find out what Tracers are in a function's closure (and hence which transformations must be applied to the function) after running it. You might imagine that makes tracing a bit tricky.
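
(As an aside, you can see this closure business directly by peeking at a function's closure cells; a tiny sketch, purely for illustration:)

import jax

def f(x):
  def g(y):
    return x * y
  # g's closure cell holding x only exists once f is running, and its contents
  # are only a Tracer while a transformation like jax.grad is tracing f.
  print(type(g.__closure__[0].cell_contents))
  return g(3.)

f(2.)            # <class 'float'>
jax.grad(f)(2.)  # prints a Tracer subclass (the exact class depends on the transformation)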

The general mechanism which sorts out these issues is the thing I alluded to in the comments above. And it's buggy for custom_vjp when closed-over vmap tracers are involved.

WDYT?

@DanPuzzuoli (Contributor)

Ah yeah, makes sense! Thanks for the clarification. I saw in a previous issue that you'd recommended swapping the order of grad and vmap to get around this; would you still recommend doing that here?
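
(To make sure I understand the suggestion, here's a minimal sketch of that reordering on a toy stand-in for the odeint objective; I haven't checked whether it actually avoids this bug:)

import jax
import jax.numpy as jnp

def per_example_loss(a, b):
  return (a * b) ** 2                 # toy stand-in for the odeint-based err above

b_vals = jnp.array([1., 2., 3., 4., 5.])

# grad of a vmapped sum (the pattern in my original code):
g1 = jax.grad(lambda a: jax.vmap(lambda b: per_example_loss(a, b))(b_vals).sum())(1.)

# vmap of grad, summed afterwards (the suggested reordering):
g2 = jax.vmap(lambda b: jax.grad(per_example_loss)(1., b))(b_vals).sum()

assert jnp.allclose(g1, g2)           # same total gradient either way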

@mattjj (Member, Author) commented Dec 3, 2021

I'm not sure. But this is a bad bug and I plan to fix this in the next few days...

@mattjj (Member, Author) commented Dec 11, 2021

Okay, let's try to crush this bug!

To start, here's a repro of the jax.custom_jvp version:

import jax
import jax.numpy as jnp

def h(z):
  def f(x):
    @jax.custom_jvp
    def g(y):
      return x * y

    # NOTE: rule closes over vmap tracer
    @g.defjvp
    def g_jvp(primals, tangents):
      (y,), (ydot,) = primals, tangents
      return x * y, x * ydot

    return g(z)  # NOTE: no vmapped arg

  return jax.vmap(f)(jnp.arange(3.))

jax.jvp(h, (1.,), (2.,))

@mattjj (Member, Author) commented Dec 12, 2021

And here's a corresponding jax.custom_vjp minimal repro:

import jax
import jax.numpy as jnp

def h(z):
  def f(x):
    @jax.custom_vjp
    def g(y):
      return x * y

    def g_fwd(y):
      return x * y, (x * y, y)
    def g_rev(xys, w_bar):
      xy, _ = xys
      return (xy * w_bar,)
    g.defvjp(g_fwd, g_rev)

    return g(z)

  return jax.vmap(f)(jnp.arange(3.)).sum()

jax.grad(h)(1.)

I have a fix, in #8915, but I need to look it over and decide if there are things to clean up, and whether to break it into multiple PRs (because there were a couple cleanups I did which made the fix easier but could be made independent).

mattjj added a commit to mattjj/jax that referenced this issue Dec 12, 2021
@mattjj (Member, Author) commented Dec 12, 2021

Actually, I just noticed that #8915 fixes my repro as well as the code in the comment above, but the code in the OP runs into a different error. I think it's a separate bug for me to fix...

mattjj added several commits to mattjj/jax that referenced this issue between Dec 28, 2021 and Jan 12, 2022
@mattjj (Member, Author) commented Jan 12, 2022

#8915 finally went in, but IIRC, last time I tried it, the repro in the OP still ran into a distinct second issue.

@mattjj (Member, Author) commented Jan 12, 2022

Ah! The distinct second issue is just that odeint has a check for floating point time values, and it raises an error if it's given complex values (as in this example).

But if we change the repro to pass real-valued times, e.g. t=jnp.array([0., T]), we get yet another issue, to do with a leaked vmap tracer...
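
(For completeness, here's a minimal, self-contained sketch of that time-dtype check on a toy decay ODE, separate from the repro above:)

import jax.numpy as jnp
from jax.experimental.ode import odeint

def rhs(y, t):
  return -y                                     # toy decay ODE

ok = odeint(rhs, jnp.array(1.0), jnp.array([0.0, 1.0]))            # real times: accepted
# odeint(rhs, jnp.array(1.0), jnp.array([0, 1], dtype=complex))    # complex times: raises an error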
