Closed
Description
Prerequisites
- I am running the latest code. Mention the version if possible as well.
- I carefully followed the README.md.
- I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed).
- I reviewed the Discussions, and have a new and useful enhancement to share.
Feature Description
Implement flash attention (FA) kernels that support a quantized KV cache on Metal GPUs.
Motivation
As described in #8918, when KV cache quantization is used on a Metal GPU, FA falls back to the CPU, which makes it extremely slow.
Possible Implementation
I'm currently working on it and have a draft PR, #9735.
The PR "works" in the sense that it's already a huge improvement over falling back to the CPU. However, it has severe performance problems when processing long inputs (131 tok/s vs 265 tok/s in pp2048).
I'm seeking advice on how to improve it. I see three ways to implement it:
- modify 'kernel_flash_attn_ext_vec_f16': This is exactly what I do in the draft PR. The problem is that this kernel's performance is very low in the prefill stage (which is expected).
- modify 'kernel_flash_attn_ext_f16': This kernel uses simdgroup matrices extensively, and I don't see an obvious place to insert dequantization code. Maybe I could first dequantize a slice of K or V into a shared buffer?
- copy CUDA's implementation: I'm a bit confused, since the FA kernels in CUDA and Metal are not the same. 'fattn-vec-f16.cuh' seems to contain FA for quantized KV. However, the algorithm there uses two kernels (one for the computation, one for combining the results), which makes it a bit hard to port to Metal.
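To make the second option more concrete, here is a rough CPU-side sketch in C++ of what "dequantize a slice of K into a shared buffer first" could look like. It is not Metal code and not taken from the PR. `BlockQ8`, `dequant_tile`, and `qk_scores` are hypothetical names, the block layout is a simplified Q8_0-style format (32 int8 values sharing one scale, stored here as a float rather than fp16), and the `tile` vector only stands in for threadgroup memory.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified Q8_0-like block (assumption): 32 int8 values with one shared scale.
constexpr int kBlock = 32;
struct BlockQ8 {
    float  d;           // scale; value[i] = d * qs[i]
    int8_t qs[kBlock];  // quantized values
};

// Dequantize rows [row0, row0 + rows) of a quantized K matrix into `tile`.
// `tile` plays the role of the shared (threadgroup) buffer.
void dequant_tile(const std::vector<BlockQ8>& k, int head_dim,
                  int row0, int rows, std::vector<float>& tile) {
    assert(head_dim % kBlock == 0);
    const int blocks_per_row = head_dim / kBlock;
    tile.resize(static_cast<size_t>(rows) * head_dim);
    for (int r = 0; r < rows; ++r) {
        for (int b = 0; b < blocks_per_row; ++b) {
            const BlockQ8& blk = k[static_cast<size_t>(row0 + r) * blocks_per_row + b];
            for (int i = 0; i < kBlock; ++i) {
                tile[static_cast<size_t>(r) * head_dim + b * kBlock + i] = blk.d * blk.qs[i];
            }
        }
    }
}

// Q*K^T scores for a single query row, processed one dequantized tile at a time.
// The inner dot product is where the float matrix code would run unchanged.
std::vector<float> qk_scores(const std::vector<float>& q, const std::vector<BlockQ8>& k,
                             int head_dim, int n_kv, int tile_rows) {
    std::vector<float> scores(n_kv, 0.0f);
    std::vector<float> tile;
    for (int row0 = 0; row0 < n_kv; row0 += tile_rows) {
        const int rows = std::min(tile_rows, n_kv - row0);
        dequant_tile(k, head_dim, row0, rows, tile);
        for (int r = 0; r < rows; ++r) {
            float acc = 0.0f;
            for (int j = 0; j < head_dim; ++j) {
                acc += q[j] * tile[static_cast<size_t>(r) * head_dim + j];
            }
            scores[row0 + r] = acc;
        }
    }
    return scores;
}
```

The point of this structure is that the matrix-multiply part only ever sees float data, so in principle the existing simdgroup matrix code would not need to know about quantization. Whether the extra pass through shared memory is cheap enough on Metal is exactly what I'm unsure about.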
Could anyone give me some advice on what's the best way to implement FA with quantized KV in Metal?
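For reference, this is my understanding of the "one kernel for the computation, one for combining the results" idea from the third option, written as a plain C++ sketch. It is one common way to merge partial softmax results and has not been checked line by line against the CUDA code. `Partial`, `attend_chunk`, and `combine` are hypothetical names, and each chunk is assumed to be non-empty.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Partial result for one chunk of the KV sequence.
struct Partial {
    float m;               // max score seen in the chunk
    float l;               // sum of exp(score - m) over the chunk
    std::vector<float> o;  // unnormalized output: sum of exp(score - m) * V row
};

// "Kernel 1": attention over KV rows [begin, end), using a chunk-local softmax max.
Partial attend_chunk(const std::vector<float>& scores,
                     const std::vector<std::vector<float>>& v, int begin, int end) {
    assert(begin < end);
    Partial p;
    p.m = -INFINITY;
    for (int i = begin; i < end; ++i) p.m = std::max(p.m, scores[i]);
    p.l = 0.0f;
    p.o.assign(v[begin].size(), 0.0f);
    for (int i = begin; i < end; ++i) {
        const float w = std::exp(scores[i] - p.m);
        p.l += w;
        for (size_t j = 0; j < p.o.size(); ++j) p.o[j] += w * v[i][j];
    }
    return p;
}

// "Kernel 2": rescale every partial to the global max, then normalize once.
std::vector<float> combine(const std::vector<Partial>& parts) {
    assert(!parts.empty());
    float m_all = -INFINITY;
    for (const Partial& p : parts) m_all = std::max(m_all, p.m);
    float l_all = 0.0f;
    std::vector<float> out(parts[0].o.size(), 0.0f);
    for (const Partial& p : parts) {
        const float s = std::exp(p.m - m_all);  // correction factor for this chunk
        l_all += s * p.l;
        for (size_t j = 0; j < out.size(); ++j) out[j] += s * p.o[j];
    }
    for (float& x : out) x /= l_all;
    return out;
}
```

Calling `combine` on the partials from several chunks should give the same result as a single `attend_chunk` over the whole range followed by normalization. On a GPU the two steps would be separate dispatches with the partials kept in a temporary buffer, which is the part that seems awkward to port.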
benja0x40