Numerical precision differences in jitted vs. non-jitted code! #22881

Closed
shoumikdc opened this issue Aug 5, 2024 · 3 comments


shoumikdc commented Aug 5, 2024

Description

Hi all! I am seeing some strange numerical behavior in jitted vs. non-jitted code. I start by defining the two functions below. They should ideally give the same result since the second term in f2 gets multiplied by zero!

from jax import jit, vmap, config
import jax.numpy as jnp
config.update("jax_enable_x64", True)

def f1(x):
    return jnp.cos(x/2)

def f2(x):
    return jnp.cos(x/2) + 0 * jnp.sin(x/2)

If I then evaluate these over the range xs = jnp.linspace(-2, 2, 101), I get different results depending on whether jit is used or not:

xs = jnp.linspace(-2, 2, 101) 
print(vmap(f1)(xs) - vmap(f2)(xs))

Output: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0.]

print(jit(vmap(f1))(xs) - jit(vmap(f2))(xs))

Output: [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00 -1.11022302e-16  0.00000000e+00  0.00000000e+00
 -1.11022302e-16 -1.11022302e-16 -1.11022302e-16  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  1.11022302e-16  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00 -1.11022302e-16  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00 -1.11022302e-16  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
 -1.11022302e-16 -1.11022302e-16  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00 -1.11022302e-16 -1.11022302e-16
 -1.11022302e-16  0.00000000e+00 -1.11022302e-16  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00]

I'll also note that in this case, I could do away with the vmap and just pass in xs to the two functions. However, I am seeing differences between the jitted and non-jitted versions there as well.
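For concreteness, the non-vmapped version of the comparison I mean is just (a minimal sketch reusing f1, f2, and xs from above):

print(f1(xs) - f2(xs))            # all zeros for me
print(jit(f1)(xs) - jit(f2)(xs))  # small nonzero entries (~1e-16) for me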

Is this kind of behavior expected? This feels like a numerical precision problem, but I have enabled 64-bit mode in the config, so I'm not sure why that would be the case.

Any assistance would be greatly appreciated!

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.24.3
python: 3.9.16 | packaged by conda-forge | (main, Feb  1 2023, 21:38:11)  [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='dhcp-10-29-164-99.dyn.MIT.EDU', release='21.1.0', version='Darwin Kernel Version 21.1.0: Wed Oct 13 17:33:01 PDT 2021; root:xnu-8019.41.5~1/RELEASE_ARM64_T6000', machine='arm64')
shoumikdc added the bug (Something isn't working) label on Aug 5, 2024

jakevdp (Collaborator) commented Aug 5, 2024

Hi - thanks for the question! This kind of difference is expected: in general, floating point operations are only accurate to within a particular precision that depends on the width of the float representation. When you compute the "same" result in two ways, the results will not in general be bitwise-equivalent. JIT-compilation replaces your original sequence of operations with a more efficient compiled kernel, and so in general you should not expect bitwise-equivalent outputs.

You can see the approximate expected precision using finfo:

In [1]: import numpy as np

In [2]: np.finfo(np.float64).eps
Out[2]: np.float64(2.220446049250313e-16)

The differences you're seeing are consistent with those expectations.
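In practice this means that comparisons between the two results should be tolerance-based rather than exact. As a minimal sketch, reusing f1, f2, xs, jit, and vmap from your snippet:

import numpy as np

y_eager = vmap(f2)(xs)
y_jit = jit(vmap(f2))(xs)

# Agreement to within a few ULPs is the most you can expect, so compare with an
# absolute tolerance on the order of machine epsilon instead of exact equality.
print(np.allclose(y_eager, y_jit, rtol=0.0, atol=4 * np.finfo(np.float64).eps))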

jakevdp (Collaborator) commented Aug 5, 2024

Also, a side note: you might wonder why the compiler doesn't just simplify x + 0 * y to x. The reason is that for floating point math, these two expressions may return different results! For example, if you plug in y = np.inf or y = np.nan, the two expressions are not equivalent.
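To make that concrete, here's a minimal sketch in plain NumPy:

import numpy as np

x = np.float64(1.0)
print(x + 0 * np.float64(2.0))  # 1.0: with a finite y the extra term drops out
print(x + 0 * np.inf)           # nan: 0 * inf is nan, so x + 0 * y is nan, not x
print(x + 0 * np.nan)           # nan: 0 * nan is nan, so x + 0 * y is nan, not x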

mattjj (Collaborator) commented Aug 6, 2024

I think we should probably close this as working-as-intended, and Jake's answer covers the reasoning well.

mattjj closed this as completed on Aug 6, 2024
mattjj added the question (Questions for the JAX team) label and removed the bug (Something isn't working) label on Aug 6, 2024