diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 2fc04173aa2c7..a50a6855acfec 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -93,13 +93,15 @@ class ModelBase:
     # Mistral format specifics
     is_mistral_format: bool = False
     disable_mistral_community_chat_template: bool = False
+    sentence_transformers_dense_modules: bool = False

     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
-                 disable_mistral_community_chat_template: bool = False):
+                 disable_mistral_community_chat_template: bool = False,
+                 sentence_transformers_dense_modules: bool = False):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
@@ -114,6 +116,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
+        self.sentence_transformers_dense_modules = sentence_transformers_dense_modules

         if remote_hf_model_id is not None:
             self.is_safetensors = True
@@ -5260,6 +5263,54 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 @ModelBase.register("Gemma3TextModel")
 class EmbeddingGemma(Gemma3Model):
     model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.module_paths: list[str] = []
+        self.dense_features_dims: dict[str, tuple[int, int]] = {}
+        if self.sentence_transformers_dense_modules:
+            # read modules.json to determine if the model has Dense layers
+            modules_file = self.dir_model / "modules.json"
+            if modules_file.is_file():
+                with open(modules_file, encoding="utf-8") as modules_json_file:
+                    mods = json.load(modules_json_file)
+                    for mod in mods:
+                        if mod["type"] == "sentence_transformers.models.Dense":
+                            mod_path = mod["path"]
+                            # check if a model.safetensors file for the Dense layer exists
+                            model_tensors_file = self.dir_model / mod_path / "model.safetensors"
+                            if model_tensors_file.is_file():
+                                self.module_paths.append(mod_path)
+                            # read the config.json of the Dense layer to get its in/out features
+                            mod_conf_file = self.dir_model / mod_path / "config.json"
+                            if mod_conf_file.is_file():
+                                with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
+                                    mod_conf = json.load(mod_conf_json_file)
+                                    # hparams dense_2_feat_out and dense_3_feat_in are required when loading the model's dense weights
+                                    prefix = self._get_dense_prefix(mod_path)
+                                    if (mod_conf["in_features"] is not None
+                                            and mod_conf["out_features"] is not None):
+                                        self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        from safetensors.torch import load_file
+        module_paths = list(self.module_paths)
+        for module_path in module_paths:
+            tensors_file = self.dir_model / module_path / "model.safetensors"
+            local_tensors = load_file(tensors_file)
+            tensor_name = self._get_dense_prefix(module_path)
+            for name, local_tensor in local_tensors.items():
+                if not name.endswith(".weight"):
+                    continue
+                orig_name = name.replace("linear", tensor_name)
+                name = self.map_tensor_name(orig_name)
+                yield name, local_tensor.clone()
+
+    @staticmethod
+    def _get_dense_prefix(module_path: str) -> str:
+        """Get the tensor name prefix for the Dense layer from its module path."""
+        tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3"
+        return tensor_name

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -5276,6 +5327,11 @@ def set_gguf_parameters(self):
             logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
                         f"instead of {self.hparams['sliding_window']}")
             self.gguf_writer.add_sliding_window(orig_sliding_window)
+        if self.sentence_transformers_dense_modules:
+            for dense, dims in self.dense_features_dims.items():
+                logger.info(f"Setting dense layer {dense} in/out features to {dims}")
+                self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])
+            self.gguf_writer.add_pooling_type_opt(False)

         self._try_set_pooling_type()

@@ -9257,6 +9313,13 @@ def parse_args() -> argparse.Namespace:
         )
     )

+    parser.add_argument(
+        "--sentence-transformers-dense-modules", action="store_true",
+        help=("Whether to include sentence-transformers dense modules. "
+              "It can be used for sentence-transformers models, like google/embeddinggemma-300m. "
+              "By default these modules are not included.")
+    )
+
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
         parser.error("the following arguments are required: model")
@@ -9319,9 +9382,13 @@ def main() -> None:
     if args.remote:
         hf_repo_id = args.model
         from huggingface_hub import snapshot_download
+        allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
+        if args.sentence_transformers_dense_modules:
+            # also download the safetensors files of the sentence-transformers dense modules
+            allowed_patterns.append("*.safetensors")
         local_dir = snapshot_download(
             repo_id=hf_repo_id,
-            allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
+            allow_patterns=allowed_patterns)
         dir_model = Path(local_dir)
         logger.info(f"Downloaded config and tokenizer to {local_dir}")
     else:
@@ -9389,7 +9456,8 @@ def main() -> None:
             split_max_tensors=args.split_max_tensors,
             split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
             small_first_shard=args.no_tensor_first_split,
-            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
+            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
+            sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
         )

         if args.vocab_only:
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 1600405ea8693..65cbaca475103 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -128,6 +128,9 @@ class LLM:
        ALTUP_ACTIVE_IDX          = "{arch}.altup.active_idx"
        ALTUP_NUM_INPUTS          = "{arch}.altup.num_inputs"
        EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"
+       DENSE_FEAT_IN_SIZE        = "{arch}.{dense}_feat_in"
+       DENSE_FEAT_OUT_SIZE       = "{arch}.{dense}_feat_out"
+       POOLING_TYPE_OPT          = "{arch}.pooling_type_opt"

     class Attention:
        HEAD_COUNT = "{arch}.attention.head_count"
@@ -431,6 +434,8 @@ class MODEL_TENSOR(IntEnum):
    TOKEN_TYPES       = auto()
    POS_EMBD          = auto()
    OUTPUT            = auto()
+   DENSE_2_OUT       = auto()  # embeddinggemma 2_Dense
+   DENSE_3_OUT       = auto()  # embeddinggemma 3_Dense
    OUTPUT_NORM       = auto()
    ROPE_FREQS        = auto()
    ROPE_FACTORS_LONG = auto()
@@ -774,6 +779,8 @@ class MODEL_TENSOR(IntEnum):
    MODEL_TENSOR.POS_EMBD:           "position_embd",
    MODEL_TENSOR.OUTPUT_NORM:        "output_norm",
    MODEL_TENSOR.OUTPUT:             "output",
+   MODEL_TENSOR.DENSE_2_OUT:        "dense_2",  # embeddinggemma 2_Dense
+   MODEL_TENSOR.DENSE_3_OUT:        "dense_3",  # embeddinggemma 3_Dense
    MODEL_TENSOR.ROPE_FREQS:         "rope_freqs",
    MODEL_TENSOR.ROPE_FACTORS_LONG:  "rope_factors_long",
    MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
@@ -1756,6 +1763,8 @@ class MODEL_TENSOR(IntEnum):
    MODEL_ARCH.GEMMA_EMBEDDING: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT,
+       MODEL_TENSOR.DENSE_2_OUT,
+       MODEL_TENSOR.DENSE_3_OUT,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_Q_NORM,
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 30fc1a05ec052..2fe1854beb695 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -730,6 +730,13 @@ def add_shared_kv_layers(self, value: int) -> None:
     def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
         self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)

+    def add_dense_features_dims(self, dense: str, in_f: int, out_f: int) -> None:
+        self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f)
+        self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f)
+
+    def add_pooling_type_opt(self, enable: bool) -> None:
+        self.add_bool(Keys.LLM.POOLING_TYPE_OPT.format(arch=self.arch), enable)
+
     def add_logit_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)

diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 67b27413405f1..258b8e13c1aeb 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -76,7 +76,12 @@ class TensorNameMap:
            "lm_head",                  # llama4
            "model.transformer.ff_out", # llada
        ),
-
+       MODEL_TENSOR.DENSE_2_OUT: (
+           "dense_2_out",  # embeddinggemma
+       ),
+       MODEL_TENSOR.DENSE_3_OUT: (
+           "dense_3_out",  # embeddinggemma
+       ),
        # Output norm
        MODEL_TENSOR.OUTPUT_NORM: (
            "gpt_neox.final_layer_norm", # gptneox
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 4fd083aa04843..74aa14f8b1078 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -218,6 +218,12 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },

     { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
+    // sentence-transformers dense modules feature dims
+    { LLM_KV_DENSE_2_FEAT_IN,  "%s.dense_2_feat_in"  },
+    { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" },
+    { LLM_KV_DENSE_3_FEAT_IN,  "%s.dense_3_feat_in"  },
+    { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" },
+    { LLM_KV_POOLING_TYPE_OPT, "%s.pooling_type_opt" },

     { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
     { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
@@ -1070,6 +1076,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
        { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
        { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
        { LLM_TENSOR_OUTPUT, "output" },
+       { LLM_TENSOR_DENSE_2_OUT, "dense_2" },
+       { LLM_TENSOR_DENSE_3_OUT, "dense_3" },
        { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
        { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
        { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
@@ -2254,6 +2262,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+   {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
+   {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
    {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
    {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
    {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index bc4b04bb4e015..6339133d70729 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -270,6 +270,13 @@ enum llm_kv {
     LLM_KV_TOKENIZER_PREFIX_ID,
     LLM_KV_TOKENIZER_SUFFIX_ID,
     LLM_KV_TOKENIZER_MIDDLE_ID,
+
+    // sentence-transformers dense layers in and out features
+    LLM_KV_DENSE_2_FEAT_IN,
+    LLM_KV_DENSE_2_FEAT_OUT,
+    LLM_KV_DENSE_3_FEAT_IN,
+    LLM_KV_DENSE_3_FEAT_OUT,
+    LLM_KV_POOLING_TYPE_OPT,
 };

 enum llm_tensor {
@@ -277,6 +284,8 @@ enum llm_tensor {
     LLM_TENSOR_TOKEN_EMBD,
     LLM_TENSOR_TOKEN_EMBD_NORM,
     LLM_TENSOR_TOKEN_TYPES,
     LLM_TENSOR_POS_EMBD,
+    LLM_TENSOR_DENSE_2_OUT,
+    LLM_TENSOR_DENSE_3_OUT,
     LLM_TENSOR_OUTPUT,
     LLM_TENSOR_OUTPUT_NORM,
     LLM_TENSOR_ROPE_FREQS,
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index d8a8b5e647a85..582c5253fb091 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -2346,6 +2346,15 @@ llama_context * llama_init_from_model(
         return nullptr;
     }

+    // if overriding the pooling_type is disabled, force it to the model default;
+    // sentence-transformers models (e.g. EmbeddingGemma) require mean pooling
+    // when the dense layers are enabled
+    if (!model->hparams.pooling_type_opt) {
+        params.pooling_type = model->hparams.pooling_type;
+        LLAMA_LOG_INFO("%s: setting pooling_type to the model default: %d\n", __func__, params.pooling_type);
+
+    }
+
     try {
         auto * ctx = new llama_context(*model, params);
         return ctx;
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 90cd885a60a4f..b5e9d49435e8a 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1853,6 +1853,23 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
 }

+void llm_graph_context::build_dense_out(
+        ggml_tensor * dense_2,
+        ggml_tensor * dense_3) const {
+    ggml_tensor * cur = res->get_embd_pooled();
+    if (dense_2 != nullptr) {
+        cur = ggml_mul_mat(ctx0, dense_2, cur);
+        cb(cur, "result_embd_pooled", -1);
+    }
+    if (dense_3 != nullptr) {
+        cur = ggml_mul_mat(ctx0, dense_3, cur);
+        cb(cur, "result_embd_pooled", -1);
+    }
+    res->t_embd_pooled = cur;
+    ggml_build_forward_expand(gf, cur);
+}
+
+
 void llm_graph_context::build_pooling(
     ggml_tensor * cls,
     ggml_tensor * cls_b,
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 34b984afeb043..dc84b7942893a 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -814,6 +814,14 @@ struct llm_graph_context {
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
             ggml_tensor * cls_out_b) const;
+
+    //
+    // dense (out)
+    //
+
+    void build_dense_out(
+            ggml_tensor * dense_2,
+            ggml_tensor * dense_3) const;
 };

 // TODO: better name
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index f29b23eeffe56..23f5b208137f6 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -169,6 +169,14 @@ struct llama_hparams {
     uint32_t laurel_rank = 64;
     uint32_t n_embd_altup = 256;

+    // needed for the sentence-transformers dense layers
+    uint32_t dense_2_feat_in  = 0; // in_features of 2_Dense
+    uint32_t dense_2_feat_out = 0; // out_features of 2_Dense
+    uint32_t dense_3_feat_in  = 0; // in_features of 3_Dense
+    uint32_t dense_3_feat_out = 0; // out_features of 3_Dense
+
+    // whether pooling_type can be overridden by the user
+    bool pooling_type_opt = true;
     // xIELU
     std::array xielu_alpha_n;
     std::array xielu_alpha_p;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 4c2d481a41d42..0d4b9405ece76 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1215,20 +1215,28 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.set_swa_pattern(6);

                 hparams.causal_attn = false; // embeddings do not use causal attention
-                hparams.rope_freq_base_train_swa  = 10000.0f;
+                hparams.rope_freq_base_train_swa = 10000.0f;
                 hparams.rope_freq_scale_train_swa = 1.0f;

-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_POOLING_TYPE,                hparams.pooling_type);
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);

-                switch (hparams.n_layer) {
-                    case 24: type = LLM_TYPE_0_3B; break;
-                    default: type = LLM_TYPE_UNKNOWN;
-                }
-                hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+                // applied only if the model was converted with --sentence-transformers-dense-modules
+                ml.get_key(LLM_KV_DENSE_2_FEAT_IN,  hparams.dense_2_feat_in,  false);
+                ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false);
+                ml.get_key(LLM_KV_DENSE_3_FEAT_IN,  hparams.dense_3_feat_in,  false);
+                ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false);
+                ml.get_key(LLM_KV_POOLING_TYPE_OPT, hparams.pooling_type_opt, false);

-            } break;
+
+                switch (hparams.n_layer) {
+                    case 24: type = LLM_TYPE_0_3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+                hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+            }
+            break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3668,6 +3676,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                    }

+                   // Dense linear weights
+                   dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED);
+                   dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED);
+
+
                    for (int i = 0; i < n_layer; ++i) {
                        auto & layer = layers[i];

@@ -19841,6 +19854,13 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
     // add on pooling layer
     llm->build_pooling(cls, cls_b, cls_out, cls_out_b);

+    // if the GGUF model was converted with --sentence-transformers-dense-modules,
+    // there will be two additional dense projection layers;
+    // they are applied to the pooled embeddings
+    if (dense_2_out_layers != nullptr || dense_3_out_layers != nullptr) {
+        llm->build_dense_out(dense_2_out_layers, dense_3_out_layers);
+    }
+
     return llm->res->get_gf();
 }

diff --git a/src/llama-model.h b/src/llama-model.h
index eec564e70b69e..647303c9c3778 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -437,6 +437,12 @@ struct llama_model {

     std::vector<llama_layer> layers;

+    // Dense linear projections for sentence-transformers models such as EmbeddingGemma.
+    // For the structure of sentence-transformers models see
+    // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models
+    ggml_tensor * dense_2_out_layers = nullptr;
+    ggml_tensor * dense_3_out_layers = nullptr;
+
     llama_model_params params;

     // gguf metadata
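
Reviewer note (not part of the patch): after converting a sentence-transformers checkpoint with the new flag, e.g. something like `python convert_hf_to_gguf.py --remote --sentence-transformers-dense-modules google/embeddinggemma-300m`, the GGUF carries the dense_2/dense_3 weights and llama.cpp applies them to the pooled embedding via build_dense_out(). The snippet below is a minimal, assumption-laden sketch of that post-pooling math in plain PyTorch, handy for sanity-checking a conversion; the local directory name and the random stand-in encoder output are illustrative only, and the Dense modules are treated as bias-free because the converter exports only the *.weight tensors.

    # minimal sketch: mirror what build_dense_out() does to the pooled embedding
    from pathlib import Path

    import torch
    from safetensors.torch import load_file

    model_dir = Path("embeddinggemma-300m")   # assumed local sentence-transformers checkout
    token_embeddings = torch.randn(7, 768)    # stand-in for the encoder output (n_tokens x n_embd)

    # mean pooling; pooling_type_opt = false forces this model default at load time
    cur = token_embeddings.mean(dim=0)

    # apply the 2_Dense and 3_Dense projections in module order (weights only),
    # matching the ggml_mul_mat(dense_n, cur) calls in build_dense_out()
    for module in ("2_Dense", "3_Dense"):
        w = load_file(model_dir / module / "model.safetensors")["linear.weight"]
        cur = w @ cur                         # (out_features, in_features) @ (in_features,)

    print(cur.shape)                          # final embedding size = out_features of 3_Dense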