### Basic MOE 
![basic-moe](./pics/Basic_MOE.png)
### Import packages


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class Basic_Expert(nn.Module):
    """A single expert: one fully connected layer.

    Maps the last dimension of its input from ``feature_in`` to
    ``feature_out`` through ``nn.Linear``.
    """

    def __init__(self, feature_in, feature_out):
        super(Basic_Expert, self).__init__()
        # The only learnable component of the expert.
        self.fc = nn.Linear(
            in_features=feature_in,
            out_features=feature_out,
        )

    def forward(self, x):
        """Apply the linear projection to ``x`` and return the result."""
        projected = self.fc(x)
        return projected

In [None]:
class Basic_MOE(nn.Module):
    """Dense mixture-of-experts layer with a learned softmax gate.

    Every expert processes every input. A linear gate produces one logit per
    expert, the logits are normalised with softmax, and the layer returns the
    resulting weighted sum of the expert outputs.

    Args:
        feature_in: Size of the last dimension of the input (hidden dim).
        feature_out: Size of the last dimension of the output.
        num_experts: Number of ``Basic_Expert`` modules to mix; must be >= 1.

    Raises:
        ValueError: If ``num_experts`` is smaller than 1.
    """

    def __init__(self, feature_in, feature_out, num_experts):
        super().__init__()
        if num_experts < 1:
            raise ValueError(
                f"num_experts must be >= 1, got {num_experts}"
            )

        # The gate has to be a callable, trainable layer that emits one
        # logit per expert; storing the bare integer ``feature_in`` here
        # made ``self.gate(x)`` in ``forward`` raise a TypeError.
        self.gate = nn.Linear(feature_in, num_experts)
        self.experts = nn.ModuleList(
            Basic_Expert(feature_in, feature_out)
            for _ in range(num_experts)
        )

    def forward(self, x):
        """Return the gate-weighted combination of all expert outputs.

        Args:
            x: Tensor of shape ``[batch, feature_in]``. Extra leading
                dimensions (e.g. ``[batch, seq, feature_in]``) also work
                because all axes are indexed from the end.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``feature_out``.
        """
        # Routing distribution over experts: [..., num_experts]
        expert_weights = F.softmax(self.gate(x), dim=-1)

        # Stack per-expert outputs on a new axis just before the feature
        # axis: [..., num_experts, feature_out]
        expert_out = torch.stack(
            [expert(x) for expert in self.experts],
            dim=-2,
        )

        # [..., 1, num_experts] @ [..., num_experts, feature_out]
        #   -> [..., 1, feature_out]
        output = expert_weights.unsqueeze(-2) @ expert_out

        return output.squeeze(-2)

