
Float8 cache usage #155

Closed

YLGH opened this issue Mar 5, 2024 · 2 comments

YLGH commented Mar 5, 2024

Hi! I'm playing with batch_decode_with_padded_kv_cache and wanted to test out the FP8 KV cache, but I couldn't find good instructions in the docs.
I've tried the following:

import torch
import flashinfer

num_qo_heads = 32
num_kv_heads = 32
batch_size = 16
head_dim = 128
padded_kv_len = 1024

# Query, key, and value all cast to fp8 e4m3.
q = torch.empty(
    batch_size,
    num_qo_heads,
    head_dim,
    device=torch.device("cuda"),
    dtype=torch.float8_e4m3fn,
)
k_padded = (
    torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim)
    .to("cuda:0")
    .to(torch.float8_e4m3fn)
)
v_padded = (
    torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim)
    .to("cuda:0")
    .to(torch.float8_e4m3fn)
)
o = flashinfer.batch_decode_with_padded_kv_cache(
    q, k_padded, v_padded, "NHD", "NONE"
)

But this fails with BatchDecodeWithPaddedKVCache kernel launch failed: supported data type.

How can I enable FP8 KV cache? Thanks in advance!


zhyncs commented Mar 5, 2024

Refer to #150.

Collaborator

yzh119 commented Mar 5, 2024

@YLGH done in #156.
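
For reference, a minimal sketch of the usage after #156, assuming the updated kernels accept a half-precision query alongside an fp8 KV cache. Before #156 the padded-decode kernel did not accept fp8 inputs at all, so the exact supported dtype combinations are defined in that PR; treat this as an illustration rather than the definitive API:

import torch
import flashinfer

num_qo_heads = 32
num_kv_heads = 32
batch_size = 16
head_dim = 128
padded_kv_len = 1024

# Assumption: the query stays in fp16; only the KV cache is stored in fp8 e4m3.
q = torch.randn(
    batch_size, num_qo_heads, head_dim, device="cuda", dtype=torch.float16
)
k_padded = torch.randn(
    batch_size, padded_kv_len, num_kv_heads, head_dim, device="cuda"
).to(torch.float8_e4m3fn)
v_padded = torch.randn(
    batch_size, padded_kv_len, num_kv_heads, head_dim, device="cuda"
).to(torch.float8_e4m3fn)

o = flashinfer.batch_decode_with_padded_kv_cache(
    q, k_padded, v_padded, "NHD", "NONE"
)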

yzh119 closed this as completed Mar 5, 2024