diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b759366684396..0b8c05a183f03 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1054,6 +1054,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e": # ref: https://huggingface.co/ibm-granite/granite-docling-258M res = "granite-docling" + if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95": + # ref: https://huggingface.co/MiniMaxAI/MiniMax-M2 + res = "minimax-m2" if res is None: logger.warning("\n") @@ -6909,6 +6912,84 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("MiniMaxM2ForCausalLM") +class MiniMaxM2Model(TextModel): + model_arch = gguf.MODEL_ARCH.MINIMAXM2 + _experts_cache: dict[int, dict[str, Tensor]] = {} + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hparams["num_experts"] = self.hparams["num_local_experts"] + + def set_gguf_parameters(self): + if self.hparams["scoring_func"] == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + elif self.hparams["scoring_func"] == "softmax": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + else: + raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}") + + block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) + rms_eps = self.find_hparam(["rms_norm_eps"]) + max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) + head_dim = self.find_hparam(["head_dim"]) + + self.gguf_writer.add_context_length(max_pos_embds) + self.gguf_writer.add_embedding_length(n_embd) + self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"])) + self.gguf_writer.add_expert_feed_forward_length(self.find_hparam(["intermediate_size"])) + self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts"])) + self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok"])) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_head_count_kv(n_head_kv) + self.gguf_writer.add_layer_norm_rms_eps(rms_eps) + self.gguf_writer.add_layer_norm_eps(rms_eps) + self.gguf_writer.add_key_length(head_dim) + self.gguf_writer.add_value_length(head_dim) + self.gguf_writer.add_rope_dimension_count(self.find_hparam(["rotary_dim"])) + self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"])) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + # merge expert weights + if 'experts' in name: + n_experts = self.hparams["num_experts"] + assert bid is not None + + expert_cache = self._experts_cache.setdefault(bid, {}) + expert_cache[name] = data_torch + expert_weights = ["w1", "w2", "w3"] + + # not enough expert weights to merge + if len(expert_cache) < n_experts * len(expert_weights): + return [] + + tensors: list[tuple[str, Tensor]] = [] + for w_name in expert_weights: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight" + datas.append(expert_cache[ename]) + del expert_cache[ename] + + data_torch = torch.stack(datas, dim=0) + merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + + del self._experts_cache[bid] + return tensors + + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("Dots1ForCausalLM") class Dots1Model(Qwen2MoeModel): model_arch = gguf.MODEL_ARCH.DOTS1 diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 0ebc1b160f603..65b2cecbb66ce 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -141,6 +141,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", }, {"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", }, {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, + {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", }, ] # some models are known to be broken upstream, so we will skip them as exceptions @@ -438,6 +439,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: except OSError as e: logger.error(f"Failed to load tokenizer for model {name}. Error: {e}") continue # Skip this model and continue with the next one in the loop + except TypeError as e: + logger.error(f"Failed to load tokenizer for model {name}. Error: {e}") + continue # Skip this model and continue with the next one in the loop if not os.path.exists(f"models/ggml-vocab-{name}.gguf"): logger.info(f"Skip vocab files for model {name}, no GGUF file found") diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 94fcfaf69cf09..6796abe2946e7 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -420,6 +420,7 @@ class MODEL_ARCH(IntEnum): SEED_OSS = auto() GROVEMOE = auto() APERTUS = auto() + MINIMAXM2 = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -766,6 +767,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.SEED_OSS: "seed_oss", MODEL_ARCH.GROVEMOE: "grovemoe", MODEL_ARCH.APERTUS: "apertus", + MODEL_ARCH.MINIMAXM2: "minimax-m2", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -2837,6 +2839,25 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_CHEXP, MODEL_TENSOR.FFN_UP_CHEXP, ], + MODEL_ARCH.MINIMAXM2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], + # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index d7dcd8efb8426..f9f7e9e6e9556 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -377,6 +377,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe "model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2 "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe + "model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2 ), # Feed-forward up diff --git a/models/ggml-vocab-minimax-m2.gguf b/models/ggml-vocab-minimax-m2.gguf new file mode 100644 index 0000000000000..4e2f5c6d1f990 Binary files /dev/null and b/models/ggml-vocab-minimax-m2.gguf differ diff --git a/models/ggml-vocab-minimax-m2.gguf.inp b/models/ggml-vocab-minimax-m2.gguf.inp new file mode 100644 index 0000000000000..86b934e4020fb --- /dev/null +++ b/models/ggml-vocab-minimax-m2.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-minimax-m2.gguf.out b/models/ggml-vocab-minimax-m2.gguf.out new file mode 100644 index 0000000000000..900c7aa916410 --- /dev/null +++ b/models/ggml-vocab-minimax-m2.gguf.out @@ -0,0 +1,46 @@ + 1233 32 52 32 23901 4632 + 69967 30230 295 + + 32 + 256 + 326 + 9 + 10 + 367 + 4368 + 10380 + 19739 2035 + 53398 2035 + 19739 5476 + 53398 5476 + 53398 5476 33 + 19739 44 2035 33 + 53398 44 2035 33 + 546 355 9753 166 153 46 52243 + 119 48218 32 55 116 2157 60350 40081 6107 15931 + 8827 40614 3642 11575 185034 8623 + 76300 128 76300 182 76300 147 157246 139 76300 143 157246 130 76300 150 76300 183 76300 159 225 35097 76300 159 76300 162 76300 182 76300 133 76300 129 76300 155 76300 133 225 35097 76300 137 + 150333 359 14291 41 19918 182 61587 79213 171 21243 359 79401 158243 176756 41 181343 359 10141 113958 389 760 1072 1813 11248 41 + 19739 + 53398 + 32 53398 + 256 53398 + 326 53398 + 326 53398 10 326 53398 + 359 + 10 409 + 39 5784 + 19739 44 330 53147 33 2329 457 390 184404 3479 32020 594 44450 2489 17246 35341 49 1419 5516 + 34485 6255 + 51 + 2893 + 18397 + 18397 51 + 18397 2893 + 18397 18397 + 18397 18397 51 + 18397 18397 2893 + 18397 18397 18397 + 67 191937 97 31042 84408 116 + 2300 2958 + 137106 35066 24361 56254 151540 4315 10877 7671 41564 150333 359 14291 41 19918 182 61587 79213 171 21243 359 79401 158243 176756 41 181343 9753 166 153 186278 153 32 51 32 2893 32 18397 32 18397 51 32 18397 2893 32 18397 18397 32 18397 18397 51 32 18397 18397 2893 32 51 46 51 32 51 645 51 32 51 1662 51 29559 158 128 76300 182 76300 147 157246 139 76300 143 157246 130 76300 150 76300 183 76300 159 225 35097 76300 159 76300 162 76300 182 76300 133 21557 129 3479 32020 594 44450 2489 17246 35341 49 1419 5516 109618 1246 9435 6833 40614 3642 11575 185034 8623 8462 3443 64346 2765 111832 22815 34485 6255 61018 13074 8244 1040 722 116 1186 13396 986 44 722 2380 390 3123 63 722 77 516 3123 13098 1454 412 44 722 68 390 1079 1001 17251 63 1559 39 34121 258 99132 76 diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 8ca769c5fd2ef..26d976b9f299a 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -103,6 +103,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_SEED_OSS, "seed_oss" }, { LLM_ARCH_GROVEMOE, "grovemoe" }, { LLM_ARCH_APERTUS, "apertus" }, + { LLM_ARCH_MINIMAX_M2, "minimax-m2" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -2312,6 +2313,27 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_CHEXPS, "blk.%d.ffn_up_chexps" }, }, }, + { + LLM_ARCH_MINIMAX_M2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + }, + }, { LLM_ARCH_UNKNOWN, { diff --git a/src/llama-arch.h b/src/llama-arch.h index dea725c1a753a..cef99196cef79 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -107,6 +107,7 @@ enum llm_arch { LLM_ARCH_SEED_OSS, LLM_ARCH_GROVEMOE, LLM_ARCH_APERTUS, + LLM_ARCH_MINIMAX_M2, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ea6f59ed482bb..ee0eab9d9de4e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -120,6 +120,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; @@ -2124,6 +2125,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MINIMAX_M2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + switch (hparams.n_layer) { + case 62: type = LLM_TYPE_230B_A10B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + default: throw std::runtime_error("unsupported model architecture"); } @@ -6136,6 +6149,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_MINIMAX_M2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k * n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_k_gqa}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -19641,6 +19683,130 @@ struct llm_build_apertus : public llm_graph_context { } }; +struct llm_build_minimax_m2 : public llm_graph_context { + llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + // GGML_ASSERT(n_embd_head == hparams.n_rot); this is wrong in case of minimax, head_dim = 128, n_rot = 64 + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto inp_attn = build_attn_inp_kv(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = inpL; + + // self_attention + { + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const { llama_memory_i * res; @@ -20165,6 +20331,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_MINIMAX_M2: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -20382,6 +20552,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_SEED_OSS: case LLM_ARCH_GROVEMOE: case LLM_ARCH_APERTUS: + case LLM_ARCH_MINIMAX_M2: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index 1ab1cf7f8e94d..a4fb962bf0287 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -114,6 +114,7 @@ enum llm_type { LLM_TYPE_30B_A3B, LLM_TYPE_100B_A6B, LLM_TYPE_106B_A12B, // GLM-4.5-Air + LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big LLM_TYPE_355B_A32B, // GLM-4.5 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 639fecbd31745..735c5d547f9e4 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -401,6 +401,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { }; break; case LLAMA_VOCAB_PRE_TYPE_GPT4O: + case LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2: regex_exprs = { // original regex from tokenizer.json // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", @@ -1992,6 +1993,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "grok-2") { pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2; clean_spaces = false; + } else if ( + tokenizer_pre == "minimax-m2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 5e468675e4447..1194ec473d03a 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -49,6 +49,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, + LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, }; struct LLM_KV; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9cc5e933f4ce..3896d2c00a8f0 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -124,6 +124,7 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-phi-3 ARGS ${PROJE llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2 ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-qwen2.gguf) llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-refact.gguf) llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-starcoder.gguf) +llama_test(test-tokenizer-0 NAME test-tokenizer-0-minimax-m2 ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-minimax-m2.gguf) if (NOT WIN32) llama_test_cmd(