In [2]:
import torch
from torch import nn
import torch.nn.functional as F

In [3]:
# Toy hyperparameters used by every cell below.
B, N, D = 16, 512, 64   # batch size, sequence length, model dimension
vocab_size = 50012      # token vocabulary size (embedding rows / output logits)
n_experts = 8           # experts per MoE layer
n_heads = 8             # attention heads per layer
expansion = 4           # expert hidden width = expansion * D

# Seed torch so random inputs and weight initialisation are reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [4]:
# Random dense activations of shape (B, N, D) used to shape-check the layers below.
inputs = torch.rand((B, N, D))

In [5]:
class ExpertsLayer(nn.Module):
    """One position-wise feed-forward expert: Linear -> GELU -> Linear.

    The hidden width is ``dim * expansion`` and the output width is ``dim``,
    so the layer maps (..., dim) -> (..., dim).

    Note: the activation attribute is called ``relu`` but holds an ``nn.GELU``;
    the name is kept so existing references to ``.relu`` keep working.
    """

    def __init__(self, dim, expansion):
        super().__init__()
        hidden_dim = dim * expansion
        self.linear1 = nn.Linear(dim, hidden_dim)
        self.relu = nn.GELU()  # GELU, despite the attribute name
        self.linear2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)

In [6]:
expert = ExpertsLayer(D, expansion=expansion)
expert_out = expert(inputs)
# An expert must preserve the (B, N, D) shape so it can sit inside a residual block.
assert expert_out.shape == inputs.shape
print(expert_out.shape)

torch.Size([16, 512, 64])


In [7]:
class GatingLayer(nn.Module):
    """Router mapping each token vector to a probability distribution over experts.

    A single ``Linear(dim -> n_experts)`` produces logits, which are soft-maxed
    over the last axis: output shape is (..., n_experts) and sums to 1 there.
    """

    def __init__(self, dim, n_experts):
        super().__init__()
        self.gate = nn.Linear(dim, n_experts)

    def forward(self, x):
        logits = self.gate(x)
        return logits.softmax(dim=-1)

In [8]:
gate = GatingLayer(D, n_experts)
gate_probs = gate(inputs)
# One probability per expert for every token, summing to 1 over experts.
assert gate_probs.shape == (B, N, n_experts)
assert torch.allclose(gate_probs.sum(dim=-1), torch.ones(B, N))
print(gate_probs.shape)

torch.Size([16, 512, 8])


In [9]:
class MoEModule(nn.Module):
    """Dense (soft) mixture of experts.

    Every expert processes every token, and the per-token outputs are mixed
    with the gate's softmax probabilities: ``out = sum_e p_e(x) * expert_e(x)``.

    Args:
        dim: size of the input's last axis (and of the output's).
        n_experts: number of feed-forward experts.
        expansion: hidden-layer expansion factor for each expert.
    """

    def __init__(self, dim, n_experts, expansion):
        super(MoEModule, self).__init__()
        self.experts = nn.ModuleList([ExpertsLayer(dim, expansion) for _ in range(n_experts)])
        self.gate = GatingLayer(dim, n_experts)

    def forward(self, x):
        """Return the probability-weighted sum of all experts for input (..., dim)."""
        # (..., n_experts): routing probabilities for each token.
        experts_prob = self.gate(x)
        # (..., dim, n_experts): every expert's full output, stacked on a new last axis.
        # The whole output tensor must be used here; indexing it with the expert
        # index would select a batch row instead of an expert.
        experts_output = torch.stack([expert(x) for expert in self.experts], dim=-1)
        # Broadcast probabilities over the feature axis, then sum out the experts.
        return (experts_output * experts_prob.unsqueeze(-2)).sum(dim=-1)

In [10]:
moe = MoEModule(D, n_experts, expansion=expansion)
moe_out = moe(inputs)
# The mixture must preserve the (B, N, D) shape of its input.
assert moe_out.shape == inputs.shape
print(moe_out.shape)

torch.Size([16, 512, 64])


In [13]:
class MoEDecoderBlock(nn.Module):
    """Post-LayerNorm transformer block with a dense MoE feed-forward.

    ``h = norm1(x + attn(x))`` followed by ``out = norm2(h + moe(h))``.
    Inputs are batch-first: (batch, seq_len, dim).

    No causal mask is applied by default; pass ``attn_mask`` to ``forward``
    (e.g. an upper-triangular ``-inf`` mask) for autoregressive use.

    Args:
        dim: model dimension.
        n_heads: number of attention heads.
        n_experts: number of experts in the MoE feed-forward.
        expansion: hidden-layer expansion factor for each expert.
    """

    def __init__(self, dim, n_heads, n_experts, expansion):
        super(MoEDecoderBlock, self).__init__()
        # batch_first=True: this block is fed (B, N, D). The default
        # (seq, batch, dim) layout would make tokens attend across the batch.
        self.multihead_attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm([dim])  # no dropout here; RMSNorm is a possible alternative
        self.moe = MoEModule(dim, n_experts, expansion)
        self.norm2 = nn.LayerNorm([dim])

    def forward(self, x, attn_mask=None):
        """Apply attention + MoE sub-layers.

        Args:
            x: tensor of shape (batch, seq_len, dim).
            attn_mask: optional mask forwarded to ``nn.MultiheadAttention``;
                ``None`` lets every token attend to every token.

        Returns:
            Tensor of shape (batch, seq_len, dim).
        """
        # The attention weights are not used, so skip computing/averaging them.
        attn_output, _ = self.multihead_attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        att_out = self.norm1(attn_output + x)
        moe_out = self.moe(att_out)
        # The residual wraps the MoE sub-layer, so it adds the MoE's own input
        # (att_out), not the block input x.
        return self.norm2(moe_out + att_out)

