
Padding doesn't work in jitted mode #3620

Closed
AdrienCorenflos opened this issue Jul 1, 2020 · 10 comments
Assignees: mattjj
Labels: better_errors (Improve the error reporting), question (Questions for the JAX team)

Comments

@AdrienCorenflos
Contributor

Hi,

Thanks for the work on the library.
For a project I need to pad the last dimension of a tensor. Everything works fine when the function is not jitted, but I get a conversion error when I jit it. It seems to me like it's coming from a bug in the lax_numpy internals (this line), but maybe I'm missing something...

Adrien

The code is as follows:

from jax import jit as jjit
from jax import ops
import jax.numpy as jnp
import numpy as np

vals = np.random.randn(50, 100)

def pad_last_dim(array, pad_size):
    ndim = jnp.ndim(array)
    npad = jnp.zeros((ndim, 2), dtype=jnp.int32)
    axis = ndim - 1
    npad = ops.index_update(npad, ops.index[axis, 1], pad_size)
    return jnp.pad(array, npad, 'constant', constant_values=0)

print(pad_last_dim(vals, 1))  # All good
print(jjit(pad_last_dim)(vals, 1))  # raises
@mattjj added the question (Questions for the JAX team) and better_errors (Improve the error reporting) labels on Jul 1, 2020
@mattjj self-assigned this on Jul 1, 2020
@mattjj
Member

mattjj commented Jul 1, 2020

Thanks for the question!

The issue is that the result array's size depends on the value of the pad_size argument to the jitted function. This kind of value-dependent shape isn't allowed in a jitted function. The reason is that we try to compile a single XLA HLO program for all possible input values of a given shape and dtype, but XLA HLO programs carry array shapes in their static type system (for performance reasons), so we can't compile a function whose output shapes depend on the numerical values of its inputs into a single program.
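To make the constraint concrete, here's a minimal sketch of a jitted function whose output shape depends on an argument's value:

import jax
import jax.numpy as jnp

@jax.jit
def make_zeros(n):
    # The output shape depends on the *value* of n, which is abstract
    # (traced) under jit, so this can't be compiled as a single program.
    return jnp.zeros(n)

make_zeros(3)  # raises a tracer/concretization error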

One thing you can do is have JAX re-compile a new version of the function for every value (rather than shape/dtype) of the pad_size argument; I show how in a comment below. (EDIT: another change is needed too, see below.)

I want to improve the error message you get here so that it's crystal clear what's going wrong, what the constraints are, and what the options are for fixing it. Do you have any suggestions for how this could have been made clearer?

(The recent #3603 is related.)

@mattjj
Member

mattjj commented Jul 1, 2020

For posterity, here's the current error message (stack trace removed):

Exception: The numpy.ndarray conversion method __array__() was called on the
JAX Tracer object Traced<ShapedArray(int32[2,2]):JaxprTrace(level=-1/1)>.

This error can occur when a JAX Tracer object is passed to a raw numpy
function, or a method on a numpy.ndarray object. You might want to check that
you are using `jnp` together with `import jax.numpy as jnp` rather than using
`np` via `import numpy as np`. If this error arises on a line that involves
array indexing, like `x[idx]`, it may be that the array being indexed `x` is a
raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that
case, you can instead write `jax.device_put(x)[idx]`

@mattjj
Member

mattjj commented Jul 1, 2020

One thing you can do now is have JAX re-compile new versions of the function for every value (rather than shape/dtype) of the pad_size argument. You can do that using static_argnums:

print(jjit(pad_last_dim, static_argnums=(1,))(vals, 1))
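To see the recompilation behavior, here's a sketch (using the list-based pad widths discussed below, plus a Python-side print that only runs during tracing):

from functools import partial
import jax
import jax.numpy as jnp
import numpy as np

@partial(jax.jit, static_argnums=(1,))
def pad_last_dim(array, pad_size):
    print(f"tracing for pad_size={pad_size}")  # Python side effect: runs once per trace
    npad = [(0, 0)] * (array.ndim - 1) + [(0, pad_size)]
    return jnp.pad(array, npad, 'constant', constant_values=0)

vals = np.random.randn(50, 100)
pad_last_dim(vals, 1)  # prints "tracing for pad_size=1"
pad_last_dim(vals, 1)  # cache hit: no print
pad_last_dim(vals, 2)  # prints "tracing for pad_size=2" (a fresh compilation)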

But this isn't a great idea on its own; there are more issues here. The code treats npad as an array, creating it with jnp.zeros and then applying jax.ops.index_update to it, which means that value will be computed on, say, your GPU. That's likely not where you want to do scalar integer shape arithmetic! It's better not to use jnp or jax.ops for shape computations, and in the future this is likely to raise an error.

Let me know if you have thoughts on how to make the error message more clear! I have some I'm starting to play with.

@mattjj
Member

mattjj commented Jul 1, 2020

After #3627, here's the new error message even with static_argnums=(1,):

TypeError: jax.numpy.pad got an unexpected type for 'pad_width': got [[0 0]
 [0 1]] of type <class 'jax.interpreters.xla.DeviceArray'>.

Unlike numpy, jax.numpy requires the 'pad_width' argument to jax.numpy.pad to
be an int, single-element tuple/list with an int element, or tuple/list of int
pairs (each pair a tuple or list); in particular, 'pad_width' cannot be an
array.

If you need to manipulate a pad argument with NumPy, as in the original code, I recommend doing it with np rather than jnp, perhaps like this:

from jax import jit as jjit
import jax.numpy as jnp
import numpy as np

vals = np.random.randn(50, 100)

def pad_last_dim(array, pad_size):
    ndim = jnp.ndim(array)
    npad = np.zeros((ndim, 2), dtype=np.int32)
    axis = ndim - 1
    npad[axis, 1] = pad_size
    npad = list(map(list, npad))
    return jnp.pad(array, npad, 'constant', constant_values=0)

print(pad_last_dim(vals, 1))  # All good
print(jjit(pad_last_dim, static_argnums=(1,))(vals, 1))

However, it's probably better just to use lists/tuples here:

from jax import jit as jjit
import jax.numpy as jnp
import numpy as np

vals = np.random.randn(50, 100)

def pad_last_dim(array, pad_size):
    ndim = jnp.ndim(array)
    npad = [(0, 0)] * (ndim - 1) + [(0, pad_size)]  # pad only the last dimension
    return jnp.pad(array, npad, 'constant', constant_values=0)

print(pad_last_dim(vals, 1))  # All good
print(jjit(pad_last_dim, static_argnums=(1,))(vals, 1))

@mattjj
Member

mattjj commented Jul 1, 2020

Hrm, test failures on #3627 have made me think we should allow numpy.ndarrays, but not JAX arrays.

@AdrienCorenflos
Contributor Author

Thanks a lot for that, Matt. It's a bit late for me, but I'll come back to it tomorrow. I think you should forbid lists and ndarrays, to be honest; that would be consistent with TF, where lists are considered a collection of some kind.

In the meantime I like your solution and I'll use it (I don't actually need the ndarray; it was just an easier way to set the padding).

@shoyer
Member

shoyer commented Jul 1, 2020

Hrm, test failures on #3627 have made me think we should allow numpy.ndarrays, but not JAX arrays.

I think these are just failures due to JAX's test harness, which uses np.random to make random padding sizes. This is much less common in actual user code.

@mattjj
Member

mattjj commented Jul 1, 2020

Thanks for catching that, @shoyer!

@AdrienCorenflos
Contributor Author

AdrienCorenflos commented Jul 2, 2020

from jax import jit as jjit
import jax.numpy as jnp
import numpy as np

vals = np.random.randn(50, 100)

def pad_last_dim(array, pad_size):
    ndim = jnp.ndim(array)
    npad = [(0, 0)] * (ndim - 1) + [(0, pad_size)]  # pad only the last dimension
    return jnp.pad(array, npad, 'constant', constant_values=0)

print(pad_last_dim(vals, 1))  # All good
print(jjit(pad_last_dim, static_argnums=(1,))(vals, 1))

Your solution worked for me, thanks a lot!

@hawkinsp
Member

hawkinsp commented Oct 4, 2023

I think we should declare this issue fixed. We still allow arrays as pad_width arguments to jnp.pad, but I think that's likely a good thing. The error message is much better these days:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[2,2].
pad_width argument of jnp.pad
The error occurred while tracing the function pad_last_dim at <ipython-input-1-ec65a8c978ae>:7 for jit. This concrete value was not available in Python because it depends on the value of the argument pad_size.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

and I think that's good enough.
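For reference, a minimal sketch that reproduces the quoted error on a current JAX release (assuming the .at[...] update syntax that replaced jax.ops.index_update):

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def pad_last_dim(array, pad_size):
    npad = jnp.zeros((array.ndim, 2), dtype=jnp.int32).at[-1, 1].set(pad_size)
    return jnp.pad(array, npad)  # pad_width depends on the traced value pad_size

pad_last_dim(np.random.randn(50, 100), 1)  # raises ConcretizationTypeError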

@hawkinsp closed this as completed on Oct 4, 2023