# Mixture of Experts

## Introduction

The **Mixture of Experts (MoE)** architecture is a deep learning paradigm that scales model
capacity efficiently by dividing the work among specialized submodels. It was originally
introduced by Jacobs et al. (1991) and later popularized in deep neural networks. It follows
a "divide and conquer" principle. Instead of training a single monolithic model to handle
every input, it trains a set of specialized models (experts) together with a routing
mechanism (gating network) that decides which experts should process each input.

In sparse MoE variants, computational efficiency comes from conditional computation. The
model can hold a large number of parameters spread across many experts, but only a subset of
those experts is evaluated for each input. This makes it possible to build models with very
large parameter counts while keeping the per-input computational cost manageable. The
architecture has proven particularly effective in large-scale language models, where
different experts can specialize in different linguistic domains, styles, or types of
knowledge.
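
To make this trade-off concrete, the sketch below counts total and per-input active
parameters for a sparse MoE layer that routes each input to its `top_k` experts. All sizes
are hypothetical. `shared_params` stands for parameters that every input uses, such as
layers outside the experts. The count ignores the gating network.

```python
def moe_param_budget(num_experts: int, top_k: int, params_per_expert: int,
                     shared_params: int) -> tuple[int, int]:
    """Return (total, active-per-input) parameter counts for a top-k sparse MoE.

    shared_params is a hypothetical stand-in for parameters every input uses.
    """
    if not 1 <= top_k <= num_experts:
        raise ValueError("top_k must be between 1 and num_experts")
    # Every expert's parameters exist in the model...
    total = shared_params + num_experts * params_per_expert
    # ...but only top_k experts are evaluated for a given input.
    active = shared_params + top_k * params_per_expert
    return total, active


# Hypothetical sizes: 8 experts of 1M parameters each, 2 active per input,
# plus 0.5M parameters shared by all inputs.
total, active = moe_param_budget(num_experts=8, top_k=2,
                                 params_per_expert=1_000_000,
                                 shared_params=500_000)
print(total, active)            # 8500000 2500000
print(f"{active / total:.1%}")  # 29.4%
```

In this example, the model stores 8.5M parameters, but each input uses fewer than a third of
them.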

## Individual expert architecture

Each **expert** in an MoE architecture is an independent neural network. In its simplest
form, an expert is a feed-forward network whose hidden layer transforms the input into an
output representation. Experts are not assigned to parts of the input space in advance.
Instead, their specialization emerges during training. The routing mechanism learns to send
different types of inputs to different experts. Each expert, in turn, adapts mostly to the
inputs for which it receives a high gating weight.

Mathematically, each expert $E_i$ can be represented as a parameterized function:

$$E_i(x; \theta_i) : \mathbb{R}^{d_{in}} \rightarrow \mathbb{R}^{d_{out}}$$

where $\theta_i$ represents the specific parameters of expert $i$, and $d_{in}$ and
$d_{out}$ are the input and output dimensionalities respectively.

```python
import torch
from torch import nn
from torch.nn import functional as F

class ExpertModel(nn.Module):
    """
    Individual expert model for MoE
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int) -> None:
        """
        Initializes an expert model with a simple feed-forward network.

        Args:
            input_dim: Dimensionality of the input data.
            output_dim: Dimensionality of the output data.
            hidden_dim: Dimensionality of the hidden layer.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        self.model = nn.Sequential(
            nn.Linear(in_features=self.input_dim, out_features=self.hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=self.hidden_dim, out_features=self.output_dim),
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the expert model.

        Args:
            input_tensor: Input tensor to the model.

        Returns:
            The model's output tensor.
        """
        return self.model(input_tensor)
```
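
As a sanity check on model size, the sketch below counts the learnable parameters of the
two-layer expert above by hand. Each `nn.Linear(in, out)` holds an `in × out` weight matrix
plus an `out`-dimensional bias, and `nn.ReLU` has no parameters. The helper names are
hypothetical. In PyTorch, the same number can be read off with
`sum(p.numel() for p in expert.parameters())`.

```python
def linear_param_count(in_features: int, out_features: int) -> int:
    # Weight matrix (out_features x in_features) plus one bias per output unit.
    return in_features * out_features + out_features


def expert_param_count(input_dim: int, hidden_dim: int, output_dim: int) -> int:
    # Linear(input_dim -> hidden_dim), ReLU (no parameters), Linear(hidden_dim -> output_dim).
    return (linear_param_count(input_dim, hidden_dim)
            + linear_param_count(hidden_dim, output_dim))


# Dimensions used in the usage example at the end of this notebook.
per_expert = expert_param_count(input_dim=10, hidden_dim=128, output_dim=5)
print(per_expert)      # 2053
print(3 * per_expert)  # 6159 for three experts
```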

## Routing mechanism via gating network

The **routing mechanism** (gating network) constitutes the central component that
determines how inputs are distributed among available experts. This network learns to
assign weights to each expert based on input characteristics, producing a probability
distribution over experts through a softmax function. The gating network can be
interpreted as a soft classifier that determines which experts are most relevant for
processing each specific input.

The gating function $G(x; \phi)$ produces a vector of normalized weights:

$$G(x; \phi) = \text{softmax}(W_g \cdot h(x) + b_g)$$

where $h(x)$ represents intermediate transformations applied to the input, $W_g$ and
$b_g$ are learnable gating parameters, and $\phi$ denotes the complete set of routing
network parameters. The resulting weights satisfy $\sum_{i=1}^{N} g_i(x) = 1$, where $N$
is the total number of experts.
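
The following is a minimal pure-Python sketch of the gating function for a single input. It
assumes that $h(x)$ has already been computed and is given directly as a feature vector. The
values of $W_g$ and $b_g$ are small hypothetical numbers. The sketch subtracts the largest
logit before exponentiating. This is a standard trick that avoids overflow and leaves the
softmax unchanged.

```python
import math


def gate(h: list[float], W_g: list[list[float]], b_g: list[float]) -> list[float]:
    """Compute softmax(W_g . h + b_g): one normalized weight per expert."""
    # One logit per expert: row i of W_g dotted with h, plus bias i.
    logits = [sum(w * x for w, x in zip(row, h)) + b for row, b in zip(W_g, b_g)]
    # Shift by the max logit for numerical stability (softmax is shift-invariant).
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical example: 2-dimensional h(x), N = 3 experts.
h = [1.0, 2.0]
W_g = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b_g = [0.0, 0.0, 0.0]
g = gate(h, W_g, b_g)            # logits are [1.0, 2.0, 3.0]
print([round(v, 4) for v in g])  # [0.09, 0.2447, 0.6652]
print(round(sum(g), 6))          # 1.0
```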

```python
class Gating(nn.Module):
    """
    Gating mechanism to select experts.
    """

    def __init__(
        self, input_dim: int, num_experts: int, dropout_rate: float = 0.2
    ) -> None:
        """
        Initializes a gating network for expert selection.

        Args:
            input_dim: Dimensionality of the input data.
            num_experts: Number of experts to select from.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.dropout_rate = dropout_rate

        self.model = nn.Sequential(
            nn.Linear(in_features=self.input_dim, out_features=128),
            nn.Dropout(self.dropout_rate),
            nn.LeakyReLU(),
            nn.Linear(in_features=128, out_features=256),
            nn.LeakyReLU(),
            nn.Dropout(self.dropout_rate),
            nn.Linear(in_features=256, out_features=128),
            nn.LeakyReLU(),
            nn.Dropout(self.dropout_rate),
            nn.Linear(in_features=128, out_features=num_experts),
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the gating network.

        Args:
            input_tensor: Input tensor to the network.

        Returns:
            Softmax probabilities for expert selection.
        """
        return F.softmax(self.model(input_tensor), dim=-1)
```

## Complete Mixture of Experts architecture

The complete MoE architecture integrates individual experts with the gating mechanism to
produce a final output through a weighted combination of expert predictions. For an input
$x$, the MoE system output is calculated as:

$$\text{MoE}(x) = \sum_{i=1}^{N} g_i(x) \cdot E_i(x)$$

where $g_i(x)$ is the weight assigned to expert $i$ by the gating network, and $E_i(x)$ is
the output of expert $i$. This formulation lets the model learn which experts are most
relevant for different regions of the input space, which encourages specialization. In this
dense form, every expert is evaluated for every input, so the cost grows linearly with $N$.
The computational savings described in the introduction require a sparse variant, in which
only the experts with the largest gating weights are evaluated.
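
The weighted sum can be checked on a toy case. The sketch below combines scalar expert
outputs for a single input in two ways. First, it combines them densely, evaluating all
experts exactly as in the formula above. Second, it uses a top-k variant. This variant keeps
only the `k` largest gating weights and renormalizes them so they sum to 1. It is one common
way for sparse MoE layers to avoid evaluating every expert. The function names, the lambda
"experts", and the weights are all hypothetical.

```python
from collections.abc import Callable


def moe_dense(x: float, experts: list[Callable[[float], float]], g: list[float]) -> float:
    # MoE(x) = sum_i g_i(x) * E_i(x): every expert is evaluated.
    return sum(g_i * expert(x) for g_i, expert in zip(g, experts))


def moe_top_k(x: float, experts: list[Callable[[float], float]],
              g: list[float], k: int) -> float:
    # Indices of the k largest gating weights.
    top = sorted(range(len(g)), key=lambda i: g[i], reverse=True)[:k]
    # Renormalize the kept weights so they again sum to 1.
    norm = sum(g[i] for i in top)
    # Only the selected experts are evaluated.
    return sum(g[i] / norm * experts[i](x) for i in top)


# Hypothetical scalar "experts" and gating weights for one input.
experts = [lambda x: 2.0 * x, lambda x: x + 1.0, lambda x: -x]
g = [0.5, 0.3, 0.2]
print(round(moe_dense(1.0, experts, g), 6))       # 1.4  (0.5*2 + 0.3*2 + 0.2*(-1))
print(round(moe_top_k(1.0, experts, g, k=2), 6))  # 2.0  ((0.5*2 + 0.3*2) / 0.8)
```

With `k` equal to the number of experts, the top-k version reduces to the dense formula.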

During training, both the experts and the gating network are optimized jointly through
standard backpropagation. The gradient flows through every expert, scaled by that expert's
gating weight, so the system learns both expert specialization and routing end to end.
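
To see how the gradient is shared between the experts and the router, consider a single
input with scalar expert outputs $e_i = E_i(x)$ and gating logits $z_i$. Then
$g = \text{softmax}(z)$ and $y = \sum_i g_i e_i$. It follows that
$\partial y / \partial e_i = g_i$. Differentiating the softmax gives
$\partial y / \partial z_j = g_j (e_j - y)$. A logit's gradient is therefore positive exactly
when its expert's output exceeds the current mixture output. The pure-Python sketch below
checks both expressions against central finite differences. The logits and expert outputs are
hypothetical.

```python
import math
from collections.abc import Callable


def softmax(z: list[float]) -> list[float]:
    # Shift by the max for numerical stability.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [v / s for v in exps]


def mixture(z: list[float], e: list[float]) -> float:
    # y = sum_i softmax(z)_i * e_i for one input with scalar expert outputs.
    return sum(g_i * e_i for g_i, e_i in zip(softmax(z), e))


def numeric_grad(f: Callable[[list[float]], float], v: list[float], j: int,
                 eps: float = 1e-6) -> float:
    # Central finite difference of f with respect to coordinate j of v.
    up = v.copy()
    up[j] += eps
    down = v.copy()
    down[j] -= eps
    return (f(up) - f(down)) / (2 * eps)


z = [0.2, -1.0, 0.7]  # hypothetical gating logits
e = [1.5, -0.5, 2.0]  # hypothetical expert outputs
g = softmax(z)
y = mixture(z, e)

for j in range(len(z)):
    dy_de = g[j]               # analytic: dy/de_j = g_j
    dy_dz = g[j] * (e[j] - y)  # analytic: dy/dz_j = g_j * (e_j - y)
    assert abs(numeric_grad(lambda ee: mixture(z, ee), e, j) - dy_de) < 1e-6
    assert abs(numeric_grad(lambda zz: mixture(zz, e), z, j) - dy_dz) < 1e-6
print("analytic gradients match finite differences")
```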

```python
class MoE(nn.Module):
    """
    Mixture of Experts (dense: every expert is evaluated for every input).
    """

    def __init__(
        self,
        trained_experts: list[ExpertModel],
        input_dim: int,
        dropout_rate: float = 0.2,
    ) -> None:
        """
        Initializes a mixture of experts with gating.

        Args:
            trained_experts: List of expert models (pre-trained or freshly
                initialized); all must share the same output dimensionality.
            input_dim: Dimensionality of the input data.
            dropout_rate: Rate of dropout in the gating network.
        """
        super().__init__()
        self.experts = nn.ModuleList(trained_experts)
        self.num_experts = len(trained_experts)
        self.input_dim = input_dim
        self.dropout_rate = dropout_rate

        self.gating_layer = Gating(
            input_dim=self.input_dim,
            num_experts=self.num_experts,
            dropout_rate=self.dropout_rate,
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the mixture of experts.

        Args:
            input_tensor: Input tensor of shape (batch_size, input_dim).

        Returns:
            Weighted sum of expert outputs, shape (batch_size, output_dim).
        """
        # Gating weights: (batch_size, num_experts); each row sums to 1.
        expert_weights = self.gating_layer(input_tensor)

        # Evaluate every expert on the full batch; each output is (batch_size, output_dim).
        _expert_outputs: list[torch.Tensor] = []
        for expert in self.experts:
            _expert_outputs.append(expert(input_tensor))

        # Stack along a new last axis: (batch_size, output_dim, num_experts).
        expert_outputs = torch.stack(_expert_outputs, dim=-1)
        # Insert an output axis so the weights broadcast: (batch_size, 1, num_experts).
        expert_weights = expert_weights.unsqueeze(1)

        # Weighted sum over the expert axis: (batch_size, output_dim).
        return torch.sum(expert_outputs * expert_weights, dim=-1)
```

## Usage example and verification

Because `MoE` is an ordinary `nn.Module`, it can be used like any other PyTorch layer. The
following example instantiates the model and runs a forward pass on a random batch. It then
checks that the gating weights for each sample sum to 1.

```python
if __name__ == "__main__":
    input_dim = 10
    output_dim = 5
    num_experts = 3
    batch_size = 32
    hidden_dim = 128

    experts = [
        ExpertModel(input_dim=input_dim, output_dim=output_dim, hidden_dim=hidden_dim)
        for _ in range(num_experts)
    ]

    moe = MoE(experts, input_dim)
    # Switch off dropout so that the gating weights inspected below are the same
    # ones used in the forward pass (in train mode every call draws a new mask).
    moe.eval()

    x = torch.randn(batch_size, input_dim)
    output = moe(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Expected output shape: ({batch_size}, {output_dim})")

    gating_weights = moe.gating_layer(x)
    print(f"Gating weights shape: {gating_weights.shape}")
    print(f"Gating weights sum per sample: {gating_weights.sum(dim=1)}")
```

```text
Input shape: torch.Size([32, 10])
Output shape: torch.Size([32, 5])
Expected output shape: (32, 5)
Gating weights shape: torch.Size([32, 3])
Gating weights sum per sample: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)
```
