### Sparse-MOE
![sparse-moe](./pics/Sparse_MoE.png)
和 Basic-MOE 的区别是：在 Sparse-MOE 中，路由器为每个 token 选择 TopK 个专家，然后对这 TopK 个专家的输出进行加权求和。
同时把输入样本变成了大模型中真实的输入 Shape：(batch_size, seq_len, hidden_dim)。

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class MOEConfig:
    """Plain container for the hyper-parameters of the sparse MoE layer.

    Attributes:
        hidden_dim: Feature size; used as the input width of the router's
            gate and as both dimensions passed to each expert.
        expert_number: Number of routed experts (output width of the gate).
        top_k: How many experts the router keeps per input row.
        shared_expert_numbers: Stored as given (default 2); nothing in the
            visible code reads it yet.
    """

    def __init__(self, hidden_dim, expert_number, top_k, shared_expert_numbers=2):
        # Paired assignments keep the attribute creation order of the
        # original one-per-line version.
        self.hidden_dim, self.expert_number = hidden_dim, expert_number
        self.top_k, self.shared_expert_numbers = top_k, shared_expert_numbers


class MOERouter(nn.Module):
    """Top-k gating network for a sparse mixture-of-experts layer.

    A single linear layer scores every expert for each input row. The
    ``top_k`` most probable experts are kept and their probabilities are
    renormalised so that they sum to 1 for each row.

    Args:
        config: Object exposing ``hidden_dim``, ``expert_number`` and
            ``top_k`` attributes.
    """

    def __init__(self, config):
        super().__init__()

        # Maps a hidden vector to one logit per expert.
        self.gate = nn.Linear(config.hidden_dim, config.expert_number)

        self.expert_number = config.expert_number
        self.top_k = config.top_k

    def forward(self, x):
        """Select the ``top_k`` experts for every row of ``x``.

        Args:
            x: 2-D tensor of shape ``(num_tokens, hidden_dim)``. The
                three-axis ``permute`` at the end requires exactly this rank.

        Returns:
            A 4-tuple ``(router_logits, router_weights,
            selected_expert_indices, expert_mask)`` with shapes
            ``(num_tokens, expert_number)``, ``(num_tokens, top_k)``,
            ``(num_tokens, top_k)`` and ``(expert_number, top_k, num_tokens)``
            respectively.
        """
        router_logits = self.gate(x)

        # Softmax over the expert axis, computed in float32 regardless of
        # the input dtype. ``dim=-1`` equals the former ``dim=1`` for 2-D
        # input and is consistent with the top-k / renormalisation axes.
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)

        # The selected values (not the indices) stay differentiable.
        # Fix: the function is ``torch.topk``; ``torch.top_k`` does not exist.
        router_weights, selected_expert_indices = torch.topk(
            router_probs,
            self.top_k,
            dim=-1,
        )

        # Renormalise the kept probabilities so they sum to 1 per row, then
        # cast back to the dtype of the input.
        router_weights = router_weights / router_weights.sum(
            dim=-1,
            keepdim=True,
        )
        router_weights = router_weights.to(x.dtype)

        expert_mask = F.one_hot(
            selected_expert_indices,
            num_classes=self.expert_number,
        )  # (num_tokens, top_k, expert_number)

        # Put the expert axis first: (expert_number, top_k, num_tokens).
        expert_mask = expert_mask.permute(2, 1, 0)

        return router_logits, router_weights, selected_expert_indices, expert_mask




class Spare_MOE(nn.Module):
    def __init__(self, config):
        """Create the routed experts, the shared-expert container and the router.

        Args:
            config: Object exposing ``hidden_dim``, ``expert_number`` and
                ``top_k`` attributes (e.g. ``MOEConfig``).
        """
        super().__init__()

        self.config = config
        self.top_k = config.top_k
        self.expert_number = config.expert_number

        # ``expert_number`` experts, each built with (hidden_dim, hidden_dim).
        # NOTE(review): Basic_Expert is defined outside this section; its
        # (in_features, out_features) signature is assumed from this call.
        self.experts = nn.ModuleList(
            [
                Basic_Expert(config.hidden_dim, config.hidden_dim)
                for _ in range(config.expert_number)
            ]
        )

        # Created empty; nothing in this method populates it.
        self.shared_experts = nn.ModuleList()

        # Fix: forward() calls ``self.router(...)`` but the attribute was
        # never created, which raised AttributeError. It is registered after
        # the existing submodules so their initialisation order is unchanged.
        self.router = MOERouter(config)


    def forward(self, x):
        # x shape: [batch, seq_len, hidden_dim]
        batch_size, seq_len, hidden_dim = x.size()

        # token 维度计算
        # x reshape to [batch * seq_len, hidden_dim]
        hidden_state = x.view(-1, hidden_dim)

        # 专家计算
        router_logits, router_weights, selected_expert_indices, expert_mask = self.router(hidden_state)

        