In [2]:

import math 
from contextlib import nullcontext 
from typing import Optional 

import torch 
import torch.nn as nn 
import torch.nn.functional as F

In [3]:
def _gelu(x: torch.Tensor) -> torch.Tensor:
    """
    Element-wise tanh approximation of GELU:

        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))

    Numerically this is the same formula as
    ``F.gelu(x, approximate="tanh")``; the output has the same shape as ``x``.
    """
    cubic_term = 0.044715 * torch.pow(x, 3)
    tanh_arg = math.sqrt(2.0 / math.pi) * (x + cubic_term)
    return 0.5 * x * (1.0 + torch.tanh(tanh_arg))
    
class ExpertFFN(nn.Module):
    """
    Two-layer, bias-free MLP used as a single expert.

    Computes ``fc2(dropout(gelu(fc1(x))))``: d_model -> hidden -> d_model,
    using the tanh-approximated GELU from ``_gelu``. In MoE layers the expert
    hidden size is usually chosen smaller than a dense FFN's.

    Args:
        d_model: size of the input and output features.
        hidden: inner (expert) feature size.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, d_model: int, hidden: int, dropout: float = 0.0):
        super().__init__()
        # Attribute names and registration order define the state_dict keys
        # (fc1.weight, fc2.weight) and parameter order, so keep them stable.
        self.fc1 = nn.Linear(in_features=d_model, out_features=hidden, bias=False)
        self.fc2 = nn.Linear(in_features=hidden, out_features=d_model, bias=False)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden_act = _gelu(self.fc1(x))
        hidden_act = self.dropout(hidden_act)
        return self.fc2(hidden_act)



class DeepSeekMoE(nn.Module):
    """
    DeepSeek-V3 style Mixture-of-Experts (MoE) layer.

    This MoE layer incorporates both routed experts (selected by a router)
    and shared experts (applied to all inputs). It is designed based on
    the architecture described in the DeepSeek-V3 paper.

    Args:
        d_model (int): The dimension of the input and output features.
        n_routed_exp (int): The number of routed experts.
        n_shared_exp (int, optional): The number of shared experts. Defaults to 1.
        top_k (int, optional): The number of routed experts to select for each token.
                               Defaults to 8.
        routed_hidden (int, optional): The hidden dimension for routed experts.
                                      Defaults to 2048.
        shared_hidden (Optional[int], optional): The hidden dimension for shared experts.
                                                If None, uses routed_hidden. Defaults to None.
        bias_lr (float, optional): Learning rate for the router bias (updated online).
                                   Defaults to 0.01.
        fp16_router (bool, optional): Whether to use FP16 precision for router calculations.
                                     Defaults to False.
    """
    def __init__(
        self,
        d_model: int,
        n_routed_exp: int,
        n_shared_exp: int = 1,
        top_k: int = 8,
        routed_hidden: int = 2_048,
        shared_hidden: Optional[int] = None,
        bias_lr: float = 0.01,
        fp16_router: bool = False,
    ):
        super().__init__()
        # Assert that the number of selected experts (top_k) is less than or equal to the total number of routed experts.
        assert top_k <= n_routed_exp, "k must be ≤ number of routed experts"

        self.d_model = d_model
        self.n_routed = n_routed_exp
        self.n_shared = n_shared_exp
        self.top_k = top_k
        self.bias_lr = bias_lr
        self.fp16_router = fp16_router

        # Module list for the routed experts.
        self.routed = nn.ModuleList(
            [ExpertFFN(d_model, routed_hidden) for _ in range(n_routed_exp)]
        )
        # Determine the hidden dimension for shared experts. Use routed_hidden if shared_hidden is not provided.
        hidden_shared = shared_hidden or routed_hidden
        # Module list for the shared experts.
        self.shared = nn.ModuleList(
            [ExpertFFN(d_model, hidden_shared) for _ in range(n_shared_exp)]
        )

        # Register a parameter for the centroids used by the router.
        # Centroids represent the "preference" of each expert for different input features.
        self.register_parameter("centroids", nn.Parameter(torch.empty(n_routed_exp, d_model)))
        # Initialize centroids with a normal distribution.
        nn.init.normal_(self.centroids, std=d_model ** -0.5)

        # Register a buffer for the router bias. This bias is updated online
        # without using standard gradient descent, hence it's not a parameter.
        self.register_buffer("bias", torch.zeros(n_routed_exp))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the DeepSeekMoE layer.

        Args:
            x (torch.Tensor): The input tensor of shape [B, S, D], where B is
                              batch size, S is sequence length, and D is d_model.

        Returns:
            torch.Tensor: The output tensor of shape [B, S, D], which is the
                          sum of the input, shared expert outputs, and routed
                          expert outputs.
        """
        # Get dimensions of the input tensor.
        B, S, D = x.shape
        # Reshape the input to [N, D], where N = B * S (number of tokens).
        x_flat = x.reshape(-1, D)  # [N, D] with N=B*S

        # 1) Shared path: Process the input through all shared experts and sum their outputs.
        shared_out = torch.zeros_like(x)
        for exp in self.shared:
            shared_out += exp(x)
        # (Optional) Scale the shared expert output by the number of shared experts.
        # This can help in balancing the contribution of shared vs. routed experts.
        # shared_out = shared_out / max(1, self.n_shared)

        # 2) Router logits: Calculate the affinity of each token to each routed expert.
        # Use autocasting to FP16 if fp16_router is True and the device is CUDA.
        use_autocast = self.fp16_router and x.is_cuda
        device_type = "cuda" if x.is_cuda else x.device.type
        with torch.autocast(device_type=device_type, enabled=use_autocast):
            # Calculate logits by taking the dot product of the flattened input with the expert centroids.
            logits = F.linear(x_flat, self.centroids)  # [N, E]
            # Add the router bias to the logits. Ensure bias matches the logits' dtype.
            logits = logits + self.bias.to(logits.dtype)

        # Select the top_k experts with the highest logits for each token.
        topk_logits, topk_idx = torch.topk(logits, self.top_k, dim=-1)        # [N, k]
        # Apply softmax to the top_k logits to get gating weights.
        # Ensure the gate weights have the same dtype as the input for subsequent calculations.
        gate = F.softmax(topk_logits, dim=-1, dtype=x.dtype)                   # [N, k]

        # 3) Dispatch per expert: Route tokens to their selected experts and combine outputs.
        routed_out = torch.zeros_like(x_flat)                                   # [N, D]
        # Iterate through each routed expert.
        for i in range(self.n_routed):
            # Create a mask to identify which tokens selected the current expert (expert i).
            mask = (topk_idx == i)
            # Find the indices of the rows (tokens) and columns (which of the top-k) where expert i was selected.
            row_idx, which_k = mask.nonzero(as_tuple=True)                      # 1-D each
            # If no tokens selected this expert, skip.
            if row_idx.numel() == 0:
                continue
            # Select the input tokens that are routed to expert i.
            exp_in = x_flat.index_select(0, row_idx)                            # [Ti, D] where Ti is the number of tokens routed to expert i
            # Pass the selected tokens through the expert's FFN.
            out = self.routed[i](exp_in)                                        # [Ti, D]
            # Get the gating weights for the tokens routed to expert i.
            w = gate[row_idx, which_k].unsqueeze(-1)                            # [Ti, 1]
            # Scale the expert output by the gating weights and add it to the routed_out tensor
            # at the original token positions using index_add_.
            routed_out.index_add_(0, row_idx, out * w)

        # Reshape the routed output back to the original [B, S, D] shape.
        routed_out = routed_out.view(B, S, D)
        # The final output is the sum of the original input, shared expert outputs, and routed expert outputs.
        return x + shared_out + routed_out

    @torch.no_grad()
    def update_bias(self, x: torch.Tensor):
        """
        Updates the router bias based on expert load.

        This method is typically called once per optimizer step using the
        same batch of tokens that were passed through the forward method.
        It uses the current router logits (including the current bias) to
        estimate the load on each expert and adjusts the bias to encourage
        a more balanced distribution of tokens across experts.

        Args:
            x (torch.Tensor): The input tensor of shape [B, S, D], identical
                              to the input used in the corresponding forward pass.
        """
        # Calculate the total number of tokens.
        N = x.shape[0] * x.shape[1]
        # Calculate the router logits (affinity scores) for each token and expert, including the current bias.
        logits = F.linear(x.reshape(-1, self.d_model), self.centroids) + self.bias
        # Determine the top_k experts selected for each token based on the current logits.
        _, idx = torch.topk(logits, self.top_k, dim=-1)

        # Count how many times each expert was selected as one of the top_k.
        counts = torch.bincount(idx.flatten(), minlength=self.n_routed).float()
        # Calculate the average number of times an expert should ideally be selected.
        avg = counts.sum() / max(1, self.n_routed)

        # Calculate the "violation" for each expert. A positive violation means
        # the expert is under-loaded compared to the average, and its bias
        # should be increased to make it more likely to be selected in the future.
        # A negative violation means it's over-loaded, and its bias should be decreased.
        # Add a small epsilon (1e-6) to the denominator to avoid division by zero.
        violation = (avg - counts) / (avg + 1e-6)
        # Update the bias using a smooth, bounded update based on the violation.
        # torch.tanh() squashes the violation into the range [-1, 1], preventing
        # excessively large bias updates. The bias_lr controls the step size.
        self.bias.add_(self.bias_lr * torch.tanh(violation))
     

In [4]:

# Smoke test for DeepSeekMoE: build a layer, run one forward pass, check the shape.
# Seed first so weight init and the random input are reproducible on Restart & Run All.
torch.manual_seed(0)

d_model = 1024
n_routed_exp = 16
n_shared_exp = 2
top_k = 8

model = DeepSeekMoE(d_model, n_routed_exp, n_shared_exp=n_shared_exp, top_k=top_k)

# Random input of shape [batch, seq, d_model]
batch_size_new = 2
seq_len_new = 64
random_input_new = torch.randn(batch_size_new, seq_len_new, d_model)

# Pass the random input through the model's forward method
output_new = model(random_input_new)

# The layer is shape-preserving: [B, S, D] -> [B, S, D].
assert output_new.shape == random_input_new.shape, output_new.shape
print("New output shape:", output_new.shape)

New output shape: torch.Size([2, 64, 1024])
