
ENH: numpy bfloat16 support #19808

Open
jakpiase opened this issue Sep 1, 2021 · 20 comments
Labels
01 - Enhancement 33 - Question Question about NumPy usage or development Tracking / planning

Comments

@jakpiase

jakpiase commented Sep 1, 2021

Feature

Hi,
I am working at PaddlePaddle(chinese DL framework). We and other DL frameworks would extremely benefit from integrated bfloat16 numpy datatype. I have seen that TF added its own implementation and lately a standalone pip package bfloat16 was released.
I have also seen NEP 41, 42, 43 which from what I understand will allow adding new, user-defined datatypes. Are you guys planning to integrate bfloat16 into core numpy? If you don't have the bandwidth to do that, is there anyone that can guide me how to implement that so I can make a PR? And how are these NEPs going?

@melissawm melissawm added the 33 - Question Question about NumPy usage or development label Sep 2, 2021
@seberg
Member

seberg commented Sep 7, 2021

NEPs 41-43 are mostly implemented. One major step remains, though it is itself mostly finished; a few smaller steps will follow before we are fully there, and I expect issues to come up as soon as someone tries to write a more complex dtype (which bfloat16 probably is not).

However, it is not yet public API; I hope to follow up with an experimental public API very soon. And there is a repo (which will have to change a bit soon) here.

I am personally not sure whether we should include bfloat16 in NumPy proper; there are some advantages and some disadvantages. (NumPy dtypes are a bit strange sometimes; defining bfloat16 at least somewhat outside of NumPy would sidestep this issue by making it very clear that it is not one of the "main" NumPy dtypes.)

But opinions will probably differ. I would somewhat prefer a numpy_bfloat16 package right now; the main problem is that not including it in NumPy probably means the new DType API needs to be out of its experimental state (or there needs to be strict NumPy version binding, i.e. forcing the user to install numpy_bfloat16 with a version of NumPy that it knows it will work with correctly).

Or, we include it in NumPy (for ABI reasons) but make it accessible through its own import?

@leofang
Contributor

leofang commented Sep 7, 2021

Or, we include it in NumPy (for ABI reasons) but make it accessible through its own import?

This would be my preference for bfloat16 (the current issue) and complex32 (#14753). It would be very useful for GPU libraries, especially CuPy, which does not have its own type system and simply uses the dtypes from NumPy.

@seberg
Member

seberg commented Sep 7, 2021

No matter the outcome of the discussion, I think it makes sense to start with a new bfloat16 PyPI package as soon as the experimental API for it exists (which I can expedite a bit if someone wants to work on creating such a prototype package – it's great for motivation ☺; and it already exists enough that one could start, IMO).

The reason is that I think there are a few questions that need answering, e.g. whether we even want a bfloat16 scalar.

@jakirkham
Contributor

Agreed. Just to add to that: a bfloat16 implementation may run into issues that need fast iteration to address. Having it live outside NumPy for a while until things mature is likely good both for the bfloat16 implementation and for NumPy.

Sebastian, do you have thoughts on whether having a new repo live here (under the NumPy org) would make sense? Or is there some incubator or contrib NumPy org that would make sense? How should experimental things like this develop?

@seberg
Member

seberg commented Sep 8, 2021

I think we can start such a thing in the NumPy organization if that seems like the easiest place. I am happy to help, come to meetings, etc., but at least for now I would much prefer not to champion such an effort or take the brunt of the implementation.
It is not really all that high a priority for me, and it would be helpful if more people got acquainted with the new DType API (and were thus able to provide feedback).

@alvarosg

alvarosg commented Jun 9, 2022

+1 to adding this support.

From the thread in this JAX bug, it seems that there are several inconsistencies in type-promotion behavior for the current custom bfloat16 type used by both TensorFlow and JAX, which cannot be fixed without actually adding better native support in numpy (as I understand it from that bug).

@seberg
Member

seberg commented Jun 9, 2022

which cannot be fixed without actually adding better native support in numpy

This is no longer true: they can be addressed in a fully compatible way with the new (although not yet public) API. Yes, there will be road bumps, but the point is that there is a path of starting this outside of NumPy and then considering whether that is good or whether inclusion is desired.

EDIT: The API is experimentally public, see also https://github.com/seberg/unitdtype

@hawkinsp
Contributor

hawkinsp commented Jun 9, 2022

I wrote that bfloat16 extension (the TF/JAX one) targeting NumPy 1.16 when working on early TPUs. NumPy's type extension APIs have certainly advanced since then. So it's probably time someone had a go at adapting it to the new APIs.

In passing, one other observation while I am here: at least two varieties of 8-bit floating-point type are gaining popularity in the machine learning world as well (e.g., the non-standard E4M3 and E5M2 types from hardware vendors), and we were looking at generalizing the extension to support those too.
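For readers unfamiliar with the two formats: they differ in how the 8 bits are split between exponent and mantissa, trading dynamic range against precision. A rough back-of-the-envelope sketch (the function name is mine; it follows plain IEEE-style conventions and deliberately ignores E4M3's quirk of reclaiming the all-ones exponent for finite values, so it should only be trusted for the E5M2 numbers here):

```python
def fp8_extremes(exp_bits, man_bits, bias):
    """Largest finite value and smallest positive normal for a simple
    sign/exponent/mantissa layout with IEEE-style reserved top exponent."""
    # max normal: all-ones-minus-one exponent, full mantissa
    max_normal = (2 - 2 ** -man_bits) * 2 ** ((2 ** exp_bits - 2) - bias)
    # min normal: exponent field = 1
    min_normal = 2.0 ** (1 - bias)
    return max_normal, min_normal

# E5M2: 5 exponent bits, 2 mantissa bits, bias 15 -> wide range, low precision
mx, mn = fp8_extremes(5, 2, 15)
print(mx, mn)  # 57344.0 and 2**-14 (~6.1e-05)
```

E4M3 (4 exponent bits, 3 mantissa bits, bias 7) keeps more precision but much less range, and its actual maximum finite value (448 in the vendor specs) is larger than this naive formula predicts because the format repurposes most of the top exponent row for finite values.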

@alvarosg

alvarosg commented Jun 9, 2022

(@jakevdp for visibility)

Thanks! In that case, @hawkinsp, would it be useful if I filed a bug similar to google/jax#11014, but in the tensorflow repo?

@hawkinsp
Contributor

hawkinsp commented Jun 9, 2022

The JAX repo bug is fine: in practice we are the ones working on this and that's the project we work in most.

(Really this extension should be its own pip package, and we may do that. I note that someone already forked our code and did this (bfloat16), so there are now at least 4 packages providing what is in essence the same extension, and that's a mess that needs to be cleaned up.)

@jakpiase
Author

Hi @hawkinsp, I am the author of the paddle_bfloat package, an upgraded version of the "original" bfloat16 package that contains some minor fixes (fixing compiler warnings, some missing functions on Windows, etc.). If I can help you in any way, please let me know.

@BlueskyFR

Is bfloat16 support being developed in NumPy, or is it still just an idea as of today?

@mattip
Member

mattip commented Jun 30, 2022

@BlueskyFR We will adopt this approach

... there is a path of starting this outside of NumPy and then considering whether that is good or whether inclusion is desired.

so please help one of the implementations mentioned here to reach maturity.

@BlueskyFR

@mattip I don't really understand what is going on. Why is it being developed in external repos rather than as a NumPy branch?

@mattip
Member

mattip commented Jul 18, 2022

NumPy is very conservative when it comes to enhancements, since any new NumPy enhancement becomes part of the entire scientific Python stack. There are too many unknowns with bfloat16 support by compilers and hardware for it to be immediately part of NumPy; the team would prefer the code mature outside of NumPy first.

The Data API has also not adopted bfloat16, see point 4 on what is out of scope for the standard.

The approach to first have code as a stand-alone library was successfully used when revamping the numpy.random module: it was first developed as part of randomgen, then NEP 19 was written about the migration path, and only then the code was merged to NumPy.

@BlueskyFR

Okay, thanks for the explanation. Still, developing a feature on a branch of the repo doesn't mean it is released with it; it would just be a common place to develop such a feature. The problem with waiting for external libraries to do it is that community effort there is extremely limited, whereas the NumPy repo can bring people together.

I am afraid that waiting for external support will take a couple more years before reaching something stable, so IMO the NumPy repo has the capacity to help this effort without endangering the public stable releases.

@seberg
Member

seberg commented Jul 18, 2022

@BlueskyFR doing a branch in NumPy means you get slowed down by PRs to the NumPy repo: slower review, etc. If there were a champion who is already a NumPy dev there might be a point to it, but we do not have that.

We can still consider putting it into the NumPy organization, but I am not sure that it would help much. NumPy does not have resources to push this, having a feature branch that doesn't move quickly will not be helpful IMO.

You can use this issue, the NumPy mailing list, and https://discuss.scientific-python.org to coordinate. I can try to keep an eye on things also to have a NumPy dev in the loop, but we are not the ones that can drive this for now.

@BlueskyFR

BlueskyFR commented Jul 18, 2022

Okay, many thanks for the replies!

@ogrisel
Contributor

ogrisel commented Sep 23, 2022

The problem with waiting for external libraries to do that is that community effort is extremely limited whereas the Numpy repo can bring people together.

@BlueskyFR if you start such an external implementation, please post a link to it in a comment here (or even in the description of this issue), so that people who google "numpy bfloat16" and end up here, as I just did, will find their way to your project and start consolidating efforts there.

@jakevdp
Contributor

jakevdp commented Mar 21, 2023

Hi all – we just released ml_dtypes: https://github.com/jax-ml/ml_dtypes. It's a lightweight stand-alone implementation of bfloat16 and a few other dtypes used in machine learning contexts, extracted from the source code of JAX/XLA. We plan to maintain this actively going forward as the source for dtype definitions in JAX and other libraries. The code is permissively licensed, with the goal that if NumPy were ever to add a native bfloat16 dtype, it could be adapted from the implementation in ml_dtypes.
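For anyone landing on this issue now, a minimal usage sketch of the package (assuming `pip install ml_dtypes`; the dtype interoperates with regular NumPy arrays):

```python
import numpy as np
import ml_dtypes  # registers bfloat16 and friends as NumPy-compatible dtypes

# Construct a bfloat16 array; values are rounded to the nearest bfloat16
x = np.array([1.0, 0.5, 1.2], dtype=ml_dtypes.bfloat16)
print(x.dtype)                 # bfloat16
print(x.astype(np.float32))    # 1.2 rounds to the nearest bfloat16, 1.203125
```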
