
Conversation

@ggerganov
Member

ref #13231 (comment)

Sample implementation for using FA in the CLIP. Reduces memory usage and improves performance.

Testing with Gemma 12B, using llama-server and 2 images:

# before
alloc_compute_meta:      Metal compute buffer size =  1132.00 MiB
alloc_compute_meta:        CPU compute buffer size =     9.19 MiB

srv  process_chun: image processed in 1653 ms
srv  process_chun: image processed in 1093 ms

# after
alloc_compute_meta:      Metal compute buffer size =   121.25 MiB
alloc_compute_meta:        CPU compute buffer size =     9.19 MiB

srv  process_chun: image processed in 1386 ms
srv  process_chun: image processed in 810 ms

TODO:

  • Add FA sizes to other backends (e.g. Gemma uses a non-standard head size of 72)

@ggerganov
Member Author

ggerganov commented Oct 29, 2025

@ngxson Feel free to use this PR as a starting point for enabling FA generally in the CLIP. I think the main thing that we are missing is backend support for any unusual head sizes that might occur with vision models. I added Metal support for HS=72 as an example.

@github-actions github-actions bot added testing Everything test related examples ggml changes relating to the ggml tensor library for machine learning Apple Metal https://en.wikipedia.org/wiki/Metal_(API) labels Oct 29, 2025
@ngxson
Collaborator

ngxson commented Oct 29, 2025

Thanks for the initial work. I actually planned to work on flash attn this week/next week, so this will help me a lot.

Btw, how do we know if we can use flash attn for a given cgraph? I don't quite remember how llama.cpp checks whether the current model can use flash attn or not.

@ggerganov
Member Author

Btw, how do we know if we can use flash attn for a given cgraph?

Maybe you are thinking about the logic in libllama for checking when to enable FA?

// resolve automatic Flash Attention use
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
    auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
    if (!gf) {
        throw std::runtime_error("failed to split graph for Flash Attention check");
    }

    const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;

    bool fa_device_mismatch = false;
    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        ggml_tensor * n = ggml_graph_node(gf, i);
        if (n->op != GGML_OP_FLASH_ATTN_EXT) {
            continue;
        }
        ggml_backend_dev_t device_fa = ggml_backend_get_device(
            ggml_backend_sched_get_tensor_backend(sched.get(), n));

        // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
        GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
        const int il = std::stoi(n->name + prefix_len);

        ggml_backend_dev_t device_kv = model.dev_layer(il);
        if (device_fa != device_kv) {
            LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
                           "is assigned to device %s (usually due to missing support)\n",
                           __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
            // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
            fa_device_mismatch = true;
            break;
        }
    }

    if (fa_device_mismatch) {
        cparams.flash_attn = false;
        LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
        if (ggml_is_quantized(params.type_v)) {
            throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
        }
    } else {
        cparams.flash_attn = true;
        LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
    }
}

I think for CLIP it can safely always be enabled.

@ngxson
Collaborator

ngxson commented Oct 29, 2025

Hmm, so for example, if I completely replace the clip's build_attn with flash attn and a certain head size is not supported by the backend, it will fall back to CPU, right?

@ggerganov
Member Author

Yes. CPU FA is always supported. It could be a problem because we might not notice that the CPU fallback is being triggered in some cases. But still, it seems like the better default IMO.

@ngxson
Collaborator

ngxson commented Oct 29, 2025

In that case, I think we still need the logic to check whether the GPU backend supports flash attn or not.

I agree that flash attn should be the default. I think I can safely reuse the same llama_flash_attn_type enum (with LLAMA_FLASH_ATTN_TYPE_AUTO as the default value) and also reuse the detection logic from llama.cpp.
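
For reference, a minimal sketch of what such a clip-side auto-detection could look like (the helper name clip_backend_supports_fa is hypothetical; only the ggml calls are real, and essentially the same check appears in the diff further below):

// Sketch only: decide whether FA can be enabled for the CLIP graph by checking
// every GGML_OP_FLASH_ATTN_EXT node against the accelerated backend.
static bool clip_backend_supports_fa(ggml_backend_t backend, ggml_cgraph * gf) {
    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        ggml_tensor * node = ggml_graph_node(gf, i);
        if (node->op != GGML_OP_FLASH_ATTN_EXT) {
            continue;
        }
        // an unsupported FA node (e.g. an unusual head size) means the caller
        // should build the graph without FA instead of silently falling back to CPU
        if (!ggml_backend_supports_op(backend, node)) {
            return false;
        }
    }
    return true;
}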

@ggerganov
Member Author

We can also print a big warning each time the CLIP scheduler runs with more than one graph split. This way we will immediately spot cases where the implementation uses an unsupported operator.

Demonstrated in a4b54f2

Sample output from llama-server:

0.02.977.454 I main: server is listening on http://127.0.0.1:8013 - starting the main loop
0.02.977.454 I srv  update_slots: all slots are idle
0.10.164.226 I srv  params_from_: Chat format: Content-only
0.10.164.254 I slot get_availabl: id  0 | task -1 | selected slot by LRU, t_last = -1
0.10.164.280 I slot launch_slot_: id  0 | task 0 | processing task
0.10.164.285 I slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 131072, n_keep = 0, n_prompt_tokens = 289
0.10.164.372 I slot update_slots: id  0 | task 0 | n_past = 0, memory_seq_rm [0, end)
0.10.164.377 I slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 27, n_tokens = 27, progress = 0.093426
0.10.168.064 I slot update_slots: id  0 | task 0 | n_past = 27, memory_seq_rm [27, end)
0.10.168.067 I srv  process_chun: processing image...
encoding image slice...
clip_image_batch_encode: *****************************************************************
clip_image_batch_encode: WARNING: the CLIP graph uses unsupported operators by the backend
clip_image_batch_encode:          the performance will be suboptimal                      
clip_image_batch_encode:                                                                  
clip_image_batch_encode: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118
clip_image_batch_encode: *****************************************************************
image slice encoded in 13058 ms
decoding image batch 1/1, n_tokens_batch = 256
image decoded (batch 1/1) in 4 ms
0.23.229.992 I srv  process_chun: image processed in 13062 ms
0.23.230.214 I slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 289, n_tokens = 6, progress = 1.000000
0.23.230.240 I slot update_slots: id  0 | task 0 | prompt done, n_past = 289, n_tokens = 6
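
For context, a minimal sketch of how such a split-count warning could be wired up (the function name is a placeholder; ggml_backend_sched_get_n_splits is an existing ggml API and LOG_WRN is the logging macro shown in the diff below):

// Sketch only: warn when the scheduler had to split the CLIP graph,
// i.e. some ops fell back to another (usually the CPU) backend.
static void clip_warn_on_graph_splits(ggml_backend_sched_t sched) {
    const int n_splits = ggml_backend_sched_get_n_splits(sched);
    if (n_splits > 1) {
        LOG_WRN("%s: the CLIP graph was split into %d parts - some operators are not supported "
                "by the accelerated backend and performance will be suboptimal\n", __func__, n_splits);
    }
}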

Comment on lines 3213 to 3214
LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
// TODO: maybe log more details about why flash attention is not supported
Collaborator

@ggerganov I implemented a simple solution to auto-enable flash attn only when the backend supports it. Probably we should make this LOG_WRN more prominent. Also, what kind of info do you think should be displayed here?

Some users are potentially already using models with shapes not supported by GPU flash attn. Falling back to CPU would suddenly make those models very slow, which is not good UX overall. The auto mode plus a prominent warning is a better solution, as it also encourages users to voluntarily report the relevant info back to us, without being forceful about it.

Member Author

Also, what kind of info do you think should be displayed here?

We can print the actual tensor (shape, strides, types) for which FA is not supported.

Comment on lines 3268 to 3275
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
    ggml_tensor * node = ggml_graph_node(gf, i);
    if (node->op == GGML_OP_FLASH_ATTN_EXT) {
        if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
            return false;
        }
    }
}
Collaborator

I tested this by temporarily modifying the code to always return false in ggml_metal_device_supports_op. It works well for now, but I'm not sure if there are any edge cases.

Currently, mtmd only supports 2 backends at the same time: CPU and one GPU backend.

@github-actions github-actions bot added the server label Nov 1, 2025
@ggerganov ggerganov marked this pull request as ready for review November 2, 2025 09:17
@ggerganov ggerganov requested a review from slaren as a code owner November 2, 2025 09:17
@ggerganov
Member Author

I extended the warmup logic to print all ops that are not supported by the accelerated backend of the CLIP context. For example, we are now informed that the Metal backend does not support the UPSCALE op for Qwen3 VL:

alloc_compute_meta: warmup with image size = 512 x 512
alloc_compute_meta:      Metal compute buffer size =    38.02 MiB
alloc_compute_meta:        CPU compute buffer size =    16.02 MiB
alloc_compute_meta: graph splits = 3, nodes = 766
warmup: flash attention is enabled
warmup: op          UPSCALE is not supported by the CLIP backend: type = f32, ne = [32 32 1024 1]

Comment on lines 3262 to 3278
if (!unsupported_ops.empty()) {
    LOG_WRN("%s: *****************************************************************\n", __func__);
    LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
    LOG_WRN("%s: the performance will be suboptimal \n", __func__);
    LOG_WRN("%s: list of unsupported ops (backend=%s):\n", __func__, ggml_backend_name(ctx_clip.backend));
    for (const auto & op : unsupported_ops) {
        LOG_WRN("%s: %16s: type = %s, ne = [%d %d %d %d]\n", __func__,
            ggml_op_name(op.op->op),
            ggml_type_name(op.op->type),
            op.op->ne[0], op.op->ne[1], op.op->ne[2], op.op->ne[3]);
    }
    LOG_WRN("%s: flash attention is %s\n", __func__,
        (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
    LOG_WRN("%s: please report this on github as an issue\n", __func__);
    LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
    LOG_WRN("%s: *****************************************************************\n", __func__);
}
Collaborator

I improved these messages to make them more prominent while giving the user instructions on what to do. Lmk if this looks good to you.

Comment on lines +3229 to +3238
static auto print_shape = [](const char * fn, const char * name, ggml_tensor * t) {
    LOG_WRN("%s: %s: type = %s, ne = [%d %d %d %d], nb = [%d %d %d %d]\n", fn,
        name, ggml_type_name(t->type),
        t->ne[0], t->ne[1], t->ne[2], t->ne[3],
        t->nb[0], t->nb[1], t->nb[2], t->nb[3]);
};

print_shape(__func__, " dst", op);
print_shape(__func__, "src0", op->src[0]);
print_shape(__func__, "src1", op->src[1]);
print_shape(__func__, "src2", op->src[2]);
Collaborator

Re. what to print when flash attn is not supported: I'm printing the tensor shapes, types, and strides.

Member Author

This should be good - usually the head size (i.e. src[0]->ne[0]) is the thing that is not supported.
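
For illustration, the backend-side gate is typically along these lines (the set of supported head sizes below is a made-up placeholder, not the actual Metal or CUDA list):

// Sketch only: how a backend's supports_op might gate GGML_OP_FLASH_ATTN_EXT
// on the query head size. The sizes listed here are illustrative placeholders.
static bool example_supports_flash_attn_ext(const ggml_tensor * op) {
    const int64_t head_size = op->src[0]->ne[0]; // i.e. src[0]->ne[0] from the comment above
    switch (head_size) {
        case  64:
        case  72: // e.g. the non-standard size used by Gemma's vision tower
        case  80:
        case  96:
        case 128:
        case 256:
            return true;
        default:
            return false;
    }
}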

@ggerganov ggerganov requested a review from ngxson November 2, 2025 16:24
@ngxson ngxson merged commit 2f966b8 into master Nov 2, 2025
66 of 72 checks passed
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Nov 3, 2025
* origin/master: (169 commits)
opencl: support imrope (ggml-org#16914)
fix: Viewing multiple PDF attachments (ggml-org#16974)
model-conversion : pass config to from_pretrained (ggml-org#16963)
server : add props.model_alias (ggml-org#16943)
ggml: CUDA: add head size 72 for flash-attn (ggml-org#16962)
mtmd: add --image-min/max-tokens (ggml-org#16921)
mtmd: pad mask for qwen2.5vl (ggml-org#16954)
ggml : LoongArch fixes (ggml-org#16958)
sync: minja (glm 4.6 & minmax m2 templates) (ggml-org#16949)
SYCL: optimized repeat_back kernel (3× fewer asm instructions, 2× faster)Feature/sycl repeat back opt (ggml-org#16869)
feat(webui): improve LaTeX rendering with currency detection (ggml-org#16508)
test-backend-ops : fix segfault in moe-expert-reduce test in support mode and coverage (ggml-org#16936)
ci : disable failing riscv cross build (ggml-org#16952)
model: add Janus Pro for image understanding (ggml-org#16906)
clip : use FA (ggml-org#16837)
server : support unified cache across slots (ggml-org#16736)
common : move gpt-oss reasoning processing to init params (ggml-org#16937)
docs: remove llama_sampler_accept reference in sampling sample usage (ggml-org#16920)
CUDA: add FLOOR, CEIL, ROUND, TRUNC unary ops (ggml-org#16917)
devops: fix failing s390x docker build (ggml-org#16918)
...