In [5]:
import torch
import torch.nn as nn
import torch.optim as optim

from datasets import load_dataset
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence

import numpy as np
import random

# --- 1. Load AG News data: train and test splits ---
train_dataset = load_dataset('ag_news', split='train')
test_dataset  = load_dataset('ag_news', split='test')

# --- 2. Build vocabulary using ONLY training data ---
tokenizer = lambda s: s.lower().split()
vocab = build_vocab_from_iterator((tokenizer(x['text']) for x in train_dataset), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

# --- 3. Encode samples for train and test separately ---
def encode(text):
    tokens = tokenizer(text)
    return torch.tensor([vocab[token] for token in tokens][:8])  # seq_len=8

X_train = [encode(sample['text']) for sample in train_dataset]
X_train = pad_sequence(X_train, batch_first=True, padding_value=0)
y_train = torch.tensor([sample['label'] for sample in train_dataset])

X_test = [encode(sample['text']) for sample in test_dataset]
X_test = pad_sequence(X_test, batch_first=True, padding_value=0)
y_test = torch.tensor([sample['label'] for sample in test_dataset])

n_samples_train = len(X_train)
n_samples_test = len(X_test)

# --- 4. Model Definitions ---
class RoutingNetwork(nn.Module):
    def __init__(self, model_dim, n_experts):
        super().__init__()
        self.linear = nn.Linear(model_dim, n_experts)
    def forward(self, x):
        return self.linear(x)  # (batch, n_experts)

class MoEBackbone(nn.Module):
    def __init__(self, vocab_size, model_dim, n_heads=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, model_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, 10, model_dim))  # Max seq len 10
        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
    def forward(self, x):
        x = self.embedding(x) + self.pos_embedding[:, :x.size(1)]
        out = self.encoder(x)
        return out[:, 0, :]  # Use the first token as representation

class MoEExpert(nn.Module):
    def __init__(self, model_dim, out_dim):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(model_dim, model_dim),
            nn.ReLU(),
            nn.Linear(model_dim, out_dim)
        )
    def forward(self, x):
        return self.ffn(x)

class MoE(nn.Module):
    def __init__(self, n_experts, vocab_size, model_dim, out_dim):
        super().__init__()
        self.backbone = MoEBackbone(vocab_size, model_dim, n_heads=2)
        self.experts = nn.ModuleList([MoEExpert(model_dim, out_dim) for _ in range(n_experts)])
        self.router = RoutingNetwork(model_dim, n_experts)

    def forward(self, x, return_probs=False):
        shared = self.backbone(x)  # (batch, model_dim)
        expert_logits = self.router(shared)  # (batch, n_experts)
        expert_probs = torch.softmax(expert_logits, dim=-1)
        expert_idx = torch.argmax(expert_probs, dim=-1)
        if shared.size(0) == 1:
            idx = expert_idx.item() if isinstance(expert_idx, torch.Tensor) else expert_idx
            out = self.experts[idx](shared)
        else:
            outs = []
            for i in range(shared.size(0)):
                idx = expert_idx[i].item() if isinstance(expert_idx, torch.Tensor) else expert_idx[i]
                outs.append(self.experts[idx](shared[i:i+1]))
            out = torch.cat(outs, dim=0)
        if return_probs:
            return out, expert_probs
        else:
            return out

# --- 5. Training/Evaluation Utilities (updated for train/test split) ---

def train_joint_moe_supervised_router(
    model, X, y, loss_fn, optimizer, epochs=5, n_experts=4,
    lb_lambda=3, router_lambda=1.0, entropy_lambda=0.05, batch_size=32
):
    n_samples = len(X)
    indices = torch.arange(n_samples)
    for epoch in range(epochs):
        total_loss = 0
        total_lb_loss = 0
        total_router_loss = 0
        total_entropy = 0
        model.train()
        routing_counts = [0 for _ in range(n_experts)]
        expert_probs_sum = torch.zeros(n_experts)

        indices = indices[torch.randperm(n_samples)]
        for batch_start in range(0, n_samples, batch_size):
            batch_idx = indices[batch_start:batch_start+batch_size]
            X_batch = X[batch_idx]
            y_batch = y[batch_idx]
            
            shared = model.backbone(X_batch)
            router_logits = model.router(shared)  # (batch, n_experts)
            expert_probs = torch.softmax(router_logits, dim=-1)
            hard_assign = expert_probs.argmax(dim=-1)
            for k in range(n_experts):
                routing_counts[k] += (hard_assign == k).sum().item()
            expert_probs_sum += expert_probs.sum(dim=0).detach()

            # Supervised router loss
            router_loss = nn.CrossEntropyLoss()(router_logits, y_batch)
            
            # Main MoE output (with dynamic routing)
            expert_idx = hard_assign
            outs = []
            for i in range(X_batch.size(0)):
                idx = expert_idx[i].item()
                outs.append(model.experts[idx](shared[i:i+1]))
            out = torch.cat(outs, dim=0)

            ce_loss = loss_fn(out, y_batch)
            probs_mean = expert_probs.mean(dim=0)
            lb_loss = ((probs_mean - 1.0/n_experts) ** 2).sum()
            
            # Entropy regularization (maximize entropy for diversity)
            entropy = -torch.sum(expert_probs * torch.log(expert_probs + 1e-8), dim=1).mean()

            # Total loss
            loss = ce_loss + lb_lambda * lb_loss + router_lambda * router_loss + entropy_lambda * entropy

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += ce_loss.item() * X_batch.size(0)
            total_lb_loss += lb_loss.item() * X_batch.size(0)
            total_router_loss += router_loss.item() * X_batch.size(0)
            total_entropy += entropy.item() * X_batch.size(0)

        print(f"[MoE][Training] Epoch {epoch+1} | CE: {total_loss/n_samples_train:.3f} | LB: {total_lb_loss/n_samples_train:.3f} | Router: {total_router_loss/n_samples_train:.3f} | Entropy: {total_entropy/n_samples_train:.3f}")
        print("[MoE][Training] Hard assignment counts per expert:", routing_counts)
        print("[MoE][Training] Mean softmax probability per expert:", (expert_probs_sum / n_samples_train).tolist())

def evaluate_per_expert_moe(model, X, y, n_experts, batch_size=32):
    results = []
    model.eval()
    expert_assignments = []
    n_samples = len(X)
    with torch.no_grad():
        for expert_id in range(n_experts):
            idxs = (y == expert_id).nonzero(as_tuple=True)[0]
            if len(idxs) == 0:
                results.append(float('nan'))
                continue
            correct = 0
            total = 0
            for batch_start in range(0, len(idxs), batch_size):
                batch_idx = idxs[batch_start:batch_start+batch_size]
                X_batch = X[batch_idx]
                y_batch = y[batch_idx]
                out, expert_probs = model(X_batch, True)
                pred = out.argmax(dim=1)
                routed_expert = expert_probs.argmax(dim=-1)
                expert_assignments += routed_expert.cpu().tolist()
                correct += (pred == y_batch).sum().item()
                total += y_batch.size(0)
            acc = correct / total if total > 0 else 0
            results.append(acc)
    # Print final routing histogram for analysis
    unique, counts = torch.tensor(expert_assignments).unique(return_counts=True)
    dist = {int(u): int(c) for u, c in zip(unique, counts)}
    print(f"[MoE] Final routing distribution: {dist}")
    return results

# --- 6. Confusion Matrix Utility ---
def compute_routing_confusion(model, X, y, n_experts, n_classes=4, batch_size=32):
    import numpy as np
    try:
        from sklearn.metrics import confusion_matrix
    except ImportError:
        confusion_matrix = None
    confusion = torch.zeros(n_classes, n_experts, dtype=torch.long)
    n_samples = len(X)
    model.eval()
    with torch.no_grad():
        for batch_start in range(0, n_samples, batch_size):
            X_batch = X[batch_start:batch_start+batch_size]
            y_batch = y[batch_start:batch_start+batch_size]
            _, expert_probs = model(X_batch, True)
            routed_expert = expert_probs.argmax(dim=-1)
            for i in range(X_batch.size(0)):
                true_label = y_batch[i].item()
                expert = routed_expert[i].item()
                confusion[true_label, expert] += 1
    print("\n[MoE] Confusion Matrix: Rows = True Class, Columns = Routed Expert")
    print(confusion)
    if confusion_matrix is not None:
        y_true = []
        y_pred = []
        for batch_start in range(0, n_samples, batch_size):
            X_batch = X[batch_start:batch_start+batch_size]
            y_batch = y[batch_start:batch_start+batch_size]
            _, expert_probs = model(X_batch, True)
            routed_expert = expert_probs.argmax(dim=-1).cpu().tolist()
            y_true.extend(y_batch.cpu().tolist())
            y_pred.extend(routed_expert)
        print("\n[MoE] Sklearn confusion_matrix (same axes):")
        print(confusion_matrix(y_true, y_pred))

# --- 7. Usage Example: Only train on train set, evaluate on test set ---
n_experts = 4
model_dim = 32
out_dim = 4
vocab_size = len(vocab)

moe_model = MoE(n_experts, vocab_size, model_dim, out_dim)
moe_loss_fn = nn.CrossEntropyLoss()
moe_optimizer = optim.Adam(moe_model.parameters(), lr=1e-3)
epochs = 5

# TRAINING (on training set only)
train_joint_moe_supervised_router(
    moe_model, X_train, y_train, moe_loss_fn, moe_optimizer,
    epochs=epochs, n_experts=n_experts, lb_lambda=3, router_lambda=1.0, batch_size=32
)

# TESTING (on test set only)
moe_acc = evaluate_per_expert_moe(moe_model, X_test, y_test, n_experts)
print("[MoE][testing] Per-expert accuracy on test set:", moe_acc)

# Confusion Matrix (test set)
compute_routing_confusion(moe_model, X_test, y_test, n_experts, n_classes=4, batch_size=32)

# Overall test accuracy
def compute_overall_accuracy(model, X, y, batch_size=32):
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for batch_start in range(0, len(X), batch_size):
            X_batch = X[batch_start:batch_start+batch_size]
            y_batch = y[batch_start:batch_start+batch_size]
            out = model(X_batch)
            pred = out.argmax(dim=1)
            correct += (pred == y_batch).sum().item()
            total += y_batch.size(0)
    return correct / total

test_acc = compute_overall_accuracy(moe_model, X_test, y_test)
print("[MoE][testing] Overall test accuracy:", test_acc)


[MoE][Training] Epoch 1 | CE: 0.989 | LB: 0.008 | Router: 0.987 | Entropy: 0.992
[MoE][Training] Hard assignment counts per expert: [30400, 32616, 26932, 30052]
[MoE][Training] Mean softmax probability per expert: [0.24954630434513092, 0.2502956688404083, 0.2502772808074951, 0.24988067150115967]
[MoE][Training] Epoch 2 | CE: 0.620 | LB: 0.014 | Router: 0.618 | Entropy: 0.619
[MoE][Training] Hard assignment counts per expert: [29096, 32075, 28704, 30125]
[MoE][Training] Mean softmax probability per expert: [0.24949970841407776, 0.25036853551864624, 0.25016531348228455, 0.2499663382768631]
[MoE][Training] Epoch 3 | CE: 0.484 | LB: 0.016 | Router: 0.483 | Entropy: 0.481
[MoE][Training] Hard assignment counts per expert: [29009, 31442, 29247, 30302]
[MoE][Training] Mean softmax probability per expert: [0.2496137171983719, 0.25029751658439636, 0.250068724155426, 0.25002002716064453]
[MoE][Training] Epoch 4 | CE: 0.402 | LB: 0.017 | Router: 0.400 | Entropy: 0.398
[MoE][Training] Hard assignm