From 5734546315241c1d5b1cb2efaca7bd801542f82b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 12 Oct 2025 10:34:44 +0300 Subject: [PATCH 1/3] graph : support cacheless embeddings with FA and iSWA --- src/llama-graph.cpp | 137 ++++++++++++++++++++++++++++++-------------- src/llama-graph.h | 10 +++- src/llama-model.cpp | 5 +- src/llama.cpp | 1 + 4 files changed, 105 insertions(+), 48 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index a24853c63ada4..cb41c7959e69c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { } } -static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { +static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__); - const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" : - (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" : - (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" : - (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown"; + const char * swa_type_str = "unknown"; + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break; + case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break; + case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break; + case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break; + }; + LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__); LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__); @@ -295,50 +300,88 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_kv = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens; - GGML_ASSERT(kq_mask); - GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer)); + { + GGML_ASSERT(self_kq_mask); + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); - float * data = (float *) kq_mask->data; + float * data = (float *) self_kq_mask->data; - // [TAG_NO_CACHE_ISWA] - GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement"); + for (int h = 0; h < 1; ++h) { + for (int i1 = 0; i1 < n_tokens; ++i1) { + const llama_seq_id s1 = ubatch->seq_id[i1][0]; - for (int h = 0; h < 1; ++h) { - for (int i1 = 0; i1 < n_tokens; ++i1) { - const llama_seq_id s1 = ubatch->seq_id[i1][0]; + for (int i0 = 0; i0 < n_tokens; ++i0) { + float f = -INFINITY; - for (int i0 = 0; i0 < n_tokens; ++i0) { - float f = -INFINITY; + for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) { + const llama_seq_id s0 = ubatch->seq_id[i0][0]; - for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) { - const llama_seq_id s0 = ubatch->seq_id[i0][0]; + if (s0 != s1) { + continue; // skip different sequences + } - if (s0 != s1) { - continue; // skip different sequences - } + if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) { + continue; // skip future tokens for causal attention + } - if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) { - continue; // skip future tokens for causal attention + // TODO: reimplement this like in llama_kv_cache_unified + if (hparams.use_alibi) { + f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]); + } else { + f = 0.0f; + } } + data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f; + } + } + } + if (debug) { + print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + } + } - // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA] - //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) { - // continue; // skip masked tokens for SWA - //} + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + GGML_ASSERT(self_kq_mask_swa); + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); - // TODO: reimplement this like in llama_kv_cache_unified - if (hparams.use_alibi) { - f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]); - } else { - f = 0.0f; + float * data = (float *) self_kq_mask_swa->data; + + for (int h = 0; h < 1; ++h) { + for (int i1 = 0; i1 < n_tokens; ++i1) { + const llama_seq_id s1 = ubatch->seq_id[i1][0]; + + for (int i0 = 0; i0 < n_tokens; ++i0) { + float f = -INFINITY; + + for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) { + const llama_seq_id s0 = ubatch->seq_id[i0][0]; + + if (s0 != s1) { + continue; // skip different sequences + } + + if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) { + continue; // skip future tokens for causal attention + } + + if (llama_hparams::is_masked_swa(hparams.n_swa, hparams.swa_type, ubatch->pos[i0], ubatch->pos[i1])) { + continue; // skip masked tokens for SWA + } + + // TODO: reimplement this like in llama_kv_cache_unified + if (hparams.use_alibi) { + f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]); + } else { + f = 0.0f; + } } + data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f; } - data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f; } } - } - if (debug) { - print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + if (debug) { + print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + } } } @@ -1299,12 +1342,10 @@ ggml_tensor * llm_graph_context::build_attn_mha( k = ggml_permute(ctx0, k, 0, 2, 1, 3); v = ggml_permute(ctx0, v, 0, 2, 1, 3); - const auto n_kv = k->ne[1]; - ggml_tensor * cur; // TODO: replace hardcoded padding with ggml-provided padding - if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) { + if (cparams.flash_attn && kq_b == nullptr) { GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); if (v_trans) { @@ -1419,10 +1460,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con auto inp = std::make_unique(hparams, cparams); // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); - ggml_set_input(inp->kq_mask); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + ggml_set_input(inp->self_kq_mask); - inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + ggml_set_input(inp->self_kq_mask_swa); + + inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + } else { + inp->self_kq_mask_swa = nullptr; + inp->self_kq_mask_swa_cnv = nullptr; + } return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp)); } @@ -1447,7 +1498,9 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto & kq_mask = inp->get_kq_mask(); + const bool is_swa = hparams.is_swa(il); + + const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); // [TAG_NO_CACHE_PAD] // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams diff --git a/src/llama-graph.h b/src/llama-graph.h index dc84b7942893a..d0c3934f67927 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -257,10 +257,14 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; - ggml_tensor * get_kq_mask() const { return kq_mask_cnv; } + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } - ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1] - ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1] + // n_tokens == n_batch + ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] const llama_hparams hparams; const llama_cparams cparams; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 36d495d6cfeab..98b59394228aa 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11376,8 +11376,7 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA] - auto * inp_attn = build_attn_inp_kv_iswa(); + auto * inp_attn = build_attn_inp_no_cache(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -19378,7 +19377,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: case LLM_ARCH_WAVTOKENIZER_DEC: - //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA] + case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: diff --git a/src/llama.cpp b/src/llama.cpp index fe5a7a835488c..38700f97a0688 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -312,6 +312,7 @@ struct llama_model * llama_model_load_from_splits( LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__); return nullptr; } + splits.reserve(n_paths); for (size_t i = 0; i < n_paths; ++i) { splits.push_back(paths[i]); } From 71a264c7a74fdfc82259d81ed51754f58cc52070 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 13 Oct 2025 22:17:24 +0300 Subject: [PATCH 2/3] cont : deduplicate mask creation --- src/llama-graph.cpp | 89 +++++++++++++++++---------------------------- 1 file changed, 34 insertions(+), 55 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index cb41c7959e69c..f2c628a0b18a3 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -300,43 +300,51 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_kv = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens; - { - GGML_ASSERT(self_kq_mask); - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); - - float * data = (float *) self_kq_mask->data; - + const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) { for (int h = 0; h < 1; ++h) { for (int i1 = 0; i1 < n_tokens; ++i1) { const llama_seq_id s1 = ubatch->seq_id[i1][0]; + const llama_pos p1 = ubatch->pos[i1]; - for (int i0 = 0; i0 < n_tokens; ++i0) { - float f = -INFINITY; + const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv; - for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) { - const llama_seq_id s0 = ubatch->seq_id[i0][0]; + for (int i0 = 0; i0 < n_tokens; ++i0) { + const llama_seq_id s0 = ubatch->seq_id[i0][0]; + const llama_pos p0 = ubatch->pos[i0]; - if (s0 != s1) { - continue; // skip different sequences - } + // mask different sequences + if (s0 != s1) { + continue; + } - if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) { - continue; // skip future tokens for causal attention - } + // mask future tokens + if (cparams.causal_attn && p0 > p1) { + continue; + } - // TODO: reimplement this like in llama_kv_cache_unified - if (hparams.use_alibi) { - f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]); - } else { - f = 0.0f; - } + // apply SWA if any + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + continue; } - data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f; + + data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; } } } + }; + + { + GGML_ASSERT(self_kq_mask); + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); + + float * data = (float *) self_kq_mask->data; + + std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY); + + fill_mask(data, 0, LLAMA_SWA_TYPE_NONE); + if (debug) { - print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE); } } @@ -346,39 +354,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { float * data = (float *) self_kq_mask_swa->data; - for (int h = 0; h < 1; ++h) { - for (int i1 = 0; i1 < n_tokens; ++i1) { - const llama_seq_id s1 = ubatch->seq_id[i1][0]; + std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY); - for (int i0 = 0; i0 < n_tokens; ++i0) { - float f = -INFINITY; + fill_mask(data, hparams.n_swa, hparams.swa_type); - for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) { - const llama_seq_id s0 = ubatch->seq_id[i0][0]; - - if (s0 != s1) { - continue; // skip different sequences - } - - if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) { - continue; // skip future tokens for causal attention - } - - if (llama_hparams::is_masked_swa(hparams.n_swa, hparams.swa_type, ubatch->pos[i0], ubatch->pos[i1])) { - continue; // skip masked tokens for SWA - } - - // TODO: reimplement this like in llama_kv_cache_unified - if (hparams.use_alibi) { - f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]); - } else { - f = 0.0f; - } - } - data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f; - } - } - } if (debug) { print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); } From 5c6583190ee7592df9e78a3d183226e65f96ecd6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 13 Oct 2025 22:42:09 +0300 Subject: [PATCH 3/3] cont : fix name [no ci] --- src/llama-model.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 98b59394228aa..0cdad9babd9b2 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11358,8 +11358,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { } }; -struct llm_build_gemma_embedding_iswa : public llm_graph_context { - llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +struct llm_build_gemma_embedding : public llm_graph_context { + llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; ggml_tensor * cur; @@ -19670,7 +19670,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_GEMMA_EMBEDDING: { - llm = std::make_unique(*this, params); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_STARCODER2: {