# 08: Train Decoder-Only Model with Mixture of Experts (MoE)
In this notebook, we demonstrate how to train a GPT-style decoder model using Mixture of Experts (MoE) layers.

**Key features:**
- Multi-domain synthetic dataset (poetry, news, code, dialog)
- MoE layers replace the standard feed-forward (FFN) block in every transformer block
- A learned top-k gate routes each token representation to a subset of the experts (2 of 4 by default)

In [None]:
!pip install torch transformers

In [None]:
import sys
import os
sys.path.append(os.path.abspath(".."))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import random

## Create synthetic multi-domain dataset

In [None]:
# Load the pretrained "gpt2" tokenizer and register '[PAD]' as its padding token.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# One fixed prompt per domain; the key is also used as a "<domain>" tag prefix.
domains = {
    "poetry": "Shall I compare thee to a summer's day?",
    "news": "The central bank announced an increase in interest rates.",
    "code": "def add(a, b): return a + b",
    "dialog": "Hey! How are you doing today?"
}

# 300 tagged copies of every prompt, appended domain by domain, then shuffled
# in place so the domains are interleaved in the final corpus.
samples = []
for domain, prompt in domains.items():
    samples.extend([f"<{domain}> {prompt}"] * 300)
random.shuffle(samples)

# Join everything into a single newline-separated string and encode it as one
# flat list of token ids (no special tokens are inserted by the tokenizer).
text = "\n".join(samples)
tokens = tokenizer.encode(text, add_special_tokens=False)

## Dataset and DataLoader

In [None]:
class TextDataset(Dataset):
    """Sliding-window next-token dataset over one flat token sequence.

    Item ``i`` is the pair ``(tokens[i:i+block_size], tokens[i+1:i+block_size+1])``,
    i.e. an input window and the same window shifted one position to the right.

    Args:
        tokens: Sequence of integer token ids.
        block_size: Number of tokens in each input (and target) window.
    """

    def __init__(self, tokens, block_size: int):
        self.block_size = block_size
        # Keep a single tensor and slice it on demand instead of materialising
        # one (block_size + 1)-long tensor per start position.
        self.tokens = torch.as_tensor(tokens, dtype=torch.long)

    def __len__(self) -> int:
        # A window starting at i needs block_size + 1 tokens, so the valid
        # starts are 0 .. len(tokens) - block_size - 1 inclusive.
        return max(0, self.tokens.numel() - self.block_size)

    def __getitem__(self, idx):
        """Return ``(inputs, targets)``, both 1-D tensors of length ``block_size``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            # Tensor slicing would silently return a short window instead.
            raise IndexError("TextDataset index out of range")
        window = self.tokens[idx:idx + self.block_size + 1]
        return window[:-1], window[1:]

# Window length (in tokens) of every training example.
block_size = 64
# Sliding-window dataset over the token list, served in shuffled batches of 8.
dataset = TextDataset(tokens, block_size=block_size)
dataloader = DataLoader(dataset=dataset, batch_size=8, shuffle=True)

## MoE Layer Implementation

In [None]:
class TopKGate(nn.Module):
    """Linear router that keeps the ``k`` highest-scoring experts per position.

    Args:
        input_dim: Size of the last dimension of the input.
        num_experts: Number of experts to score.
        k: How many experts to keep for every position.
    """

    def __init__(self, input_dim, num_experts, k=2):
        super().__init__()
        self.k = k
        self.w_gating = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Return ``(weights, expert_ids)``, each with last dimension ``k``.

        ``weights`` is a softmax taken over the ``k`` selected logits only, so
        it sums to 1 along the last dimension.
        """
        scores = self.w_gating(x)  # [..., E]
        best = scores.topk(self.k, dim=-1)
        mixing = best.values.softmax(dim=-1)  # [..., K]
        return mixing, best.indices

