Feature Request: [metal] implement FA kernels for quantized KV cache  #9736

@FanShupei

Description

Prerequisites

  • I am running the latest code. Mention the version if possible as well.
  • I carefully followed the README.md.
  • I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed).
  • I reviewed the Discussions, and have a new and useful enhancement to share.

Feature Description

Implement FA kernels supporting quantized KV cache on Metal GPU.

Motivation

As described in #8918, when using KV cache quantization on the Metal GPU, FA falls back to the CPU, which makes it extremely slow.

Possible Implementation

I'm currently working on it and have a draft PR: #9735.

The PR "works" in the sense that it is already a huge improvement over falling back to the CPU. However, it has severe performance problems when processing long inputs (131 tok/s vs 265 tok/s at pp2048).

I'm seeking advice on how to improve it. I see three possible ways to implement it:

  • Modify `kernel_flash_attn_ext_vec_f16`: this is exactly what the draft PR does. The problem is that this kernel's performance is very low in the prefill stage (which is expected for a vector kernel).
  • Modify `kernel_flash_attn_ext_f16`: this kernel uses simdgroup matrices extensively and I don't see an obvious place to insert dequantization code. Maybe I could first dequantize a slice of K or V into a shared (threadgroup) buffer?
  • Port CUDA's implementation: I'm a bit confused, since the FA kernels in CUDA and Metal are not the same. `fattn-vec-f16.cuh` seems to contain FA for a quantized KV cache. However, the algorithm there uses two kernels (one for the computation, one to combine the partial results), which makes it a bit hard to port to Metal.

Could anyone give me some advice on the best way to implement FA with a quantized KV cache in Metal?
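For reference, the "combine" step of the two-kernel scheme mentioned above is just the standard merge of partial softmax-weighted sums: each chunk kernel emits an unnormalized partial output plus its running max and denominator, and a second pass rescales and normalizes. A hedged CPU sketch (all names illustrative, not the CUDA code's actual API):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Partial result of attention over one chunk of the KV cache.
struct Partial {
    float m;              // max score seen in this chunk
    float l;              // sum of exp(score - m) over the chunk
    std::vector<float> o; // unnormalized output: sum of exp(score - m) * v
};

// First kernel's job (per chunk): softmax-weighted sum in "online" form.
Partial attend_chunk(const std::vector<float>& scores,
                     const std::vector<std::vector<float>>& values) {
    Partial p{-INFINITY, 0.0f, std::vector<float>(values[0].size(), 0.0f)};
    for (float s : scores) p.m = std::fmax(p.m, s);
    for (size_t t = 0; t < scores.size(); ++t) {
        float w = std::exp(scores[t] - p.m);
        p.l += w;
        for (size_t d = 0; d < p.o.size(); ++d) p.o[d] += w * values[t][d];
    }
    return p;
}

// Second kernel's job: rescale each partial by exp(m_j - m) and normalize.
std::vector<float> combine(const std::vector<Partial>& parts) {
    float m = -INFINITY;
    for (const auto& p : parts) m = std::fmax(m, p.m);
    float l = 0.0f;
    std::vector<float> out(parts[0].o.size(), 0.0f);
    for (const auto& p : parts) {
        float scale = std::exp(p.m - m);
        l += scale * p.l;
        for (size_t d = 0; d < out.size(); ++d) out[d] += scale * p.o[d];
    }
    for (float& v : out) v /= l;
    return out;
}
```

Because the combine is a cheap elementwise pass, it could plausibly be a second small Metal kernel, or folded into a final threadgroup reduction, without porting the CUDA structure verbatim.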
