Support for bfloat16 #7527
Comments
Curious where/how you would use bf16 if CuPy were to support it? Any pointer or reference? Thanks! 🙂
It would be good if numpy data type extensions à la https://github.com/jax-ml/ml_dtypes/tree/main were supported, which include bfloat16, fp8, etc.
Seconding this! bfloat16 and fp8 support are important for my use case. I'd love to see these.
Any progress on this? We really need it for LLM training and inference.
bfloat16 support is sorely missed in cupy. Would really appreciate it getting fixed!
[bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) is widely used in LLM training and inference since it can achieve higher throughput and is less prone to weight growth. ray.util.collective uses cupy.cuda.nccl for GPU communication, but cupy doesn't currently support bfloat16 (cupy/cupy#7527). So for allgather/reduce-scatter operations, we should bypass cupy.array and use torch.tensor directly. Signed-off-by: wuxibin <wuxibin89@163.com> Co-authored-by: Stephanie Wang <swang@cs.berkeley.edu>
Would also love this!
We are using a spiking neural network training library that implements custom CuPy functions for forward and backward propagation. The fact that CuPy lacks bfloat16 support is a real pain for us. I would highly appreciate any progress on this issue.
Description
Are there plans to support the bfloat16 data type in the near future? This data type is becoming increasingly popular in LLM training. It currently appears to be unsupported: calling
y = cp.asarray(x)
where `x` is a torch tensor of type `torch.bfloat16` returns "TypeError: Got unsupported ScalarType BFloat16". Are there any recommended workarounds in the meantime?
Additional Information
No response
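A common interim workaround, since neither cupy nor numpy ships a native bfloat16 dtype, is to carry the raw bfloat16 bit patterns in a `uint16` array and widen them to `float32` only when arithmetic is needed: bfloat16 is by construction the upper 16 bits of an IEEE float32. The sketch below illustrates the bit manipulation with plain numpy (the same `view`/shift operations work on cupy arrays); the function names are hypothetical helpers, not part of any library, and the float32→bfloat16 direction simply truncates rather than rounding to nearest-even as hardware typically does.

```python
import numpy as np

def bfloat16_bits_to_float32(bits_u16):
    # Place the 16 bfloat16 bits into the upper half of a uint32,
    # then reinterpret the bytes as float32.
    return (bits_u16.astype(np.uint32) << 16).view(np.float32)

def float32_to_bfloat16_bits(x_f32):
    # Keep only the upper 16 bits (sign, exponent, top 7 mantissa bits).
    # This truncates; proper bfloat16 conversion rounds to nearest-even.
    return (x_f32.view(np.uint32) >> 16).astype(np.uint16)

# These values have mantissas that fit in bfloat16's 7 bits,
# so the round trip is exact for them.
x = np.array([1.0, -2.5, 3.140625], dtype=np.float32)
bits = float32_to_bfloat16_bits(x)   # dtype uint16
roundtrip = bfloat16_bits_to_float32(bits)
```

For torch interop specifically, the analogous trick is `t.view(torch.uint16)` (or `torch.int16` on older versions) before handing the tensor to cupy, which is essentially the bypass the ray patch above describes.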