Hi - thanks for the question! This kind of difference is expected: in general, floating point operations are only accurate to within a particular precision that depends on the width of the float representation. When you compute the "same" result in two ways, the results will not in general be bitwise-equivalent. JIT-compilation replaces your original sequence of operations with a more efficient compiled kernel, and so in general you should not expect bitwise-equivalent outputs.
You can see the approximate expected precision using finfo:
    In [1]: import numpy as np

    In [2]: np.finfo(np.float64).eps
    Out[2]: np.float64(2.220446049250313e-16)
The differences you're seeing are consistent with those expectations.
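As a rough illustration of what "consistent with those expectations" means in practice, here is a minimal sketch that compares eager and jitted outputs against a tolerance instead of testing for exact equality. The original function definitions are not reproduced in this thread, so `f2` below is a hypothetical stand-in with a term multiplied by zero, and the `1e-12` tolerances are an arbitrary choice for float64, not a recommended value.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode at startup, before creating any arrays.
jax.config.update("jax_enable_x64", True)


def f2(x):
    # Hypothetical stand-in: the second term is multiplied by zero.
    return jnp.sin(x) * jnp.exp(-x**2) + 0.0 * jnp.cos(x)


xs = jnp.linspace(-2, 2, 101)

eager = jax.vmap(f2)(xs)            # op-by-op execution
jitted = jax.jit(jax.vmap(f2))(xs)  # single compiled kernel

# Bitwise equality is not guaranteed, so look at the size of the difference.
max_abs_diff = jnp.max(jnp.abs(eager - jitted))
eps = jnp.finfo(jnp.float64).eps
print("max |eager - jitted|:", max_abs_diff)
print("float64 eps:         ", eps)

# Tolerance-based comparison (tolerances here are an assumption).
close = bool(jnp.allclose(eager, jitted, rtol=1e-12, atol=1e-12))
print("allclose:", close)
```

If the largest difference is within a few multiples of `eps` times the magnitude of the values, it is ordinary rounding noise rather than a sign of a bug.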
Also, a side note: you might wonder why the compiler doesn't just simplify x + 0 * y to x. The reason is that in floating-point math these two expressions may return different results! For example, if you plug in y = np.inf or y = np.nan, then 0 * y is NaN, so the first expression evaluates to NaN while the second is just x.
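The non-equivalence can be checked directly with NumPy. This is a small sketch; the helper name `with_zero_term` is made up for illustration.

```python
import numpy as np


def with_zero_term(x, y):
    # The expression a compiler might be tempted to rewrite as just `x`.
    return x + 0.0 * y


x = np.float64(1.5)

# Silence the "invalid value" warning that NumPy may emit for 0 * inf.
with np.errstate(invalid="ignore"):
    finite = with_zero_term(x, np.float64(2.0))
    from_inf = with_zero_term(x, np.inf)
    from_nan = with_zero_term(x, np.nan)

print(finite, from_inf, from_nan)
```

For a finite y the result equals x, but for y = np.inf or y = np.nan the result is NaN, so rewriting x + 0 * y as x would change the program's output.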
Description
Hi all! I am seeing some strange numerical behavior in jitted vs. non-jitted code. I start by defining the two functions below. They should ideally give the same result since the second term in f2 gets multiplied by zero!

If I then calculate this over a given range xs = jnp.linspace(-2, 2, 101), I get different results depending on whether jit is used or not:

I'll also note that in this case, I could do away with the vmap and just pass in xs to the two functions. However, I am seeing differences between the jitted and non-jitted versions there as well.

Is this kind of behavior expected? This feels like a numerical precision problem, but I have enabled 64 bit mode in the config - so I'm not sure why that would be the case.
Any assistance would be greatly appreciated!
System info (python version, jaxlib version, accelerator, etc.)