Padding doesn't work in jitted mode #3620
Thanks for the question! The issue is that the result array size depends on the value of the `pad_size` argument.
I want to improve the error message you get here so that it's crystal clear what's going wrong, what the constraints are, and what the options are for fixing it. Do you have any suggestions for how this could have been made more clear? (The recent #3603 is related.)
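The constraint can be shown with a tiny sketch (hypothetical function name, assumes JAX and NumPy are installed): under `jit`, output shapes may depend only on input shapes/dtypes, never on input values, because tracing replaces values with abstract placeholders.

```python
import jax
import jax.numpy as jnp
import numpy as np

def pad_amount_from_value(x, n):
    # Output shape depends on the *value* of n, which jit abstracts away.
    return jnp.zeros(x.shape[0] + n)

x = np.ones(3)
print(pad_amount_from_value(x, 2).shape)  # works eagerly: (5,)

try:
    jax.jit(pad_amount_from_value)(x, 2)
except Exception as e:
    # Under jit, n is a tracer, so the output shape can't be computed.
    print(type(e).__name__)

# Marking n static recompiles per value and restores the eager behavior:
print(jax.jit(pad_amount_from_value, static_argnums=1)(x, 2).shape)  # (5,)
```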
For posterity, here's the current error message (stack trace removed):
One thing you can do now is have JAX re-compile new versions of the function for every value (rather than shape/dtype) of the `pad_size` argument, using `static_argnums`:

```python
print(jjit(pad_last_dim, static_argnums=(1,))(vals, 1))
```

But this isn't a great idea; there are more issues here. Let me know if you have thoughts on how to make the error message more clear! I have some ideas I'm starting to play with.
After #3627, here's the new error message even with
If you need to manipulate a NumPy array for the pad widths, `static_argnums` makes this work:

```python
from jax import jit as jjit
import jax.numpy as jnp
import numpy as np

vals = np.random.randn(50, 100)

def pad_last_dim(array, pad_size):
    ndim = jnp.ndim(array)
    npad = np.zeros((ndim, 2), dtype=np.int32)
    axis = ndim - 1
    npad[axis, 1] = pad_size
    npad = list(map(list, npad))
    return jnp.pad(array, npad, 'constant', constant_values=0)

print(pad_last_dim(vals, 1))  # All good
print(jjit(pad_last_dim, static_argnums=(1,))(vals, 1))
```

However, it's probably better just to use lists/tuples here:

```python
from jax import jit as jjit
import jax.numpy as jnp
import numpy as np

vals = np.random.randn(50, 100)

def pad_last_dim(array, pad_size):
    ndim = jnp.ndim(array)
    # Pad only the trailing dimension.
    npad = [[0, 0]] * (ndim - 1) + [[0, pad_size]]
    return jnp.pad(array, npad, 'constant', constant_values=0)

print(pad_last_dim(vals, 1))  # All good
print(jjit(pad_last_dim, static_argnums=(1,))(vals, 1))
```
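As a usage note, a common pattern is to bake `static_argnums` into the function definition with `functools.partial`, so every call site gets the same jit configuration. This is a sketch of that pattern (variable names are illustrative):

```python
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax import jit

@partial(jit, static_argnums=(1,))
def pad_last_dim(array, pad_size):
    # pad_size is static, so it can appear in the pad config freely;
    # a new version compiles for each distinct pad_size value.
    npad = [[0, 0]] * (jnp.ndim(array) - 1) + [[0, pad_size]]
    return jnp.pad(array, npad, 'constant', constant_values=0)

vals = np.random.randn(50, 100)
print(pad_last_dim(vals, 1).shape)  # (50, 101)
```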
Hrm, test failures on #3627 have made me think we should allow arrays here.
Thanks a lot for that, Matt. Bit late for me, but I'll come back to it tomorrow. I think you should forbid lists and ndarrays, tbh; that would be consistent with TF, where they consider lists to be a collection of some kind. In the meantime I like your solution and I'll use it (I don't need the ndarray, it was just easier for setting the padding).
I think these are just failures due to JAX's test harness, which uses np.random to make random padding sizes. This is much less common in actual user code. |
Thanks for catching that, @shoyer! |
Your solution worked for me, thanks a lot! |
I think we should declare this issue fixed. We still allow arrays as pad widths, and I think that's good enough.
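To illustrate the resolution (a sketch; `pad_one` and `npad` are names chosen for this example): a *concrete* NumPy array of pad widths is a compile-time constant from `jit`'s point of view, so it traces fine. Only pad widths that depend on traced values are a problem.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Concrete pad widths: known at trace time, so jit is happy.
npad = np.array([[0, 1]])

@jax.jit
def pad_one(x):
    return jnp.pad(x, npad, 'constant', constant_values=0)

print(pad_one(np.ones(3)).shape)  # (4,)
```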
Hi,
Thanks for the work on the library.
For a project I need to implement padding on the last dimension of a tensor. Everything works fine when the function is not jitted, but I get a conversion error when I jit it. It seems to me like it's coming from a bug in the lax_numpy internals (this line), but maybe I'm missing something...
Adrien
The code is as follows:
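The original snippet was flattened by the page scrape; the following is a minimal sketch of the failure mode, assuming the `pad_last_dim` function quoted earlier in the thread. Jitting without `static_argnums` turns `pad_size` into a tracer, and assigning a tracer into a NumPy integer array forces an integer conversion, which raises.

```python
import jax
import jax.numpy as jnp
import numpy as np

def pad_last_dim(array, pad_size):
    ndim = jnp.ndim(array)
    npad = np.zeros((ndim, 2), dtype=np.int32)
    npad[ndim - 1, 1] = pad_size  # fails under jit: pad_size is a tracer
    return jnp.pad(array, list(map(list, npad)), 'constant', constant_values=0)

vals = np.random.randn(50, 100)
print(pad_last_dim(vals, 1).shape)  # (50, 101) -- works eagerly

try:
    jax.jit(pad_last_dim)(vals, 1)
except Exception as e:
    # The tracer can't be converted to a concrete integer.
    print(type(e).__name__)
```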