### 기본(초창기) MoE 구현
- 가장 기본적인 전문가(Expert) 네트워크와 게이팅(Gating) 네트워크로 구현 (모든 전문가가 모든 입력을 처리하고, 게이팅 네트워크의 softmax 가중치로 전문가 출력을 가중합하는 dense MoE)

In [1]:
import torch
import torch.nn as nn

In [3]:
class Expert(nn.Module):
    """A single expert: one fully connected (linear) layer.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Attribute name kept as ``layer`` so state_dict keys stay "layer.*".
        self.layer = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        projected = self.layer(x)
        return projected

In [4]:
class GatingNetwork(nn.Module):
    """Gating network: produces a softmax weight for each expert.

    Args:
        input_dim: size of the last dimension of the input.
        num_experts: number of experts, i.e. number of gate outputs.
    """

    def __init__(self, input_dim, num_experts):
        super().__init__()
        # Attribute name kept as ``layer`` so state_dict keys stay "layer.*".
        self.layer = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Return softmax-normalized expert weights over the last axis.

        For a 2-D input ``[batch size, input_dim]`` the result is
        ``[batch size, num experts]``; each row sums to 1.
        """
        logits = self.layer(x)
        return logits.softmax(dim=-1)

In [6]:
class MoE(nn.Module):
    """Dense (soft) mixture-of-experts layer.

    Every expert processes every input. The gating network assigns a softmax
    weight to each expert, and the output is the gate-weighted sum of all
    expert outputs.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of each expert's output.
        num_experts: number of experts (and of gate outputs).
    """

    def __init__(self, input_dim, output_dim, num_experts):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList(
            [Expert(input_dim, output_dim) for _ in range(num_experts)]
        )
        self.gate = GatingNetwork(input_dim, num_experts)

    def forward(self, x):
        """Combine all expert outputs using the gate's weights.

        Args:
            x: tensor of shape ``[..., input_dim]`` with any number of leading
                dimensions, e.g. ``[batch, input_dim]`` or
                ``[batch, seq, input_dim]``.

        Returns:
            Tensor of shape ``[..., output_dim]``.
        """
        gate_weights = self.gate(x)  # [..., num_experts]

        # Run every expert on the full input. Stack on the second-to-last
        # axis (not axis 1) so the expert axis lines up with the gate's last
        # axis regardless of how many leading dims x has. With dim=1, a
        # [batch, seq, dim] input would misbroadcast (or silently mix the seq
        # and expert axes when seq == num_experts).
        expert_outputs = torch.stack(
            [expert(x) for expert in self.experts], dim=-2
        )  # [..., num_experts, output_dim]

        # Weighted sum of expert outputs over the expert axis.
        weighted = gate_weights.unsqueeze(-1) * expert_outputs  # [..., num_experts, output_dim]
        return torch.sum(weighted, dim=-2)  # [..., output_dim]

In [19]:
# Toy configuration for a quick shape check of the MoE layer.
batch_size, input_dim, output_dim, num_experts = 10, 512, 256, 4

moe = MoE(
    input_dim=input_dim,
    output_dim=output_dim,
    num_experts=num_experts,
)

In [21]:
# Push one random batch through the layer; expected shape is [batch_size, output_dim].
x = torch.randn((batch_size, input_dim))
output = moe(x)
(output.shape, output)

(torch.Size([10, 256]),
 tensor([[ 4.5667e-02,  2.7210e-01,  3.2642e-02,  ...,  1.3159e-01,
          -1.6857e-01, -5.6110e-02],
         [-1.1751e-01, -4.8721e-01, -1.2975e-01,  ...,  2.4370e-01,
           2.4538e-01, -9.4550e-02],
         [ 1.6024e-01,  2.1195e-01, -4.2267e-01,  ...,  2.1199e-01,
           5.7871e-02,  6.9476e-01],
         ...,
         [-2.9266e-04, -4.3952e-01,  2.4019e-01,  ..., -8.0054e-02,
           1.4274e-01, -3.5533e-01],
         [-3.5217e-01, -1.9802e-01,  4.3595e-01,  ...,  2.2851e-01,
           3.6919e-01,  4.3869e-01],
         [ 3.2570e-01, -1.5187e-01, -4.3439e-01,  ..., -7.7880e-02,
           2.7718e-01,  3.2285e-01]], grad_fn=<SumBackward1>))