From a0b83f6408d86c8486e56cf985c5bbc017672b3a Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Mon, 29 Sep 2025 15:17:15 +0200 Subject: [PATCH 1/5] model: EmbeddingGemma sentence-transformers dense linear projections support --- convert_hf_to_gguf.py | 39 ++++++++++++++++++++++++++++++++++ gguf-py/gguf/constants.py | 6 ++++++ gguf-py/gguf/tensor_mapping.py | 8 ++++++- src/llama-arch.cpp | 4 ++++ src/llama-arch.h | 2 ++ src/llama-graph.cpp | 14 ++++++++++++ src/llama-graph.h | 8 +++++++ src/llama-model.cpp | 14 ++++++++++++ src/llama-model.h | 6 ++++++ 9 files changed, 100 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 411e36f8cf41e..ee4b48d01a91e 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -280,6 +280,8 @@ def prepare_tensors(self): max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()): + if "dense" in name: + break_here = 1 # we don't need these if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): continue @@ -332,6 +334,9 @@ def prepare_tensors(self): gguf.MODEL_TENSOR.A_ENC_EMBD_POS, gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF, gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF, + #gguf.MODEL_TENSOR.DENSE_2_OUT, + #gguf.MODEL_TENSOR.DENSE_3_OUT, + ) ) or not new_name.endswith(".weight") @@ -5255,6 +5260,40 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter @ModelBase.register("Gemma3TextModel") class EmbeddingGemma(Gemma3Model): model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING + module_paths = [] + dense_tensors = [] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # read molues.json to determine if model has Dense layers + module_path = self.dir_model / "modules.json" + if module_path.is_file(): + with open(module_path, encoding="utf-8") as f: + modules = json.load(f) + for mod in modules: + if mod["type"] == "sentence_transformers.models.Dense": + module_path = mod["path"] + tensors_file = self.dir_model / module_path / "model.safetensors" + if tensors_file.is_file(): + self.module_paths.append(module_path) + + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + from safetensors.torch import load_file + module_paths = list(self.module_paths) + for i, module_path in enumerate(module_paths): + tensors_file = self.dir_model / module_path / "model.safetensors" + local_tensors = load_file(tensors_file) + tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3" + for name, local_tensor in local_tensors.items(): + if not name.endswith(".weight"): + continue + orig_name = name.replace("linear", tensor_name) + name = self.map_tensor_name(orig_name) + logger.info(f"Adding extra tensor {i+1}/{len(module_paths)}: {orig_name} -> {name}, shape={local_tensor.shape}") + yield name, local_tensor.clone() + + def set_gguf_parameters(self): super().set_gguf_parameters() diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 88ea9f32f8c28..654e3255632b6 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -423,6 +423,8 @@ class MODEL_TENSOR(IntEnum): TOKEN_TYPES = auto() POS_EMBD = auto() OUTPUT = auto() + DENSE_2_OUT = auto() # embedding-gemma Dense layers + DENSE_3_OUT = auto() # embedding-gemma Dense layers OUTPUT_NORM = auto() ROPE_FREQS = auto() ROPE_FACTORS_LONG = auto() @@ -765,6 +767,8 @@ class MODEL_TENSOR(IntEnum): 
MODEL_TENSOR.POS_EMBD: "position_embd", MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embedding-gemma Dense layers + MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embedding-gemma Dense layers MODEL_TENSOR.ROPE_FREQS: "rope_freqs", MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", @@ -1747,6 +1751,8 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA_EMBEDDING: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.DENSE_2_OUT, + MODEL_TENSOR.DENSE_3_OUT, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index c533b55c0120a..cd5cab603636e 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -76,7 +76,12 @@ class TensorNameMap: "lm_head", # llama4 "model.transformer.ff_out", # llada ), - + MODEL_TENSOR.DENSE_2_OUT: ( + "dense_2_out", # phi2 + ), + MODEL_TENSOR.DENSE_3_OUT: ( + "dense_3_out", # phi2 + ), # Output norm MODEL_TENSOR.OUTPUT_NORM: ( "gpt_neox.final_layer_norm", # gptneox @@ -116,6 +121,7 @@ class TensorNameMap: } block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { + # Attention norm MODEL_TENSOR.ATTN_NORM: ( "gpt_neox.layers.{bid}.input_layernorm", # gptneox diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 4e8d54c4193cc..856509abf42fe 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1064,6 +1064,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_DENSE_2_OUT, "dense_2" }, + { LLM_TENSOR_DENSE_3_OUT, "dense_3" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, @@ -2229,6 +2231,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index b5c6f3d76a62c..dd44b80acf36a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -271,6 +271,8 @@ enum llm_tensor { LLM_TENSOR_TOKEN_EMBD_NORM, LLM_TENSOR_TOKEN_TYPES, LLM_TENSOR_POS_EMBD, + LLM_TENSOR_DENSE_2_OUT, + LLM_TENSOR_DENSE_3_OUT, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_ROPE_FREQS, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 90cd885a60a4f..98de2ece5e0ea 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1853,6 +1853,19 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +void llm_graph_context::build_dense_out( + ggml_tensor *dense_2, + ggml_tensor *dense_3) const { + ggml_tensor * cur = res->get_embd_pooled(); + cur = ggml_mul_mat(ctx0, dense_2, cur); + cb(cur, "result_embd_pooled", -1); + cur = ggml_mul_mat(ctx0, dense_3, cur); + cb(cur, "result_embd_pooled", -1); + res->t_embd_pooled = 
cur; + ggml_build_forward_expand(gf, cur); +} + + void llm_graph_context::build_pooling( ggml_tensor * cls, ggml_tensor * cls_b, @@ -1937,6 +1950,7 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } + int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; diff --git a/src/llama-graph.h b/src/llama-graph.h index 34b984afeb043..dc84b7942893a 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -814,6 +814,14 @@ struct llm_graph_context { ggml_tensor * cls_b, ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; + + // + // dense (out) + // + + void build_dense_out( + ggml_tensor * dense_2, + ggml_tensor * dense_3) const; }; // TODO: better name diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 63655bf6517b4..5edaf30e23f76 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3645,6 +3645,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + // Dense output layers + //FIXME: meta_data is hardcoded for now + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, 4 * n_embd}, 0); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {4 * n_embd, n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -11352,6 +11358,7 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context { cur = ggml_add(ctx0, cur, sa_out); + cur = build_cvec(cur, il); cb(cur, "l_out", il); @@ -19628,6 +19635,13 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + //FIXME: remove this hack when we have a better way to handle pooling + //(sentence-transformer) dense layers are applied after pooling + // for LLM_ARCH_GEMMA_EMBEDDING mean pooling is already added to the graph + if (llm->arch == LLM_ARCH_GEMMA_EMBEDDING) { + //LLAMA_LOG_WARN("%s: adding pooling layer\n", __func__); + llm->build_dense_out(dense_2_out_layers,dense_3_out_layers); + } return llm->res->get_gf(); } diff --git a/src/llama-model.h b/src/llama-model.h index d73ce9693230f..75418d9632a0d 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -431,6 +431,12 @@ struct llama_model { std::vector layers; + //Dense out layer for Sentence Transformers models like embeddinggemma + // For Sentence Transformers models structure see + // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models + ggml_tensor * dense_2_out_layers; + ggml_tensor * dense_3_out_layers; + llama_model_params params; // gguf metadata From 8ceff264de22a48c97dc0c70945db47ed7688cf1 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Wed, 1 Oct 2025 12:53:22 +0200 Subject: [PATCH 2/5] model: add support for EmbeddingGemma SentenceTransformers dense linear projections Adding support for the Dense modules used in EmbeddingGemma models. EmbeddingGemma is a SentenceTransformers model with additional modules beyond the base Transformer backbone. 
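For reference, a minimal PyTorch sketch (illustration only, not part of the
patch) of what the two Dense modules compute on top of the pooled embedding.
The shapes follow the ones used in this series (2_Dense: n_embd -> 4*n_embd,
3_Dense: 4*n_embd -> n_embd, with n_embd = 768 assumed for
embeddinggemma-300m); random weights stand in for the tensors stored in
2_Dense/model.safetensors and 3_Dense/model.safetensors, and bias is omitted
since only *.weight tensors are converted:

    import torch

    n_embd = 768                                 # assumed backbone hidden size
    dense_2 = torch.nn.Linear(n_embd, 4 * n_embd, bias=False)   # 2_Dense
    dense_3 = torch.nn.Linear(4 * n_embd, n_embd, bias=False)   # 3_Dense

    pooled = torch.randn(1, n_embd)              # mean-pooled token embeddings
    embedding = dense_3(dense_2(pooled))         # projections applied after pooling

In llama.cpp the same two matrix multiplications are added to the graph by
build_dense_out() (introduced in PATCH 1/5), right after build_pooling().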
See: https://developers.googleblog.com/en/gemma-explained-embeddinggemma-architecture-and-recipe/ --- convert_hf_to_gguf.py | 7 ------- gguf-py/gguf/constants.py | 8 ++++---- gguf-py/gguf/tensor_mapping.py | 5 ++--- src/llama-arch.cpp | 8 ++++---- src/llama-graph.cpp | 1 - src/llama-model.cpp | 10 +++------- src/llama-model.h | 2 +- 7 files changed, 14 insertions(+), 27 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ee4b48d01a91e..d08e3b57c9a15 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -280,8 +280,6 @@ def prepare_tensors(self): max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()): - if "dense" in name: - break_here = 1 # we don't need these if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): continue @@ -334,9 +332,6 @@ def prepare_tensors(self): gguf.MODEL_TENSOR.A_ENC_EMBD_POS, gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF, gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF, - #gguf.MODEL_TENSOR.DENSE_2_OUT, - #gguf.MODEL_TENSOR.DENSE_3_OUT, - ) ) or not new_name.endswith(".weight") @@ -5290,11 +5285,9 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: continue orig_name = name.replace("linear", tensor_name) name = self.map_tensor_name(orig_name) - logger.info(f"Adding extra tensor {i+1}/{len(module_paths)}: {orig_name} -> {name}, shape={local_tensor.shape}") yield name, local_tensor.clone() - def set_gguf_parameters(self): super().set_gguf_parameters() diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 654e3255632b6..433cfb6a5a9ce 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -423,8 +423,8 @@ class MODEL_TENSOR(IntEnum): TOKEN_TYPES = auto() POS_EMBD = auto() OUTPUT = auto() - DENSE_2_OUT = auto() # embedding-gemma Dense layers - DENSE_3_OUT = auto() # embedding-gemma Dense layers + DENSE_2_OUT = auto() # embeddinggemma 2_Dense + DENSE_3_OUT = auto() # embeddinggemma 3_Dense OUTPUT_NORM = auto() ROPE_FREQS = auto() ROPE_FACTORS_LONG = auto() @@ -767,8 +767,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POS_EMBD: "position_embd", MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.OUTPUT: "output", - MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embedding-gemma Dense layers - MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embedding-gemma Dense layers + MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embeddinggemma 2_Dense + MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embeddinggemma 2_Dense MODEL_TENSOR.ROPE_FREQS: "rope_freqs", MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index cd5cab603636e..7b71c3e846e36 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -77,10 +77,10 @@ class TensorNameMap: "model.transformer.ff_out", # llada ), MODEL_TENSOR.DENSE_2_OUT: ( - "dense_2_out", # phi2 + "dense_2_out", # embeddinggemma ), MODEL_TENSOR.DENSE_3_OUT: ( - "dense_3_out", # phi2 + "dense_3_out", # embeddinggemma ), # Output norm MODEL_TENSOR.OUTPUT_NORM: ( @@ -121,7 +121,6 @@ class TensorNameMap: } block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { - # Attention norm MODEL_TENSOR.ATTN_NORM: ( "gpt_neox.layers.{bid}.input_layernorm", # gptneox diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 856509abf42fe..9746a4613e1af 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ 
-1064,8 +1064,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_DENSE_2_OUT, "dense_2" }, - { LLM_TENSOR_DENSE_3_OUT, "dense_3" }, + { LLM_TENSOR_DENSE_2_OUT, "dense_2" }, + { LLM_TENSOR_DENSE_3_OUT, "dense_3" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, @@ -2231,8 +2231,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output - {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 98de2ece5e0ea..d1148dc4b4c89 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1950,7 +1950,6 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } - int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5edaf30e23f76..c6a274c9fb913 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3645,8 +3645,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - // Dense output layers - //FIXME: meta_data is hardcoded for now + // Dense linear weights dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, 4 * n_embd}, 0); dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {4 * n_embd, n_embd}, 0); @@ -11358,7 +11357,6 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context { cur = ggml_add(ctx0, cur, sa_out); - cur = build_cvec(cur, il); cb(cur, "l_out", il); @@ -19635,11 +19633,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); - //FIXME: remove this hack when we have a better way to handle pooling - //(sentence-transformer) dense layers are applied after pooling - // for LLM_ARCH_GEMMA_EMBEDDING mean pooling is already added to the graph + // embeddinggemma specific + //sentence-transformer dense linear projections are applied after pooling if (llm->arch == LLM_ARCH_GEMMA_EMBEDDING) { - //LLAMA_LOG_WARN("%s: adding pooling layer\n", __func__); llm->build_dense_out(dense_2_out_layers,dense_3_out_layers); } diff --git a/src/llama-model.h b/src/llama-model.h index 75418d9632a0d..a57c3864f2591 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -431,7 +431,7 @@ struct llama_model { std::vector layers; - //Dense out layer for Sentence Transformers models like embeddinggemma + //Dense linear projections for 
SentenceTransformers models like embeddinggemma // For Sentence Transformers models structure see // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models ggml_tensor * dense_2_out_layers; From f3be74e100b1a1fea27565b682dfb8852dca3975 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Sat, 4 Oct 2025 08:58:23 +0200 Subject: [PATCH 3/5] model: add support for EmbeddingGemma SentenceTransformers dense linear projections - converting model with dense-layers is optional - introduced dense config params --- convert_hf_to_gguf.py | 70 ++++++++++++++++++++++++++++--------- gguf-py/gguf/constants.py | 3 ++ gguf-py/gguf/gguf_writer.py | 7 ++++ src/llama-arch.cpp | 6 ++++ src/llama-arch.h | 7 ++++ src/llama-context.cpp | 9 +++++ src/llama-graph.cpp | 14 +++++--- src/llama-hparams.h | 9 +++++ src/llama-model.cpp | 40 +++++++++++++-------- 9 files changed, 128 insertions(+), 37 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d08e3b57c9a15..1bb8cf1a8165c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -93,13 +93,15 @@ class ModelBase: # Mistral format specifics is_mistral_format: bool = False disable_mistral_community_chat_template: bool = False + sentence_transformers_dense_modules: bool = False def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None, - disable_mistral_community_chat_template: bool = False): + disable_mistral_community_chat_template: bool = False, + sentence_transformers_dense_modules: bool = False): if type(self) is ModelBase or \ type(self) is TextModel or \ type(self) is MmprojModel: @@ -114,6 +116,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.lazy = not eager or (remote_hf_model_id is not None) self.dry_run = dry_run self.remote_hf_model_id = remote_hf_model_id + self.sentence_transformers_dense_modules = sentence_transformers_dense_modules if remote_hf_model_id is not None: self.is_safetensors = True @@ -5256,22 +5259,33 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter class EmbeddingGemma(Gemma3Model): model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING module_paths = [] - dense_tensors = [] + dense_features_dims = {} def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # read molues.json to determine if model has Dense layers - module_path = self.dir_model / "modules.json" - if module_path.is_file(): - with open(module_path, encoding="utf-8") as f: - modules = json.load(f) - for mod in modules: - if mod["type"] == "sentence_transformers.models.Dense": - module_path = mod["path"] - tensors_file = self.dir_model / module_path / "model.safetensors" - if tensors_file.is_file(): - self.module_paths.append(module_path) - + if self.sentence_transformers_dense_modules: + # read molues.json to determine if model has Dense layers + modules_file = self.dir_model / "modules.json" + if modules_file.is_file(): + with open(modules_file, encoding="utf-8") as modules_json_file: + mods = json.load(modules_json_file) + for mod in mods: + if mod["type"] == "sentence_transformers.models.Dense": + mod_path = 
mod["path"] + # check if model.safetensors file for Dense layer exists + model_tensors_file = self.dir_model / mod_path / "model.safetensors" + if model_tensors_file.is_file(): + self.module_paths.append(mod_path) + # read config.json of the Dense layer to get in/out features + mod_conf_file = self.dir_model / mod_path / "config.json" + if mod_conf_file.is_file(): + with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file: + mod_conf = json.load(mod_conf_json_file) + # hparams dense_2_feat_out and dense_3_feat_in are required when loading model's dense weights + prefix = self._get_dense_prefix(mod_path) + if (mod_conf["in_features"] is not None + and mod_conf["out_features"] is not None): + self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"]) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: from safetensors.torch import load_file @@ -5279,7 +5293,7 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: for i, module_path in enumerate(module_paths): tensors_file = self.dir_model / module_path / "model.safetensors" local_tensors = load_file(tensors_file) - tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3" + tensor_name = self._get_dense_prefix(module_path) for name, local_tensor in local_tensors.items(): if not name.endswith(".weight"): continue @@ -5287,6 +5301,11 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: name = self.map_tensor_name(orig_name) yield name, local_tensor.clone() + @staticmethod + def _get_dense_prefix(module_path) -> str: + """Get the tensor name prefix for the Dense layer from module path.""" + tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3" + return tensor_name def set_gguf_parameters(self): super().set_gguf_parameters() @@ -5303,6 +5322,11 @@ def set_gguf_parameters(self): logger.info(f"Using original sliding_window from config: {orig_sliding_window} " f"instead of {self.hparams['sliding_window']}") self.gguf_writer.add_sliding_window(orig_sliding_window) + if self.sentence_transformers_dense_modules: + for dense, dims in self.dense_features_dims.items(): + logger.info(f"Setting dense layer {dense} in/out features to {dims}") + self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1]) + self.gguf_writer.add_pooling_type_opt(False) self._try_set_pooling_type() @@ -9247,6 +9271,13 @@ def parse_args() -> argparse.Namespace: ) ) + parser.add_argument( + "--sentence-transformers-dense-modules", action="store_true", + help=("Whether to include sentence-transformers dense modules." 
+ "It can be used for sentence-transformers models, like google/embeddinggemma-300m" + "Default these modules are not included.") + ) + args = parser.parse_args() if not args.print_supported_models and args.model is None: parser.error("the following arguments are required: model") @@ -9309,9 +9340,13 @@ def main() -> None: if args.remote: hf_repo_id = args.model from huggingface_hub import snapshot_download + allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"] + if args.sentence_transformers_dense_modules: + # include sentence-transformers dense modules safetensors files + allowed_patterns.append("*.safetensors") local_dir = snapshot_download( repo_id=hf_repo_id, - allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]) + allow_patterns=allowed_patterns) dir_model = Path(local_dir) logger.info(f"Downloaded config and tokenizer to {local_dir}") else: @@ -9379,7 +9414,8 @@ def main() -> None: split_max_tensors=args.split_max_tensors, split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, small_first_shard=args.no_tensor_first_split, - remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template + remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template, + sentence_transformers_dense_modules=args.sentence_transformers_dense_modules ) if args.vocab_only: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 433cfb6a5a9ce..d40d9bf28fcad 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -128,6 +128,9 @@ class LLM: ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx" ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs" EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input" + DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in" + DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out" + POOLING_TYPE_OPT = "{arch}.pooling_type_opt" class Attention: HEAD_COUNT = "{arch}.attention.head_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 3152a30d7b212..655b50bfdf132 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -730,6 +730,13 @@ def add_shared_kv_layers(self, value: int) -> None: def add_sliding_window_pattern(self, value: Sequence[bool]) -> None: self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value) + def add_dense_features_dims(self, dense:str, in_f:int, out_f:int) -> None: + self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f) + self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f) + + def add_pooling_type_opt(self, enable: bool) -> None: + self.add_bool(Keys.LLM.POOLING_TYPE_OPT.format(arch=self.arch), enable) + def add_logit_scale(self, value: float) -> None: self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 9746a4613e1af..6d19ba038de77 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -217,6 +217,12 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, + // sentence-transformers dense modules feature dims + { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, + { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, + { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, + { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, + { LLM_KV_POOLING_TYPE_OPT, 
"%s.pooling_type_opt" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index dd44b80acf36a..0801ce9720402 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -264,6 +264,13 @@ enum llm_kv { LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, + + // sentence-transformers dense layers in and out features + LLM_KV_DENSE_2_FEAT_IN, + LLM_KV_DENSE_2_FEAT_OUT, + LLM_KV_DENSE_3_FEAT_IN, + LLM_KV_DENSE_3_FEAT_OUT, + LLM_KV_POOLING_TYPE_OPT, }; enum llm_tensor { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index d8a8b5e647a85..582c5253fb091 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2346,6 +2346,15 @@ llama_context * llama_init_from_model( return nullptr; } + // if setting pooling_type is disabled, set it to model default + // for sentence-transformers models (e.g. EmbeddingGemma) mean-pooling is required + // when dense layers are enabled + if (!model->hparams.pooling_type_opt) { + params.pooling_type = model->hparams.pooling_type; + LLAMA_LOG_INFO("%s: setting pooling_type to models default: %d\n", __func__, params.pooling_type); + + } + try { auto * ctx = new llama_context(*model, params); return ctx; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index d1148dc4b4c89..b5e9d49435e8a 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1856,11 +1856,15 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { void llm_graph_context::build_dense_out( ggml_tensor *dense_2, ggml_tensor *dense_3) const { - ggml_tensor * cur = res->get_embd_pooled(); - cur = ggml_mul_mat(ctx0, dense_2, cur); - cb(cur, "result_embd_pooled", -1); - cur = ggml_mul_mat(ctx0, dense_3, cur); - cb(cur, "result_embd_pooled", -1); + ggml_tensor *cur = res->get_embd_pooled(); + if (dense_2 != nullptr) { + cur = ggml_mul_mat(ctx0, dense_2, cur); + cb(cur, "result_embd_pooled", -1); + } + if (dense_3 != nullptr) { + cur = ggml_mul_mat(ctx0, dense_3, cur); + cb(cur, "result_embd_pooled", -1); + } res->t_embd_pooled = cur; ggml_build_forward_expand(gf, cur); } diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 0fe4b56942405..5a7acdb249163 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -169,6 +169,15 @@ struct llama_hparams { uint32_t laurel_rank = 64; uint32_t n_embd_altup = 256; + // needed for sentence-transformers dense layers + uint32_t dense_2_feat_in = 0; // in_features of the 2_Dense + uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense + uint32_t dense_3_feat_in = 0; // in_features of the 3_Dense + uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense + + // whether pooling_type can be overridden by user + bool pooling_type_opt = true; + // needed by encoder-decoder models (e.g. 
T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c6a274c9fb913..d8b50f53941ea 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1207,20 +1207,28 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.set_swa_pattern(6); hparams.causal_attn = false; // embeddings do not use causal attention - hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_0_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + //applied only if model converted with --sentence-transformers-dense-modules + ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); + ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); + ml.get_key(LLM_KV_POOLING_TYPE_OPT, hparams.pooling_type_opt, false); - } break; + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_0_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + } + break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -3646,8 +3654,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // Dense linear weights - dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, 4 * n_embd}, 0); - dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {4 * n_embd, n_embd}, 0); + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); for (int i = 0; i < n_layer; ++i) { @@ -19633,10 +19641,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); - // embeddinggemma specific - //sentence-transformer dense linear projections are applied after pooling - if (llm->arch == LLM_ARCH_GEMMA_EMBEDDING) { - llm->build_dense_out(dense_2_out_layers,dense_3_out_layers); + + // if the gguf model was converted with --sentence-transformers-dense-modules + // there will be two additional dense projection layers + // dense linear projections are applied after pooling + if (dense_2_out_layers != nullptr || dense_3_out_layers != nullptr) { + llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); } return llm->res->get_gf(); From f48b704abc186f94d1d0777f5e7397b4624a9275 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Mon, 6 Oct 2025 10:00:48 +0200 Subject: [PATCH 4/5] Update convert_hf_to_gguf.py Co-authored-by: Daniel Bevenius --- convert_hf_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a50a6855acfec..6c4104da58c1b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5269,7 +5269,7 @@ class EmbeddingGemma(Gemma3Model): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.sentence_transformers_dense_modules: - # read molues.json to determine if model has Dense layers + # read modules.json to determine if model has Dense layers modules_file = self.dir_model / "modules.json" if modules_file.is_file(): with open(modules_file, encoding="utf-8") as modules_json_file: From e22325cdb4da1fd3961eeda0e26e34a7293cd172 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Mon, 6 Oct 2025 14:57:13 +0200 Subject: [PATCH 5/5] fixed formatting issues --- src/llama-hparams.h | 8 ++++---- src/llama-model.cpp | 13 ++++++------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 23f5b208137f6..0b7eaa1601ee6 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -170,10 +170,10 @@ struct llama_hparams { uint32_t n_embd_altup = 256; // needed for sentence-transformers dense layers - uint32_t dense_2_feat_in = 0; // in_features of the 2_Dense - uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense - uint32_t dense_3_feat_in = 0; // in_features of the 3_Dense - uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense + uint32_t dense_2_feat_in = 0; // in_features of the 2_Dense + uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense + uint32_t dense_3_feat_in = 0; // in_features of the 3_Dense + uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense // whether pooling_type can be overridden by user bool pooling_type_opt = true; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0d4b9405ece76..59e9a1adcf9cb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1229,14 +1229,13 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); ml.get_key(LLM_KV_POOLING_TYPE_OPT, hparams.pooling_type_opt, false); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_0_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k)); - } - break; + case 24: type = LLM_TYPE_0_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + + } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);