### Sparse-MOE
![sparse-moe](./pics/Sparse_MoE.png)
和 Basic-MOE 的区别是：在 Sparse-MOE 中，MOE 为每个 token 选择 Top-K 个专家，然后对这 Top-K 个专家的输出进行加权求和。
同时把输入样本改成了大模型中真实的输入 Shape：(batch_size, seq_len, hidden_dim)。

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class BasicExpert(nn.Module):
    """Minimal expert network: a single linear projection.

    An expert could equally be an MLP, or a gated MLP using a SwiGLU
    activation; the simplest possible form is used here.

    Args:
        feature_in: size of the last dimension of the input.
        feature_out: size of the last dimension of the output.
    """

    def __init__(self, feature_in, feature_out):
        super().__init__()
        self.linear = nn.Linear(in_features=feature_in, out_features=feature_out)

    def forward(self, x):
        """Project ``x`` with the expert's linear layer and return the result."""
        projected = self.linear(x)
        return projected

In [4]:
class MOERouter(nn.Module):
    """Top-k token router for a sparse mixture-of-experts layer.

    A single linear gate scores every token against every expert. For each
    token the ``top_k`` most probable experts are kept and their
    probabilities are renormalised so that they sum to 1.

    Args:
        config: object exposing ``hidden_dim``, ``expert_number`` and
            ``top_k`` attributes.
    """

    def __init__(self, config):
        super().__init__()
        self.gate = nn.Linear(config.hidden_dim, config.expert_number)
        # The gate scores all experts, but only top_k of them are used per token.
        self.expert_number = config.expert_number
        # Misspelled original attribute name, kept as an alias for existing readers.
        self.exper_number = self.expert_number
        self.top_k = config.top_k

    def forward(self, x):
        """Route tokens to their top-k experts.

        Args:
            x: tensor whose last dimension is ``hidden_dim``. Any leading
                dimensions are flattened into a single token axis, so both
                ``(tokens, hidden_dim)`` and ``(batch, seq_len, hidden_dim)``
                are accepted.

        Returns:
            A tuple ``(router_logits, router_weights,
            selected_experts_indices, expert_mask)`` with shapes
            ``(tokens, expert_number)``, ``(tokens, top_k)``,
            ``(tokens, top_k)`` and ``(expert_number, top_k, tokens)``.
        """
        # Work at token level regardless of how many leading dims were given.
        tokens = x.reshape(-1, x.shape[-1])

        # Example: with expert_number=8 and top_k=2 every token gets 8 logits.
        router_logits = self.gate(tokens)

        # Per-token probability over experts, computed in float32 and
        # normalised over the expert axis (the same axis topk uses below).
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)

        # Keep the top_k experts per token; gradients flow through the
        # selected probability values.
        router_weights, selected_experts_indices = torch.topk(
            router_probs, self.top_k, dim=-1
        )

        # Renormalise the kept weights so they sum to 1 per token, then cast
        # back to the input dtype.
        router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
        router_weights = router_weights.to(x.dtype)

        # One-hot mask of the selection: (tokens, top_k, expert_number) ...
        expert_mask = F.one_hot(
            selected_experts_indices,
            num_classes=self.expert_number,
        )
        # ... transposed to (expert_number, top_k, tokens) so it can be
        # indexed by expert.
        expert_mask = expert_mask.permute(2, 1, 0)

        return router_logits, router_weights, selected_experts_indices, expert_mask


class MOEConfig:
    """Plain container for the mixture-of-experts hyper-parameters.

    Args:
        hidden_dim: stored as ``hidden_dim``.
        expert_number: stored as ``expert_number``.
        top_k: stored as ``top_k``.
        shared_experts_number: stored as ``shared_experts_number``
            (defaults to 2).
    """

    def __init__(self, hidden_dim, expert_number, top_k, shared_experts_number=2):
        self.hidden_dim, self.expert_number = hidden_dim, expert_number
        self.top_k, self.shared_experts_number = top_k, shared_experts_number



class SparseMOE(nn.Module):
    """Sparse mixture-of-experts layer.

    Every token is sent through its ``top_k`` routed experts, and the expert
    outputs are combined using the router weights to form the token's output
    hidden state.

    Args:
        config: object exposing ``hidden_dim``, ``expert_number`` and
            ``top_k`` attributes.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number
        self.top_k = config.top_k

        # One expert per routing slot; each maps hidden_dim -> hidden_dim.
        self.experts = nn.ModuleList(
            [
                BasicExpert(self.hidden_dim, self.hidden_dim)
                for _ in range(self.expert_number)
            ]
        )

        self.router = MOERouter(config)

    def forward(self, x):
        """Apply the routed experts to every token.

        Args:
            x: tensor of shape ``(batch, seq_len, hidden_dim)``.

        Returns:
            A tuple ``(final_hidden_states, router_logits)`` where
            ``final_hidden_states`` has shape ``(batch, seq_len, hidden_dim)``
            and ``router_logits`` is whatever the router returned as its
            first output (``(batch * seq_len, expert_number)`` for
            ``MOERouter``).
        """
        batch_size, seq_len, hidden_dim = x.size()

        # Work at token level: (batch * seq_len, hidden_dim). reshape (rather
        # than view) also accepts non-contiguous inputs.
        hidden_states = x.reshape(-1, hidden_dim)

        # router_weights: (tokens, top_k)
        # expert_mask:    (expert_number, top_k, tokens)
        router_logits, router_weights, _, expert_mask = self.router(hidden_states)

        # Accumulator for the weighted expert outputs, one row per token.
        final_hidden_states = torch.zeros(
            (batch_size * seq_len, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # Visit every expert and add its weighted output for the tokens that
        # selected it.
        for expert_idx in range(self.expert_number):
            expert_layer = self.experts[expert_idx]

            # current mask: (top_k, tokens).
            # slot_idx  - which top-k slot (0 .. top_k - 1) the expert
            #             occupies for the token; used to pick the weight.
            # token_idx - position of the token in the flattened batch
            #             (0 .. batch * seq_len - 1); used to pick the state.
            slot_idx, token_idx = torch.where(expert_mask[expert_idx])

            # (selected_tokens, hidden_dim)
            expert_output = expert_layer(hidden_states[token_idx])

            # (selected_tokens, 1) so it broadcasts over hidden_dim.
            token_weight = router_weights[token_idx, slot_idx].unsqueeze(-1)

            final_hidden_states.index_add_(
                0,
                token_idx,
                (expert_output * token_weight).to(hidden_states.dtype),
            )

        # Restore the original layout only after ALL experts have contributed.
        final_hidden_states = final_hidden_states.reshape(
            batch_size, seq_len, hidden_dim
        )
        return final_hidden_states, router_logits


In [5]:
def test_token_level_moe():
    """Smoke-run SparseMOE on a random (2, 4, 16) batch and print both output shapes."""
    sample = torch.rand(2, 4, 16)
    # hidden_dim=16, expert_number=2, top_k=2
    moe_layer = SparseMOE(MOEConfig(16, 2, 2))
    hidden, logits = moe_layer(sample)
    print(hidden.shape, logits.shape)


test_token_level_moe()

torch.Size([2, 4, 16]) torch.Size([8, 2])
