In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Basic MoE

In [2]:
class BasicExpert(nn.Module):
    """A single expert: one affine projection from ``feature_in`` to ``feature_out`` features.

    Args:
        feature_in: size of the input's last dimension.
        feature_out: size of the output's last dimension.
    """

    def __init__(self, feature_in, feature_out):
        super(BasicExpert, self).__init__()
        # Keep the attribute name ``fc``: it determines the state_dict keys ("fc.weight", "fc.bias").
        projection = nn.Linear(feature_in, feature_out)
        self.fc = projection

    def forward(self, x):
        """Apply the expert's linear layer to the last dimension of ``x``."""
        projected = self.fc(x)
        return projected

In [3]:
class BasicMoe(nn.Module):
    """Dense mixture of experts.

    Every expert runs on every input; their outputs are combined with softmax
    weights produced by a linear gate.

    Args:
        feature_in: size of the input's last dimension.
        feature_out: size of each expert's (and the layer's) output last dimension.
        num_expert: number of experts.
    """

    def __init__(self, feature_in, feature_out, num_expert):
        super().__init__()
        # Gate network: maps each input vector to one score per expert.
        self.gate = nn.Linear(feature_in, num_expert)
        self.experts = nn.ModuleList(
            BasicExpert(feature_in, feature_out) for _ in range(num_expert)
        )

    def forward(self, x):
        """Return the gate-weighted sum of all expert outputs.

        Args:
            x: tensor of shape (..., feature_in) with any number of leading dims,
               e.g. (batch, feature_in) or (batch, seq_len, feature_in).

        Returns:
            Tensor of shape (..., feature_out).
        """
        # (..., num_expert): softmax over the expert axis, which is the last axis.
        expert_weights = F.softmax(self.gate(x), dim=-1)

        # (..., num_expert, feature_out). The expert axis sits right before the
        # feature axis; negative indices keep this true for any number of leading dims.
        expert_output = torch.stack([expert(x) for expert in self.experts], dim=-2)

        # (..., 1, num_expert) @ (..., num_expert, feature_out) -> (..., 1, feature_out)
        output = expert_weights.unsqueeze(-2) @ expert_output
        return output.squeeze(-2)  # (..., feature_out)

In [4]:
def test_basic_moe():
    """Smoke-test BasicMoe: a (4, 512) input must map to a (4, 128) output."""
    x = torch.rand(4, 512)
    basic_moe = BasicMoe(512, 128, 4)
    output = basic_moe(x)
    print(output.shape)
    # Fail loudly instead of relying on someone reading the printed shape.
    assert output.shape == (4, 128), f"unexpected output shape {tuple(output.shape)}"

In [5]:
# Smoke test: prints the BasicMoe output shape; expected torch.Size([4, 128]).
test_basic_moe()

torch.Size([4, 128])


### SparseMoE

SparseMoE 和 BasicMoE 的区别是：BasicMoE 对所有专家的输出做加权求和，而 SparseMoE 对每个 token 只选择路由概率最高的 top K 个专家，把这 K 个概率重新归一化（使其和为 1）后作为权重，再对这 top K 个专家的输出进行加权求和。

In [6]:
class MOEConfig:
    """Hyper-parameters for MOERouter / SparseMOE / SharedExpertMOE.

    Args:
        hidden_dim: token feature size (input and output size of every expert).
        expert_number: number of routed experts.
        top_k: number of routed experts selected per token; must satisfy
            1 <= top_k <= expert_number, since torch.topk cannot select more
            experts than exist.
        shared_experts_number: number of shared experts applied to every token
            (default 2).

    Raises:
        ValueError: if top_k is outside [1, expert_number].
    """

    def __init__(self, hidden_dim, expert_number, top_k, shared_experts_number=2):
        # Validate up front; otherwise the error only surfaces inside the router's forward pass.
        if not 1 <= top_k <= expert_number:
            raise ValueError(
                f"top_k must be in [1, expert_number={expert_number}], got {top_k}"
            )
        self.hidden_dim = hidden_dim
        self.expert_number = expert_number
        self.top_k = top_k
        self.shared_experts_number = shared_experts_number

class MOERouter(nn.Module):
    """Token-level top-k router.

    A linear gate scores every expert for each token. The top_k highest-probability
    experts are kept, and their probabilities are renormalized to sum to 1.

    Args:
        config: object exposing ``hidden_dim``, ``expert_number`` and ``top_k``.
    """

    def __init__(self, config):
        super().__init__()
        # Scores all expert_number experts; only top_k of them are used per token.
        self.gate = nn.Linear(config.hidden_dim, config.expert_number)
        self.expert_number = config.expert_number
        self.top_k = config.top_k

    def forward(self, x):
        """Route each token to its top_k experts.

        Args:
            x: (num_tokens, hidden_dim), i.e. (batch_size * seq_len, hidden_dim).
               Must be 2-D: the final permute assumes a 3-D one-hot mask.

        Returns:
            router_logits: (num_tokens, expert_number) raw gate scores.
            router_weights: (num_tokens, top_k) renormalized weights, cast to x.dtype.
            selected_experts_indices: (num_tokens, top_k) chosen expert ids.
            expert_mask: (expert_number, top_k, num_tokens) one-hot mask;
                expert_mask[e, k, t] == 1 iff expert e is token t's k-th choice.
        """
        router_logits = self.gate(x)  # (num_tokens, expert_number)

        # Softmax over the expert axis, the same axis topk selects along below.
        # Probabilities are computed in float32 and cast back to x.dtype after renormalization.
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)

        # topk values are differentiable (gradients reach the gate); the indices are not.
        router_weights, selected_experts_indices = torch.topk(router_probs, self.top_k, dim=-1)
        # Both: (num_tokens, top_k)

        # Renormalize the kept probabilities so each token's weights sum to 1.
        router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
        router_weights = router_weights.to(x.dtype)

        expert_mask = F.one_hot(
            selected_experts_indices,
            num_classes=self.expert_number,
        )  # (num_tokens, top_k, expert_number)
        expert_mask = expert_mask.permute(2, 1, 0)  # (expert_number, top_k, num_tokens)

        return router_logits, router_weights, selected_experts_indices, expert_mask

class SparseMOE(nn.Module):
    """Token-level sparse MoE.

    Each token is processed only by the top_k experts chosen by the router. The
    outputs of those experts are combined with the router's renormalized weights.

    Args:
        config: MOEConfig-like object with ``hidden_dim``, ``expert_number`` and ``top_k``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.top_k = config.top_k
        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number

        # Experts map hidden_dim -> hidden_dim so their weighted outputs can be summed per token.
        self.experts = nn.ModuleList(
            BasicExpert(
                config.hidden_dim,
                config.hidden_dim
            ) for _ in range(config.expert_number)
        )
        self.router = MOERouter(config)

    def forward(self, x):
        """Apply the sparse MoE layer.

        Args:
            x: (batch_size, seq_len, hidden_dim). Does not need to be contiguous.

        Returns:
            final_hidden_states: (batch_size, seq_len, hidden_dim).
            router_logits: (batch_size * seq_len, expert_number) raw router scores.
        """
        batch_size, seq_len, hidden_dim = x.size()

        # Routing is per token, so flatten to (batch_size * seq_len, hidden_dim).
        # reshape (not view) so non-contiguous inputs such as transposed tensors also work.
        hidden_states = x.reshape(-1, hidden_dim)

        router_logits, router_weights, _, expert_masks = self.router(hidden_states)
        # expert_masks: (expert_number, top_k, batch_size * seq_len)

        final_hidden_states = torch.zeros(
            (batch_size * seq_len, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device
        )

        # For each expert, run it on the tokens that selected it and add the
        # weighted result to those tokens' rows.
        for expert_idx, expert_layer in enumerate(self.experts):
            # token_idx[i] picked this expert as its topk_slot[i]-th choice (0 .. top_k-1).
            topk_slot, token_idx = torch.where(expert_masks[expert_idx])

            # An expert with no tokens still runs (on an empty batch). This is not
            # skipped, so every expert stays part of the autograd graph.
            current_state = expert_layer(hidden_states[token_idx])  # (selected_tokens, hidden_dim)

            # This expert's routing weight for each selected token, broadcast over hidden_dim.
            current_weight = router_weights[token_idx, topk_slot].unsqueeze(-1)  # (selected_tokens, 1)
            current_hidden_states = current_state * current_weight  # (selected_tokens, hidden_dim)

            # In-place scatter-add back into the selected token rows.
            final_hidden_states.index_add_(
                0,
                token_idx,
                current_hidden_states.to(hidden_states.dtype)
            )

        # Restore the original (batch_size, seq_len, hidden_dim) layout.
        final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)

        return final_hidden_states, router_logits

def test_token_level_moe():
    """Smoke-test SparseMOE output and router-logit shapes."""
    x = torch.rand(2, 4, 16)
    config = MOEConfig(16, 2, 2)  # top_k == expert_number: every expert is selected
    token_level_moe = SparseMOE(config)
    out = token_level_moe(x)
    print(out[0].shape, out[1].shape)
    assert out[0].shape == (2, 4, 16)
    assert out[1].shape == (2 * 4, 2)  # (batch_size * seq_len, expert_number)

    # Also exercise real sparsity (top_k < expert_number); checked, not printed.
    sparse_out, sparse_logits = SparseMOE(MOEConfig(16, 4, 2))(x)
    assert sparse_out.shape == (2, 4, 16)
    assert sparse_logits.shape == (2 * 4, 4)

test_token_level_moe()

torch.Size([2, 4, 16]) torch.Size([8, 2])


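### SharedExpert SparseMoE

在 SparseMoE 的基础上增加 `shared_experts_number` 个共享专家：每个 token 都会经过所有共享专家（不经过路由），共享专家的输出求和后再与 SparseMoE 的输出相加。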

In [13]:
class SharedExpertMOE(nn.Module):
    """Sparse MoE plus shared experts.

    The routed SparseMOE output is added to the sum of ``shared_experts_number``
    shared experts, which run on every token without routing.

    Args:
        config: MOEConfig-like object; ``shared_experts_number`` may be 0, in which
            case the layer reduces to the routed SparseMOE.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.routed_experts_moe = SparseMOE(config)
        self.shared_experts = nn.ModuleList(
            [
                BasicExpert(
                    config.hidden_dim,
                    config.hidden_dim,
                )
                for _ in range(config.shared_experts_number)
            ]
        )

    def forward(self, x):
        """Apply the layer.

        Args:
            x: (batch_size, seq_len, hidden_dim).

        Returns:
            output: (batch_size, seq_len, hidden_dim).
            router_logits: router scores returned by the routed SparseMOE.
        """
        sparse_moe_out, router_logits = self.routed_experts_moe(x)

        # torch.stack rejects an empty list, so handle zero shared experts explicitly.
        if len(self.shared_experts) == 0:
            return sparse_moe_out, router_logits

        shared_experts_output = torch.stack(
            [expert(x) for expert in self.shared_experts],
            dim=0,
        )  # (shared_experts_number, batch_size, seq_len, hidden_dim)
        shared_expert_out = shared_experts_output.sum(dim=0)

        output = shared_expert_out + sparse_moe_out
        return output, router_logits

def test_share_expert_moe():
    """Smoke-test SharedExpertMOE output and router-logit shapes."""
    x = torch.rand(2, 4, 16)
    config = MOEConfig(16, 2, 2)  # shared_experts_number defaults to 2
    share_expert_moe = SharedExpertMOE(config)
    out = share_expert_moe(x)
    print(out[0].shape, out[1].shape)
    assert out[0].shape == (2, 4, 16)
    assert out[1].shape == (2 * 4, 2)  # router logits: (batch_size * seq_len, expert_number)

    # Also exercise real sparsity (top_k < expert_number); checked, not printed.
    sparse_out, sparse_logits = SharedExpertMOE(MOEConfig(16, 4, 2))(x)
    assert sparse_out.shape == (2, 4, 16)
    assert sparse_logits.shape == (2 * 4, 4)

test_share_expert_moe()

torch.Size([2, 4, 16]) torch.Size([8, 2])
