In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F



In [4]:
class Linear(nn.Module):
    """Thin wrapper module around a single ``nn.Linear`` layer.

    The wrapped layer is stored as ``self.fc``, so its parameters appear in
    the state dict under the ``fc.`` prefix.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        projection = nn.Linear(in_features=input_dim, out_features=output_dim)
        self.fc = projection

    def forward(self, x):
        """Return ``self.fc`` applied to ``x``."""
        projected = self.fc(x)
        return projected
    

In [18]:
class MoELayer(nn.Module):
    """Dense (soft) mixture-of-experts layer.

    Every expert is evaluated on every input, and the expert outputs are
    combined with softmax-normalised weights produced by a gating layer.

    Args:
        num_experts: Number of expert sub-modules; must be at least 1.
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.

    Raises:
        ValueError: If ``num_experts`` is smaller than 1.
    """

    def __init__(self, num_experts, in_features, out_features):
        super().__init__()
        if num_experts < 1:
            # Without experts there is nothing to stack in forward(); fail
            # early with a clear message instead of deep inside torch.stack.
            raise ValueError(f"num_experts must be >= 1, got {num_experts}")
        self.num_experts = num_experts
        self.expert = nn.ModuleList([Linear(in_features, out_features) for _ in range(num_experts)])
        self.gate = Linear(in_features, num_experts)

    def forward(self, x):
        """Mix the expert outputs using the gate's softmax weights.

        Args:
            x: Tensor of shape ``(..., in_features)``. Any number of leading
                dimensions is accepted (including none).

        Returns:
            Tensor of shape ``(..., out_features)``.
        """
        # (..., num_experts): one weight per expert, summing to 1 over the last dim.
        gate_scores = F.softmax(self.gate(x), dim=-1)
        # (..., num_experts, out_features): stack on the second-to-last axis so
        # the layout does not depend on how many leading dims x has.
        expert_outputs = torch.stack([expert(x) for expert in self.expert], dim=-2)
        # (..., 1, num_experts) @ (..., num_experts, out_features) -> (..., 1, out_features).
        # matmul broadcasts over leading dims, unlike bmm which requires exactly 3-D inputs.
        output = torch.matmul(gate_scores.unsqueeze(-2), expert_outputs).squeeze(-2)
        return output


In [19]:
# Smoke test: push one random batch through a small MoE layer and check
# that the output has shape (batch_size, output_dim).
input_dim = 4
output_dim = 6
num_experts = 3
batch_size = 10

# Build the layer first, then draw the batch, so RNG consumption order is kept.
model = MoELayer(
    num_experts=num_experts,
    in_features=input_dim,
    out_features=output_dim,
)
demo = torch.randn((batch_size, input_dim))
result = model(demo)

print(result.shape)

torch.Size([10, 6])
