### Basic MOE 
![basic-moe](./pics/Basic_MOE.png)
### Import packages


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Basic_Expert(nn.Module):
    """A single expert: one learnable linear map from ``feature_in`` to ``feature_out``."""

    def __init__(self, feature_in: int, feature_out: int):
        super().__init__()
        # The entire expert is one affine layer (weight + bias).
        self.fc = nn.Linear(in_features=feature_in, out_features=feature_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the expert's linear layer to ``x`` and return the projection."""
        projected = self.fc(x)
        return projected

In [5]:
class Basic_MOE(nn.Module):
    """Dense (soft-routing) mixture of experts.

    Every expert is evaluated on every input. A linear gate produces one
    logit per expert; the logits are softmax-normalized and used as weights
    for a weighted sum of the expert outputs.

    Args:
        feature_in: Size of the last dimension of the input (hidden_dim).
        feature_out: Size of the last dimension of the output.
        num_experts: Number of ``Basic_Expert`` modules to mix.
    """

    def __init__(self, feature_in: int, feature_out: int, num_experts: int):
        super().__init__()

        self.experts = nn.ModuleList(
            Basic_Expert(feature_in, feature_out) for _ in range(num_experts)
        )
        # The gate maps each input vector to one routing logit per expert.
        self.gate = nn.Linear(feature_in, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gate-weighted sum of all expert outputs.

        Args:
            x: Tensor of shape ``[..., feature_in]``. Any number of leading
                dimensions is accepted (e.g. ``[batch, feature_in]`` or
                ``[batch, seq_len, feature_in]``).

        Returns:
            Tensor of shape ``[..., feature_out]`` with the same leading
            dimensions as ``x``.
        """
        # gate_logits shape is [..., num_experts]
        gate_logits = self.gate(x)

        # All dimensions are indexed from the end so that extra leading
        # dimensions (such as a sequence axis) pass through untouched.
        # expert_out shape is [..., num_experts, feature_out]
        expert_out = torch.stack(
            [expert(x) for expert in self.experts],
            dim=-2,
        )

        # Normalize over the expert axis so the weights of each input sum to 1.
        # expert_weights shape is [..., 1, num_experts]
        expert_weights = F.softmax(gate_logits, dim=-1).unsqueeze(-2)

        # [..., 1, num_experts] @ [..., num_experts, feature_out]
        #   -> [..., 1, feature_out]
        output = expert_weights @ expert_out

        return output.squeeze(-2)



In [6]:
def test_basic_moe():
    """Smoke test: push one random batch through Basic_MOE and print the output shape."""
    batch_size, hidden_dim, out_dim, expert_count = 4, 512, 128, 4

    sample = torch.rand(batch_size, hidden_dim)
    model = Basic_MOE(hidden_dim, out_dim, expert_count)
    print(model(sample).shape)


test_basic_moe()

torch.Size([4, 128])
