From 3e41c1493563c8176ce1007fcdea4af63a9cf4af Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 25 Nov 2025 14:23:45 +0100 Subject: [PATCH 01/10] conversion script --- convert_hf_to_gguf.py | 45 ++++++++++++++++++++++++++++++++++--- gguf-py/gguf/constants.py | 23 +++++++++++++++++++ gguf-py/gguf/gguf_writer.py | 3 +++ 3 files changed, 68 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d24a4682f3d..e4014ed7221 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1581,10 +1581,27 @@ def __init__(self, *args, **kwargs): # load preprocessor config self.preprocessor_config = {} - if not self.is_mistral_format: - with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: + + # prefer preprocessor_config.json if possible + preprocessor_config_path = self.dir_model / "preprocessor_config.json" + if preprocessor_config_path.is_file(): + with open(preprocessor_config_path, "r", encoding="utf-8") as f: self.preprocessor_config = json.load(f) + # prefer processor_config.json if possible + processor_config_path = self.dir_model / "processor_config.json" + if processor_config_path.is_file(): + with open(processor_config_path, "r", encoding="utf-8") as f: + cfg = json.load(f) + # move image_processor to root level for compat + if "image_processor" in cfg: + cfg = { + **cfg, + **cfg["image_processor"], + } + # merge configs + self.preprocessor_config = {**self.preprocessor_config, **cfg} + def get_vision_config(self) -> dict[str, Any] | None: config_name = "vision_config" if not self.is_mistral_format else "vision_encoder" return self.global_config.get(config_name) @@ -2797,7 +2814,29 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter @ModelBase.register("Mistral3ForConditionalGeneration") class Mistral3Model(LlamaModel): - model_arch = gguf.MODEL_ARCH.LLAMA + model_arch = gguf.MODEL_ARCH.MISTRAL3 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # for compatibility, we use LLAMA arch for older models + if self.hparams.get("model_type") != "ministral3": + self.model_arch = gguf.MODEL_ARCH.LLAMA + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + rope_params = self.hparams.get("rope_parameters") + if self.hparams.get("model_type") == "ministral3": + assert rope_params is not None, "ministral3 must have 'rope_parameters' config" + assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'" + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_params["factor"]) + self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_params["beta_fast"]) + self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_params["beta_slow"]) + self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_params["mscale_all_dim"]) # copied from deepseekv2 + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_params["original_max_position_embeddings"]) + self.gguf_writer.add_rope_freq_base(rope_params["rope_theta"]) + self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): name = name.replace("language_model.", "") diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6f5a742e04a..a77551d6e9f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -175,6 +175,7 @@ class Attention: VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers" SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern" + TEMPERATURE_SCALE = "{arch}.attention.temperature_scale" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" @@ -443,6 +444,7 @@ class MODEL_ARCH(IntEnum): MINIMAXM2 = auto() RND1 = auto() PANGU_EMBED = auto() + MISTRAL3 = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -814,6 +816,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.COGVLM: "cogvlm", MODEL_ARCH.RND1: "rnd1", MODEL_ARCH.PANGU_EMBED: "pangu-embedded", + MODEL_ARCH.MISTRAL3: "mistral3", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -3038,6 +3041,26 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.MISTRAL3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 642ae2ae596..ab6dc28956f 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -893,6 +893,9 @@ def add_attn_output_scale(self, value: float) -> None: def add_attn_temperature_length(self, value: int) -> None: self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value) + def add_attn_temperature_scale(self, value: float) -> None: + self.add_float32(Keys.Attention.TEMPERATURE_SCALE.format(arch=self.arch), value) + def add_pooling_type(self, value: PoolingType) -> None: self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) From 2b2f411a0df33bb3c2d078f6f3176df3918546a6 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 25 Nov 2025 14:59:39 +0100 Subject: [PATCH 02/10] support ministral 3 --- src/CMakeLists.txt | 1 + src/llama-arch.cpp | 28 +++++++ src/llama-arch.h | 2 + src/llama-graph.cpp | 3 + src/llama-hparams.h | 4 +- src/llama-model.cpp | 49 ++++++++++-- src/models/mistral3.cpp | 167 ++++++++++++++++++++++++++++++++++++++++ src/models/models.h | 4 + 8 files changed, 251 insertions(+), 7 deletions(-) create mode 100644 src/models/mistral3.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f7a8c9841ec..0997e21d9a6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -131,6 +131,7 @@ add_library(llama models/t5-enc.cpp models/wavtokenizer-dec.cpp models/xverse.cpp + models/mistral3.cpp models/graph-context-mamba.cpp ) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 7ef87acf1b3..c1e64b4ee77 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -110,6 +110,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_COGVLM, "cogvlm" }, { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, + { LLM_ARCH_MISTRAL3, "mistral3" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -203,6 +204,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, + { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, @@ -2479,6 +2481,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_MISTRAL3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_UNKNOWN, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 9ad3157bf67..8316fc2741d 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -114,6 +114,7 @@ enum llm_arch { LLM_ARCH_COGVLM, LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, + LLM_ARCH_MISTRAL3, LLM_ARCH_UNKNOWN, }; @@ -207,6 +208,7 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, + LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 650e40ec6ff..3138424867a 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -71,6 +71,9 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && attn_scale) { const int64_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(f_attn_temp_scale != 0.0f); + GGML_ASSERT(n_attn_temp_floor_scale != 0); + std::vector attn_scale_data(n_tokens, 0.0f); for (int i = 0; i < n_tokens; ++i) { const float pos = ubatch->pos[i]; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 9203af83b2e..270de346d27 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -162,8 +162,8 @@ struct llama_hparams { // llama4 smallthinker uint32_t n_moe_layer_step = 0; uint32_t n_no_rope_layer_step = 4; - uint32_t n_attn_temp_floor_scale = 8192; - float f_attn_temp_scale = 0.1; + uint32_t n_attn_temp_floor_scale = 0; + float f_attn_temp_scale = 0.0f; // gemma3n altup uint32_t n_altup = 4; // altup_num_inputs diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a042ea9632c..790c05f3ea2 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -627,8 +627,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (arch) { case LLM_ARCH_LLAMA: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - if (hparams.n_expert == 8) { switch (hparams.n_layer) { case 32: type = LLM_TYPE_8x7B; break; @@ -664,8 +662,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope } else { - hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; - hparams.n_swa = 8192; + hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; + hparams.n_swa = 8192; + hparams.n_attn_temp_floor_scale = 8192; + hparams.f_attn_temp_scale = 0.1f; hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full } @@ -2225,6 +2225,39 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MISTRAL3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + + ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); + + // TODO: maybe add n_attn_temp_floor_scale as a separate KV? + if (hparams.f_attn_temp_scale != 0.0f) { + hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn; + if (hparams.n_attn_temp_floor_scale == 0) { + throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling"); + } + } + + // the same as deepseek2 + hparams.rope_freq_scale_train = 1.0f; + if (hparams.rope_yarn_log_mul != 0.0f) { + float freq_scale = hparams.rope_freq_scale_train; + float mscale = hparams.yarn_attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + + hparams.f_attention_scale = 1.0f * mscale * mscale / sqrtf(float(hparams.n_embd_head_k)); + hparams.rope_attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); + } + + switch (hparams.n_layer) { + // TODO + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -2538,6 +2571,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_MINICPM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_MISTRAL3: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7425,7 +7459,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_PANGU_EMBED: { llm = std::make_unique(*this, params); - }break; + } break; + case LLM_ARCH_MISTRAL3: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -7594,6 +7632,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARCEE: case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_MISTRAL3: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/src/models/mistral3.cpp b/src/models/mistral3.cpp new file mode 100644 index 00000000000..3f38a07798f --- /dev/null +++ b/src/models/mistral3.cpp @@ -0,0 +1,167 @@ +#include "models.h" + +llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // (optional) temperature tuning + ggml_tensor * inp_attn_scale = nullptr; + if (hparams.f_attn_temp_scale != 0.0f) { + inp_attn_scale = build_inp_attn_scale(); + } + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + if (inp_attn_scale) { + // apply llama 4 temperature scaling + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); + cb(Qcur, "Qcur_attn_temp_scaled", il); + } + + if (hparams.use_kq_norm) { + // Llama4TextL2Norm + Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); + Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + } + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index 5f019c59be8..0b9cf3f402c 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -487,3 +487,7 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { struct llm_build_xverse : public llm_graph_context { llm_build_xverse(const llama_model & model, const llm_graph_params & params); }; + +struct llm_build_mistral3 : public llm_graph_context { + llm_build_mistral3(const llama_model & model, const llm_graph_params & params); +}; From 4cebf7bad09740a6ac4b0d48b036d65548bd7c08 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 25 Nov 2025 15:35:26 +0100 Subject: [PATCH 03/10] maybe this is better? --- convert_hf_to_gguf.py | 4 ++-- src/llama-model.cpp | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e4014ed7221..cf035f5a6af 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2830,10 +2830,10 @@ def set_gguf_parameters(self): assert rope_params is not None, "ministral3 must have 'rope_parameters' config" assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'" self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) - self.gguf_writer.add_rope_scaling_factor(rope_params["factor"]) + self.gguf_writer.add_rope_scaling_yarn_attn_factor(rope_params["factor"]) self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_params["beta_fast"]) self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_params["beta_slow"]) - self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_params["mscale_all_dim"]) # copied from deepseekv2 + self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"]) self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_params["original_max_position_embeddings"]) self.gguf_writer.add_rope_freq_base(rope_params["rope_theta"]) self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"]) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 790c05f3ea2..ba5d4283c59 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2243,14 +2243,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } - // the same as deepseek2 - hparams.rope_freq_scale_train = 1.0f; if (hparams.rope_yarn_log_mul != 0.0f) { - float freq_scale = hparams.rope_freq_scale_train; - float mscale = hparams.yarn_attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); - - hparams.f_attention_scale = 1.0f * mscale * mscale / sqrtf(float(hparams.n_embd_head_k)); - hparams.rope_attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); + float factor = hparams.yarn_attn_factor; + float mscale = 1.0f; + float mscale_all_dims = hparams.rope_yarn_log_mul; + static auto get_mscale = [](float scale, float mscale) { + return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f); + }; + hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims); } switch (hparams.n_layer) { From 84be00fdedb0a2813c6c617056ef79f322703201 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 25 Nov 2025 15:52:51 +0100 Subject: [PATCH 04/10] add TODO for rope_yarn_log_mul --- src/llama-model.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ba5d4283c59..601b042d326 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2243,6 +2243,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } + // TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f + // but may need further verification with other values if (hparams.rope_yarn_log_mul != 0.0f) { float factor = hparams.yarn_attn_factor; float mscale = 1.0f; From 786b3f8e5ac8d5b229f2480d87156bf2126a72f9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 26 Nov 2025 00:01:51 +0100 Subject: [PATCH 05/10] better ppl (tested on 14B-Instruct) --- convert_hf_to_gguf.py | 2 +- src/llama-model.cpp | 3 +-- src/models/mistral3.cpp | 7 ------- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index cf035f5a6af..cd595bd2b74 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2830,7 +2830,7 @@ def set_gguf_parameters(self): assert rope_params is not None, "ministral3 must have 'rope_parameters' config" assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'" self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) - self.gguf_writer.add_rope_scaling_yarn_attn_factor(rope_params["factor"]) + self.gguf_writer.add_rope_scaling_factor(rope_params["factor"]) self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_params["beta_fast"]) self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_params["beta_slow"]) self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"]) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 601b042d326..5030c90f5dd 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2230,7 +2230,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false); ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); @@ -2246,7 +2245,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f // but may need further verification with other values if (hparams.rope_yarn_log_mul != 0.0f) { - float factor = hparams.yarn_attn_factor; + float factor = 1.0f / hparams.rope_freq_scale_train; float mscale = 1.0f; float mscale_all_dims = hparams.rope_yarn_log_mul; static auto get_mscale = [](float scale, float mscale) { diff --git a/src/models/mistral3.cpp b/src/models/mistral3.cpp index 3f38a07798f..0b672235911 100644 --- a/src/models/mistral3.cpp +++ b/src/models/mistral3.cpp @@ -85,13 +85,6 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap cb(Qcur, "Qcur_attn_temp_scaled", il); } - if (hparams.use_kq_norm) { - // Llama4TextL2Norm - Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); - Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); - cb(Qcur, "Qcur_normed", il); - cb(Kcur, "Kcur_normed", il); - } cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); From a4f540bc3043185d5ff34cb87599330972c785d5 Mon Sep 17 00:00:00 2001 From: Julien Denize Date: Sun, 30 Nov 2025 19:28:35 +0000 Subject: [PATCH 06/10] Add Ministral3 support to Mistral format --- convert_hf_to_gguf.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 86897bcf0e2..ba322dc3e96 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9893,6 +9893,20 @@ def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mis return template + def set_gguf_parameters(self): + super().set_gguf_parameters() + if "yarn" in self.hparams: + yarn_params = self.hparams["yarn"] + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"]) + self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"]) + self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"]) + self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) + self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"]) + + if "llama_4_scaling" in self.hparams: + self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"]) + class PixtralModel(LlavaVisionModel): model_name = "Pixtral" From bf08fcc4550a0f9b8de5667202ad4f96eba30de3 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 1 Dec 2025 11:00:59 +0100 Subject: [PATCH 07/10] improve arch handling --- convert_hf_to_gguf.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ba322dc3e96..3eabeedc3d6 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2819,8 +2819,11 @@ class Mistral3Model(LlamaModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # for compatibility, we use LLAMA arch for older models + # TODO: remove this once everyone has migrated to newer version of llama.cpp if self.hparams.get("model_type") != "ministral3": self.model_arch = gguf.MODEL_ARCH.LLAMA + self.gguf_writer.arch = str(self.model_arch) + self.gguf_writer.add_architecture() self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) def set_gguf_parameters(self): @@ -9848,12 +9851,22 @@ def modify_tensors(self, data_torch, name, bid): class MistralModel(LlamaModel): - model_arch = gguf.MODEL_ARCH.LLAMA + model_arch = gguf.MODEL_ARCH.MISTRAL3 model_name = "Mistral" hf_arch = "" is_mistral_format = True undo_permute = False + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # for compatibility, we use LLAMA arch for older models + # TODO: remove this once everyone migrates to newer version of llama.cpp + if "llama_4_scaling" not in self.hparams: + self.model_arch = gguf.MODEL_ARCH.LLAMA + self.gguf_writer.arch = str(self.model_arch) + self.gguf_writer.add_architecture() + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + @staticmethod def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool): assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg From 34234a50edf55ce28b42c8e3d13fb4445fdb98e5 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 1 Dec 2025 11:02:42 +0100 Subject: [PATCH 08/10] add sizes --- src/llama-model.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d5a7d922a12..584efbf3c84 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2277,7 +2277,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { } switch (hparams.n_layer) { - // TODO + case 26: type = LLM_TYPE_3B; break; + case 34: type = LLM_TYPE_8B; break; + case 40: type = LLM_TYPE_14B; break; default: type = LLM_TYPE_UNKNOWN; } } break; From b185b7fc6f43a51190939541bd2e732f962dd40c Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 1 Dec 2025 11:44:45 +0100 Subject: [PATCH 09/10] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3eabeedc3d6..c9bc96f86d6 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2822,7 +2822,7 @@ def __init__(self, *args, **kwargs): # TODO: remove this once everyone has migrated to newer version of llama.cpp if self.hparams.get("model_type") != "ministral3": self.model_arch = gguf.MODEL_ARCH.LLAMA - self.gguf_writer.arch = str(self.model_arch) + self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch] self.gguf_writer.add_architecture() self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) @@ -9863,7 +9863,7 @@ def __init__(self, *args, **kwargs): # TODO: remove this once everyone migrates to newer version of llama.cpp if "llama_4_scaling" not in self.hparams: self.model_arch = gguf.MODEL_ARCH.LLAMA - self.gguf_writer.arch = str(self.model_arch) + self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch] self.gguf_writer.add_architecture() self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) From 5600361cf4508879885c5d2bb2b324c4d5d175ff Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 1 Dec 2025 11:46:11 +0100 Subject: [PATCH 10/10] nits --- convert_hf_to_gguf.py | 2 +- src/models/models.h | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c9bc96f86d6..a54cce887bb 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9914,7 +9914,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"]) self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"]) self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"]) - self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) + self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"]) if "llama_4_scaling" in self.hparams: diff --git a/src/models/models.h b/src/models/models.h index bfe59284339..d93601ad06a 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -322,6 +322,10 @@ struct llm_build_minimax_m2 : public llm_graph_context { llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_mistral3 : public llm_graph_context { + llm_build_mistral3(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_mpt : public llm_graph_context { llm_build_mpt(const llama_model & model, const llm_graph_params & params); }; @@ -537,7 +541,3 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { struct llm_build_xverse : public llm_graph_context { llm_build_xverse(const llama_model & model, const llm_graph_params & params); }; - -struct llm_build_mistral3 : public llm_graph_context { - llm_build_mistral3(const llama_model & model, const llm_graph_params & params); -};