metal : FA support F32 K and V #16531
base: gg/cacheless-embd
Conversation
@JohannesGaessler @jeffbolznv Would it be possible to add support for F32 K and V tensors in the respective backends? The issue is the casts on lines 1313 to 1326 (at commit 4b2dae3). If we remove the casts, the memory usage should be significantly reduced for this use case. But to remove them, the FA implementation has to support F32 K and V.
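For illustration, here is a minimal sketch (not the actual llama.cpp code) of what that call site could look like once backends advertise F32 support; the `fa_supports_f32_kv` flag and the helper function are hypothetical:

```cpp
// Minimal sketch, not the actual llama.cpp code: only cast K and V to F16 when
// the backend's flash-attention kernel cannot consume F32 directly.
// `fa_supports_f32_kv` is a hypothetical capability flag used for illustration.
#include "ggml.h"

static void maybe_cast_kv(struct ggml_context * ctx,
                          struct ggml_tensor ** k,
                          struct ggml_tensor ** v,
                          bool fa_supports_f32_kv) {
    if (fa_supports_f32_kv) {
        return; // keep K/V in F32 and skip the extra F16 copies entirely
    }
    if ((*k)->type == GGML_TYPE_F32) {
        *k = ggml_cast(ctx, *k, GGML_TYPE_F16);
    }
    if ((*v)->type == GGML_TYPE_F32) {
        *v = ggml_cast(ctx, *v, GGML_TYPE_F16);
    }
}
```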
It's definitely possible but it will require additional considerations w.r.t. SRAM limits. For the tile kernel, what would need to be done is to determine FP16 vs. FP32 use via a template parameter rather than the […]
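As a rough illustration of the SRAM concern (plain C++, not the actual CUDA kernel; `kv_f16_t` and `kv_tile_bytes` are made up for this sketch), templating the tile on the K/V element type makes the larger shared-memory footprint of F32 explicit:

```cpp
// Illustration only: in the real tile kernel the K/V element type would become
// a template parameter, and the shared-memory tile size would be budgeted from
// sizeof(T_KV) -- F32 tiles take twice the SRAM of F16 tiles, so the tile
// dimensions may have to shrink for the F32 instantiation.
#include <cstddef>
#include <cstdint>

struct kv_f16_t { uint16_t bits; };   // stand-in for the device half type

template <typename T_KV>
constexpr size_t kv_tile_bytes(size_t rows, size_t cols) {
    return rows * cols * sizeof(T_KV);
}

// an F32 tile of the same dimensions needs 2x the shared memory of an F16 tile
static_assert(kv_tile_bytes<float>(32, 128) == 2 * kv_tile_bytes<kv_f16_t>(32, 128), "");
```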
Do the models for which this is relevant use GQA?
We could also consider making the operations preceding FA write back their data as FP16 in the first place. In terms of performance that would definitely be preferable for all CUDA/ROCm GPUs except for Pascal.
Generally yes. If it would make the implementation simpler, maybe we can treat F32 K and V as just another "quantization" type, where the dequantize function is a cast to F16?
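A minimal sketch of that idea (illustrative names only, not ggml's actual API; `half_t` assumes a compiler with `_Float16` support): each stored type provides a per-element dequantize step, and for F32 that step is simply a cast to F16 performed while the tile is loaded, so no intermediate buffer is needed:

```cpp
// Sketch of treating F32 K/V as "just another quantized type": the per-type
// dequantize step runs while loading a tile, so no extra staging buffer is
// allocated. Names are illustrative, not ggml's actual API.
#include <cstddef>

using half_t = _Float16;   // assumes GCC/Clang _Float16; real backends use their native half type

// per-type "dequantize": pass-through for F16, a narrowing cast for F32
static inline half_t dequantize_kv(half_t x) { return x; }
static inline half_t dequantize_kv(float  x) { return (half_t) x; }

// the K/V tile load is written once, independent of the stored element type
template <typename T_KV>
static void load_kv_tile(const T_KV * src, half_t * dst, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        dst[i] = dequantize_kv(src[i]);
    }
}
```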
For CUDA that can definitely be done with comparatively little effort but it would not eliminate the additional memory use, it would just shift it from the compute buffer to the buffer pool in the CUDA backend.
I think this should be relatively straightforward in the Vulkan backend; I'll look into it. This comment is how I'd expect to implement it (we dequantize while loading, so no extra memory usage).
Done for Vulkan in #16543
Basic CUDA support in #16546.
target #16528
Remove K and V casts with cacheless contexts (we should keep the casts for now)

Sample command for testing:
llama-embedding -hf ggml-org/bge-small-en-v1.5-Q8_0-GGUF -e -p "$(printf 'hello %.0s' {1..510})" --pooling cls -c 512 -fa on