From 877566d512d02a3a288ea38c1f7b1c25ccd6c082 Mon Sep 17 00:00:00 2001 From: Aaron Teo Date: Tue, 25 Nov 2025 09:56:07 +0800 Subject: [PATCH 1/7] llama: introduce support for model-embedded sampling parameters (#17120) --- common/arg.cpp | 12 ++++++ common/common.cpp | 55 ++++++++++++++++++++++++ common/common.h | 18 ++++++++ gguf-py/gguf/constants.py | 14 ++++++ gguf-py/gguf/gguf_writer.py | 36 ++++++++++++++++ gguf-py/gguf/metadata.py | 85 +++++++++++++++++++++++++++++++++++++ include/llama.h | 18 ++++++++ src/llama-arch.cpp | 38 +++++++++++------ src/llama-arch.h | 12 ++++++ src/llama-model.cpp | 18 ++++++++ 10 files changed, 293 insertions(+), 13 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 430ab45dfe2..dd787290d25 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1232,6 +1232,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { const auto sampler_names = string_split(value, ';'); params.sampling.samplers = common_sampler_types_from_names(sampler_names, true); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS; } ).set_sparam()); add_opt(common_arg( @@ -1261,6 +1262,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.sampling.temp = std::stof(value); params.sampling.temp = std::max(params.sampling.temp, 0.0f); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP; } ).set_sparam()); add_opt(common_arg( @@ -1268,6 +1270,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k), [](common_params & params, int value) { params.sampling.top_k = value; + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K; } ).set_sparam()); add_opt(common_arg( @@ -1275,6 +1278,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p), [](common_params & params, const std::string & value) { params.sampling.top_p = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P; } ).set_sparam()); add_opt(common_arg( @@ -1282,6 +1286,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p), [](common_params & params, const std::string & value) { params.sampling.min_p = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P; } ).set_sparam()); add_opt(common_arg( @@ -1296,6 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), [](common_params & params, const std::string & value) { params.sampling.xtc_probability = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY; } ).set_sparam()); add_opt(common_arg( @@ -1303,6 +1309,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold), [](common_params & params, const std::string & value) { params.sampling.xtc_threshold = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD; } ).set_sparam()); add_opt(common_arg( @@ -1321,6 +1328,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } params.sampling.penalty_last_n = value; params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N; } ).set_sparam()); add_opt(common_arg( @@ -1328,6 +1336,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat), [](common_params & params, const std::string & value) { params.sampling.penalty_repeat = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT; } ).set_sparam()); add_opt(common_arg( @@ -1425,6 +1434,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat), [](common_params & params, int value) { params.sampling.mirostat = value; + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT; } ).set_sparam()); add_opt(common_arg( @@ -1432,6 +1442,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta), [](common_params & params, const std::string & value) { params.sampling.mirostat_eta = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA; } ).set_sparam()); add_opt(common_arg( @@ -1439,6 +1450,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau), [](common_params & params, const std::string & value) { params.sampling.mirostat_tau = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU; } ).set_sparam()); add_opt(common_arg( diff --git a/common/common.cpp b/common/common.cpp index f3cc55247e7..0d7fd9a9371 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -8,6 +8,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "sampling.h" #include #include @@ -949,6 +950,58 @@ std::vector fs_list_files(const std::string & path) { // Model utils // +static inline void common_init_sampler_from_model( + const llama_model * model, + common_params_sampling & sparams) { + + const uint64_t config = sparams.user_sampling_config; + + auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) { + if (config & user_config) return; + + char buf[64] = {0}; + if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { + char * end = nullptr; + int32_t v = strtol(buf, &end, 10); + if (end && end != buf) dst = v; + } + }; + + auto get_float = [&](const char * key, float & dst, uint64_t user_config) { + if (config & user_config) return; + + char buf[128] = {0}; + if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { + char * end = nullptr; + float v = strtof(buf, &end); + if (end && end != buf) dst = v; + } + }; + + // Sampling sequence + if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) { + char buf[512] = {0}; + if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) { + const std::vector sampler_names = string_split(std::string(buf), ';'); + if (!sampler_names.empty()) { + sparams.samplers = common_sampler_types_from_names(sampler_names, true); + } + } + } + + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP); + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT); + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA); +} + struct common_init_result common_init_from_params(common_params & params) { common_init_result iparams; auto mparams = common_model_params_to_llama(params); @@ -960,6 +1013,8 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } + common_init_sampler_from_model(model, params.sampling); + const llama_vocab * vocab = llama_model_get_vocab(model); auto cparams = common_context_params_to_llama(params); diff --git a/common/common.h b/common/common.h index de5b404dd88..2f23d0baa83 100644 --- a/common/common.h +++ b/common/common.h @@ -140,6 +140,22 @@ struct common_grammar_trigger { llama_token token = LLAMA_TOKEN_NULL; }; +enum common_params_sampling_config : uint64_t { + COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0, + COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1, + COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2, + COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3, + COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4, + COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5, + COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6, + COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7, + COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11, +}; + + // sampling parameters struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler @@ -172,6 +188,8 @@ struct common_params_sampling { bool no_perf = false; // disable performance metrics bool timing_per_token = false; + uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers + std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8bc558fe4b5..6f5a742e04a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -25,6 +25,20 @@ class General: ALIGNMENT = "general.alignment" FILE_TYPE = "general.file_type" + # Recommended Sampler Parameters + SAMPLING_SEQUENCE = "general.sampling.sequence" + SAMPLING_TOP_K = "general.sampling.top_k" + SAMPLING_TOP_P = "general.sampling.top_p" + SAMPLING_MIN_P = "general.sampling.min_p" + SAMPLING_XTC_PROBABILITY = "general.sampling.xtc_probability" + SAMPLING_XTC_THRESHOLD = "general.sampling.xtc_threshold" + SAMPLING_TEMP = "general.sampling.temp" + SAMPLING_PENALTY_LAST_N = "general.sampling.penalty_last_n" + SAMPLING_PENALTY_REPEAT = "general.sampling.penalty_repeat" + SAMPLING_MIROSTAT = "general.sampling.mirostat" + SAMPLING_MIROSTAT_TAU = "general.sampling.mirostat_tau" + SAMPLING_MIROSTAT_ETA = "general.sampling.mirostat_eta" + # Authorship Metadata NAME = "general.name" AUTHOR = "general.author" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a051daeeb13..642ae2ae596 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -496,6 +496,42 @@ def add_custom_alignment(self, alignment: int) -> None: def add_file_type(self, ftype: int) -> None: self.add_uint32(Keys.General.FILE_TYPE, ftype) + def add_sampling_sequence(self, sequence: str) -> None: + self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence) + + def add_sampling_top_k(self, top_k: int) -> None: + self.add_int32(Keys.General.SAMPLING_TOP_K, top_k) + + def add_sampling_top_p(self, top_p: float) -> None: + self.add_float32(Keys.General.SAMPLING_TOP_P, top_p) + + def add_sampling_min_p(self, min_p: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIN_P, min_p) + + def add_sampling_xtc_probability(self, xtc_probability: float) -> None: + self.add_float32(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability) + + def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None: + self.add_float32(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold) + + def add_sampling_temp(self, temp: float) -> None: + self.add_float32(Keys.General.SAMPLING_TEMP, temp) + + def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None: + self.add_int32(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n) + + def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None: + self.add_float32(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat) + + def add_sampling_mirostat(self, mirostat: int) -> None: + self.add_int32(Keys.General.SAMPLING_MIROSTAT, mirostat) + + def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau) + + def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta) + def add_name(self, name: str) -> None: self.add_string(Keys.General.NAME, name) diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index 67efedbdbc5..e0d478ce95d 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -17,6 +17,20 @@ @dataclass class Metadata: + # Recommended Sampler Parameters to be written to GGUF KV Store + sampling_sequence: Optional[str] = None + sampling_top_k: Optional[int] = None + sampling_top_p: Optional[float] = None + sampling_min_p: Optional[float] = None + sampling_xtc_probability: Optional[float] = None + sampling_xtc_threshold: Optional[float] = None + sampling_temp: Optional[float] = None + sampling_penalty_last_n: Optional[int] = None + sampling_penalty_repeat: Optional[float] = None + sampling_mirostat: Optional[int] = None + sampling_mirostat_tau: Optional[float] = None + sampling_mirostat_eta: Optional[float] = None + # Authorship Metadata to be written to GGUF KV Store name: Optional[str] = None author: Optional[str] = None @@ -54,15 +68,43 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat model_card = Metadata.load_model_card(model_path) hf_params = Metadata.load_hf_parameters(model_path) + gen_config = Metadata.load_generation_config(model_path) # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter # heuristics metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params) + if gen_config: + metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence) + metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k) + metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p) + metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p) + metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability) + metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold) + metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp) + metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n) + metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat) + metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat) + metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau) + metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta) + # Metadata Override File Provided # This is based on LLM_KV_NAMES mapping in llama.cpp metadata_override = Metadata.load_metadata_override(metadata_override_path) + metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence) + metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k) + metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p) + metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p) + metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability) + metadata.sampling_xtc_threshold = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold) + metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp) + metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n) + metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat) + metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat) + metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau) + metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta) + metadata.name = metadata_override.get(Keys.General.NAME, metadata.name) metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) @@ -172,6 +214,23 @@ def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]: with open(config_path, "r", encoding="utf-8") as f: return json.load(f) + @staticmethod + def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]: + if model_path is None or not model_path.is_dir(): + return {} + + generation_config_path = model_path / "generation_config.json" + + if not generation_config_path.is_file(): + return {} + + try: + with open(generation_config_path, "r", encoding="utf-8") as f: + return json.load(f) + except (json.JSONDecodeError, IOError): + # not all models have valid generation_config.json + return {} + @staticmethod def id_to_title(string): # Convert capitalization into title form unless acronym or version number @@ -546,6 +605,32 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter): assert self.name is not None + + if self.sampling_sequence is not None: + gguf_writer.add_sampling_sequence(self.sampling_sequence) + if self.sampling_top_k is not None: + gguf_writer.add_sampling_top_k(self.sampling_top_k) + if self.sampling_top_p is not None: + gguf_writer.add_sampling_top_p(self.sampling_top_p) + if self.sampling_min_p is not None: + gguf_writer.add_sampling_min_p(self.sampling_min_p) + if self.sampling_xtc_probability is not None: + gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability) + if self.sampling_xtc_threshold is not None: + gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold) + if self.sampling_temp is not None: + gguf_writer.add_sampling_temp(self.sampling_temp) + if self.sampling_penalty_last_n is not None: + gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n) + if self.sampling_penalty_repeat is not None: + gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat) + if self.sampling_mirostat is not None: + gguf_writer.add_sampling_mirostat(self.sampling_mirostat) + if self.sampling_mirostat_tau is not None: + gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau) + if self.sampling_mirostat_eta is not None: + gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta) + gguf_writer.add_name(self.name) if self.author is not None: diff --git a/include/llama.h b/include/llama.h index 8547226ff21..b52eaacfa7e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -246,6 +246,21 @@ extern "C" { LLAMA_KV_OVERRIDE_TYPE_STR, }; + enum llama_model_meta_key { + LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE, + LLAMA_MODEL_META_KEY_SAMPLING_TOP_K, + LLAMA_MODEL_META_KEY_SAMPLING_TOP_P, + LLAMA_MODEL_META_KEY_SAMPLING_MIN_P, + LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY, + LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD, + LLAMA_MODEL_META_KEY_SAMPLING_TEMP, + LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N, + LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA, + }; + struct llama_model_kv_override { enum llama_model_kv_override_type tag; @@ -518,6 +533,9 @@ extern "C" { // Get the number of metadata key/value pairs LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model); + // Get sampling metadata key name. Returns nullptr if the key is invalid + LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key); + // Get metadata key name by index LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size); diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index fc6cddc92f5..7ef87acf1b3 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -114,19 +114,31 @@ static const std::map LLM_ARCH_NAMES = { }; static const std::map LLM_KV_NAMES = { - { LLM_KV_GENERAL_TYPE, "general.type" }, - { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, - { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, - { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, - { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, - { LLM_KV_GENERAL_NAME, "general.name" }, - { LLM_KV_GENERAL_AUTHOR, "general.author" }, - { LLM_KV_GENERAL_VERSION, "general.version" }, - { LLM_KV_GENERAL_URL, "general.url" }, - { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, - { LLM_KV_GENERAL_LICENSE, "general.license" }, - { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, - { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, + { LLM_KV_GENERAL_TYPE, "general.type" }, + { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, + { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, + { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, + { LLM_KV_GENERAL_SAMPLING_SEQUENCE, "general.sampling.sequence" }, + { LLM_KV_GENERAL_SAMPLING_TOP_K, "general.sampling.top_k" }, + { LLM_KV_GENERAL_SAMPLING_TOP_P, "general.sampling.top_p" }, + { LLM_KV_GENERAL_SAMPLING_MIN_P, "general.sampling.min_p" }, + { LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, "general.sampling.xtc_probability" }, + { LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, "general.sampling.xtc_threshold" }, + { LLM_KV_GENERAL_SAMPLING_TEMP, "general.sampling.temp" }, + { LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, "general.sampling.penalty_last_n" }, + { LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, "general.sampling.penalty_repeat" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT, "general.sampling.mirostat" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, "general.sampling.mirostat_tau" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, "general.sampling.mirostat_eta" }, + { LLM_KV_GENERAL_NAME, "general.name" }, + { LLM_KV_GENERAL_AUTHOR, "general.author" }, + { LLM_KV_GENERAL_VERSION, "general.version" }, + { LLM_KV_GENERAL_URL, "general.url" }, + { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, + { LLM_KV_GENERAL_LICENSE, "general.license" }, + { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, + { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 02a1c2dc258..9ad3157bf67 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -123,6 +123,18 @@ enum llm_kv { LLM_KV_GENERAL_QUANTIZATION_VERSION, LLM_KV_GENERAL_ALIGNMENT, LLM_KV_GENERAL_FILE_TYPE, + LLM_KV_GENERAL_SAMPLING_SEQUENCE, + LLM_KV_GENERAL_SAMPLING_TOP_K, + LLM_KV_GENERAL_SAMPLING_TOP_P, + LLM_KV_GENERAL_SAMPLING_MIN_P, + LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, + LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, + LLM_KV_GENERAL_SAMPLING_TEMP, + LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, + LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, + LLM_KV_GENERAL_SAMPLING_MIROSTAT, + LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, + LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, LLM_KV_GENERAL_NAME, LLM_KV_GENERAL_AUTHOR, LLM_KV_GENERAL_VERSION, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 35179a98e0c..a042ea9632c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7687,6 +7687,24 @@ int32_t llama_model_meta_count(const llama_model * model) { return (int)model->gguf_kv.size(); } +const char * llama_model_meta_key_str(llama_model_meta_key key) { + switch (key) { + case LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE: return "general.sampling.sequence"; + case LLAMA_MODEL_META_KEY_SAMPLING_TOP_K: return "general.sampling.top_k"; + case LLAMA_MODEL_META_KEY_SAMPLING_TOP_P: return "general.sampling.top_p"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIN_P: return "general.sampling.min_p"; + case LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY: return "general.sampling.xtc_probability"; + case LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD: return "general.sampling.xtc_threshold"; + case LLAMA_MODEL_META_KEY_SAMPLING_TEMP: return "general.sampling.temp"; + case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N: return "general.sampling.penalty_last_n"; + case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT: return "general.sampling.penalty_repeat"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT: return "general.sampling.mirostat"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU: return "general.sampling.mirostat_tau"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA: return "general.sampling.mirostat_eta"; + default: return nullptr; + } +} + int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) { if (i < 0 || i >= (int)model->gguf_kv.size()) { if (buf_size > 0) { From d414db02d3ab3744402bc57a7b3fce7de66e3d5a Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 25 Nov 2025 00:11:27 -0600 Subject: [PATCH 2/7] vulkan: Use fewer rows for scalar FA when HS is not a multiple of 16 (#17455) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 12 +++++++----- tests/test-backend-ops.cpp | 3 +++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d78c727e53b..6cf15b43bb3 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2501,9 +2501,11 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { +static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) { if (hsv >= 192) { return 2; + } else if ((hsv | hsk) & 8) { + return 4; } else { return 8; } @@ -2535,9 +2537,9 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if ((hsv | hsk) & 8) { // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - return {get_fa_scalar_num_large_rows(hsv), 64}; + return {get_fa_scalar_num_large_rows(hsk, hsv), 64}; } else { - return {get_fa_scalar_num_large_rows(hsv), 32}; + return {get_fa_scalar_num_large_rows(hsk, hsv), 32}; } } } @@ -7740,7 +7742,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsv); + const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv); const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t tmpsh = wg_size * sizeof(float); @@ -7871,7 +7873,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSV); + max_gqa = get_fa_scalar_num_large_rows(HSK, HSV); break; case FA_COOPMAT2: max_gqa = get_fa_num_small_rows(FA_COOPMAT2); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ce8c068d7aa..fd48d254752 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7859,6 +7859,9 @@ static std::vector> make_test_cases_perf() { } } + // Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012 + test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + for (int kv : { 4096, 8192, 16384, }) { for (int hs : { 64, 128, }) { for (int nr : { 1, 4, }) { From b1846f1c8ecd97ee08593e9498ef3244d43c1ad6 Mon Sep 17 00:00:00 2001 From: Pascal Date: Tue, 25 Nov 2025 08:01:02 +0100 Subject: [PATCH 3/7] webui: add rehype plugin to restore HTML in Markdown table cells (#17477) * webui: add rehype plugin to restore HTML in Markdown table cells The remark/rehype pipeline neutralizes inline HTML as literal text (remarkLiteralHtml) so that XML/HTML snippets in LLM responses display as-is instead of being rendered. This causes
and