74 changes: 70 additions & 4 deletions convert_hf_to_gguf.py
@@ -1581,10 +1581,27 @@ def __init__(self, *args, **kwargs):

         # load preprocessor config
         self.preprocessor_config = {}
-        if not self.is_mistral_format:
-            with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
+
+        # prefer preprocessor_config.json if possible
+        preprocessor_config_path = self.dir_model / "preprocessor_config.json"
+        if preprocessor_config_path.is_file():
+            with open(preprocessor_config_path, "r", encoding="utf-8") as f:
                 self.preprocessor_config = json.load(f)
+
+        # prefer processor_config.json if possible
+        processor_config_path = self.dir_model / "processor_config.json"
+        if processor_config_path.is_file():
+            with open(processor_config_path, "r", encoding="utf-8") as f:
+                cfg = json.load(f)
+                # move image_processor to root level for compat
+                if "image_processor" in cfg:
+                    cfg = {
+                        **cfg,
+                        **cfg["image_processor"],
+                    }
+                # merge configs
+                self.preprocessor_config = {**self.preprocessor_config, **cfg}

     def get_vision_config(self) -> dict[str, Any] | None:
         config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
         return self.global_config.get(config_name)
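As a small illustration (not part of the PR) of how the merge above behaves, the sketch below uses hypothetical file contents; the point is that keys hoisted out of `image_processor` in processor_config.json override whatever was loaded from preprocessor_config.json.

```python
# Hypothetical file contents, for illustration only.
preprocessor_config = {"image_mean": [0.5, 0.5, 0.5], "size": 336}  # preprocessor_config.json
processor_cfg = {
    "processor_class": "SomeProcessor",
    "image_processor": {"size": 448},
}                                                                    # processor_config.json

# same merge order as the converter code above
cfg = {**processor_cfg, **processor_cfg["image_processor"]}          # hoist image_processor to root
merged = {**preprocessor_config, **cfg}

assert merged["size"] == 448                   # processor_config.json wins on conflicting keys
assert merged["image_mean"] == [0.5, 0.5, 0.5] # non-conflicting keys are kept
```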
@@ -2797,7 +2814,32 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

 @ModelBase.register("Mistral3ForConditionalGeneration")
 class Mistral3Model(LlamaModel):
-    model_arch = gguf.MODEL_ARCH.LLAMA
+    model_arch = gguf.MODEL_ARCH.MISTRAL3
Inline review comment from the PR author:
Note for maintainers: while ministral3 and the older Mistral models have almost the same cgraph, the hparams handling in llama_model::load_hparams is quite a bit more complicated. Therefore, it's better to separate the two archs to keep the code readable.

This also makes the code more future-proof, in case future Mistral models diverge significantly from the traditional llama arch.


+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # for compatibility, we use LLAMA arch for older models
+        # TODO: remove this once everyone has migrated to newer version of llama.cpp
+        if self.hparams.get("model_type") != "ministral3":
+            self.model_arch = gguf.MODEL_ARCH.LLAMA
+            self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
+            self.gguf_writer.add_architecture()
+            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        rope_params = self.hparams.get("rope_parameters")
+        if self.hparams.get("model_type") == "ministral3":
+            assert rope_params is not None, "ministral3 must have 'rope_parameters' config"
+            assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_params["factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_params["beta_fast"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_params["beta_slow"])
+            self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_params["original_max_position_embeddings"])
+            self.gguf_writer.add_rope_freq_base(rope_params["rope_theta"])
+            self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
         name = name.replace("language_model.", "")
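To make the new conversion path easier to follow, here is a hypothetical `rope_parameters` block from a Ministral-3 `config.json`. The key names are the ones `set_gguf_parameters()` above reads; every numeric value is invented for illustration.

```python
# Hypothetical config.json fragment (values are made up):
hparams = {
    "model_type": "ministral3",
    "rope_parameters": {
        "rope_type": "yarn",
        "factor": 16.0,                             # -> add_rope_scaling_factor
        "beta_fast": 32.0,                          # -> add_rope_scaling_yarn_beta_fast
        "beta_slow": 1.0,                           # -> add_rope_scaling_yarn_beta_slow
        "mscale_all_dim": 1.0,                      # -> add_rope_scaling_yarn_log_mul
        "original_max_position_embeddings": 32768,  # -> add_rope_scaling_orig_ctx_len
        "rope_theta": 1000000.0,                    # -> add_rope_freq_base
        "llama_4_scaling_beta": 0.1,                # -> add_attn_temperature_scale
    },
}
```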
@@ -9809,12 +9851,22 @@ def modify_tensors(self, data_torch, name, bid):


 class MistralModel(LlamaModel):
-    model_arch = gguf.MODEL_ARCH.LLAMA
+    model_arch = gguf.MODEL_ARCH.MISTRAL3
     model_name = "Mistral"
     hf_arch = ""
     is_mistral_format = True
     undo_permute = False

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # for compatibility, we use LLAMA arch for older models
+        # TODO: remove this once everyone migrates to newer version of llama.cpp
+        if "llama_4_scaling" not in self.hparams:
+            self.model_arch = gguf.MODEL_ARCH.LLAMA
+            self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
+            self.gguf_writer.add_architecture()
+            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

     @staticmethod
     def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
         assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg
@@ -9854,6 +9906,20 @@ def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mis

         return template

+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if "yarn" in self.hparams:
+            yarn_params = self.hparams["yarn"]
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
+            self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0)  # mscale_all_dim
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
+
+        if "llama_4_scaling" in self.hparams:
+            self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"])

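For comparison, a hypothetical fragment of the Mistral-native `params.json` that the branch above handles. The key names ("yarn" with factor/beta/alpha/original_max_position_embeddings, and "llama_4_scaling" with beta) match what the code reads; the values are placeholders.

```python
# Hypothetical Mistral-format params.json fragment (values are placeholders):
hparams = {
    "yarn": {
        "factor": 16.0,                             # add_rope_scaling_factor
        "beta": 32.0,                               # add_rope_scaling_yarn_beta_fast
        "alpha": 1.0,                               # add_rope_scaling_yarn_beta_slow
        "original_max_position_embeddings": 32768,  # add_rope_scaling_orig_ctx_len
    },
    "llama_4_scaling": {
        "beta": 0.1,                                # add_attn_temperature_scale
    },
}
```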

 class PixtralModel(LlavaVisionModel):
     model_name = "Pixtral"
23 changes: 23 additions & 0 deletions gguf-py/gguf/constants.py
@@ -175,6 +175,7 @@ class Attention:
         VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
         SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
         SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
+        TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"

     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -444,6 +445,7 @@ class MODEL_ARCH(IntEnum):
     MINIMAXM2 = auto()
     RND1 = auto()
     PANGU_EMBED = auto()
+    MISTRAL3 = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
@@ -817,6 +819,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.COGVLM: "cogvlm",
     MODEL_ARCH.RND1: "rnd1",
     MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
+    MODEL_ARCH.MISTRAL3: "mistral3",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -3071,6 +3074,26 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.MISTRAL3: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
# TODO
}

3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -904,6 +904,9 @@ def add_attn_output_scale(self, value: float) -> None:
     def add_attn_temperature_length(self, value: int) -> None:
         self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)

+    def add_attn_temperature_scale(self, value: float) -> None:
+        self.add_float32(Keys.Attention.TEMPERATURE_SCALE.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

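A minimal usage sketch (not part of the PR): together with the `Keys.Attention.TEMPERATURE_SCALE` constant added in constants.py above, the new writer method emits a single float32 KV such as `mistral3.attention.temperature_scale`. The file name and value below are illustrative.

```python
# Sketch only; assumes a model using the "mistral3" arch.
from gguf import GGUFWriter

writer = GGUFWriter("model.gguf", "mistral3")
writer.add_attn_temperature_scale(0.1)  # writes "mistral3.attention.temperature_scale" = 0.1 (float32)
```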
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -132,6 +132,7 @@ add_library(llama
             models/t5-enc.cpp
             models/wavtokenizer-dec.cpp
             models/xverse.cpp
+            models/mistral3.cpp
             models/graph-context-mamba.cpp
)

28 changes: 28 additions & 0 deletions src/llama-arch.cpp
@@ -111,6 +111,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_COGVLM, "cogvlm" },
     { LLM_ARCH_RND1, "rnd1" },
     { LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
+    { LLM_ARCH_MISTRAL3, "mistral3" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -204,6 +205,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
     { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
     { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
+    { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },

@@ -2512,6 +2514,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_MISTRAL3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
2 changes: 2 additions & 0 deletions src/llama-arch.h
@@ -115,6 +115,7 @@ enum llm_arch {
     LLM_ARCH_COGVLM,
     LLM_ARCH_RND1,
     LLM_ARCH_PANGU_EMBED,
+    LLM_ARCH_MISTRAL3,
     LLM_ARCH_UNKNOWN,
 };

@@ -208,6 +209,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_OUTPUT_SCALE,
     LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
+    LLM_KV_ATTENTION_TEMPERATURE_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,

3 changes: 3 additions & 0 deletions src/llama-graph.cpp
@@ -71,6 +71,9 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && attn_scale) {
         const int64_t n_tokens = ubatch->n_tokens;

+        GGML_ASSERT(f_attn_temp_scale != 0.0f);
+        GGML_ASSERT(n_attn_temp_floor_scale != 0);
+
         std::vector<float> attn_scale_data(n_tokens, 0.0f);
         for (int i = 0; i < n_tokens; ++i) {
             const float pos = ubatch->pos[i];
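The asserts guard the per-token scale computed in the loop that follows (truncated in the diff). As a rough reference, the Llama-4-style temperature scaling this input feeds looks like the sketch below; this is a paraphrase of the idea, not the verbatim llama.cpp code, so treat the exact expression as an assumption. The 8192 / 0.1f values are the ones set for the chunked-SWA branch in the load_hparams hunk further down.

```python
import math

def attn_temp_scale(pos: int, floor_scale: int, scale: float) -> float:
    # Llama-4-style attention temperature tuning (sketch): tokens further into the
    # context get a slightly larger scale, stepping once every `floor_scale` positions.
    return math.log(math.floor((pos + 1) / floor_scale) + 1.0) * scale + 1.0

# With floor_scale=8192 and scale=0.1 (the llama4 values set in load_hparams):
assert attn_temp_scale(0, 8192, 0.1) == 1.0   # log(1) * 0.1 + 1
print(attn_temp_scale(20000, 8192, 0.1))      # ~1.11
```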
4 changes: 2 additions & 2 deletions src/llama-hparams.h
@@ -162,8 +162,8 @@ struct llama_hparams {
     // llama4 smallthinker
     uint32_t n_moe_layer_step = 0;
     uint32_t n_no_rope_layer_step = 4;
-    uint32_t n_attn_temp_floor_scale = 8192;
-    float f_attn_temp_scale = 0.1;
+    uint32_t n_attn_temp_floor_scale = 0;
+    float f_attn_temp_scale = 0.0f;

     // gemma3n altup
     uint32_t n_altup = 4; // altup_num_inputs
50 changes: 46 additions & 4 deletions src/llama-model.cpp
@@ -626,8 +626,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     switch (arch) {
         case LLM_ARCH_LLAMA:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
                 if (hparams.n_expert == 8) {
                     switch (hparams.n_layer) {
                         case 32: type = LLM_TYPE_8x7B; break;
@@ -663,8 +661,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.swa_type = LLAMA_SWA_TYPE_NONE;
                     hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope
                 } else {
-                    hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
-                    hparams.n_swa = 8192;
+                    hparams.swa_type                = LLAMA_SWA_TYPE_CHUNKED;
+                    hparams.n_swa                   = 8192;
+                    hparams.n_attn_temp_floor_scale = 8192;
+                    hparams.f_attn_temp_scale       = 0.1f;
                     hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
                 }

@@ -2247,6 +2247,42 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_MISTRAL3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
+
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
+
+                // TODO: maybe add n_attn_temp_floor_scale as a separate KV?
+                if (hparams.f_attn_temp_scale != 0.0f) {
+                    hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn;
+                    if (hparams.n_attn_temp_floor_scale == 0) {
+                        throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling");
+                    }
+                }
+
+                // TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f
+                //       but may need further verification with other values
+                if (hparams.rope_yarn_log_mul != 0.0f) {
+                    float factor = 1.0f / hparams.rope_freq_scale_train;
+                    float mscale = 1.0f;
+                    float mscale_all_dims = hparams.rope_yarn_log_mul;
+                    static auto get_mscale = [](float scale, float mscale) {
+                        return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+                    };
+                    hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+                }
+
+                switch (hparams.n_layer) {
+                    case 26: type = LLM_TYPE_3B; break;
+                    case 34: type = LLM_TYPE_8B; break;
+                    case 40: type = LLM_TYPE_14B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }

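To illustrate the TODO above: with mscale == mscale_all_dims == 1.0f the two get_mscale() terms cancel, so yarn_attn_factor stays 1.0 regardless of the scaling factor. A quick numeric check (sketch mirroring the lambda above; the factor value is just an example):

```python
import math

def get_mscale(scale: float, mscale: float) -> float:
    # same formula as the C++ lambda above
    return 1.0 if scale <= 1.0 else 0.1 * mscale * math.log(scale) + 1.0

factor = 16.0          # 1 / rope_freq_scale_train, e.g. a 16x YaRN extension
mscale = 1.0
mscale_all_dims = 1.0  # hparams.rope_yarn_log_mul as converted above

yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims)
print(yarn_attn_factor)  # 1.0 -> numerator and denominator are both ~1.277
```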
@@ -2560,6 +2596,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
+        case LLM_ARCH_MISTRAL3:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

@@ -7522,6 +7559,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_qwen3next>(*this, params);
             } break;
+        case LLM_ARCH_MISTRAL3:
+            {
+                llm = std::make_unique<llm_build_mistral3>(*this, params);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -7690,6 +7731,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_ARCEE:
         case LLM_ARCH_ERNIE4_5:
         case LLM_ARCH_ERNIE4_5_MOE:
+        case LLM_ARCH_MISTRAL3:
             return LLAMA_ROPE_TYPE_NORM;

         // the pairs of head values are offset by n_rot/2