llama: use f16 mask for FA to save VRAM #23764
This stacked along with reserving only |
JohannesGaessler
left a comment
This seems like the right approach to me. For non-FA configurations the mask is still FP32, but since the mask size scales linearly with context depth while the size of the KQ matrix scales quadratically, it should not make a difference. For most models we would not even need FP16 since the mask values will just be either 0 or -inf. We could even calculate the mask values from indices (though that gets a lot more complicated with multiple concurrent contexts). But these further optimizations would be a lot more invasive, would require a lot of effort to implement properly, and would only make sense if, with this PR, the mask were still the largest tensor in the compute graph.
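To make the two ideas in this comment concrete, here is a minimal sketch, not code from llama.cpp. It assumes a plain causal mask without SWA or ALiBi, and the `token_info` struct and both function names are hypothetical. It shows a mask entry computed purely from sequence id and position, and the IEEE 754 binary16 bit patterns of the only two values such a mask takes.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical per-token metadata (an assumption for this sketch);
// llama.cpp keeps the equivalent information in llama_ubatch.
struct token_info {
    std::int32_t seq; // sequence id
    std::int32_t pos; // position within the sequence
};

// Causal mask entry for query token q looking at key token k:
// 0 if k is visible (same sequence and not in the future), -inf otherwise.
constexpr float causal_mask_value(token_info q, token_info k) {
    return (q.seq == k.seq && k.pos <= q.pos)
        ? 0.0f
        : -std::numeric_limits<float>::infinity();
}

// IEEE 754 binary16 encodings of the two possible mask values.
constexpr std::uint16_t F16_ZERO    = 0x0000; // +0.0
constexpr std::uint16_t F16_NEG_INF = 0xFC00; // sign=1, exponent=11111, mantissa=0

// Same decision as above, emitted directly as f16 bits.
constexpr std::uint16_t causal_mask_bits_f16(token_info q, token_info k) {
    return (q.seq == k.seq && k.pos <= q.pos) ? F16_ZERO : F16_NEG_INF;
}
```

Both values are exactly representable in f16, so nothing is lost by storing such a mask in half precision. The sketch covers a single flat set of tokens; the multi-context case mentioned above is not handled.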
|
@am17an since you wrote:
Does that mean there will be more changes or is this PR ready for review? |
|
You can review, I just didn't like the extra llama_mask function |
JohannesGaessler
left a comment
Looking at the code for creating the mask again, I think it can be simplified a bit more with this patch:
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 92e8d0d29..1b976eb8e 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -396,8 +396,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
- const auto fill_mask_inner = [&](auto * data, int n_swa, llama_swa_type swa_type) {
+ const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) {
using T = std::remove_reference_t<decltype(*data)>;
+ std::fill(data, data + ne, llama_cast<T>(-INFINITY));
for (int i1 = 0; i1 < n_tokens; ++i1) {
const llama_seq_id s1 = ubatch->seq_id[i1][0];
const llama_pos p1 = ubatch->pos[i1];
@@ -426,39 +427,27 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f);
}
}
- };
-
- const auto fill_mask = [&](ggml_tensor * mask, int n_swa, llama_swa_type swa_type) {
- GGML_ASSERT(mask);
- GGML_ASSERT(ggml_backend_buffer_is_host(mask->buffer));
-
- if (mask->type == GGML_TYPE_F16) {
- ggml_fp16_t * data = (ggml_fp16_t *) mask->data;
-
- std::fill(data, data + ggml_nelements(mask), llama_cast<ggml_fp16_t>(-INFINITY));
-
- fill_mask_inner(data, n_swa, swa_type);
-
- if (debug) {
- print_mask(data, n_tokens, n_kv, n_swa, swa_type);
- }
- } else {
- float * data = (float *) mask->data;
-
- std::fill(data, data + ggml_nelements(mask), -INFINITY);
-
- fill_mask_inner(data, n_swa, swa_type);
-
- if (debug) {
- print_mask(data, n_tokens, n_kv, n_swa, swa_type);
- }
+ if (debug) {
+ print_mask(data, n_tokens, n_kv, n_swa, swa_type);
}
};
- fill_mask(self_kq_mask, 0, LLAMA_SWA_TYPE_NONE);
+ GGML_ASSERT(self_kq_mask);
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+ if (self_kq_mask->type == GGML_TYPE_F16) {
+ fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
+ } else {
+ fill_mask((float *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
+ }
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
- fill_mask(self_kq_mask_swa, hparams.n_swa, hparams.swa_type);
+ GGML_ASSERT(self_kq_mask_swa);
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+ if (self_kq_mask_swa->type == GGML_TYPE_F16) {
+ fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
+ } else {
+ fill_mask((float *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
+ }
}
}
From upstream PR ggml-org#23764. Avoids ggml_cast in the graph builder. Our kernel already reads masks as half, so no changes are needed.
Overview
Currently we reserve the KQ mask in f32 even if FA is used, and it is then converted to f16 when passed to the backends. The f32 mask still occupies the compute buffer even though it is not used, taking up extra VRAM. This PR reserves the KQ mask in f16. This saves 1.2 GB of VRAM at -ub 2048 and ~300 MB at -ub 512 when using MTP.
Additional information
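As a rough illustration of where the saving comes from, the sketch below computes the size of a single KQ mask of shape [n_kv, n_ubatch] for 4-byte and 2-byte elements. The helper name and the context size of 131072 are assumptions chosen for illustration, not values from this PR, and the sketch ignores padding and any additional masks (for example SWA) that a real graph may reserve.

```cpp
#include <cstdint>

// Bytes for one KQ mask with n_kv columns and n_tokens rows.
// Hypothetical helper; ignores padding applied by the real graph builder.
constexpr std::uint64_t kq_mask_bytes(std::uint64_t n_kv,
                                      std::uint64_t n_tokens,
                                      std::uint64_t elem_size) {
    return n_kv * n_tokens * elem_size;
}

// Bytes saved by storing the mask in f16 (2 bytes) instead of f32 (4 bytes).
constexpr std::uint64_t kq_mask_saving(std::uint64_t n_kv, std::uint64_t n_tokens) {
    return kq_mask_bytes(n_kv, n_tokens, 4) - kq_mask_bytes(n_kv, n_tokens, 2);
}

// Assumed example context size (not taken from the PR).
constexpr std::uint64_t example_n_kv = 131072;

constexpr std::uint64_t saving_ub_2048 = kq_mask_saving(example_n_kv, 2048); // 512 MiB
constexpr std::uint64_t saving_ub_512  = kq_mask_saving(example_n_kv, 512);  // 128 MiB
```

Under these assumptions the saving per mask scales linearly with the ubatch size, which is consistent with the roughly 4x ratio between the figures reported for -ub 2048 and -ub 512. The absolute numbers depend on the actual context size and on how many masks the graph reserves.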
Requirements