Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
/.vscode
*.egg-info/
/.idea/
/.qoder/
build/
dist/
__pycache__/
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ The following is the list of models supported by MCore-Bridge:
| GLM | glm4, glm4_moe, glm4_moe_lite<br />glm4v, glm4v_moe, <br />glm_moe_dsa |
| MiniMax | minimax_m2 |
| Kimi | kimi_k2, kimi_vl, kimi_k25 |
| Bailing | bailing_moe |
| InternLM | internlm3, internvl_chat, internvl |
| Ovis | ovis2_5 |
| Llama | llama, llama4 |
Expand Down
1 change: 1 addition & 0 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ uv pip install -e . --torch-backend=auto
| GLM | glm4, glm4_moe, glm4_moe_lite<br />glm4v, glm4v_moe, <br />glm_moe_dsa |
| MiniMax | minimax_m2 |
| Kimi | kimi_k2, kimi_vl, kimi_k25 |
| Bailing | bailing_moe |
| InternLM | internlm3, internvl_chat, internvl |
| Ovis | ovis2_5 |
| Llama | llama, llama4 |
Expand Down
29 changes: 19 additions & 10 deletions src/mcore_bridge/bridge/gpt_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class GPTBridge:
# HF Keys
hf_q_norm_key = 'q_norm.weight'
hf_k_norm_key = 'k_norm.weight'
hf_o_proj_key = 'o_proj'
hf_attn_prefix = 'self_attn'
hf_mlp_prefix = 'mlp'
hf_gate_key = 'gate.weight'
hf_shared_expert_key = None
Expand Down Expand Up @@ -523,11 +525,7 @@ def _filter_prefix(state_dict, prefix: str):
return state_dict
return {k: v for k, v in state_dict.items() if k.startswith(prefix)}

def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
if to_mcore:
hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
else:
hf_state_dict = {}
def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool):
config = self.config
num_query_groups = (
config.num_query_groups if config.num_query_groups is not None else config.num_attention_heads)
Expand Down Expand Up @@ -618,9 +616,6 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int
hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[:, -kv_block:, :].reshape(
-1, hidden_size_block).clone()
del mg_attn_weight
self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore)
if config.add_bias_linear:
self._set_state_dict(mg_attn, 'linear_proj.bias', hf_state_dict, 'o_proj.bias', to_mcore)

# Copy bias
if (config.add_bias_linear or config.add_qkv_bias) and not self._peft_format:
Expand All @@ -640,6 +635,18 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int
hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone()
hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone()
hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone()
return hf_state_dict

def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
if to_mcore:
hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
else:
hf_state_dict = {}
config = self.config
hf_state_dict.update(self._set_qkv(mg_attn, hf_state_dict, to_mcore))
self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, f'{self.hf_o_proj_key}.weight', to_mcore)
if config.add_bias_linear:
self._set_state_dict(mg_attn, 'linear_proj.bias', hf_state_dict, f'{self.hf_o_proj_key}.bias', to_mcore)
if getattr(config, 'softmax_type', 'vanilla') == 'learnable':
self._set_state_dict(mg_attn, 'core_attention.softmax_offset', hf_state_dict, 'sinks', to_mcore)
if config.qk_layernorm:
Expand Down Expand Up @@ -1559,10 +1566,12 @@ def _set_mla_attn_state(
def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool):
mg_attn = None if mg_layer is None else mg_layer.self_attention
if self.config.multi_latent_attention:
hf_state_dict.update(self._set_mla_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore))
hf_state_dict.update(
self._set_mla_attn_state(mg_attn, hf_state_dict, f'{self.hf_attn_prefix}.', layer_idx, to_mcore))
self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore)
else:
hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore))
hf_state_dict.update(
self._set_attn_state(mg_attn, hf_state_dict, f'{self.hf_attn_prefix}.', layer_idx, to_mcore))
self._set_state_dict(mg_layer, 'self_attention.linear_qkv.layer_norm_weight', hf_state_dict,
'input_layernorm.weight', to_mcore)
return hf_state_dict
Expand Down
4 changes: 3 additions & 1 deletion src/mcore_bridge/config/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
'moe_router_group_topk': ['topk_group'],
'num_moe_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts', 'num_local_experts'],
'moe_router_pre_softmax': ['norm_topk_prob'],
'moe_router_enable_expert_bias': ['moe_router_enable_expert_bias'],
'rotary_interleaved': ['rope_interleave'],
# deepseek
'q_lora_rank': ['q_lora_rank'],
'kv_lora_rank': ['kv_lora_rank'],
'moe_router_score_function': ['scoring_func', 'moe_router_use_sigmoid'],
'moe_router_score_function': ['scoring_func', 'moe_router_use_sigmoid', 'score_function'],
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While adding score_function to the config_mapping is correct, the bailing_moe model type should also be explicitly handled in the hf_to_mcore_config function (around line 120 and 164) to ensure qk_layernorm is enabled and the router score function is set to sigmoid. The bridge definition in bailing_moe.py includes QK normalization keys and expert bias, which strongly suggests these configurations are required for the model to function correctly in Megatron-Core.

