74 changes: 71 additions & 3 deletions convert_hf_to_gguf.py
@@ -93,13 +93,15 @@ class ModelBase:
# Mistral format specifics
is_mistral_format: bool = False
disable_mistral_community_chat_template: bool = False
sentence_transformers_dense_modules: bool = False

def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
disable_mistral_community_chat_template: bool = False):
disable_mistral_community_chat_template: bool = False,
sentence_transformers_dense_modules: bool = False):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is MmprojModel:
@@ -114,6 +116,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
self.lazy = not eager or (remote_hf_model_id is not None)
self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
if remote_hf_model_id is not None:
self.is_safetensors = True

@@ -5260,6 +5263,54 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
@ModelBase.register("Gemma3TextModel")
class EmbeddingGemma(Gemma3Model):
model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# keep per-instance state; avoid shared mutable class attributes
self.module_paths: list[str] = []
self.dense_features_dims: dict[str, tuple[int, int]] = {}
if self.sentence_transformers_dense_modules:
# read modules.json to determine whether the model has Dense layers
modules_file = self.dir_model / "modules.json"
if modules_file.is_file():
with open(modules_file, encoding="utf-8") as modules_json_file:
mods = json.load(modules_json_file)
for mod in mods:
if mod["type"] == "sentence_transformers.models.Dense":
mod_path = mod["path"]
# check if model.safetensors file for Dense layer exists
model_tensors_file = self.dir_model / mod_path / "model.safetensors"
if model_tensors_file.is_file():
self.module_paths.append(mod_path)
# read config.json of the Dense layer to get in/out features
mod_conf_file = self.dir_model / mod_path / "config.json"
if mod_conf_file.is_file():
with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
mod_conf = json.load(mod_conf_json_file)
# the dense_2_feat_out and dense_3_feat_in hparams are required when loading the model's dense weights
prefix = self._get_dense_prefix(mod_path)
if (mod_conf["in_features"] is not None
and mod_conf["out_features"] is not None):
self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])
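
For orientation, here is a sketch of the sentence-transformers layout this constructor walks. The values are illustrative (modeled on google/embeddinggemma-300m) and only the modules relevant here are shown:

```python
# Illustrative only: the files read above. modules.json lists the pipeline modules,
# and each <N>_Dense directory carries its own config.json and model.safetensors.
example_modules_json = [
    {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
    {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"},
    {"idx": 2, "name": "2", "path": "2_Dense", "type": "sentence_transformers.models.Dense"},
    {"idx": 3, "name": "3", "path": "3_Dense", "type": "sentence_transformers.models.Dense"},
]
# 2_Dense/config.json; the in/out features are what ends up in dense_features_dims:
example_dense_config = {
    "in_features": 768,
    "out_features": 3072,
    "bias": False,
    "activation_function": "torch.nn.modules.linear.Identity",
}
```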

def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
from safetensors.torch import load_file
module_paths = list(self.module_paths)
for i, module_path in enumerate(module_paths):
tensors_file = self.dir_model / module_path / "model.safetensors"
local_tensors = load_file(tensors_file)
tensor_name = self._get_dense_prefix(module_path)
for name, local_tensor in local_tensors.items():
if not name.endswith(".weight"):
continue
orig_name = name.replace("linear", tensor_name)
name = self.map_tensor_name(orig_name)
yield name, local_tensor.clone()

@staticmethod
def _get_dense_prefix(module_path: str) -> str:
"""Get the tensor name prefix for the Dense layer from module path."""
return "dense_2" if module_path == "2_Dense" else "dense_3"
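
The renaming in generate_extra_tensors() is the bridge between the sentence-transformers checkpoint and the new GGUF tensor names. A minimal sketch, assuming the Dense checkpoint holds the usual single "linear.weight" tensor:

```python
# Sketch (assumption: a 2_Dense module whose safetensors file holds "linear.weight").
module_path = "2_Dense"
prefix = "dense_2" if module_path == "2_Dense" else "dense_3"  # _get_dense_prefix()
renamed = "linear.weight".replace("linear", prefix)            # -> "dense_2.weight"
# map_tensor_name() accepts "dense_2.weight" because TensorNameMap registers the GGUF
# base names ("dense_2", "dense_3") in addition to the HF-side names listed in
# gguf-py/gguf/tensor_mapping.py.
print(renamed)
```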

def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -5276,6 +5327,11 @@ def set_gguf_parameters(self):
logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
f"instead of {self.hparams['sliding_window']}")
self.gguf_writer.add_sliding_window(orig_sliding_window)
if self.sentence_transformers_dense_modules:
for dense, dims in self.dense_features_dims.items():
logger.info(f"Setting dense layer {dense} in/out features to {dims}")
self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])
self.gguf_writer.add_pooling_type_opt(False)

self._try_set_pooling_type()

@@ -9257,6 +9313,13 @@ def parse_args() -> argparse.Namespace:
)
)

parser.add_argument(
"--sentence-transformers-dense-modules", action="store_true",
help=("Whether to include sentence-transformers dense modules."
"It can be used for sentence-transformers models, like google/embeddinggemma-300m"
"Default these modules are not included.")
)

args = parser.parse_args()
if not args.print_supported_models and args.model is None:
parser.error("the following arguments are required: model")
@@ -9319,9 +9382,13 @@ def main() -> None:
if args.remote:
hf_repo_id = args.model
from huggingface_hub import snapshot_download
allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
if args.sentence_transformers_dense_modules:
# include sentence-transformers dense modules safetensors files
allowed_patterns.append("*.safetensors")
local_dir = snapshot_download(
repo_id=hf_repo_id,
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
allow_patterns=allowed_patterns)
dir_model = Path(local_dir)
logger.info(f"Downloaded config and tokenizer to {local_dir}")
else:
@@ -9389,7 +9456,8 @@ def main() -> None:
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
)

if args.vocab_only:
9 changes: 9 additions & 0 deletions gguf-py/gguf/constants.py
@@ -128,6 +128,9 @@ class LLM:
ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx"
ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs"
EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"
DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in"
DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
POOLING_TYPE_OPT = "{arch}.pooling_type_opt"

class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
@@ -431,6 +434,8 @@ class MODEL_TENSOR(IntEnum):
TOKEN_TYPES = auto()
POS_EMBD = auto()
OUTPUT = auto()
DENSE_2_OUT = auto() # embeddinggemma 2_Dense
DENSE_3_OUT = auto() # embeddinggemma 3_Dense
OUTPUT_NORM = auto()
ROPE_FREQS = auto()
ROPE_FACTORS_LONG = auto()
@@ -774,6 +779,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.POS_EMBD: "position_embd",
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
MODEL_TENSOR.OUTPUT: "output",
MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embeddinggemma 2_Dense
MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embeddinggemma 3_Dense
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
@@ -1756,6 +1763,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GEMMA_EMBEDDING: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.DENSE_2_OUT,
MODEL_TENSOR.DENSE_3_OUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
7 changes: 7 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -730,6 +730,13 @@ def add_shared_kv_layers(self, value: int) -> None:
def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)

def add_dense_features_dims(self, dense: str, in_f: int, out_f: int) -> None:
self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f)
self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f)

def add_pooling_type_opt(self, enable: bool) -> None:
self.add_bool(Keys.LLM.POOLING_TYPE_OPT.format(arch=self.arch), enable)
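
A hedged sketch of how a converter is expected to drive these helpers; the "gemma-embedding" architecture string and the 768/3072 sizes are illustrative, not taken from this diff:

```python
import gguf

# Writes only the new metadata keys; tensor data and header writing are omitted here.
writer = gguf.GGUFWriter("embeddinggemma.gguf", "gemma-embedding")
writer.add_dense_features_dims("dense_2", 768, 3072)  # gemma-embedding.dense_2_feat_in / _feat_out
writer.add_dense_features_dims("dense_3", 3072, 768)  # gemma-embedding.dense_3_feat_in / _feat_out
writer.add_pooling_type_opt(False)                    # gemma-embedding.pooling_type_opt = false
```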

def add_logit_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)

7 changes: 6 additions & 1 deletion gguf-py/gguf/tensor_mapping.py
@@ -76,7 +76,12 @@ class TensorNameMap:
"lm_head", # llama4
"model.transformer.ff_out", # llada
),

MODEL_TENSOR.DENSE_2_OUT: (
"dense_2_out", # embeddinggemma
),
MODEL_TENSOR.DENSE_3_OUT: (
"dense_3_out", # embeddinggemma
),
# Output norm
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
10 changes: 10 additions & 0 deletions src/llama-arch.cpp
@@ -218,6 +218,12 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },

{ LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
// sentence-transformers dense modules feature dims
{ LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" },
{ LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" },
{ LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" },
{ LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" },
{ LLM_KV_POOLING_TYPE_OPT, "%s.pooling_type_opt" },

{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
@@ -1070,6 +1076,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_DENSE_2_OUT, "dense_2" },
{ LLM_TENSOR_DENSE_3_OUT, "dense_3" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
@@ -2254,6 +2262,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
{LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
9 changes: 9 additions & 0 deletions src/llama-arch.h
@@ -270,13 +270,22 @@ enum llm_kv {
LLM_KV_TOKENIZER_PREFIX_ID,
LLM_KV_TOKENIZER_SUFFIX_ID,
LLM_KV_TOKENIZER_MIDDLE_ID,

// sentence-transformers dense layers in and out features
LLM_KV_DENSE_2_FEAT_IN,
LLM_KV_DENSE_2_FEAT_OUT,
LLM_KV_DENSE_3_FEAT_IN,
LLM_KV_DENSE_3_FEAT_OUT,
LLM_KV_POOLING_TYPE_OPT,
};

enum llm_tensor {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_TOKEN_EMBD_NORM,
LLM_TENSOR_TOKEN_TYPES,
LLM_TENSOR_POS_EMBD,
LLM_TENSOR_DENSE_2_OUT,
LLM_TENSOR_DENSE_3_OUT,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_ROPE_FREQS,
9 changes: 9 additions & 0 deletions src/llama-context.cpp
@@ -2346,6 +2346,15 @@ llama_context * llama_init_from_model(
return nullptr;
}

// if overriding the pooling type is disabled, fall back to the model's default
// sentence-transformers models (e.g. EmbeddingGemma) require mean pooling
// when their dense layers are enabled
if (!model->hparams.pooling_type_opt) {
params.pooling_type = model->hparams.pooling_type;
LLAMA_LOG_INFO("%s: setting pooling_type to the model's default: %d\n", __func__, params.pooling_type);
}

try {
auto * ctx = new llama_context(*model, params);
return ctx;
17 changes: 17 additions & 0 deletions src/llama-graph.cpp
@@ -1853,6 +1853,23 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}

void llm_graph_context::build_dense_out(
ggml_tensor * dense_2,
ggml_tensor * dense_3) const {
ggml_tensor * cur = res->get_embd_pooled();
if (dense_2 != nullptr) {
cur = ggml_mul_mat(ctx0, dense_2, cur);
cb(cur, "result_embd_pooled", -1);
}
if (dense_3 != nullptr) {
cur = ggml_mul_mat(ctx0, dense_3, cur);
cb(cur, "result_embd_pooled", -1);
}
res->t_embd_pooled = cur;
ggml_build_forward_expand(gf, cur);
}
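
In effect the graph applies two extra matrix multiplications to the pooled embedding. A sketch of the math (shapes come from the dense_*_feat_in/out hparams; the exact row/column layout follows ggml's conventions):

```latex
\[
  e = W_3\,(W_2\,p), \qquad
  W_2 \in \mathbb{R}^{\text{dense\_2\_feat\_out} \times \text{dense\_2\_feat\_in}}, \quad
  W_3 \in \mathbb{R}^{\text{dense\_3\_feat\_out} \times \text{dense\_3\_feat\_in}}
\]
```

where $p$ is the pooled embedding, $W_2$ is the dense_2 weight and $W_3$ is the dense_3 weight.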


void llm_graph_context::build_pooling(
ggml_tensor * cls,
ggml_tensor * cls_b,
8 changes: 8 additions & 0 deletions src/llama-graph.h
@@ -814,6 +814,14 @@ struct llm_graph_context {
ggml_tensor * cls_b,
ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const;

//
// dense (out)
//

void build_dense_out(
ggml_tensor * dense_2,
ggml_tensor * dense_3) const;
};

// TODO: better name
8 changes: 8 additions & 0 deletions src/llama-hparams.h
@@ -169,6 +169,14 @@ struct llama_hparams {
uint32_t laurel_rank = 64;
uint32_t n_embd_altup = 256;

// needed for sentence-transformers dense layers
uint32_t dense_2_feat_in = 0; // in_features of the 2_Dense
uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense
uint32_t dense_3_feat_in = 0; // in_features of the 3_Dense
uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense

// whether pooling_type can be overridden by user
bool pooling_type_opt = true;
// xIELU
std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_n;
std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_p;
38 changes: 29 additions & 9 deletions src/llama-model.cpp
@@ -1215,20 +1215,28 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.set_swa_pattern(6);

hparams.causal_attn = false; // embeddings do not use causal attention
hparams.rope_freq_base_train_swa = 10000.0f;
hparams.rope_freq_scale_train_swa = 1.0f;

ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);

// applied only if the model was converted with --sentence-transformers-dense-modules
ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false);
ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false);
ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false);
ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false);
ml.get_key(LLM_KV_POOLING_TYPE_OPT, hparams.pooling_type_opt, false);

switch (hparams.n_layer) {
case 24: type = LLM_TYPE_0_3B; break;
default: type = LLM_TYPE_UNKNOWN;
}
hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
}
break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3668,6 +3676,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}

// Dense linear weights
dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED);
dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED);


for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

@@ -19841,6 +19854,13 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
// add on pooling layer
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);

// if the GGUF model was converted with --sentence-transformers-dense-modules,
// there will be two additional dense projection layers;
// the dense linear projections are applied after pooling
if (dense_2_out_layers != nullptr || dense_3_out_layers != nullptr) {
llm->build_dense_out(dense_2_out_layers, dense_3_out_layers);
}

return llm->res->get_gf();
}
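
Stepping back from the diff: a quick sanity check of a converted file, as a sketch using gguf-py. The file name is illustrative and the key names assume the "gemma-embedding" architecture string:

```python
from gguf import GGUFReader

reader = GGUFReader("embeddinggemma.gguf")  # illustrative path
# expect dense_2.weight / dense_3.weight when converted with --sentence-transformers-dense-modules
print([t.name for t in reader.tensors if t.name.startswith("dense_")])
print([k for k in reader.fields if "feat_in" in k or "feat_out" in k or "pooling_type_opt" in k])
```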

6 changes: 6 additions & 0 deletions src/llama-model.h
@@ -437,6 +437,12 @@ struct llama_model {

std::vector<llama_layer> layers;

// Dense linear projections for sentence-transformers models like embeddinggemma
// For the structure of sentence-transformers models see
// https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models
ggml_tensor * dense_2_out_layers = nullptr;
ggml_tensor * dense_3_out_layers = nullptr;

llama_model_params params;

// gguf metadata