In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
class Expert(nn.Module):
  """Single expert: a two-layer perceptron that outputs class probabilities.

  Args:
    input_dim: Number of input features.
    hidden_dim: Width of the hidden layer.
    output_dim: Number of output classes.
  """

  def __init__(self, input_dim, hidden_dim, output_dim):
    super().__init__()
    # Attribute names are part of the state-dict layout; keep them stable.
    self.layer1 = nn.Linear(input_dim, hidden_dim)
    self.layer2 = nn.Linear(hidden_dim, output_dim)

  def forward(self, x):
    """Return a softmax distribution over the last axis of the output."""
    hidden = self.layer1(x).relu()
    logits = self.layer2(hidden)
    return logits.softmax(dim=-1)

In [None]:
class Gating(nn.Module):
  """Gating network producing a probability distribution over experts.

  The network is an MLP of widths input_dim -> 128 -> 256 -> 128 -> num_experts
  with dropout after each hidden layer and a softmax on the output.

  Args:
    input_dim: Number of input features.
    num_experts: Number of experts to weight (size of the output's last axis).
    dropout_rate: Dropout probability applied after each hidden layer.
  """

  def __init__(self, input_dim, num_experts, dropout_rate=0.1):
    super(Gating, self).__init__()

    self.layer1 = nn.Linear(input_dim, 128)
    self.dropout1 = nn.Dropout(dropout_rate)

    self.layer2 = nn.Linear(128, 256)
    # LeakyReLU holds no parameters, so one instance is shared by the
    # second and third hidden layers.
    self.leaky_relu = nn.LeakyReLU()
    self.dropout2 = nn.Dropout(dropout_rate)

    self.layer3 = nn.Linear(256, 128)
    self.dropout3 = nn.Dropout(dropout_rate)

    self.layer4 = nn.Linear(128, num_experts)

  def forward(self, x):
    """Compute gating weights for `x`.

    Returns:
      A tensor whose last axis has size `num_experts` and sums to 1.
    """
    # The first hidden layer uses plain ReLU; the next two use LeakyReLU.
    x = torch.relu(self.layer1(x))
    x = self.dropout1(x)

    x = self.layer2(x)
    x = self.leaky_relu(x)
    x = self.dropout2(x)

    x = self.layer3(x)
    x = self.leaky_relu(x)
    x = self.dropout3(x)

    # Normalise over the expert axis (the last one). For 2-D (batch, features)
    # input this is the same axis as dim=1.
    return torch.softmax(self.layer4(x), dim=-1)

In [None]:
class MoE(nn.Module):
  """Mixture of experts that combines frozen experts with a trainable gate.

  Args:
    trained_experts: Iterable of already-trained expert modules. Each expert
      must expose `layer1.in_features`; all experts are assumed to share the
      same input dimension.

  Raises:
    ValueError: If `trained_experts` is empty.
  """

  def __init__(self, trained_experts):
    super(MoE, self).__init__()
    self.experts = nn.ModuleList(trained_experts)

    num_experts = len(self.experts)
    if num_experts == 0:
      raise ValueError("MoE requires at least one expert")

    # Freeze experts while MoE is training
    for expert in self.experts:
      expert.requires_grad_(False)

    # Assuming all experts have the same input dimension
    input_dim = self.experts[0].layer1.in_features
    self.gating = Gating(input_dim, num_experts)

  def forward(self, x):
    """Return the gate-weighted sum of the expert outputs for `x`."""
    # Gate weights: (..., num_experts).
    weights = self.gating(x)

    # Expert outputs stacked on a new trailing axis: (..., output_dim, num_experts).
    outputs = torch.stack([expert(x) for expert in self.experts], dim=-1)

    # Insert an output_dim axis so the weights broadcast against every
    # output component, then reduce over the expert axis.
    return torch.sum(weights.unsqueeze(-2) * outputs, dim=-1)