From 284cec438bf2445db7eb93fe0f8c12833eaffe6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C8=98tefan-Gabriel=20Muscalu?= Date: Sat, 8 Jun 2024 14:16:15 +0300 Subject: [PATCH 1/4] update: convert-hf-to-gguf.py to support Qwen2-57B-A14B --- convert-hf-to-gguf.py | 48 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index a86864f04861b..4216493fd44eb 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1624,10 +1624,54 @@ class Qwen2MoeModel(Model): model_arch = gguf.MODEL_ARCH.QWEN2MOE def set_gguf_parameters(self): - super().set_gguf_parameters() - if (n_experts := self.hparams.get("num_experts")) is not None: + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_block_count(self.block_count) + + if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None: + self.gguf_writer.add_context_length(n_ctx) + logger.info(f"gguf: context length = {n_ctx}") + + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + self.gguf_writer.add_embedding_length(n_embd) + logger.info(f"gguf: embedding length = {n_embd}") + + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + self.gguf_writer.add_head_count(n_head) + logger.info(f"gguf: head count = {n_head}") + + if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: + self.gguf_writer.add_head_count_kv(n_head_kv) + logger.info(f"gguf: key-value head count = {n_head_kv}") + + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) + logger.info(f"gguf: rope theta = {rope_theta}") + if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None: + self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) + logger.info(f"gguf: rms norm epsilon = {f_rms_eps}") + if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None: + self.gguf_writer.add_layer_norm_eps(f_norm_eps) + logger.info(f"gguf: layer norm epsilon = {f_norm_eps}") + + if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: + self.gguf_writer.add_expert_used_count(n_experts_used) + logger.info(f"gguf: experts used count = {n_experts_used}") + + if (n_experts := self.find_hparam(["num_experts", "num_local_experts"])) is not None: self.gguf_writer.add_expert_count(n_experts) + if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") + + if (shared_expert_intermediate_size := self.find_hparam(["shared_expert_intermediate_size","intermediate_size", "n_inner"])) is not None: + self.gguf_writer.add_feed_forward_length(shared_expert_intermediate_size) + logger.info(f"gguf: feed forward length = {shared_expert_intermediate_size}") + + self.gguf_writer.add_file_type(self.ftype) + logger.info(f"gguf: file type = {self.ftype}") + + _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: From aa8a7cd350310fb82784badfec0f508acb11d8eb Mon Sep 17 00:00:00 2001 From: stefan Date: Mon, 10 Jun 2024 17:50:31 +0000 Subject: [PATCH 2/4] fix: QWEN2MOE support for expert_feed_forward_length previously, expert ff was taken from n_ff (intermediate size) but it is now properly taken from LLM_KV_EXPERT_FEED_FORWARD_LENGTH n_ff_exp and n_ff_shared_exp are now properly calculated --- llama.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index 8b675ea993a38..03d0790e1c9d5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4247,6 +4247,8 @@ static void llm_load_hparams( } break; case LLM_ARCH_QWEN2MOE: { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 24: model.type = e_model::MODEL_A2_7B; break; @@ -5019,6 +5021,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); } + + if (model.arch == LLM_ARCH_QWEN2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } } // Returns false if cancelled by progress_callback @@ -5805,16 +5811,17 @@ static bool llm_load_tensors( GGML_ASSERT(hparams.n_expert_used > 0); // MoE branch - auto n_ff_exp = n_ff / hparams.n_expert_used; + auto n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / hparams.n_expert_used; layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); // Shared expert branch + auto n_ff_shared_exp = hparams.n_ff_exp && hparams.n_expert_used ? hparams.n_ff_exp * hparams.n_expert_used : n_ff; layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}); - layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff}); - layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff, n_embd}); - layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shared_exp}); + layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shared_exp, n_embd}); + layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shared_exp}); } } break; case LLM_ARCH_PHI2: From 06531cbaec419c247fcaad96f080676e958f5d70 Mon Sep 17 00:00:00 2001 From: stefan Date: Wed, 12 Jun 2024 11:30:08 +0000 Subject: [PATCH 3/4] update: convert-hf-to-gguf.py cleanup for Qwen2MoeForCausalLM --- convert-hf-to-gguf.py | 45 ++----------------------------------------- 1 file changed, 2 insertions(+), 43 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 4216493fd44eb..0b8094f4cc50b 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1624,54 +1624,13 @@ class Qwen2MoeModel(Model): model_arch = gguf.MODEL_ARCH.QWEN2MOE def set_gguf_parameters(self): - self.gguf_writer.add_name(self.dir_model.name) - self.gguf_writer.add_block_count(self.block_count) - - if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None: - self.gguf_writer.add_context_length(n_ctx) - logger.info(f"gguf: context length = {n_ctx}") - - n_embd = self.find_hparam(["hidden_size", "n_embd"]) - self.gguf_writer.add_embedding_length(n_embd) - logger.info(f"gguf: embedding length = {n_embd}") - - n_head = self.find_hparam(["num_attention_heads", "n_head"]) - self.gguf_writer.add_head_count(n_head) - logger.info(f"gguf: head count = {n_head}") - - if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: - self.gguf_writer.add_head_count_kv(n_head_kv) - logger.info(f"gguf: key-value head count = {n_head_kv}") - - if (rope_theta := self.hparams.get("rope_theta")) is not None: - self.gguf_writer.add_rope_freq_base(rope_theta) - logger.info(f"gguf: rope theta = {rope_theta}") - if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None: - self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) - logger.info(f"gguf: rms norm epsilon = {f_rms_eps}") - if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None: - self.gguf_writer.add_layer_norm_eps(f_norm_eps) - logger.info(f"gguf: layer norm epsilon = {f_norm_eps}") - - if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: - self.gguf_writer.add_expert_used_count(n_experts_used) - logger.info(f"gguf: experts used count = {n_experts_used}") - - if (n_experts := self.find_hparam(["num_experts", "num_local_experts"])) is not None: + super().set_gguf_parameters() + if (n_experts := self.hparams.get("num_experts")) is not None: self.gguf_writer.add_expert_count(n_experts) - if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") - if (shared_expert_intermediate_size := self.find_hparam(["shared_expert_intermediate_size","intermediate_size", "n_inner"])) is not None: - self.gguf_writer.add_feed_forward_length(shared_expert_intermediate_size) - logger.info(f"gguf: feed forward length = {shared_expert_intermediate_size}") - - self.gguf_writer.add_file_type(self.ftype) - logger.info(f"gguf: file type = {self.ftype}") - - _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: From d9452267a0069a2873f4855cfc0b671b31704889 Mon Sep 17 00:00:00 2001 From: stefan Date: Fri, 14 Jun 2024 11:38:12 +0000 Subject: [PATCH 4/4] fix: QWEN2MOE support for expert_feed_forward_length previously, expert ff was taken from n_ff (intermediate size) but it is now properly taken from LLM_KV_EXPERT_FEED_FORWARD_LENGTH n_ff_exp and n_ff_shexp are now properly calculated --- convert-hf-to-gguf.py | 3 +++ gguf-py/gguf/constants.py | 31 +++++++++++++------------- gguf-py/gguf/gguf_writer.py | 3 +++ llama.cpp | 44 +++++++++++++++++++++---------------- 4 files changed, 47 insertions(+), 34 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 0b8094f4cc50b..086f9dd78b0b2 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1630,6 +1630,9 @@ def set_gguf_parameters(self): if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") + if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None: + self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size) + logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}") _experts: list[dict[str, Tensor]] | None = None diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8908585ccf957..fb20cfabbcab5 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -33,21 +33,22 @@ class General: FILE_TYPE = "general.file_type" class LLM: - VOCAB_SIZE = "{arch}.vocab_size" - CONTEXT_LENGTH = "{arch}.context_length" - EMBEDDING_LENGTH = "{arch}.embedding_length" - BLOCK_COUNT = "{arch}.block_count" - LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" - FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" - EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length" - USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" - TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" - EXPERT_COUNT = "{arch}.expert_count" - EXPERT_USED_COUNT = "{arch}.expert_used_count" - EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" - EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" - POOLING_TYPE = "{arch}.pooling_type" - LOGIT_SCALE = "{arch}.logit_scale" + VOCAB_SIZE = "{arch}.vocab_size" + CONTEXT_LENGTH = "{arch}.context_length" + EMBEDDING_LENGTH = "{arch}.embedding_length" + BLOCK_COUNT = "{arch}.block_count" + LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" + FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" + EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length" + EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length" + USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" + TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" + EXPERT_COUNT = "{arch}.expert_count" + EXPERT_USED_COUNT = "{arch}.expert_used_count" + EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" + EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" + POOLING_TYPE = "{arch}.pooling_type" + LOGIT_SCALE = "{arch}.logit_scale" class Attention: HEAD_COUNT = "{arch}.attention.head_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index b93747aff58b3..0733e6aa56116 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -383,6 +383,9 @@ def add_feed_forward_length(self, length: int) -> None: def add_expert_feed_forward_length(self, length: int) -> None: self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length) + def add_expert_shared_feed_forward_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length) + def add_parallel_residual(self, use: bool) -> None: self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) diff --git a/llama.cpp b/llama.cpp index 03d0790e1c9d5..f31778693675a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -282,6 +282,7 @@ enum llm_kv { LLM_KV_LEADING_DENSE_BLOCK_COUNT, LLM_KV_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, @@ -360,21 +361,22 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, - { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, - { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, - { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, - { LLM_KV_BLOCK_COUNT, "%s.block_count" }, - { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, - { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, - { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, - { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, - { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, - { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, - { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, - { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, - { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, - { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, - { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, + { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, + { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, + { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, + { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, + { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, + { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, + { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, + { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, + { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, + { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, + { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, + { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -1840,6 +1842,7 @@ struct llama_hparams { uint32_t n_lora_q = 0; uint32_t n_lora_kv = 0; uint32_t n_ff_exp = 0; + uint32_t n_ff_shexp = 0; uint32_t n_expert_shared = 0; float expert_weights_scale = 0.0; @@ -1888,6 +1891,7 @@ struct llama_hparams { if (this->n_lora_q != other.n_lora_q) return true; if (this->n_lora_kv != other.n_lora_kv) return true; if (this->n_ff_exp != other.n_ff_exp) return true; + if (this->n_ff_shexp != other.n_ff_shexp) return true; if (this->n_expert_shared != other.n_expert_shared) return true; if (this->rope_finetuned != other.rope_finetuned) return true; @@ -4248,6 +4252,7 @@ static void llm_load_hparams( case LLM_ARCH_QWEN2MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { @@ -5024,6 +5029,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { if (model.arch == LLM_ARCH_QWEN2MOE) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } } @@ -5817,11 +5823,11 @@ static bool llm_load_tensors( layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); // Shared expert branch - auto n_ff_shared_exp = hparams.n_ff_exp && hparams.n_expert_used ? hparams.n_ff_exp * hparams.n_expert_used : n_ff; + auto n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}); - layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shared_exp}); - layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shared_exp, n_embd}); - layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shared_exp}); + layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}); + layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}); + layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}); } } break; case LLM_ARCH_PHI2: