Skip to content

llama: use f16 mask for FA to save VRAM#23764

Merged
am17an merged 3 commits into
ggml-org:masterfrom
am17an:kq_mask_f16
May 29, 2026
Merged

llama: use f16 mask for FA to save VRAM#23764
am17an merged 3 commits into
ggml-org:masterfrom
am17an:kq_mask_f16

Conversation

@am17an
Copy link
Copy Markdown
Contributor

@am17an am17an commented May 27, 2026

Overview

Currently we reserve the KQ mask in f32 even if FA is used, which is then is converted to f16 while passing to backends. The f32 mask still uses the compute buffer even though is not used, taking up extra VRAM. This PR reserves the kq-mask in f16. This provides 1.2GB of VRAM saving at -ub 2048 and ~300Mb at -ub 512 when using MTP

Additional information

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES, used CC and Codex to identify the problem and write the code. Will polish it up a bit

@am17an
Copy link
Copy Markdown
Contributor Author

am17an commented May 27, 2026

This stacked along with reserving only n_outputs == n_seqs saves some more VRAM. On -ub 512 I go from 824 Mb as compute buffer to 444 Mb. On -ub 2048 from 3.2GB to 1.5GB

Copy link
Copy Markdown
Contributor

@JohannesGaessler JohannesGaessler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like the right approach to me. For non-FA configurations the mask is still FP32 but since the mask size scales linearly with context depth while the size of the KQ matrix scales quadratically it should not make a difference. For most models we would not even need FP16 since the mask values will just be either 0 or -inf. Or we could even calculate the mask values from indices (but which gets complicated a lot for multiple concurrent contexts). But these further optimizations would be a lot more invasive, require a lot of effort to implement properly, and only make sense if with this PR the mask would still be the largest tensor in the compute graph.

@JohannesGaessler
Copy link
Copy Markdown
Contributor

@am17an since you wrote:

Will polish it up a bit

Does that mean there will be more changes or is this PR ready for review?

@am17an
Copy link
Copy Markdown
Contributor Author

am17an commented May 28, 2026

You can review, I just didn't like the extra llama_mask function

Comment thread src/llama-impl.h Outdated
Comment thread src/llama-graph.cpp Outdated
Comment thread src/llama-graph.cpp Outdated
Comment thread src/llama-graph.cpp Outdated
Comment thread src/llama-graph.h Outdated
Comment thread src/llama-graph.h Outdated
Comment thread src/llama-graph.h Outdated
Comment thread src/llama-graph.h Outdated
Comment thread src/llama-graph.h Outdated
Comment thread src/llama-graph.h Outdated
@am17an am17an requested a review from JohannesGaessler May 28, 2026 15:38
Copy link
Copy Markdown
Contributor

@JohannesGaessler JohannesGaessler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the code for creating the mask again, I think it can be simplified a bit more with this patch:

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 92e8d0d29..1b976eb8e 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -396,8 +396,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
 
-    const auto fill_mask_inner = [&](auto * data, int n_swa, llama_swa_type swa_type) {
+    const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) {
         using T = std::remove_reference_t<decltype(*data)>;
+        std::fill(data, data + ne, llama_cast<T>(-INFINITY));
         for (int i1 = 0; i1 < n_tokens; ++i1) {
             const llama_seq_id s1 = ubatch->seq_id[i1][0];
             const llama_pos    p1 = ubatch->pos[i1];
@@ -426,39 +427,27 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                 data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f);
             }
         }
-    };
-
-    const auto fill_mask = [&](ggml_tensor * mask, int n_swa, llama_swa_type swa_type) {
-        GGML_ASSERT(mask);
-        GGML_ASSERT(ggml_backend_buffer_is_host(mask->buffer));
-
-        if (mask->type == GGML_TYPE_F16) {
-            ggml_fp16_t * data = (ggml_fp16_t *) mask->data;
-
-            std::fill(data, data + ggml_nelements(mask), llama_cast<ggml_fp16_t>(-INFINITY));
-
-            fill_mask_inner(data, n_swa, swa_type);
-
-            if (debug) {
-                print_mask(data, n_tokens, n_kv, n_swa, swa_type);
-            }
-        } else {
-            float * data = (float *) mask->data;
-
-            std::fill(data, data + ggml_nelements(mask), -INFINITY);
-
-            fill_mask_inner(data, n_swa, swa_type);
-
-            if (debug) {
-                print_mask(data, n_tokens, n_kv, n_swa, swa_type);
-            }
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, n_swa, swa_type);
         }
     };
 
-    fill_mask(self_kq_mask, 0, LLAMA_SWA_TYPE_NONE);
+    GGML_ASSERT(self_kq_mask);
+    GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+    if (self_kq_mask->type == GGML_TYPE_F16) {
+        fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
+    } else {
+        fill_mask((float       *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
+    }
 
     if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-        fill_mask(self_kq_mask_swa, hparams.n_swa, hparams.swa_type);
+        GGML_ASSERT(self_kq_mask_swa);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+        if (self_kq_mask_swa->type == GGML_TYPE_F16) {
+            fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), 0, LLAMA_SWA_TYPE_NONE);
+        } else {
+            fill_mask((float       *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), 0, LLAMA_SWA_TYPE_NONE);
+        }
     }
 }

Comment thread src/llama-impl.h Outdated
Comment thread src/llama-graph.cpp Outdated
@am17an am17an merged commit 031ddb2 into ggml-org:master May 29, 2026
27 checks passed
@am17an am17an deleted the kq_mask_f16 branch May 29, 2026 07:44
DrBearJew pushed a commit to DrBearJew/llama.cpp that referenced this pull request May 29, 2026
From upstream PR ggml-org#23764. Avoids ggml_cast in graph builder.
Our kernel already reads masks as half — no changes needed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants