Does flashinfer support float datatype? #191

Open
ZSL98 opened this issue Mar 26, 2024 · 3 comments

Comments

ZSL98 commented Mar 26, 2024

The examples all use tensors of half() type. I wonder whether flashinfer supports the fp32 dtype?

@chenzhuofu

I got the same question. I am instantiating the SinglePrefillWithKVCacheDispatched function, but found that it has a static_assert(sizeof(DTypeIn) == 2); check. @yzh119 Is this due to some implementation consideration?
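For context, a simplified illustration of the kind of compile-time check being referenced (not the actual flashinfer source; prefill_entry below is only a stand-in):

#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Simplified sketch: the prefill path requires a 16-bit input dtype, so
// half/bfloat16 (sizeof == 2) compile while float (sizeof == 4) trips the
// static_assert at compile time.
template <typename DTypeIn>
void prefill_entry(/* ... */) {
  static_assert(sizeof(DTypeIn) == 2, "prefill kernels expect a 16-bit input dtype");
  // ... kernel dispatch would go here ...
}

// prefill_entry<__half>();         // OK: 2 bytes
// prefill_entry<__nv_bfloat16>();  // OK: 2 bytes
// prefill_entry<float>();          // fails to compile: sizeof(float) == 4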

yzh119 (Collaborator) commented Jun 5, 2024

The decode attention operators support fp32; we just need to add fp32 to this macro:

[&]() -> bool { \
  switch (pytorch_dtype) { \
    case at::ScalarType::Half: { \
      using c_type = nv_half; \
      return __VA_ARGS__(); \
    } \
    case at::ScalarType::BFloat16: { \
      using c_type = nv_bfloat16; \
      return __VA_ARGS__(); \
    } \
    default: \
      std::ostringstream oss; \
      oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
      TORCH_CHECK(false, oss.str()); \
      return false; \
  } \
}()
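For reference, the additional case might look roughly like this (a sketch, untested; it assumes the decode kernels compile with plain float as the c_type):

case at::ScalarType::Float: { \
  using c_type = float; \
  return __VA_ARGS__(); \
} \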

For prefill/append attention it's a little bit tricky, because many instructions such as ldmatrix (https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix) only support 16-bit operands, which makes it non-trivial to load fp32 tiles (especially the transposed loads) from shared memory into registers. One option is to convert the fp32 input to bf16 and use the bf16 prefill attention kernels; we could design an API in flashinfer that accepts bf16/fp16 input and returns fp32 output.
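To illustrate the proposed approach, here is a rough host-side sketch (hypothetical, not an existing flashinfer API; bf16_prefill is a placeholder for whichever bf16 prefill entry point is used):

#include <torch/torch.h>
#include <functional>

// Hypothetical wrapper sketch: downcast fp32 q/k/v to bf16, run an existing
// bf16 prefill kernel, and return the result upcast to fp32. The kernel's
// internal accumulation is typically already fp32; the proposal above is to
// expose that fp32 output directly instead of casting it back to 16-bit.
torch::Tensor prefill_fp32_via_bf16(
    const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
    const std::function<torch::Tensor(const torch::Tensor&, const torch::Tensor&,
                                      const torch::Tensor&)>& bf16_prefill) {
  auto out_bf16 = bf16_prefill(q.to(torch::kBFloat16), k.to(torch::kBFloat16),
                               v.to(torch::kBFloat16));
  return out_bf16.to(torch::kFloat32);
}

With this kind of wrapper, the only precision loss comes from the one-time fp32-to-bf16 cast of the inputs.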


chenzhuofu commented Jun 5, 2024

Got it, my use case is the prefill/append kernel, and it does look tricky indeed. Thanks for your kind reply. I think fp32 output support sounds great and would be very helpful!
