np.result_type fails with bfloat16 #11014
I am also curious why a bfloat16 array gets promoted to float32 when broadcast against a scalar. This is not the way it works for `np.float16`: the general rule seems to be that broadcasting an array against a scalar of a higher dtype should not promote the whole array to the higher dtype. But bfloat16 doesn't seem to respect this.
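A minimal sketch of the contrast being described, assuming the NumPy-compatible bfloat16 dtype exposed as `jnp.bfloat16` (the float32 result is as reported in this comment and may vary across NumPy versions):

```python
import numpy as np
import jax.numpy as jnp  # provides a NumPy-compatible bfloat16 dtype

# NumPy keeps float16 when broadcasting against a Python scalar...
print((np.ones(3, dtype=np.float16) * 1.0).dtype)    # float16

# ...but the third-party bfloat16 dtype reportedly gets promoted.
print((np.ones(3, dtype=jnp.bfloat16) * 1.0).dtype)  # float32, per this report
```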
Hi, thanks for the question! It looks like you are finding differences between numpy's promotion behavior and JAX's promotion behavior. These differences are intentional, and are documented in JAX Type Promotion Semantics. If you want numpy-style promotion, you can use `np.result_type` directly, as in the example below. Regarding the addition of values to arrays: if you add a Python scalar value like `np.nan`, the promotion rules that apply depend on whether you are operating on a JAX array or a NumPy array. This can be seen in the following example:

```python
import numpy as np
import jax.numpy as jnp

np_value = np.array(1.0, dtype=jnp.bfloat16)
jax_value = jnp.array(1.0, dtype=jnp.bfloat16)

print(jnp.result_type(jax_value, np.nan))  # JAX-style promotion
# bfloat16
print((jax_value + np.nan).dtype)  # jax array addition -> jax promotion rules
# bfloat16
print(np.result_type(np_value, np.nan))  # numpy-style promotion
# float64
print((np_value + np.nan).dtype)  # numpy array addition -> numpy promotion rules
# float64
```

Does that answer your question?
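To spell out the JAX side a little more, here is a hedged sketch (assuming current JAX semantics): Python scalars such as `np.nan` are weakly typed and defer to the array's dtype, while NumPy scalars carry a strong dtype and trigger ordinary promotion:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.bfloat16)

# Python scalars are weakly typed in JAX: the array dtype wins.
print((x + 1.0).dtype)              # bfloat16

# NumPy scalars have a strong dtype: standard promotion applies.
print((x + np.float32(1.0)).dtype)  # float32
```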
Thanks for your reply. It answers the question about how things are meant to be, but it still does not explain some of the behavior. For example, if I take your code and replace:

with:

and leave the rest the same, then I get:

I would not expect that. Also, the behavior described by Matthew above does seem to make a lot of sense:

Now, it is true that these seem to be problems when using numpy functionality with bfloat16, and not jax functionality per se, except that I don't know if
Thanks - all of that has to do with numpy (not JAX) dtype promotion. Here's the JAX version of what you did above, and it gives reasonable results:

```python
import jax.numpy as jnp

jnp_value = jnp.ones((2, 2), dtype=jnp.bfloat16)

print(jnp.result_type(jnp_value, jnp.nan))
# bfloat16
print((jnp_value + jnp.nan).dtype)
# bfloat16

print((jnp.array([1, 2, 3], jnp.float32) * 1.).dtype)   # float32 (same type)
print((jnp.array([1, 2, 3], jnp.float16) * 1.).dtype)   # float16 (same type)
print((jnp.array([1, 2, 3], jnp.bfloat16) * 1.).dtype)  # bfloat16 (same type)
```

If you'd like to report the numpy behavior as a bug, you could do so at http://github.com/numpy/numpy/issues, but I suspect maintainers would close it as "works as intended". I don't know of any way that NumPy could be changed to correctly handle arbitrary third-party dtypes, which is what bfloat16 is.

Alternatively, if you'd like to see NumPy handle bfloat16 natively, you could peruse this issue: numpy/numpy#19808. It has some discussion of why NumPy may or may not include bfloat16 natively in the future.
I'm going to close this issue, since I believe JAX is handling bfloat16 correctly in all cases, and the remaining poor behavior is due to numpy's handling of custom dtypes, something that JAX has no control over. Thanks for the report!
Sorry, I did not have time to reply before this was closed.
But using JAX ops tends to want to put things on accelerators, right? So for CPU-only loads, e.g. prefetching, we tend to use either TensorFlow or plain NumPy.

So whose responsibility is it to maintain a custom type for numpy: is it numpy, is it JAX, or is it XLA? As far as I can tell, it is JAX which is doing it. It may seem that I am being nit-picky, but I suspect in practice anyone doing this will run into the same problem.
The custom `bfloat16` dtype for numpy is maintained in the TensorFlow repository (see the source linked below); JAX reuses that implementation rather than defining its own.
No, in an operation like that, numpy is in full control: multiplying two numpy arrays never calls into JAX code, so the promotion behavior comes entirely from numpy plus the cast rules registered by the dtype extension.
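A sketch of that dispatch point (hedged; the key observation is only that the result is a plain NumPy array, produced without any JAX involvement):

```python
import numpy as np
import jax.numpy as jnp

a = np.ones((2, 2), dtype=jnp.bfloat16)  # a plain numpy ndarray with a custom dtype
out = a * 1.0                            # dispatched by NumPy's ufunc machinery

print(type(out))  # <class 'numpy.ndarray'>: no JAX code ran in this expression
```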
You can safely do computations with bfloat16 numpy arrays; you just need to be explicit in cases where numpy's default type promotion rules don't give you the result you want. For example, here's the solution to your problematic case above:

```python
import numpy as np
import jax.numpy as jnp

np.ones((2, 2), dtype=jnp.bfloat16) * np.array(1, dtype=jnp.bfloat16)
# array([[1, 1],
#        [1, 1]], dtype=bfloat16)
```

But I suspect there's very little value in doing numpy operations in bfloat16, because CPUs aren't optimized for this data type. So my guess is you'd be better off doing host-side computations in float32 or float64, and then converting to bfloat16 before pushing the results to device.
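A sketch of that suggested workflow (the shapes and values here are only illustrative):

```python
import numpy as np
import jax.numpy as jnp

# Do host-side math in float32, where NumPy promotion is well defined...
host = np.linspace(0.0, 1.0, 4, dtype=np.float32)
host = host * np.float32(2.0)  # stays float32

# ...and cast once when handing the result to the device.
device = jnp.asarray(host, dtype=jnp.bfloat16)
print(device.dtype)  # bfloat16
```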
If making numpy natively handle bfloat16 is important to you, I'd suggest weighing in on numpy/numpy#19808.
Thanks for the extra explanation.

It is great to hear this. I guess the reason I was getting nervous was because my thinking was something like:

what other computations may be silently outputting the wrong thing when using bfloat16 with numpy?

I just don't know the details of what it involves to write a custom type for numpy; my guess was that since numpy does not natively support bfloat16, surprises like this might be lurking elsewhere too. But probably I got the whole picture of how this works wrong, and if you have a recommendation for documentation explaining how all of this magic works, I would really appreciate a pointer :)
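One defensive pattern for the "silently outputting the wrong thing" worry is to assert dtypes after host-side ops. The helper below is hypothetical, not part of NumPy or JAX:

```python
import numpy as np
import jax.numpy as jnp

def checked(result, expected_dtype):
    # Hypothetical guard: fail loudly instead of silently promoting.
    assert result.dtype == expected_dtype, f"promoted to {result.dtype}"
    return result

a = np.ones(3, dtype=jnp.bfloat16)
b = checked(a * np.array(2, dtype=jnp.bfloat16), jnp.bfloat16)  # passes
print(b.dtype)  # bfloat16
```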
I don't know of any good documentation for creating new custom dtypes, but you can get an idea of what is available by looking at the bfloat16 source: https://github.com/tensorflow/tensorflow/blob/1f0782337808704e564883f77d89e51b50d60808/tensorflow/python/lib/core/bfloat16.cc

In particular, I think this is the line that results in the problematic promotion that you've observed: https://github.com/tensorflow/tensorflow/blob/1f0782337808704e564883f77d89e51b50d60808/tensorflow/python/lib/core/bfloat16.cc#L422-L427
Thanks so much! It sounds like I should file the bug in the TF repo then!
You certainly could, but given how numpy type promotion works, I don't think there's any change that TF could make to that code that would improve the situation. It's fundamentally a numpy dtype promotion bug.
Although, as a final nit: if I try this with TF I get an error:

So it seems that TF is trying to discourage people from actually using this type of behavior. (Although I can reproduce the same if I do
I think you're looking for
Fair enough, I will just do that. Thanks again for all of your time! |
Disclaimer: I am not sure if this is a JAX bug, a Numpy bug, or not a bug at all, so any help triaging would be welcome!

Reproducer: see the sketch below.

When I run the code above with `bfloat16`, I get inconsistent results between the output dtype of an op and the dtype returned by `result_type` (either the wrong type if using `jnp.result_type`, or a TypeError if using `np.result_type`). All issues go away if I explicitly make an array for `nan` with the same dtype.

Not sure how much support for `bfloat16` I should expect from numpy methods (or whether the TypeError is expected), but even if no support is expected, I am still surprised that `jnp.result_type` returns the wrong answer.
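The original reproducer snippet did not survive extraction; a sketch consistent with the description above (with outcomes as reported at the time, on the then-current JAX and NumPy versions) might look like:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.bfloat16)

print((x + np.nan).dtype)          # bfloat16: the op's actual output dtype
print(jnp.result_type(x, np.nan))  # reportedly disagreed with the op's dtype
# np.result_type(x, np.nan)        # reportedly raised a TypeError

# Everything is consistent once nan is an explicit array of the same dtype:
nan = jnp.array(np.nan, dtype=jnp.bfloat16)
print(jnp.result_type(x, nan))     # bfloat16
```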
returns the wrong answer.The text was updated successfully, but these errors were encountered: