# MoE Architecture - based on Deepseek-v3 code

In [36]:
import torch
from torch import nn
import torch.nn.functional as F
from typing import Tuple


In [37]:
class Expert(nn.Module):
    """
    Single feed-forward expert for a Mixture-of-Experts (MoE) layer.

    The layout follows the gated feed-forward block found in Meta's Llama code
    (https://github.com/meta-llama/llama/blob/main/llama/model.py), which
    Deepseek appears to have modeled its expert on. It is a gated linear unit
    (GLU) whose gate uses SiLU rather than a sigmoid.
    GLU variants paper: https://arxiv.org/pdf/2002.05202

    Computation: ``w2(silu(w1(x)) * w3(x))``.

    Attributes:
        w1 (nn.Module): Projects the input to the hidden size; its output is
            passed through SiLU and acts as the gate.
        w2 (nn.Module): Projects the gated hidden activations back to the
            input size.
        w3 (nn.Module): Second input-to-hidden projection; multiplied
            element-wise with the gate.
    """
    def __init__(self, dim: int, inter_dim: int):
        """
        Build the three linear projections of the expert.

        Args:
            dim (int): Size of the input and of the output features.
            inter_dim (int): Size of the hidden (intermediate) features.
        """
        super().__init__()
        # Creation order is w1, w2, w3 so parameter initialisation consumes
        # the random number generator in a fixed, reproducible order.
        self.w1 = nn.Linear(dim, inter_dim)
        self.w2 = nn.Linear(inter_dim, dim)
        self.w3 = nn.Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the gated feed-forward transformation.

        Args:
            x (torch.Tensor): Input whose last dimension has size ``dim``.

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.
        """
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)


class Gate(nn.Module):
    """
    Top-k gating mechanism for routing inputs in a mixture-of-experts (MoE) model.

    Load balancing: Deepseek-v3 uses auxiliary-loss-free load balancing, which
    adds a per-expert bias term to the affinity scores prior to selecting the
    top-k experts; the bias is dynamically updated during training. If an
    expert is overloaded its bias is decreased by a constant, and vice versa.
    The authors claim this does better than a pure auxiliary loss.

    The authors also use a complementary sequence-wise auxiliary loss, which
    further encourages the tokens within a sequence to be balanced across
    experts. They note that its contribution is small relative to the
    auxiliary-loss-free balancing.

    Implements:
    - top-k gating with sigmoid affinities, normalised over the selected experts
    - the selection-only bias term used by auxiliary-loss-free load balancing

    Does NOT implement:
    - the bias update rule itself: the bias only influences the
      (non-differentiable) top-k selection here, so it receives no gradient
      from this module's outputs and has to be adjusted by the training loop
    - sequence-wise auxiliary loss: not present in the public Deepseek code,
      but mentioned in the paper
    - group-wise routing (hierarchical routing)
    - mixed-precision training
    - shared experts
    - route-scale

    Attributes:
        dim (int): Input dimension.
        n_experts (int): Number of experts.
        topk_experts (int): Number of experts activated for each input.
        weight (torch.nn.Parameter): Gate weights of shape (n_experts, dim).
        bias (torch.nn.Parameter): Per-expert selection bias of shape (n_experts,).
    """
    def __init__(self, dim, n_experts, topk_experts):
        """
        Initializes the gate.

        Args:
            dim (int): Input dimension.
            n_experts (int): Number of experts; must be greater than 1.
            topk_experts (int): Number of experts activated for each input;
                must satisfy 1 <= topk_experts <= n_experts.

        Raises:
            AssertionError: If ``n_experts`` is not greater than 1.
            ValueError: If ``topk_experts`` is outside ``[1, n_experts]``.
        """
        super().__init__()
        # Validate before allocating anything so a bad configuration fails
        # immediately instead of at the first forward pass inside torch.topk.
        assert n_experts > 1, 'Number of experts must be greater than 1'
        if not 1 <= topk_experts <= n_experts:
            raise ValueError(
                f"topk_experts must be between 1 and n_experts={n_experts}, got {topk_experts}"
            )
        self.dim = dim
        self.n_experts = n_experts
        self.topk_experts = topk_experts
        self.weight = nn.Parameter(torch.empty(n_experts, dim))
        self.bias = nn.Parameter(torch.empty(n_experts))
        # Unlike the Deepseek code, Xavier initialization is applied to the weights.
        # Reminder: functions that end with _ indicate in-place operations.
        nn.init.xavier_uniform_(self.weight)
        # Unlike the Deepseek code, the bias starts at zero, so initially the
        # selection is driven by the affinity scores alone.
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Selects the top-k experts for every input vector.

        Args:
            x (torch.Tensor): Input of shape (..., dim); any number of leading
                dimensions is supported.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``weights`` and ``indices``,
            both of shape (..., topk_experts). ``indices`` holds the selected
            expert ids; ``weights`` holds their (unbiased) sigmoid affinities
            normalised to sum to 1 along the last dimension.
        """
        # Deepseek-v3 uses a sigmoid to compute the affinity scores and
        # normalises only after the experts have been selected.
        affinity = F.linear(x, self.weight).sigmoid()
        # The bias term is ONLY used to influence the top-k selection.
        indices = torch.topk(affinity + self.bias, self.topk_experts, dim=-1)[1]
        # Gather along the expert dimension (the last one), matching the
        # dimension used by topk; gathering along dim 1 is only correct for
        # 2-D inputs.
        weights = affinity.gather(-1, indices)
        # Normalise to a probability vector (sums to 1 over the chosen experts).
        weights = weights / weights.sum(dim=-1, keepdim=True)
        return weights, indices

class MoE(nn.Module):
    """
    Mixture-of-Experts (MoE) module (work in progress).

    Only the configuration is stored so far; the experts and the gate are not
    built yet and no forward pass is defined.

    Attributes:
        dim (int): Dimensionality of input features.
        n_experts (int): Total number of experts in the model.
        n_activated_experts (int): Number of experts activated for each input.
    """
    def __init__(self, dim: int, n_experts: int, n_activated_experts: int):
        """
        Stores the MoE configuration.

        The values are taken as explicit arguments; reading them from
        undefined module-level names made construction fail with a NameError.

        Args:
            dim (int): Dimensionality of input features.
            n_experts (int): Total number of experts in the model.
            n_activated_experts (int): Number of experts activated for each input.
        """
        super().__init__()
        self.dim = dim
        self.n_experts = n_experts
        self.n_activated_experts = n_activated_experts
        # TODO: build the experts and the gate once the expert hidden size is
        # decided, e.g. an nn.ModuleList of Expert(dim, <inter_dim>) with
        # n_experts entries and Gate(dim, n_experts, n_activated_experts).
        # Note that Gate requires the number of activated experts as its
        # third argument.
        

In [35]:
# Smoke-test the gate: route 10 random inputs of size 100 to 2 of 10 experts
# and print the routing weights and the selected expert indices.
gate = Gate(dim=100, n_experts=10, topk_experts=2)

x = torch.randn(10, 100) #10 input vectors of size 100, drawn from a normal distribution with mean 0 and variance 1
weights, indices = gate(x)
print(weights)
print(indices)

# Smoke-test the expert on the same inputs: the printed input and output
# shapes should match, because the expert maps dim -> inter_dim -> dim.
expert = Expert(dim=100, inter_dim=10)
print(x.shape)
y = expert(x)
print(y.shape)


tensor([[0.5641, 0.4359],
        [0.5179, 0.4821],
        [0.5052, 0.4948],
        [0.5439, 0.4561],
        [0.5055, 0.4945],
        [0.5003, 0.4997],
        [0.5230, 0.4770],
        [0.5349, 0.4651],
        [0.5326, 0.4674],
        [0.5061, 0.4939]], grad_fn=<DivBackward0>)
tensor([[5, 1],
        [5, 2],
        [7, 0],
        [1, 9],
        [9, 7],
        [3, 5],
        [7, 1],
        [8, 0],
        [9, 3],
        [8, 3]])
torch.Size([10, 100])
torch.Size([10, 100])


## Draft

