clip : use FA #16837
Conversation
@ngxson Feel free to use this PR as a starting point for enabling FA generally in the CLIP. I think the main thing that we are missing is backend support for any unusual head sizes that might occur with vision models. I added Metal support for HS=72 as an example.
Thanks for the initial work. I actually planned to work on flash attn this week/next week, so this will help me a lot. Btw, how do we know if we can use flash attn for a given cgraph? I don't quite remember how llama.cpp checks whether the current model can use flash attn or not.
Maybe you are thinking about the logic in llama.cpp/src/llama-context.cpp, lines 291 to 331 (at 41ebbfd). I think for CLIP it can safely be always enabled.
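(For context, the gist of that kind of check, adapted to a single CLIP backend, could look roughly like the sketch below. The helper name, the dummy token count, and the F16 K/V types are assumptions for illustration, not the actual llama.cpp code.)

```cpp
// Hedged sketch: build a dummy FLASH_ATTN_EXT node with the model's head size
// and ask the backend whether it supports it - no tensor data is allocated.
static bool clip_backend_supports_fa(ggml_backend_t backend, int64_t head_size, int64_t n_head) {
    ggml_init_params params = {
        /*.mem_size   =*/ 8*ggml_tensor_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // metadata only
    };
    ggml_context * ctx = ggml_init(params);

    // Q/K/V shaped [head_size, n_tokens, n_head, 1]; the token count is arbitrary
    ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, 16, n_head, 1);
    ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_size, 16, n_head, 1);
    ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_size, 16, n_head, 1);

    // no mask, no ALiBi, no logit softcap - only the shapes/types matter here
    ggml_tensor * fa = ggml_flash_attn_ext(ctx, q, k, v, nullptr, 1.0f, 0.0f, 0.0f);

    const bool ok = ggml_backend_supports_op(backend, fa);
    ggml_free(ctx);
    return ok;
}
```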
Hmm so for example, in case I completely replace the clip's …
Yes. CPU FA is always supported. It could be a problem because we might not notice that the CPU fallback is being triggered in some cases. But still, it seems like the better default IMO.
In that case, I think we still need the logic to check if the GPU backend supports flash attn or not. I agree that flash attn should be the default; I think I can safely reuse the same …
We can also print a big warning each time the CLIP scheduler runs with more than 1 graph split. This way we will immediately spot cases where the implementation uses an unsupported operator. Demonstrated in a4b54f2. Sample output from …
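(A rough sketch of such a split check, assuming the CLIP context already owns a ggml_backend_sched_t; the function and variable names here are made up, the actual change is in the commit above.)

```cpp
// Hedged sketch: after the graph has been scheduled/computed, ask the scheduler
// how many splits it produced; more than one split means at least one op had to
// run on a different (typically the CPU) backend.
static void clip_warn_on_graph_splits(ggml_backend_sched_t sched) {
    const int n_splits = ggml_backend_sched_get_n_splits(sched);
    if (n_splits > 1) {
        LOG_WRN("%s: the CLIP graph was split into %d parts, some ops will run on the CPU backend\n",
                __func__, n_splits);
    }
}
```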
    
        
          
tools/mtmd/clip.cpp (Outdated)
    LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
    // TODO: maybe log more details about why flash attention is not supported
@ggerganov I implemented a simple solution to auto-enable flash attn only when the backend supports it. We should probably make this LOG_WRN more prominent. Also, which kind of info do you think should be displayed here?
Some users are potentially already using models with shapes not supported by GPU flash attn. Falling back to CPU will suddenly make it very slow and thus not a good UX overall. The auto mode + a prominent warning is a better solution, as it also encourages users to "voluntarily" report certain info back to us, less forcefully for them.
> Also, which kind of info do you think should be displayed here?

We can print the actual tensor (shape, strides, types) for which FA is not supported.
    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        ggml_tensor * node = ggml_graph_node(gf, i);
        if (node->op == GGML_OP_FLASH_ATTN_EXT) {
            if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
                return false;
            }
        }
    }
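(For illustration only: the result of this check could drive the auto mode roughly as below. CLIP_FLASH_ATTN_TYPE_AUTO, CLIP_FLASH_ATTN_TYPE_DISABLED and the clip_graph_supports_fa() wrapper are assumed names; only CLIP_FLASH_ATTN_TYPE_ENABLED appears in the snippets here.)

```cpp
// Hedged sketch, not the actual patch: clip_graph_supports_fa() is assumed to
// wrap the loop above and return false if any FLASH_ATTN_EXT node is not
// supported by the accelerated backend.
if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
    if (clip_graph_supports_fa(ctx_clip, gf)) {
        ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
    } else {
        // prefer rebuilding the graph without FA over silently running FA on the CPU
        ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
        LOG_WRN("%s: flash attention not supported by the backend, disabling it\n", __func__);
    }
}
```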
I tested this by temporarily modifying the code to always return false in ggml_metal_device_supports_op. It works well for now, but I'm not sure if there are any edge cases.
Currently, mtmd only supports 2 backends at the same time: CPU and one GPU backend.
I extended the warmup logic to print all ops that are not supported by the accelerated backend of the CLIP context. For example, we are now informed that the Metal backend does not support the UPSCALE op for Qwen3 VL:
    if (!unsupported_ops.empty()) {
        LOG_WRN("%s: *****************************************************************\n", __func__);
        LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
        LOG_WRN("%s: the performance will be suboptimal\n", __func__);
        LOG_WRN("%s: list of unsupported ops (backend=%s):\n", __func__, ggml_backend_name(ctx_clip.backend));
        for (const auto & op : unsupported_ops) {
            LOG_WRN("%s: %16s: type = %s, ne = [%d %d %d %d]\n", __func__,
                    ggml_op_name(op.op->op),
                    ggml_type_name(op.op->type),
                    (int) op.op->ne[0], (int) op.op->ne[1], (int) op.op->ne[2], (int) op.op->ne[3]);
        }
        LOG_WRN("%s: flash attention is %s\n", __func__,
                (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
        LOG_WRN("%s: please report this on github as an issue\n", __func__);
        LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
        LOG_WRN("%s: *****************************************************************\n", __func__);
    }
I improved these messages to make them more prominent while giving the user instructions on what to do. Lmk if this looks good to you.
    static auto print_shape = [](const char * fn, const char * name, ggml_tensor * t) {
        LOG_WRN("%s: %s: type = %s, ne = [%d %d %d %d], nb = [%d %d %d %d]\n", fn,
                name, ggml_type_name(t->type),
                (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3],
                (int) t->nb[0], (int) t->nb[1], (int) t->nb[2], (int) t->nb[3]);
    };
    print_shape(__func__, " dst", op);
    print_shape(__func__, "src0", op->src[0]);
    print_shape(__func__, "src1", op->src[1]);
    print_shape(__func__, "src2", op->src[2]);
Re. what to print when flash attn is not supported: I'm printing the tensor shapes, types, and strides.
This should be good - usually the head size (i.e. src[0]->ne[0]) is the thing that is not supported.
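(For illustration, the shape-based gating that a backend's supports_op hook does for FA looks roughly like the sketch below. The function name and the exact head-size list are assumptions, not the actual Metal/CUDA code; this PR and the related CUDA change add kernels for HS=72.)

```cpp
// Hedged illustration: FLASH_ATTN_EXT is usually accepted only for head sizes
// the backend has kernels for, which is why src[0]->ne[0] is the value to check.
static bool example_supports_flash_attn(const ggml_tensor * op) {
    if (op->op != GGML_OP_FLASH_ATTN_EXT) {
        return true; // only FA is gated in this sketch
    }
    switch (op->src[0]->ne[0]) { // head size of Q
        case 64:
        case 72: // e.g. the head size added for Metal/CUDA around this PR
        case 80:
        case 96:
        case 112:
        case 128:
        case 256:
            return true;
        default:
            return false; // unsupported head size -> the scheduler falls back to CPU
    }
}
```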
* origin/master: (169 commits)
  opencl: support imrope (ggml-org#16914)
  fix: Viewing multiple PDF attachments (ggml-org#16974)
  model-conversion : pass config to from_pretrained (ggml-org#16963)
  server : add props.model_alias (ggml-org#16943)
  ggml: CUDA: add head size 72 for flash-attn (ggml-org#16962)
  mtmd: add --image-min/max-tokens (ggml-org#16921)
  mtmd: pad mask for qwen2.5vl (ggml-org#16954)
  ggml : LoongArch fixes (ggml-org#16958)
  sync: minja (glm 4.6 & minmax m2 templates) (ggml-org#16949)
  SYCL: optimized repeat_back kernel (3× fewer asm instructions, 2× faster) Feature/sycl repeat back opt (ggml-org#16869)
  feat(webui): improve LaTeX rendering with currency detection (ggml-org#16508)
  test-backend-ops : fix segfault in moe-expert-reduce test in support mode and coverage (ggml-org#16936)
  ci : disable failing riscv cross build (ggml-org#16952)
  model: add Janus Pro for image understanding (ggml-org#16906)
  clip : use FA (ggml-org#16837)
  server : support unified cache across slots (ggml-org#16736)
  common : move gpt-oss reasoning processing to init params (ggml-org#16937)
  docs: remove llama_sampler_accept reference in sampling sample usage (ggml-org#16920)
  CUDA: add FLOOR, CEIL, ROUND, TRUNC unary ops (ggml-org#16917)
  devops: fix failing s390x docker build (ggml-org#16918)
  ...
ref #13231 (comment)
Sample implementation for using FA in the CLIP. Reduces memory usage and improves performance.
Testing with Gemma 12B, using llama-server and 2 images:
TODO: