# In [None]:
# Simplified draft version; this code is not guaranteed to run as-is.
class SharedExpert(nn.Module):
    """Shared expert MLP; the owning MoE layer applies every shared expert to its input.

    The last dimension of the input is expanded to ``4 * input_dim`` features,
    passed through GELU, then projected down to ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        expanded = 4 * input_dim
        layers = [
            nn.Linear(input_dim, expanded),
            nn.GELU(),
            nn.Linear(expanded, hidden_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        return self.fc(x)


class RouterExpert(nn.Module):
    """Routed expert MLP (draft); the gate decides which of these each token uses.

    Same architecture as ``SharedExpert``: Linear(input_dim -> 4*input_dim),
    GELU, Linear(4*input_dim -> hidden_dim), applied along the last dimension.
    """

    def __init__(self, input_dim, hidden_dim):
        super(RouterExpert, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 4 * input_dim),
            nn.GELU(),
            nn.Linear(4 * input_dim, hidden_dim),
        )

    def forward(self, x):
        """Return the expert output; the last dimension becomes ``hidden_dim``."""
        return self.fc(x)

class MoEGate(nn.Module):
    """Draft top-k gate: scores experts with one linear layer and keeps the best ``topk``.

    Args:
        topk: number of experts kept per token.
        nums_expert: number of experts to score.
        input_dim: size of the input's last dimension.
    """

    def __init__(self, topk, nums_expert, input_dim):
        super(MoEGate, self).__init__()
        self.topk = topk
        self.gate_linear = nn.Linear(input_dim, nums_expert)

    def forward(self, x):
        """Return ``(topk_weights, topk_indices)``, each with last dimension ``topk``.

        Weights are softmax probabilities over all experts, restricted to the
        selected ones (largest first); they are not renormalized to sum to 1.
        """
        logits = self.gate_linear(x)
        weights = torch.softmax(logits, dim=-1)
        topk_weights, topk_indices = torch.topk(weights, self.topk, dim=-1)
        return topk_weights, topk_indices

class DeepseekMoE(nn.Module):
    """Draft MoE layer: averaged shared experts plus gate-weighted top-k routed experts.

    Args:
        input_dim: size of the input's last dimension.
        hidden_dim: output size of every expert.
        num_shared: number of shared experts; their outputs are averaged
            (must be >= 1, since ``torch.stack`` rejects an empty list).
        num_routed: number of routed experts the gate chooses from.
        topk: number of routed experts combined per token.
    """

    def __init__(self, input_dim, hidden_dim, num_shared, num_routed, topk):
        super(DeepseekMoE, self).__init__()
        self.shared_experts = nn.ModuleList([SharedExpert(input_dim, hidden_dim) for _ in range(num_shared)])
        self.routed_experts = nn.ModuleList([RouterExpert(input_dim, hidden_dim) for _ in range(num_routed)])
        # Keywords bind to the draft gate signature MoEGate(topk, nums_expert, input_dim).
        # NOTE: a later cell redefines MoEGate with a different signature; if that one
        # shadows the draft, this call fails loudly (TypeError) instead of misconfiguring.
        self.gate = MoEGate(topk=topk, nums_expert=num_routed, input_dim=input_dim)
        self.topk = topk

    def forward(self, x):
        """Combine shared and routed expert outputs for ``x``.

        Any number of leading dimensions is accepted; the last dimension must be
        ``input_dim``. The result has the same leading dimensions and last
        dimension ``hidden_dim``.
        """
        # Mean over shared experts: [..., hidden_dim]
        shared_out = torch.stack([expert(x) for expert in self.shared_experts]).mean(dim=0)
        topk_wght, topk_index = self.gate(x)  # both [..., topk]

        # Every routed expert is evaluated densely and stacked on a new expert
        # axis just before the feature axis: [..., num_routed, hidden_dim]
        expert_out = torch.stack([expert(x) for expert in self.routed_experts], dim=-2)
        # gather needs an index with the same rank as expert_out.
        index = topk_index.unsqueeze(-1).expand(*topk_index.shape, expert_out.size(-1))
        select_expert = expert_out.gather(-2, index)  # [..., topk, hidden_dim]
        # Weight each selected expert and reduce over the topk axis only.
        routed_out = (select_expert * topk_wght.unsqueeze(-1)).sum(dim=-2)

        return shared_out + routed_out

# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedExpert(nn.Module):
    """Shared expert network; every input is forced through it.

    Two-layer MLP on the last dimension: Linear(input_dim -> 4*input_dim),
    GELU, Linear(4*input_dim -> hidden_dim).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        ffn_dim = input_dim * 4
        up_proj = nn.Linear(input_dim, ffn_dim)
        act = nn.GELU()
        down_proj = nn.Linear(ffn_dim, hidden_dim)
        self.fc = nn.Sequential(up_proj, act, down_proj)

    def forward(self, x):
        """Return the MLP output; for [batch, seq, input_dim] input this is [batch, seq, hidden_dim]."""
        return self.fc(x)

