
np.result_type fails with bfloat16 #11014

Open
alvarosg opened this issue Jun 7, 2022 · 16 comments
Labels: bug (Something isn't working)

@alvarosg (Contributor) commented Jun 7, 2022

Disclaimer: I am not sure if this is a JAX bug, a NumPy bug, or not a bug at all, so any help triaging would be welcome!

Reproducer:

import jax.numpy as jnp
import numpy as np

dtype = jnp.bfloat16
array = np.ones((2, 2), dtype=dtype)

nan = np.nan

print((array + nan).dtype)  # float32
print(jnp.result_type(array, nan))  # bfloat16, Different from the above!
print(np.result_type(array, nan))  # raises TypeError: The DTypes <class 'numpy.dtype[float16]'> and <class 'numpy.dtype[bfloat16]'> do not have a common DType. For example they cannot be stored in a single array unless the dtype is `object`.

When I run the code above with bfloat16, I get inconsistent results between the output dtype of the op and the dtype returned by result_type: either the wrong type if using jnp.result_type, or a TypeError if using np.result_type.

All issues go away if I explicitly make an array for nan with the same dtype:

casted_nan = np.array(nan, dtype=dtype)
print((array + casted_nan).dtype)  # bfloat16
print(np.result_type(array, casted_nan))  # bfloat16, as expected from the above
print(jnp.result_type(array, casted_nan))  # bfloat16, as expected from the above

I am not sure how much support for bfloat16 I should expect from NumPy methods (or whether the TypeError is expected), but even if no support is expected, I am still surprised that jnp.result_type returns the wrong answer.

@alvarosg added the "bug" label Jun 7, 2022
@mjwillson (Contributor) commented Jun 7, 2022

I am also curious why a bfloat16 array gets promoted to float32 when broadcast against a scalar. E.g. np.array([1,2,3], jnp.bfloat16) * 1. is float32.

This is not the way it works for np.float16: np.array([1,2,3], np.float16) * 1. is np.float16.

It seems the general rule is that broadcasting an array against a scalar of higher dtype should not promote the whole array to the higher dtype. But bfloat16 doesn't seem to respect this.
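
Spelled out as a runnable snippet (the printed dtypes are what I observe with jaxlib's bfloat16 type registered in NumPy):

import numpy as np
import jax.numpy as jnp  # importing jax.numpy makes the bfloat16 extension type available

# A Python float scalar does not upcast a float16 array...
print((np.array([1, 2, 3], np.float16) * 1.).dtype)   # float16
# ...but it does upcast a bfloat16 array to float32, which is the asymmetry in question.
print((np.array([1, 2, 3], jnp.bfloat16) * 1.).dtype)  # float32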

@jakevdp (Collaborator) commented Jun 8, 2022

Hi, thanks for the question! It looks like you are finding differences between numpy's promotion behavior and JAX's promotion behavior. These differences are intentional, and are documented in JAX Type Promotion Semantics.

If you want numpy-style promotion, you can use np.result_type. If you want JAX-style promotion, you can use jnp.result_type.

Regarding the addition of values to arrays, if you add a Python scalar value like np.nan to a numpy array, you will get numpy-style promotion. If you add a Python scalar value like np.nan to a JAX array, you will get JAX-style promotion.

This can be seen in the following example:

import numpy as np
import jax.numpy as jnp

np_value = np.array(1.0, dtype=jnp.bfloat16)
jax_value = jnp.array(1.0, dtype=jnp.bfloat16)

print(jnp.result_type(jax_value, np.nan))  # JAX-style promotion
# bfloat16
print((jax_value + np.nan).dtype)  # jax array addition -> jax promotion rules
# bfloat16

print(np.result_type(np_value, np.nan))  # numpy-style promotion
# float64
print((np_value + np.nan).dtype)  # numpy array addition -> numpy promotion rules
# float64

Does that answer your question?

@alvarosg (Contributor, Author) commented Jun 8, 2022

Thanks for your reply; it answers the question about how things are meant to be, but it still does not explain some of the behavior.

For example, if I take your code, and replace:

np_value = np.array(1.0, dtype=jnp.bfloat16)

with:

np_value = np.ones((2, 2), dtype=jnp.bfloat16)

And leave the rest the same, then I get:

print(np.result_type(np_value, np.nan))
# TypeError: The DTypes <class 'numpy.dtype[float16]'> and <class 'numpy.dtype[bfloat16]'> do not have a common DType.
print((np_value + np.nan).dtype) 
# float32

I would not expect np.ones to behave so differently from np.array().

Also, the behavior described by Matthew above still does not seem to make a lot of sense:

print((np.array([1,2,3], jnp.float32) * 1.).dtype)   # float32 (same type)
print((np.array([1,2,3], jnp.float16) * 1.).dtype)   # float16 (same type)
print((np.array([1,2,3], jnp.bfloat16) * 1.).dtype)  # float32 (new type!)

Now, it is true that these seem to be problems with using NumPy functionality on bfloat16, not with JAX functionality per se. However, I don't know whether bfloat16 is officially considered JAX functionality (since NumPy does not support it natively), and I also don't know whether JAX makes any promises about supporting bfloat16 with NumPy ops.

@jakevdp (Collaborator) commented Jun 8, 2022

Thanks - all of that has to do with numpy (not JAX) dtype promotion. Since bfloat16 is a custom dtype that numpy has no knowledge of, I think it's somewhat expected that numpy type promotion does strange things with it. I would suggest that when working with bfloat16, you use JAX operations, because JAX is aware of the dtype and handles it correctly.

Here's the JAX version of what you did above, and it gives reasonable results:

import jax.numpy as jnp

jnp_value = jnp.ones((2, 2), dtype=jnp.bfloat16)

print(jnp.result_type(jnp_value, jnp.nan))
# bfloat16
print((jnp_value + jnp.nan).dtype) 
# bfloat16

print((jnp.array([1,2,3], jnp.float32) * 1.).dtype)   # float32 (same type)
print((jnp.array([1,2,3], jnp.float16) * 1.).dtype)   # float16 (same type)
print((jnp.array([1,2,3], jnp.bfloat16) * 1.).dtype)  # bfloat16 (same type)

If you'd like to report the numpy behavior as a bug, you could do so at http://github.com/numpy/numpy/issues, but I suspect maintainers would close it as "works as intended". I don't know of any way that NumPy could be changed to correctly handle arbitrary third-party dtypes, which is what bfloat16 is.

Alternatively, if you'd like to see NumPy handle bfloat16 natively, you could peruse this issue: numpy/numpy#19808. It has some discussion of why NumPy may or may not include bfloat16 natively in the future.

@jakevdp (Collaborator) commented Jun 8, 2022

I'm going to close this issue, since I believe JAX is handling bfloat16 correctly in all cases, and the remaining poor behavior is due to numpy's handling of custom dtypes, something that JAX has no control over. Thanks for the report!

@jakevdp closed this as completed Jun 8, 2022
@alvarosg (Contributor, Author) commented Jun 8, 2022

Sorry, I did not have time to reply before this was closed.

> I would suggest that when working with bfloat16, you use JAX operations, because JAX is aware of the dtype and handles it correctly.

But using JAX ops tends to put things on accelerators, right? So for CPU-only workloads, e.g. prefetching, we tend to use either TensorFlow or plain NumPy.

> Since bfloat16 is a custom dtype that numpy has no knowledge of, I think it's somewhat expected that numpy type promotion does strange things with it.
> I don't know of any way that NumPy could be changed to correctly handle arbitrary third-party dtypes, which is what bfloat16 is.

So whose responsibility is it to maintain a custom type for NumPy: is it NumPy, JAX, or XLA?

As far as I can tell, it is JAX that is doing np.dtype(jax._src.dtypes.bfloat16) here and here. Does this not mean it is JAX's responsibility to either somehow support the custom type in NumPy, or at least throw an error if a NumPy op is attempted on such an array, rather than allowing undefined behavior?
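
For reference, a quick illustrative check (assuming jaxlib is installed) that the registration does expose bfloat16 as a NumPy dtype:

import numpy as np
import jax.numpy as jnp

# jnp.bfloat16 is the NumPy scalar type bundled with jaxlib; np.dtype() accepts it
# even though NumPy itself ships no native bfloat16.
print(np.dtype(jnp.bfloat16))  # bfloat16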

It may seem that I am being nit-picky, but I suspect in practice anyone doing bfloat16 manipulation outside of jitted functions may end up falling into this trap.

@jakevdp (Collaborator) commented Jun 8, 2022

> So whose responsibility is it to maintain a custom type for NumPy: is it NumPy, JAX, or XLA?

The custom bfloat16 dtype is currently defined in tensorflow; JAX extracts this definition and bundles it as part of jaxlib. There has been recent discussion about whether bfloat16 should instead be defined elsewhere as a stand-alone numpy dtype extension, but I'm not sure of the status of that.

> Does this not mean it is JAX's responsibility to either somehow support the custom type in NumPy, or at least throw an error if a NumPy op is attempted on such an array, rather than allowing undefined behavior?

No, in an operation like np.ones((2, 2), dtype=jnp.bfloat16) + 1, JAX code never comes into play in type promotion decisions, so if we wanted to fix this issue, there would be no possible way to do so by changing code in JAX. This type promotion behavior is defined by numpy, and the only way to change it would be to modify the numpy package.

> But using JAX ops tends to put things on accelerators, right? So for CPU-only workloads, e.g. prefetching, we tend to use either TensorFlow or plain NumPy.

You can safely do computations with bfloat16 numpy arrays, you just need to be explicit in cases where numpy's default type promotion rules don't give you the result you want. For example, here's the solution to your problematic case above:

import numpy as np
import jax.numpy as jnp

np.ones((2, 2), dtype=jnp.bfloat16) * np.array(1, dtype=jnp.bfloat16)
# array([[1, 1],
#        [1, 1]], dtype=bfloat16)

But I suspect there's very little value in doing numpy operations in bfloat16, because CPUs aren't optimized for this data type. So my guess is you'd be better off doing host-side computations in float32 or float64, and then converting to bfloat16 before pushing the results to device.
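
A minimal sketch of that workflow (the array contents and shapes here are just placeholders):

import numpy as np
import jax.numpy as jnp

# Do host-side (CPU) work in a dtype NumPy handles natively...
host_batch = np.arange(4, dtype=np.float32).reshape(2, 2) / 3.0
# ...and cast to bfloat16 only when handing the data to JAX / the device.
device_batch = jnp.asarray(host_batch, dtype=jnp.bfloat16)
print(device_batch.dtype)  # bfloat16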

@jakevdp (Collaborator) commented Jun 8, 2022

If making numpy natively handle bfloat16 is important to you, I'd suggest weighing in on numpy/numpy#19808.

@alvarosg (Contributor, Author) commented Jun 8, 2022

Thanks for the extra explanation.

> You can safely do computations with bfloat16 numpy arrays, you just need to be explicit in cases where numpy's default type promotion rules don't give you the result you want.

It is great to hear this. I guess the reason I was getting nervous was that my thinking was something like:
"If numpy is not self-consistent with its own np.result_type method, e.g.:

np_value = np.ones((2, 2), dtype=jnp.bfloat16)
print(np.result_type(np_value, np.nan))  # Raises TypeError
print((np_value + np.nan).dtype)   # Does not raise, and returns float32.

what other computations may be silently outputting the wrong thing when using bfloat16?"

I just don't know the details of what writing a custom dtype for NumPy involves. My guess was that, since NumPy does not natively support bfloat16, when I do something like np.ones((2, 2), dtype=jnp.bfloat16), some other external-to-NumPy code must have told NumPy how to generate ones for bfloat16 specifically. And perhaps that same code, which tells NumPy how to generate ones or do sums for bfloat16, should also be responsible for configuring appropriate type promotion rules for the new type. I just don't know whether configuring this for custom types in NumPy is possible at all, or how implementations of methods such as ones are actually injected for custom types. My guess was that, since NumPy does not natively support bfloat16 yet I can still successfully call np.result_type(jnp.bfloat16, jnp.bfloat16) as well as np.result_type(jnp.bfloat16, np.float32), the logic defining those promotion rules must have been injected externally too.
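
For example, these calls do succeed on their own (the second result below is what I would expect, given that (array + nan).dtype was float32 above):

import numpy as np
import jax.numpy as jnp

# Both operands are dtypes the bfloat16 extension knows about, so result_type works:
print(np.result_type(jnp.bfloat16, jnp.bfloat16))  # bfloat16
print(np.result_type(jnp.bfloat16, np.float32))    # float32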

But I have probably got the whole picture of how this works wrong, and if you have a recommendation for documentation explaining how all of this magic works, I would really appreciate a pointer :)

@jakevdp (Collaborator) commented Jun 8, 2022

I don't know of any good documentation for creating new custom dtypes, but you can get an idea of what is available by looking at the bfloat16 source: https://github.com/tensorflow/tensorflow/blob/1f0782337808704e564883f77d89e51b50d60808/tensorflow/python/lib/core/bfloat16.cc

In particular, I think this is the line that results in the problematic promotion that you've observed: https://github.com/tensorflow/tensorflow/blob/1f0782337808704e564883f77d89e51b50d60808/tensorflow/python/lib/core/bfloat16.cc#L422-L427

@alvarosg (Contributor, Author) commented Jun 9, 2022

Thanks so much! It sounds like I should file the bug in the TF repo then.

@jakevdp (Collaborator) commented Jun 9, 2022

You certainly could, but given how numpy type promotion works, I don’t think there’s any change that TF could make to that code that would improve the situation. It’s fundamentally a numpy dtype promotion bug.

@alvarosg (Contributor, Author) commented Jun 9, 2022

Although, as a final nit: if I try this with TF I get an error:

import numpy as np
import tensorflow as tf
np.array(2., dtype=tf.bfloat16)  # Cannot interpret 'tf.bfloat16' as a data type

So it seems that TF is trying to discourage people from actually relying on this type of behavior. (Although I can reproduce the same behavior if I do tf.constant(2., dtype=tf.bfloat16).numpy().)

@jakevdp (Collaborator) commented Jun 9, 2022

I think you're looking for tf.bfloat16.as_numpy_dtype. But again, this is not a TF bug; it is a numpy bug, and the appropriate place to report it is the numpy repo. However, numpy is not currently interested in fixing this. If you're interested in a path to a fix, you should weigh in on numpy/numpy#19808.
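
For completeness, a sketch of the as_numpy_dtype route (assuming a standard TensorFlow install):

import numpy as np
import tensorflow as tf

# tf.bfloat16 is a tf.DType wrapper; its as_numpy_dtype attribute is the underlying
# NumPy scalar type, which np.array() accepts as a dtype.
arr = np.array(2.0, dtype=tf.bfloat16.as_numpy_dtype)
print(arr.dtype)  # bfloat16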

@alvarosg (Contributor, Author) commented Jun 9, 2022

Fair enough, I will just do that. Thanks again for all of your time!

@alvarosg (Contributor, Author) commented Jun 9, 2022

@hawkinsp, based on your reply here, should we reopen this then?