'moe_router_bias_update_rate': ['aux_loss_alpha'],
'qk_head_dim': ['qk_nope_head_dim'],
'qk_pos_emb_head_dim': ['qk_rope_head_dim'],
Expand Down
1 change: 1 addition & 0 deletions src/mcore_bridge/model/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class LLMModelType:
glm4 = 'glm4'
minimax_m2 = 'minimax_m2'
hy_v3 = 'hy_v3'
bailing_moe = 'bailing_moe'

qwen3_emb = 'qwen3_emb'

Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/model/gpts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from . import glm4, hunyuan, llm, minimax_m2, olmoe, qwen3_emb, qwen3_next
from . import bailing_moe, glm4, hunyuan, llm, minimax_m2, olmoe, qwen3_emb, qwen3_next
86 changes: 86 additions & 0 deletions src/mcore_bridge/model/gpts/bailing_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch
from megatron.core.transformer.attention import SelfAttention
from torch import Tensor
from typing import Optional

from mcore_bridge.bridge import GPTBridge

from ..constant import ModelType
from ..register import ModelLoader, ModelMeta, register_model


class BailingMoeSelfAttention(SelfAttention):

def get_query_key_value_tensors(
self,
hidden_states: Tensor,
key_value_states: Optional[Tensor] = None,
*args,
**kwargs,
):
"""Override to handle BailingMoE's non-interleaved QKV weight layout.

BailingMoE stores weights as [Q_all | K_all | V_all] (split by head count),
not Megatron's interleaved [q1 q2 k1 v1 | q3 q4 k2 v2 | ...].
"""
# [sq, b, h] --> [sq, b, (num_heads + 2 * num_kv_heads) * head_dim]
mixed_qkv, _ = self.linear_qkv(hidden_states)
Comment thread
Jintao-Huang marked this conversation as resolved.

# [sq, b, (num_heads + 2 * num_kv_heads) * head_dim]
# --> [sq, b, num_heads + 2 * num_kv_heads, head_dim]
new_tensor_shape = mixed_qkv.size()[:-1] + (
self.num_attention_heads_per_partition + 2 * self.num_query_groups_per_partition,
self.hidden_size_per_attention_head,
)
mixed_qkv = mixed_qkv.view(*new_tensor_shape)

# Split by head count: [sq, b, num_heads, hn], [sq, b, num_kv_heads, hn], [sq, b, num_kv_heads, hn]
query, key, value = torch.split(
mixed_qkv,
[
self.num_attention_heads_per_partition, self.num_query_groups_per_partition,
self.num_query_groups_per_partition
],
dim=2,
)

if self.q_layernorm is not None:
query = self.q_layernorm(query)

if self.k_layernorm is not None:
key = self.k_layernorm(key)

return query, key, value


class BailingMoeLoader(ModelLoader):

def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
transformer_layer_spec = super().get_transformer_layer_spec(vp_stage)
for layer_spec in transformer_layer_spec.layer_specs:
layer_spec.submodules.self_attention.module = BailingMoeSelfAttention
return transformer_layer_spec


class BailingMoeBridge(GPTBridge):
hf_embed_key = 'model.word_embeddings.weight'
hf_attn_prefix = 'attention'
hf_q_norm_key = 'query_layernorm.weight'
hf_k_norm_key = 'key_layernorm.weight'
hf_expert_bias_key = 'gate.expert_bias'
hf_o_proj_key = 'dense'

def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool):
self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'query_key_value.weight', to_mcore)
Comment thread
Jintao-Huang marked this conversation as resolved.
assert not self.config.add_bias_linear
return hf_state_dict


register_model(
ModelMeta(
ModelType.bailing_moe,
['bailing_moe'],
bridge_cls=BailingMoeBridge,
loader=BailingMoeLoader,
))
7 changes: 6 additions & 1 deletion tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ def test_olmoe():
_test_model('allenai/OLMoE-1B-7B-0125-Instruct')


def test_bailing():
_test_model('inclusionAI/Ling-mini-2.0')


if __name__ == '__main__':
# test_qwen2()
# test_llama2()
Expand All @@ -169,7 +173,7 @@ def test_olmoe():
# test_megrez()
# test_llama3_1()
# test_llama3_2()
test_qwen3()
# test_qwen3()
Comment thread
Jintao-Huang marked this conversation as resolved.
# test_qwen2_moe()
# test_qwen3_moe()
# test_internlm3()
Expand All @@ -190,3 +194,4 @@ def test_olmoe():
# test_minimax_m2()
# test_glm4_moe_lite()
# test_olmoe()
test_bailing()
Loading