In [14]:
decoder = MoEDecoderBlock(D, n_heads, n_experts, expansion)
decoder_out = decoder(inputs)
# A transformer block must preserve the (B, N, D) shape.
assert decoder_out.shape == inputs.shape
print(decoder_out.shape)

torch.Size([16, 512, 64])


# Sparse MoE

In [15]:
class SparseMoEModule(nn.Module):
    """Top-k routed mixture of experts.

    For every token the gate gives a softmax over experts. Only the ``top_k``
    largest probabilities are kept and renormalised to sum to 1. The token's
    output is the weighted sum of those selected experts' outputs.

    Note: every expert is still evaluated on every token (run densely, then
    masked), so this shows the routing math but does not save compute.

    Args:
        dim: size of the input's last axis (and of the output's).
        n_experts: number of experts.
        expansion: hidden-layer expansion factor for each expert.
    """

    def __init__(self, dim, n_experts, expansion):
        super(SparseMoEModule, self).__init__()
        self.experts = nn.ModuleList([ExpertsLayer(dim, expansion) for _ in range(n_experts)])
        self.gate = GatingLayer(dim, n_experts)

    def forward(self, x, top_k):
        """Mix the ``top_k`` highest-probability experts for each token.

        Args:
            x: tensor of shape (..., dim).
            top_k: experts kept per token; must satisfy 1 <= top_k <= n_experts.

        Returns:
            Tensor of shape (..., dim).
        """
        experts_prob = self.gate(x)  # (..., n_experts)
        _, topk_indices = experts_prob.topk(top_k, dim=-1, sorted=False)
        # 1.0 at each token's selected experts, 0.0 elsewhere.
        mask = torch.zeros_like(experts_prob).scatter_(-1, topk_indices, 1.0)
        # Apply the mask, then renormalise the kept weights to sum to 1.
        # Without the mask this would just be the dense softmax again.
        experts_prob = F.normalize(experts_prob * mask, p=1, dim=-1)

        # (..., n_experts, dim)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)
        return torch.einsum('...e,...eo->...o', experts_prob, expert_outputs)

In [16]:
smoe = SparseMoEModule(D, n_experts, expansion)
smoe_out = smoe(inputs, 3)  # route each token to its top-3 experts
# Sparse routing must preserve the (B, N, D) shape.
assert smoe_out.shape == inputs.shape
print(smoe_out.shape)

torch.Size([16, 512, 64])


In [23]:
class TransformerWithSparseMoE(nn.Module):
    """Embedding -> stack of encoder layers -> one sparse MoE layer -> vocab logits.

    Only a single MoE layer is used (after the last transformer layer), not an
    MoE in every transformer block.

    Args:
        n_layers: number of ``nn.TransformerEncoderLayer`` blocks.
        dim: model dimension.
        n_heads: attention heads per layer.
        n_experts: experts in the sparse MoE layer.
        expasion: expert hidden-size expansion factor (the misspelled keyword is
            kept so existing callers keep working).
        vocab_size: vocabulary size for the embedding and output projection.
    """

    def __init__(self, n_layers, dim, n_heads, n_experts, expasion, vocab_size):
        super(TransformerWithSparseMoE, self).__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        # batch_first=True: token ids arrive as (batch, seq_len). The default
        # (seq, batch, dim) layout would treat the batch axis as the sequence.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads, batch_first=True)
            for _ in range(n_layers)
        ])
        self.moe_layer = SparseMoEModule(dim, n_experts, expasion)
        self.proj = nn.Linear(dim, vocab_size)

    def forward(self, x, top_k, causal=False):
        """Compute vocabulary logits.

        Args:
            x: integer token ids of shape (batch, seq_len).
            top_k: experts used per token in the sparse MoE layer.
            causal: if True, each position attends only to itself and earlier
                positions (needed for next-token prediction). Defaults to False,
                which is full bidirectional attention.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        h = self.embedding(x)
        mask = None
        if causal:
            seq_len = h.size(-2)
            # -inf strictly above the diagonal blocks attention to future positions.
            mask = torch.triu(
                torch.full((seq_len, seq_len), float('-inf'), device=h.device, dtype=h.dtype),
                diagonal=1,
            )
        for layer in self.layers:
            h = layer(h, src_mask=mask)
        h = self.moe_layer(h, top_k)
        return self.proj(h)

In [24]:
# Build a small demo model from the config-cell sizes (toy sizes for a shape
# check, not the configuration of a production model such as Mixtral 8x7B).
model = TransformerWithSparseMoE(
    n_layers=6,              # number of transformer encoder layers
    dim=D,                   # model (embedding) dimension
    n_heads=n_heads,         # number of attention heads per layer
    n_experts=n_experts,     # number of experts in the sparse MoE layer
    expasion=expansion,      # expert hidden-size expansion factor (keyword is spelled `expasion`)
    vocab_size=vocab_size,   # vocabulary size for the embedding and output projection
)

In [25]:
# Random token ids in [0, vocab_size), shape (B, N).
tokenized_inputs = torch.randint(low=0, high=vocab_size, size=(B, N))

In [26]:
# Shape check only: no_grad avoids recording an autograd graph. The float32
# logits alone are ~1.6 GB for B=16, N=512, vocab_size=50012.
with torch.no_grad():
    logits = model(tokenized_inputs, 3)  # route each token to its top-3 experts
assert logits.shape == (B, N, vocab_size)
print(logits.shape)

torch.Size([16, 512, 50012])
