## Implementation of the paper *Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts*

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

## Los-Free Balance

In [4]:
class SwiGLUFFN(nn.Module):
    def __init__(self,input_dim:int,hidden_dim:int):
        """
        Initializes the SwiGLUFFN module.
        
        Args:
            input_dim (int): The dimensionality of the input features.
            hidden_dim (int): The dimensionality of the hidden layer.
            
        Initializes three linear layers:
        - `w_1`: Projects input features to the hidden dimension.
        - `w_2`: Projects input features to the hidden dimension using a separate path.
        - `out`: Projects the transformed hidden representation back to the input dimension.
        """
        super().__init__()
        self.w_1=nn.Linear(input_dim,hidden_dim)
        self.w_2=nn.Linear(input_dim,hidden_dim)
        self.out=nn.Linear(hidden_dim,input_dim)
    def forward(self,x:torch.Tensor):
        """
        Computes the output of the SwiGLUFFN module.
        """
        return self.out(self.w_1(x) * F.silu(self.w_2(x)))


class MOE(nn.Module):
    def __init__(self,input_size,hidden_size,output_size,n_experts,k):
        super().__init__()
        self.input_size=input_size
        self.output_size=output_size
        self.k=k
        self.experts=self.experts=nn.ModuleList([SwiGLUFFN(input_dim=input_size,hidden_dim=hidden_size) for _ in range(n_experts)])
        self.W_router=nn.Linear(in_features=input_size,out_features=n_experts)
        self.bias=nn.Parameter(torch.zeros(n_experts), requires_grad=False)
        self.u=0.1
    
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        router_scores = self.W_router(x)  # (batch_size, seq_len, num_experts)
        router_scores_biased = router_scores + self.bias.view(1, 1, -1)  
        topk_vals, topk_indices = torch.topk(router_scores_biased, k=self.k, dim=-1)#(batch_size, seq_len, k)
        topk_probs = F.softmax(topk_vals, dim=-1)  #(batch_size, seq_len, k)
        all_expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=0)  #(num_experts, batch_size, seq_len, output_size)
        indices = topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.output_size)  #(batch_size, seq_len, k, output_size)
        expert_outputs = all_expert_outputs.permute(1, 2, 0, 3)  # (batch_size, seq_len, num_experts, output_size)
        expert_outputs = torch.gather(expert_outputs, dim=2, index=indices) 
        final_output = (expert_outputs * topk_probs.unsqueeze(-1)).sum(dim=2)  #(batch_size, seq_len, output_size)

        
        topk_mask=torch.zeros_like(router_scores)
        topk_mask.scatter_(-1,topk_indices,1)
        tokens_per_expert=topk_mask.sum(dim=(0,1))
        mean_load=tokens_per_expert.mean()
        load_violation=mean_load-tokens_per_expert# (num_experts,) 
        self.bias.data+= self.u * torch.sign(load_violation)
        
        return final_output

#### Test moe

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size = 4
seq_len = 10
input_size = 32
hidden_size = 64
n_experts = 6
top_k = 2

moe = MOE(input_size=input_size, hidden_size=hidden_size, output_size=input_size, n_experts=n_experts, k=top_k)

x = torch.randn(batch_size, seq_len, input_size)

output = moe(x)

print(f"Input shape:  {x.shape}")     
print(f"Output shape: {output.shape}") 

Input shape:  torch.Size([4, 10, 32])
Output shape: torch.Size([4, 10, 32])


## Attention

In [6]:
import torch
import torch.nn as nn
import math
import copy
import torch.nn.functional as F
import torch.optim as optim


device="cuda" if torch.cuda.is_available() else "cpu"
def clones(module,N):
    """
    Create a list of N identical layers.

    Args:
        module (nn.Module): A neural network module to be cloned.
        N (int): The number of clones to create.

    Returns:
        nn.ModuleList: A list containing N deep copies of the input module.
    """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def attention(query,key,value,mask=None,dropout=None):
  
  "mask: binary mask of 0 and 1s, 1-> allowed , 0->mask it"
  #query, key, value -> N,h,T,d_k
  d_k=query.size(-1)
  scores=torch.matmul(query,key.transpose(-1,-2))/math.sqrt(d_k) # N,h,T,d_k @ N,h,d_k,T = N,h,T,T
  if mask is not None:
    scores=scores.masked_fill(mask==0,-1e9)

  p_attn=scores.softmax(dim=-1)
  if dropout is not None:
    p_attn=dropout(p_attn)

  return torch.matmul(p_attn,value), p_attn


