A case where jnp.sinc + jit make a function unstable and non-deterministic #16129
Comments
I'm trying to run your reproducer, but … Can you fix the example code?
Sorry, I'm not sure how I skipped that line; I have added it back. It should be correct now: I have just tried it on Colab and confirmed that I get the problem.
Just adding to this: it does not seem to matter whether one uses …
I just tried to reproduce this again: it reproduces for me with jax/jaxlib 0.4.10, but not with 0.4.12. Can you try updating to 0.4.12?
Yes, I have just tested it on 0.4.11 and I don't see the issue anymore. Not sure what changed.
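For anyone landing here with the same symptoms: according to this thread the problem reproduces with jax/jaxlib 0.4.10 but was not observed with 0.4.11 or 0.4.12, so a first step is to check the installed version. The following is a minimal sketch that only looks at `jax.__version__`. The helper names `parse_version` and `has_reported_fix` are hypothetical, and the (0, 4, 11) threshold reflects what was observed in this thread rather than an official fix version. Since the bisected fix is an XLA change, jaxlib should be upgraded together with jax.

```python
import jax

# Versions observed in this thread: reproduces on 0.4.10, not seen on 0.4.11 / 0.4.12.
FIRST_VERSION_WITHOUT_ISSUE = (0, 4, 11)


def parse_version(version: str) -> tuple:
    """Turn 'X.Y.Z' (optionally followed by a suffix like '.dev2024...') into a tuple of ints."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def has_reported_fix(version: str = jax.__version__) -> bool:
    """True if `version` is at or past the first version reported as unaffected here."""
    return parse_version(version) >= FIRST_VERSION_WITHOUT_ISSUE


if __name__ == "__main__":
    status = "ok" if has_reported_fix() else "consider upgrading jax and jaxlib"
    print("jax", jax.__version__, "->", status)
```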
I'm going to see if I can bisect it to a change, but fixed is fixed!
Indeed! 😄
Not sure if it is helpful, but this also got solved by 0.4.11: #14302
I bisected the fix to this XLA change: openxla/xla@215705b. I'm not sure why it fixes the problem, but it seems to do the trick!
Description
Hello JAX team!
While working on my JAX-based wave simulator, I have encountered an edge case where things get quite weird. The following is the best MRE I managed to prepare by stripping down the simulation code. Sorry that it is quite long; in a minute I will explain why I can't see an obvious way to make it shorter.
The last part of the code plots the results of running the function 5 times with jit (red) and without jit (black).
[Figure: outputs of 5 runs with jit (red) and without jit (black), logarithmic axis]
Please note the logarithmic axis. Where the red traces disappear, the output is NaN.
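To locate where the NaNs first appear, JAX's `jax_debug_nans` flag can help: with it enabled, a NaN in the output of a jitted function raises a `FloatingPointError`, and JAX re-runs the function without `jit` to try to point at the offending operation. The sketch below wraps this in a hypothetical helper, `first_nan_report`, and uses a toy function (`toy_step`, also hypothetical) in place of the simulator, whose code isn't reproduced here. Because the non-jitted version stays finite in this issue, the de-optimized re-run may not reproduce the NaN; as far as I know, recent JAX versions say so in the error message, which would itself point at the compiled code path.

```python
import jax
import jax.numpy as jnp


def first_nan_report(fn, *args):
    """Run jax.jit(fn)(*args) with NaN checking enabled.

    Returns the FloatingPointError message if a NaN is detected, otherwise None.
    """
    jax.config.update("jax_debug_nans", True)
    try:
        jax.block_until_ready(jax.jit(fn)(*args))
        return None
    except FloatingPointError as err:
        return str(err)
    finally:
        # Assumes the flag was off beforehand (its default) and restores that state.
        jax.config.update("jax_debug_nans", False)


def toy_step(u):
    # Hypothetical stand-in: the log of a negative entry produces a NaN.
    return jnp.log(u)


if __name__ == "__main__":
    print(first_nan_report(toy_step, jnp.array([1.0, -1.0])))
```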
As you can see, the function without jit is stable, while the outputs with jit are unstable and, more importantly I believe, non-deterministic.
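One way to quantify this instability and non-determinism, instead of reading it off the plot, is to run the same input several times with and without `jax.jit` and compare the results bit for bit. This is a sketch under stated assumptions. `toy_simulation` is a hypothetical FFT-based stand-in for the real simulator that keeps only the `jnp.sinc(jnp.zeros(...))` operator from the description. The helpers `repeated_outputs`, `is_bitwise_deterministic` and `compare_jit` are also hypothetical, and the tolerances are arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np


def toy_simulation(u):
    # Hypothetical stand-in for the simulator (the real reproducer isn't shown here).
    # u.shape is static under jit, so jnp.zeros(u.shape) is fine inside the trace.
    k_space_op = jnp.sinc(jnp.zeros(u.shape))  # sinc(0) == 1, so this is all ones
    for _ in range(10):
        u = jnp.fft.ifftn(k_space_op * jnp.fft.fftn(u)).real
    return u


def repeated_outputs(fn, x, n_runs=5):
    """Call fn(x) n_runs times and return the results as NumPy arrays."""
    return [np.asarray(fn(x)) for _ in range(n_runs)]


def is_bitwise_deterministic(outputs):
    """True if every run matches the first one exactly (NaNs compare as equal)."""
    first = outputs[0]
    return all(np.array_equal(first, out, equal_nan=True) for out in outputs[1:])


def compare_jit(fn, x, n_runs=5, rtol=1e-5, atol=1e-6):
    """Summarize eager vs. jitted behavior of fn on the same input."""
    eager = repeated_outputs(fn, x, n_runs)
    jitted = repeated_outputs(jax.jit(fn), x, n_runs)  # compiled once, run n_runs times
    return {
        "eager_deterministic": is_bitwise_deterministic(eager),
        "jit_deterministic": is_bitwise_deterministic(jitted),
        "jit_nan_count": sum(int(np.isnan(out).sum()) for out in jitted),
        "jit_matches_eager": bool(np.allclose(eager[0], jitted[0], rtol=rtol, atol=atol)),
    }


if __name__ == "__main__":
    x = jnp.outer(jnp.linspace(0.0, 1.0, 64), jnp.linspace(0.0, 1.0, 64))
    print(compare_jit(toy_simulation, x))
```

With the toy function on CPU all four checks should come out clean. On an affected setup (jax/jaxlib 0.4.10 on GPU, per this thread), passing in the real simulator would presumably show `jit_deterministic` as False and a nonzero `jit_nan_count`.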
Some extra observations that make things even weirder:
- Replacing `k_space_op = jnp.sinc(jnp.zeros(N))` with `k_space_op = jnp.ones(N)`, which should be numerically equivalent in every way (sinc(0) = 1, and `N` is a static argument), makes the results almost correct and the function deterministic. I say almost because the very first values actually differ from those of the non-jitted version; this is probably just a matter of numerical precision.
- With `N, dx = (64, 64), (0.1e-3, 0.1e-3)`, the jitted function is stable and deterministic again. It also exactly matches the non-jitted code.

What jax/jaxlib version are you using?
jax 0.4.10, jaxlib 0.4.10
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.10.11, Ubuntu 22.04.2 LTS
NVIDIA GPU info