In [None]:
# From: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class Gate(nn.Module):
    """
    Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.

    Attributes:
        dim (int): Dimensionality of input features.
        topk (int): Number of top experts activated for each input.
        n_groups (int): Number of groups for routing.
        topk_groups (int): Number of groups to route inputs to.
        score_func (str): Scoring function ('softmax' or 'sigmoid').
        route_scale (float): Scaling factor for routing weights.
        weight (torch.nn.Parameter): Learnable weights for the gate.
        bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the Gate module.

        Args:
            args (ModelArgs): Model arguments containing gating parameters.
        """
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.n_groups = args.n_expert_groups
        self.topk_groups = args.n_limited_groups
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        # The selection bias only exists for one hard-coded model width; every
        # other configuration routes without a bias.
        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the gating mechanism.

        Args:
            x (torch.Tensor): Input tensor. The group view and the gather
                below index dimension 1 as the expert dimension, so a 2-D
                input is expected.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
        """
        # NOTE(review): `linear` is not defined in the visible part of this
        # file; it is expected to be provided alongside this class.
        scores = linear(x, self.weight)
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
        else:
            scores = scores.sigmoid()
        # Routing weights are always read from the unbiased scores; the bias
        # only influences which experts get selected.
        original_scores = scores
        if self.bias is not None:
            scores = scores + self.bias
        if self.n_groups > 1:
            scores = scores.view(x.size(0), self.n_groups, -1)
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            indices = group_scores.topk(self.topk_groups, dim=-1)[1]
            # True marks the groups that were NOT selected.
            mask = scores.new_ones(x.size(0), self.n_groups, dtype=torch.bool).scatter_(1, indices, False)
            # Exclude those groups with -inf instead of multiplying by zero:
            # with a bias, scores of eligible experts can be negative and
            # would otherwise lose against the zeros of excluded groups.
            # Out-of-place on purpose, because `scores` may share storage
            # with `original_scores` when no bias is used.
            scores = scores.masked_fill(mask.unsqueeze(-1), float("-inf")).flatten(1)
        indices = torch.topk(scores, self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        if self.score_func == "sigmoid":
            # Sigmoid affinities are not normalised, so renormalise over the
            # selected experts.
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights.type_as(x), indices


class Expert(nn.Module):
    """
    Gated feed-forward expert for Mixture-of-Experts (MoE) models.

    Computes ``w2(silu(w1(x)) * w3(x))``.

    NOTE(review): the layers are built with ``Linear``, which is not defined
    in the visible part of this file; it is expected to be provided alongside
    this class.

    Attributes:
        w1 (nn.Module): Input-to-hidden projection whose output is passed
            through SiLU.
        w2 (nn.Module): Hidden-to-output projection.
        w3 (nn.Module): Second input-to-hidden projection, multiplied
            element-wise with the activated output of ``w1``.
    """
    def __init__(self, dim: int, inter_dim: int):
        """
        Create the three projections of the expert.

        Args:
            dim (int): Size of the input and of the output features.
            inter_dim (int): Size of the hidden features.
        """
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the expert on ``x``.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Result of the gated feed-forward computation.
        """
        activated = F.silu(self.w1(x))
        gated = activated * self.w3(x)
        return self.w2(gated)


class MoE(nn.Module):
    """
    Mixture-of-Experts (MoE) layer with routed and shared experts.

    NOTE(review): ``world_size``, ``rank``, ``dist`` and ``MLP`` are not
    defined in the visible part of this file; they are expected to be provided
    alongside this class.

    Attributes:
        dim (int): Dimensionality of input features.
        n_routed_experts (int): Total number of routed experts in the model.
        n_local_experts (int): Number of routed experts held by this rank
            (``n_routed_experts // world_size``).
        n_activated_experts (int): Number of experts activated for each input.
        experts_start_idx (int): Index of the first expert held by this rank.
        experts_end_idx (int): One past the index of the last expert held by
            this rank.
        gate (nn.Module): Gating mechanism to route inputs to experts.
        experts (nn.ModuleList): One slot per routed expert; slots outside
            this rank's range hold ``None``.
        shared_experts (nn.Module): Shared experts applied to all inputs.
    """
    def __init__(self, args: ModelArgs):
        """
        Build the gate, this rank's slice of routed experts and the shared experts.

        Args:
            args (ModelArgs): Model arguments containing MoE parameters.
        """
        super().__init__()
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.n_activated_experts = args.n_activated_experts
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        self.gate = Gate(args)
        # Only experts inside this rank's contiguous range are instantiated;
        # every other slot stays None so global expert ids index the list
        # directly.
        local_ids = range(self.experts_start_idx, self.experts_end_idx)
        slots = []
        for expert_id in range(self.n_routed_experts):
            if expert_id in local_ids:
                slots.append(Expert(args.dim, args.moe_inter_dim))
            else:
                slots.append(None)
        self.experts = nn.ModuleList(slots)
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Route every input row to its selected experts and add the shared experts.

        Args:
            x (torch.Tensor): Input tensor; it is flattened to rows of size
                ``dim`` for routing.

        Returns:
            torch.Tensor: Sum of routed and shared expert outputs, with the
            same shape as ``x``.
        """
        input_shape = x.size()
        tokens = x.view(-1, self.dim)
        weights, indices = self.gate(tokens)
        routed = torch.zeros_like(tokens)
        # Number of (row, slot) assignments per expert, used to skip idle experts.
        load = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for expert_id in range(self.experts_start_idx, self.experts_end_idx):
            if not load[expert_id]:
                continue
            # rows: which inputs chose this expert; slots: in which top-k position.
            rows, slots = torch.where(indices == expert_id)
            expert_out = self.experts[expert_id](tokens[rows])
            routed[rows] += expert_out * weights[rows, slots, None]
        shared = self.shared_experts(tokens)
        if world_size > 1:
            # Each rank only computed its local experts; combine the partial
            # results in place across ranks.
            dist.all_reduce(routed)
        return (routed + shared).view(input_shape)