class RoutedExpert(nn.Module):
    """Routed expert network, selected dynamically per token by the gate.

    Architecture: Linear(input_dim -> 4*input_dim), GELU,
    Linear(4*input_dim -> hidden_dim), applied along the last dimension.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        width = 4 * input_dim
        blocks = [nn.Linear(input_dim, width), nn.GELU(), nn.Linear(width, hidden_dim)]
        self.fc = nn.Sequential(*blocks)

    def forward(self, x):
        """Run the expert MLP on ``x``."""
        out = self.fc(x)
        return out

class MoEGate(nn.Module):
    """Top-k gating network for the routed experts.

    A single linear layer scores each expert, softmax over the last dimension
    turns the scores into probabilities, and only the ``top_k`` largest are kept.

    Args:
        input_dim: size of the input's last dimension.
        num_experts: number of routed experts to score.
        top_k: experts kept per token (default 2).
    """

    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate_linear = nn.Linear(input_dim, num_experts)
        # Placeholder for a load-balancing loss: set to 0 here and never
        # updated by this class.
        self.aux_loss = 0

    def forward(self, x):
        """Return ``(topk_weights, topk_indices)``, each with last dimension ``top_k``.

        The weights are the selected experts' softmax probabilities, largest
        first, without renormalization over the top-k.
        """
        probs = F.softmax(self.gate_linear(x), dim=-1)
        best_probs, best_experts = torch.topk(probs, self.top_k, dim=-1)
        return best_probs, best_experts

class DeepSeekMoE(nn.Module):
    """MoE module combining always-on shared experts with top-k routed experts.

    Output = mean of all shared experts' outputs
           + sum over the top-k routed experts of (gate weight * expert output).

    Args:
        input_dim: size of the input's last dimension.
        hidden_dim: output size of every expert.
        num_shared: number of shared experts (averaged; must be >= 1).
        num_routed: number of routed experts the gate chooses from.
        top_k: routed experts combined per token.
    """

    def __init__(self, input_dim, hidden_dim, num_shared=1, num_routed=8, top_k=2):
        super().__init__()
        # Expert definitions
        self.shared_experts = nn.ModuleList([SharedExpert(input_dim, hidden_dim) for _ in range(num_shared)])
        self.routed_experts = nn.ModuleList([RoutedExpert(input_dim, hidden_dim) for _ in range(num_routed)])
        # Routing gate
        self.gate = MoEGate(input_dim, num_routed, top_k=top_k)
        self.top_k = top_k

    def forward(self, x):
        """Apply the MoE layer to ``x``.

        Accepts any number of leading dimensions (e.g. [batch, seq, input_dim]
        or [tokens, input_dim]); returns the same leading dimensions with last
        dimension ``hidden_dim``.
        """
        # Shared experts are always evaluated and averaged: [..., hidden_dim]
        shared_out = torch.stack([expert(x) for expert in self.shared_experts]).mean(dim=0)

        topk_weights, topk_indices = self.gate(x)  # [..., top_k]

        # All routed experts are computed densely, stacked on an expert axis
        # just before the feature axis: [..., num_routed, hidden_dim]
        expert_outputs = torch.stack([expert(x) for expert in self.routed_experts], dim=-2)

        # Pick the top-k experts per token (index must match expert_outputs' rank),
        # then weight and sum over the top-k axis.
        gather_index = topk_indices.unsqueeze(-1).expand(*topk_indices.shape, expert_outputs.size(-1))
        selected_experts = expert_outputs.gather(-2, gather_index)  # [..., top_k, hidden_dim]
        routed_out = (selected_experts * topk_weights.unsqueeze(-1)).sum(dim=-2)

        # Merge shared and routed outputs
        return shared_out + routed_out

# In [None]:
# Demo: 2 shared experts plus 8 routed experts, with 2 routed experts active per token,
# mapping 768 -> 768 features.
moe_layer = DeepSeekMoE(768, 768, num_shared=2, num_routed=8, top_k=2)

# Random input shaped [batch, seq_len, dim] = [32, 128, 768].
x = torch.randn((32, 128, 768))
output = moe_layer(x)  # -> [32, 128, 768]