<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/MoE_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    """Single expert: a two-layer feed-forward network.

    The input is projected to ``hidden_size`` features, passed through a
    ReLU, and then projected to ``output_size`` features.
    """

    def __init__(self, input_size: int, output_size: int, hidden_size: int) -> None:
        """Create the two linear layers.

        Args:
            input_size: Number of features in each input vector.
            output_size: Number of features produced by the expert.
            hidden_size: Width of the hidden layer between the two projections.
        """
        super().__init__()
        # fc1 is built before fc2 so parameter initialisation draws from the
        # random number generator in a fixed order.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, a ReLU, then ``fc2`` to ``x`` and return the result."""
        hidden = self.fc1(x)
        activated = F.relu(hidden)
        return self.fc2(activated)

class GatingNetwork(nn.Module):
    """Linear gate that turns an input into a probability distribution over experts."""

    def __init__(self, input_size: int, num_experts: int) -> None:
        """Create the gate's single linear layer.

        Args:
            input_size: Number of features in each input vector.
            num_experts: Number of experts to produce weights for.
        """
        super().__init__()
        self.fc = nn.Linear(input_size, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-expert mixing weights for ``x``.

        Args:
            x: Tensor whose last dimension has ``input_size`` features; any
                number of leading dimensions (including none) is accepted.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``num_experts`` that sums to 1.
        """
        # One logit per expert, produced along the last axis by nn.Linear.
        logits = self.fc(x)
        # Normalise over the expert axis. That axis is always the last one,
        # so dim=-1 is used instead of a fixed index: a fixed dim=1 raises on
        # 1-D input and normalises over the wrong axis for 3-D+ input.
        return F.softmax(logits, dim=-1)

class MoEAgent(nn.Module):
    """Dense mixture of experts.

    Every expert processes every input; the gating network's weights are then
    used to form a weighted sum of the expert outputs.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        num_experts: int,
        expert_hidden_size: int,
    ) -> None:
        """Build the experts and the gating network.

        Args:
            input_size: Number of features in each input vector.
            output_size: Number of features each expert (and the agent) outputs.
            num_experts: Number of experts to create; must be at least 1.
            expert_hidden_size: Hidden-layer width passed to each expert.

        Raises:
            ValueError: If ``num_experts`` is smaller than 1.
        """
        super().__init__()
        # Fail early with a clear message: with no experts, forward() would
        # otherwise die inside torch.stack on an empty list.
        if num_experts < 1:
            raise ValueError(f"num_experts must be >= 1, got {num_experts}")
        self.num_experts = num_experts
        self.experts = nn.ModuleList([
            Expert(input_size, output_size, expert_hidden_size) for _ in range(num_experts)
        ])
        self.gating_network = GatingNetwork(input_size, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gate-weighted sum of all expert outputs.

        The combination step works for any number of leading dimensions,
        provided the gating network returns weights whose last axis indexes
        the experts.

        Args:
            x: Input tensor of shape ``(..., input_size)``.

        Returns:
            Tensor of shape ``(..., output_size)``.
        """
        # Mixing weights, shape (..., num_experts).
        expert_weights = self.gating_network(x)

        # Run every expert on the full input and stack along a new last axis:
        # shape (..., output_size, num_experts).
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-1)

        # Contract over the expert axis with a broadcasted matmul:
        # (..., output_size, num_experts) @ (..., num_experts, 1)
        #   -> (..., output_size, 1) -> (..., output_size)
        weighted_outputs = torch.matmul(
            expert_outputs, expert_weights.unsqueeze(-1)
        ).squeeze(-1)

        return weighted_outputs

# --- Demo Usage ---
if __name__ == "__main__":
    # Sizes used for this demonstration run
    input_size, output_size = 10, 5
    num_experts, expert_hidden_size = 3, 20
    batch_size = 4

    # Build the agent before sampling the batch so random-number consumption
    # happens in a fixed order (parameters first, then the input).
    moe_agent = MoEAgent(input_size, output_size, num_experts, expert_hidden_size)
    dummy_input = torch.randn(batch_size, input_size)

    # Forward pass through the full mixture
    agent_output = moe_agent(dummy_input)

    print("Dummy Input Shape:", dummy_input.shape)
    gate_weights = moe_agent.gating_network(dummy_input)
    print("Gating Network Weights Shape:", gate_weights.shape)
    print("Agent Output Shape:", agent_output.shape)

    # Each expert can also be called on its own, bypassing the gate, to
    # inspect its output before the weighted combination.
    for i in range(num_experts):
        expert_output = moe_agent.experts[i](dummy_input)
        print(f"Expert {i} Output Shape:", expert_output.shape)

    # Training is not shown here. Typically the whole agent is optimised
    # end-to-end with a task loss, and the gating network learns to route
    # each input to the experts that minimise that loss.

Dummy Input Shape: torch.Size([4, 10])
Gating Network Weights Shape: torch.Size([4, 3])
Agent Output Shape: torch.Size([4, 5])
Expert 0 Output Shape: torch.Size([4, 5])
Expert 1 Output Shape: torch.Size([4, 5])
Expert 2 Output Shape: torch.Size([4, 5])