class MultiHeadedAttention(nn.Module):
  def __init__(self,h,d_model,dropout=0.1):
    """
    Create a MultiHeadedAttention layer.

    Args:
        h (int): The number of heads in the multi-head attention mechanism.
        d_model (int): The number of expected features in the input.
        dropout (float, optional): The dropout to apply to the attention weights. Defaults to 0.1.
    """

    super(MultiHeadedAttention,self).__init__()

    self.d_k=d_model//h
    self.h=h
    self.linears=clones(nn.Linear(d_model, d_model), 4)
    self.dropout=nn.Dropout(p=dropout)

  def forward(self,query,key,value,mask=None):


    """
    Compute the forward pass of the multi-headed attention layer.

    Args:
        query (torch.Tensor): The query tensor.
        key (torch.Tensor): The key tensor.
        value (torch.Tensor): The value tensor.
        mask (torch.Tensor, optional): The mask to apply to the attention weights. Defaults to None.

    Returns:
        torch.Tensor: The output of the forward pass.
    """
    if mask is not None:
      mask=mask.unsqueeze(1)
      # print(mask.shape)

    nbatches=query.shape[0]

    query,key,value=[
        lin(x).view(nbatches,-1,self.h,self.d_k).transpose(1,2)
        for lin,x in zip(self.linears,(query,key,value))
    ]

    x,self_attn=attention(query,key,value,mask=mask,dropout=self.dropout)

    x=(x.transpose(1,2).contiguous().view(nbatches,-1,self.h*self.d_k))

    del query
    del key
    del value

    return self.linears[-1](x)

## Padding

In [7]:
def pad_mask(x, pad_token):
    """Create a mask to ignore padding tokens.
    x: (N, T)
    Returns: (N, 1, T)
    """
    return (x != pad_token).unsqueeze(-2)  # N, 1, T

def causal_mask(size):
    """Create a causal mask to prevent attending to future tokens.
    Returns: (1, T, T)
    """
    attn_shape = (1, size, size)
    mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
        torch.uint8
    )
    return mask == 0  # 1 -> valid, 0 -> masked

def combined_mask(x, pad_token):
    """Combine padding mask and causal mask.
    x: (N, T)
    Returns: (N, T, T)
    """
    N, T = x.shape



    # Create padding mask: (N, 1, T)
    pad_mask_ = pad_mask(x, pad_token)

    # Create causal mask: (1, T, T)
    causal_mask_ = causal_mask(T)

    # Expand causal mask to match batch size: (N, T, T)
    causal_mask_ = causal_mask_.expand(N, -1, -1)

    # Apply padding mask to the causal mask
    # For each query position (row), if it's a padding token, the entire row should be False
    combined_mask_ = pad_mask_.transpose(-1, -2) & causal_mask_

    return combined_mask_

## Layer

In [10]:
class SubLayerConnection(nn.Module):
  def __init__(self,d_model,dropout):
    """
  Initialize a sublayer connection layer.
  
  Parameters:
  d_model (int): The number of expected features in the input.
  dropout (float): The amount of dropout to apply.
    """
    super(SubLayerConnection,self).__init__()
    self.norm=nn.LayerNorm(d_model)
    self.dropout=nn.Dropout(dropout)

  def forward(self,x,sublayer):

    return x + self.dropout(sublayer(self.norm(x)))


class Layer(nn.Module):

  def __init__(self,self_attn,dropout,n_experts=8,k=2,hidden=768*4,h=12,d_model=768):

    """
  Initialize a GPT layer.
  
  Parameters:
  self_attn (nn.Module): A multi-headed self-attention layer.
  feed_forward (nn.Module): A MOE feed-forward layer.
  dropout (float): The amount of dropout to apply.
  hidden (int): The number of neurons in the hidden layer of the feed-forward layer. Defaults to 768*4.
  h (int): The number of attention heads. Defaults to 12.
  d_model (int): The number of expected features in the input. Defaults to 768.
  """
    super().__init__()
    self.self_attn=self_attn
    self.h=h
    self.feed_forward=MOE(input_size=d_model,hidden_size=hidden,output_size=d_model,n_experts=n_experts,k=k)
    self.sublayer=clones(SubLayerConnection(d_model,dropout),2)
    self.d_model=d_model


  def forward(self,x,mask=None):
    if mask is not None:
      x=self.sublayer[0](x, lambda x:self.self_attn(x,x,x,mask))
    else:
      x=self.sublayer[0](x, lambda x:self.self_attn(x,x,x,mask=None))
    return self.sublayer[1](x, self.feed_forward)

##### test cases for layer

In [18]:

batch_size, seq_len, d_model = 2, 3, 4
x = torch.randn(batch_size, seq_len, d_model)
mask = None

self_attn = MultiHeadedAttention(h=4,d_model=4)
layer_test = Layer(d_model=d_model, self_attn=self_attn, dropout=0.1)

output = layer_test(x, mask)
print(output.shape)
assert output.shape == x.shape, "❌ Layer output shape mismatch"


batch_size, seq_len, d_model = 2, 4, 8
pad_token = 0

x = torch.randn(batch_size, seq_len, d_model)
token_ids = torch.randint(1, 10, (batch_size, seq_len))  # Simulated token indices
mask = combined_mask(token_ids, pad_token)

self_attn = MultiHeadedAttention(h=4,d_model=d_model)
layer = Layer(d_model=d_model, self_attn=self_attn, dropout=0.1)

output = layer(x, mask)

assert output.shape == x.shape, "❌ Layer output shape mismatch"
assert not torch.isnan(output).any(), "❌ Output contains NaNs"

torch.Size([2, 3, 4])


## Model

In [None]:
class Model(nn.Module):
    def __init__(self, layer, vocab=50257, max_seq_len=1024, N=12, d_model=768, h=12, k=2, n_experts=8, hidden=768*4, dropout=0.1):
        """
        Initialize the model.

        Args:
            layer (nn.Module): The type of layer to use in the model.
            vocab (int): The size of the input vocabulary.
            max_seq_len (int): The maximum sequence length supported by the model. Defaults to 1024.
            N (int): The number of layers to use. Defaults to 12.
            d_model (int): The number of features in the input and output. Defaults to 768.
            h (int): The number of attention heads to use. Defaults to 12.
            hidden (int): The number of neurons in the hidden layer of the feed-forward layer. Defaults to 768 * 4.
            dropout (float): The dropout probability to use. Defaults to 0.1.
        """
        super().__init__()
        self.layers = clones(layer, N)  # clones
        self.norm = nn.LayerNorm(d_model)  # Fix: Use d_model, not layer.size
        self.embedding = nn.Embedding(vocab, d_model)
        self.d_model = d_model
        self.pos = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
        self.lm_head=nn.Linear(d_model,vocab)
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights for the model."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)  # Normal init with std=0.02
                if module.bias is not None:
                    module.bias.data.zero_()  # Bias = 0

            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=0.02)  # Normal init with std=0.02

            elif isinstance(module, nn.LayerNorm):
                module.weight.data.fill_(1.0)  # LayerNorm weight = 1
                module.bias.data.zero_()  # LayerNorm bias = 0

    def forward(self, x, mask=None):
        N, T = x.shape
        positions = torch.arange(0, T, device=x.device).unsqueeze(0)  # (1, T)
        token_embeddings = self.embedding(x)  # (N, T, d_model)
        position_embeddings = self.pos(positions)  # (1, T, d_model)
        x = token_embeddings + position_embeddings  # (N, T, d_model)
        x = self.dropout(x)  # Apply dropout

        for layer in self.layers:
            x = layer(x, mask) if mask is not None else layer(x)
        
        logits=self.lm_head(self.norm(x)) 
        print('logit_shape',logits.shape) # Apply final LayerNorm
        return logits


In [19]:

batch_size, seq_len, d_model = 2, 4, 768  # Matching d_model init
pad_token = 0
vocab=1000
x = torch.randint(0, 1024, (batch_size, seq_len))
token_ids = torch.randint(1, 10, (batch_size, seq_len))  # No padding
mask = combined_mask(token_ids, pad_token)
self_attn = MultiHeadedAttention(h=4, d_model=d_model)
layer = Layer(d_model=d_model, self_attn=self_attn, dropout=0.1,n_experts=8,k=2,hidden=768*4)

model = Model(layer=layer, N=8,vocab=vocab, d_model=d_model)
output = model(x, mask=mask)
print(output.shape)
assert output.shape == (batch_size, seq_len, vocab), "❌ Output shape mismatch"
assert not torch.isnan(output).any(), "❌ Output contains NaNs"

print("Passed!")


logit_shape torch.Size([2, 4, 1000])
torch.Size([2, 4, 1000])
Passed!
