From 433782bc71eec48a2ab8e5a758944fe98781b526 Mon Sep 17 00:00:00 2001
From: Shunta Saito
Date: Thu, 18 Sep 2025 13:28:32 +0900
Subject: [PATCH 1/8] Fix to use hidden_size_per_head

---
 convert_hf_to_gguf.py | 3 ++-
 src/llama-model.cpp   | 7 ++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 7ddec48ad7129..f8805a462a4a7 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4204,8 +4204,9 @@ def set_gguf_parameters(self):
 
         self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
         self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
+        self.gguf_writer.add_features_length(hparams.get("hidden_size_per_head", 128))
         self.gguf_writer.add_block_count(block_count)
-        self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
+        self.gguf_writer.add_wkv_head_size(hparams.get("num_attention_heads", 32))
         self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 981e57083c48d..b5202a2d758c4 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -3371,11 +3371,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     const uint32_t d_state = hparams.ssm_d_state;
                     const uint32_t num_heads = hparams.ssm_dt_rank;
                     const uint32_t intermediate_size = hparams.ssm_d_inner;
-                    const uint32_t head_dim = intermediate_size / num_heads;
+                    const uint32_t head_dim = hparams.wkv_head_size;
                     const uint32_t qk_dim = head_dim;
                     const uint32_t v_dim = head_dim;
-                    const int64_t num_attention_heads = hparams.n_head();
-                    const int64_t q_num_heads = num_attention_heads;
                     const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
 
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -3392,6 +3390,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     auto & layer = layers[i];
                     bool is_mamba_layer = hparams.is_recurrent(i);
 
+                    const int64_t num_attention_heads = hparams.n_head_kv_arr[i];
+                    const int64_t q_num_heads = num_attention_heads;
+
                     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
                     if (is_mamba_layer) {

From 4fe5c9662ef1b9510417bc9b6c45818406df534a Mon Sep 17 00:00:00 2001
From: Shunta Saito
Date: Thu, 18 Sep 2025 15:27:24 +0900
Subject: [PATCH 2/8] Fix num heads

---
 convert_hf_to_gguf.py | 17 ++++++++++-------
 src/llama-hparams.h   |  2 +-
 src/llama-model.cpp   | 22 +++++++++++-----------
 3 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index f8805a462a4a7..b98f5ff241e72 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4185,7 +4185,8 @@ def set_gguf_parameters(self):
         # This logic matches modeling_plamo.py's is_mamba function
         mamba_step = hparams.get("mamba_step", 2)
         mamba_enabled = hparams.get("mamba_enabled", True)
-        mamba_layers = []
+        num_key_value_heads = []
+        num_attention_heads = []
 
         if mamba_enabled:
             for i in range(block_count):
@@ -4195,18 +4196,20 @@ def set_gguf_parameters(self):
                 else:
                     is_mamba = (i % mamba_step) != (mamba_step // 2)
                 if is_mamba:
-                    mamba_layers.append(0)
+                    num_key_value_heads.append(0)
                 else:
-                    mamba_layers.append(hparams.get("num_key_value_heads", 4))
+                    num_key_value_heads.append(hparams.get("num_key_value_heads", 4))
+                    num_attention_heads.append(hparams.get("num_attention_heads", 32))
 
-        if mamba_layers:
-            self.gguf_writer.add_head_count_kv(mamba_layers)
+        if num_key_value_heads and num_attention_heads:
+            self.gguf_writer.add_head_count_kv(num_key_value_heads)
+            self.gguf_writer.add_head_count(num_attention_heads)
 
         self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
         self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
-        self.gguf_writer.add_features_length(hparams.get("hidden_size_per_head", 128))
+        self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128))
+        self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
         self.gguf_writer.add_block_count(block_count)
-        self.gguf_writer.add_wkv_head_size(hparams.get("num_attention_heads", 32))
         self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
 
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 202cbbd1b2884..582fde49432d3 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -42,7 +42,7 @@ struct llama_hparams {
     uint32_t n_embd;
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
-    int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
+    int32_t  n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
    uint32_t n_rot;
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index b5202a2d758c4..9e8a4f16dba62 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -3367,15 +3367,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_PLAMO2:
             {
-                    const uint32_t d_conv = hparams.ssm_d_conv;
-                    const uint32_t d_state = hparams.ssm_d_state;
-                    const uint32_t num_heads = hparams.ssm_dt_rank;
-                    const uint32_t intermediate_size = hparams.ssm_d_inner;
-                    const uint32_t head_dim = hparams.wkv_head_size;
-                    const uint32_t qk_dim = head_dim;
-                    const uint32_t v_dim = head_dim;
-                    const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
-
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
@@ -3390,12 +3381,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     auto & layer = layers[i];
                     bool is_mamba_layer = hparams.is_recurrent(i);
 
-                    const int64_t num_attention_heads = hparams.n_head_kv_arr[i];
-                    const int64_t q_num_heads = num_attention_heads;
 
                     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
                     if (is_mamba_layer) {
+                        const uint32_t d_conv = hparams.ssm_d_conv;
+                        const uint32_t d_state = hparams.ssm_d_state;
+                        const uint32_t num_heads = hparams.ssm_dt_rank;
+                        const uint32_t intermediate_size = hparams.ssm_d_inner;
+                        const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
+
                         layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0);
 
                         layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0);
@@ -3412,6 +3407,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0);
                         layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0);
                     } else {
+                        const uint32_t head_dim = hparams.n_embd_head_k;
+                        const uint32_t qk_dim = head_dim;
+                        const uint32_t v_dim = head_dim;
+                        const int64_t num_attention_heads = hparams.n_head(i);
+                        const int64_t q_num_heads = num_attention_heads;
                         const int64_t num_key_value_heads = hparams.n_head_kv(i);
                         const int64_t k_num_heads = num_key_value_heads;
                         const int64_t v_num_heads = num_key_value_heads;

From b164aa10da2ac65bb4a7efa73e39cd7b53e222a1 Mon Sep 17 00:00:00 2001
From: Shunta Saito
Date: Thu, 18 Sep 2025 15:28:45 +0900
Subject: [PATCH 3/8] Fix array

---
 convert_hf_to_gguf.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index b98f5ff241e72..c21b0495e4218 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4197,6 +4197,7 @@ def set_gguf_parameters(self):
                     is_mamba = (i % mamba_step) != (mamba_step // 2)
                 if is_mamba:
                     num_key_value_heads.append(0)
+                    num_attention_heads.append(0)
                 else:
                     num_key_value_heads.append(hparams.get("num_key_value_heads", 4))
                     num_attention_heads.append(hparams.get("num_attention_heads", 32))

From 7f5d8a93738b97b18419850a6146992d333d6965 Mon Sep 17 00:00:00 2001
From: Shunta Saito
Date: Thu, 18 Sep 2025 19:46:44 +0900
Subject: [PATCH 4/8] Fix loading weights

---
 src/llama-model.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 9e8a4f16dba62..51c46bdb81a3e 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1077,6 +1077,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                         break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
+
+                // Load attention parameters
+                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k);
+                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v);
             } break;
         case LLM_ARCH_GPT2:
             {
@@ -17521,6 +17525,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         const int64_t n_embd_head_q = hparams.n_embd_head_k;
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
         const int64_t n_embd_head_v = hparams.n_embd_head_v;
+        int32_t n_head = hparams.n_head(il);
         int32_t n_head_kv = hparams.n_head_kv(il);
 
         const int64_t q_offset = 0;

From 10af12c3f08eca12753f56a7c8831fe5009a7a2c Mon Sep 17 00:00:00 2001
From: Shunta Saito
Date: Fri, 19 Sep 2025 00:18:12 +0900
Subject: [PATCH 5/8] Support old GGUF converted by the previous version of llama.cpp

---
 src/llama-model.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 51c46bdb81a3e..980398a34d430 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1079,8 +1079,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
 
                 // Load attention parameters
-                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k);
-                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v);
+                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
+                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
             } break;
         case LLM_ARCH_GPT2:
             {

From 0e8aff147b6a54196e27615800500c9c5199ee7b Mon Sep 17 00:00:00 2001
From: Shunta Saito
Date: Mon, 22 Sep 2025 17:07:55 +0900
Subject: [PATCH 6/8] Update src/llama-model.cpp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Sigbjørn Skjæret
---
 src/llama-model.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 980398a34d430..8d2814b928204 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1077,10 +1077,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                         break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
-
-                // Load attention parameters
-                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
-                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
             } break;
         case LLM_ARCH_GPT2:
             {

From 07b55f4be2a8b5994119b5d4f6dba2beff416231 Mon Sep 17 00:00:00 2001
From: Shunta Saito
Date: Mon, 22 Sep 2025 17:17:12 +0900
Subject: [PATCH 7/8] Move shared parameter definitions to the outside of loop

---
 src/llama-model.cpp | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 8d2814b928204..9c0e8d98cd723 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -3367,6 +3367,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_PLAMO2:
             {
+                    // mamba parameters
+                    const uint32_t d_conv = hparams.ssm_d_conv;
+                    const uint32_t d_state = hparams.ssm_d_state;
+                    const uint32_t num_heads = hparams.ssm_dt_rank;
+                    const uint32_t intermediate_size = hparams.ssm_d_inner;
+                    const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
+
+                    // attention parameters
+                    const uint32_t qk_dim = hparams.n_embd_head_k;
+                    const uint32_t v_dim = hparams.n_embd_head_v;
+
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
@@ -3381,16 +3392,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     auto & layer = layers[i];
                     bool is_mamba_layer = hparams.is_recurrent(i);
 
-
                     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
                     if (is_mamba_layer) {
-                        const uint32_t d_conv = hparams.ssm_d_conv;
-                        const uint32_t d_state = hparams.ssm_d_state;
-                        const uint32_t num_heads = hparams.ssm_dt_rank;
-                        const uint32_t intermediate_size = hparams.ssm_d_inner;
-                        const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
-
                         layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0);
 
                         layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0);
@@ -3407,9 +3411,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0);
                         layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0);
                     } else {
-                        const uint32_t head_dim = hparams.n_embd_head_k;
-                        const uint32_t qk_dim = head_dim;
-                        const uint32_t v_dim = head_dim;
                         const int64_t num_attention_heads = hparams.n_head(i);
                         const int64_t q_num_heads = num_attention_heads;
                         const int64_t num_key_value_heads = hparams.n_head_kv(i);
@@ -3420,8 +3421,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         const int64_t v_proj_dim = v_num_heads * v_dim;
 
                         layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0);
                         layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0);
                     }
 

From 1be2787b5c32f3730608da2170da68e3dcd1bfd0 Mon Sep 17 00:00:00 2001
From: Shunta Saito
Date: Mon, 22 Sep 2025 17:44:36 +0900
Subject: [PATCH 8/8] Not calculating n_embd_head_k,v by n_embd / n_head

---
 src/llama-model.cpp | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 9c0e8d98cd723..093ed963fd73f 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1076,7 +1076,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                         }
                         break;
                     default: type = LLM_TYPE_UNKNOWN;
-                }
+                }
+
+                // Load attention parameters
+                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
+                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
             } break;
         case LLM_ARCH_GPT2:
             {
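
Note (illustrative, not part of the patches above): the net effect of patches 1-3 on
convert_hf_to_gguf.py is that PLaMo-2's hybrid layout is written to GGUF as per-layer
head counts, with attention blocks reporting their real head counts and Mamba blocks
reporting 0, while hidden_size_per_head is written directly as the key/value lengths.
The standalone Python sketch below restates that bookkeeping; the hparams dict is an
invented example standing in for config.json values, and the converter's additional
branch above the is_mamba check is omitted here.

    # Standalone sketch of the per-layer head bookkeeping (example values, not real config).
    hparams = {
        "num_hidden_layers": 8,
        "mamba_step": 2,
        "mamba_enabled": True,
        "num_attention_heads": 32,
        "num_key_value_heads": 4,
        "hidden_size_per_head": 128,
    }

    block_count = hparams["num_hidden_layers"]
    num_attention_heads = []
    num_key_value_heads = []

    if hparams["mamba_enabled"]:
        for i in range(block_count):
            # Every mamba_step-th block (offset by mamba_step // 2) is an attention block;
            # the rest are Mamba blocks, which carry no attention heads.
            is_mamba = (i % hparams["mamba_step"]) != (hparams["mamba_step"] // 2)
            if is_mamba:
                num_attention_heads.append(0)
                num_key_value_heads.append(0)
            else:
                num_attention_heads.append(hparams["num_attention_heads"])
                num_key_value_heads.append(hparams["num_key_value_heads"])

    # These per-layer lists are what add_head_count() / add_head_count_kv() receive,
    # while add_key_length() / add_value_length() get hidden_size_per_head as-is.
    print(num_attention_heads)  # [0, 32, 0, 32, 0, 32, 0, 32]
    print(num_key_value_heads)  # [0, 4, 0, 4, 0, 4, 0, 4]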