class MoELayer(nn.Module):
    """Sparse Mixture-of-Experts feed-forward layer.

    Every token is sent to the ``k`` experts chosen by the gate, and the
    expert outputs are summed using the gate weights.

    Args:
        input_dim: Feature size of the input and of the output.
        hidden_dim: Width of each expert's hidden layer.
        num_experts: Number of expert MLPs.
        k: Number of experts applied to each token.
    """

    def __init__(self, input_dim, hidden_dim, num_experts, k=2):
        super().__init__()
        self.num_experts = num_experts
        self.k = k
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, input_dim)
            ) for _ in range(num_experts)
        ])
        self.gate = TopKGate(input_dim, num_experts, k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the routed experts to ``x`` of shape ``[B, T, D]``.

        Returns:
            Tensor of shape ``[B, T, D]``: for every token, the gate-weighted
            sum of the outputs of its selected experts.
        """
        bsz, seq_len, dim = x.shape
        weights, indices = self.gate(x)  # [B, T, K], [B, T, K]

        # reshape (not view) so non-contiguous inputs, e.g. a transposed
        # tensor, are accepted instead of raising.
        x_flat = x.reshape(-1, dim)                   # [N, D]
        flat_weights = weights.reshape(-1, self.k)    # [N, K]
        flat_indices = indices.reshape(-1, self.k)    # [N, K]
        out_flat = x_flat.new_zeros(x_flat.shape)     # contiguous accumulator

        # Run each expert once on all tokens routed to it, whichever of the
        # k gate slots selected it.
        for expert_id, expert in enumerate(self.experts):
            token_ids, slot_ids = torch.nonzero(
                flat_indices == expert_id, as_tuple=True
            )
            if token_ids.numel() == 0:
                continue
            expert_out = expert(x_flat[token_ids])
            gate_w = flat_weights[token_ids, slot_ids].unsqueeze(1)
            # Cast back to the accumulator dtype in case the expert ran in a
            # different precision (index_add_ requires matching dtypes).
            out_flat.index_add_(
                0, token_ids, (gate_w * expert_out).to(out_flat.dtype)
            )
        return out_flat.view(bsz, seq_len, dim)

## Define the Decoder Model with MoE

In [None]:
class MoETransformerBlock(nn.Module):
    """Transformer block: causal self-attention followed by an MoE feed-forward.

    Both sub-layers use a residual connection with dropout (p=0.1), and
    LayerNorm is applied after each residual sum.

    Args:
        embed_dim: Model width.
        num_heads: Number of attention heads.
        num_experts: Number of experts in the MoE feed-forward layer.
        k: Number of experts applied to each token.
    """

    def __init__(self, embed_dim, num_heads, num_experts=4, k=2):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ff = MoELayer(embed_dim, hidden_dim=embed_dim * 4, num_experts=num_experts, k=k)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transform ``x`` of shape ``[B, T, D]`` into a tensor of the same shape.

        Position ``t`` only attends to positions ``<= t``. Without this mask a
        next-token model could read the very token it is asked to predict.
        """
        seq_len = x.size(1)
        # Boolean attn_mask: True marks (query, key) pairs that must NOT
        # attend, i.e. every key strictly after the query position.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        # The attention weights are discarded, so do not compute them.
        attn_out, _ = self.attn(x, x, x, attn_mask=causal_mask, need_weights=False)
        x = self.ln1(x + self.dropout(attn_out))
        x = self.ln2(x + self.dropout(self.ff(x)))
        return x

class MoEGPTDecoder(nn.Module):
    """Token-level language model built from a stack of ``MoETransformerBlock``.

    Args:
        vocab_size: Number of rows in the token embedding and of output logits.
        embed_dim: Model width.
        depth: Number of transformer blocks.
        heads: Number of attention heads per block.
        max_len: Longest sequence length the learned position table covers.
        num_experts: Number of experts in every block's MoE layer.
        k: Number of experts applied to each token.
    """

    def __init__(self, vocab_size, embed_dim, depth, heads, max_len,
                 num_experts=4, k=2):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        # Learned absolute position embeddings, initialised to zero.
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len, embed_dim))
        self.blocks = nn.ModuleList([
            MoETransformerBlock(embed_dim, heads, num_experts=num_experts, k=k)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids ``[B, T]`` to logits ``[B, T, vocab_size]``.

        Raises:
            ValueError: If ``T`` exceeds the length of the position table.
        """
        seq_len = x.size(1)
        max_len = self.pos_embed.size(1)
        if seq_len > max_len:
            # Otherwise the addition below fails with an opaque broadcasting
            # error, because the position slice is shorter than the input.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )
        x = self.token_embed(x) + self.pos_embed[:, :seq_len]
        for block in self.blocks:
            x = block(x)
        return self.lm_head(self.norm(x))

## Train the MoE Model

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The vocabulary size is taken from the tokenizer so the embedding and output
# layers cover every token id it can produce; max_len matches the window size.
model = MoEGPTDecoder(
    vocab_size=len(tokenizer),
    embed_dim=512,
    depth=4,
    heads=8,
    max_len=block_size,
)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Three passes over the dataloader; the mean batch loss is printed per epoch.
for epoch in range(3):
    model.train()
    total_loss = 0
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        logits = model(x)
        # Merge the batch and time axes: the loss takes [N, vocab] logits
        # against [N] target ids.
        loss = criterion(logits.flatten(0, 1), y.flatten())
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1} Loss: {total_loss / len(dataloader):.4f}")

# Persist only the learned parameters.
torch.save(model.state_dict(), "moe_decoder_trained.pt")
print("✅ MoE model trained and saved.")