Given a vector of input scores (logits) $ \mathbf{z} = (z_1, z_2, ..., z_d) \in \mathbb{R}^d $, the α-entmax transformation computes a probability distribution $ \mathbf{p} = \alpha\text{-entmax}(\mathbf{z}) $ where $ \mathbf{p} \in \Delta^d $ (the probability simplex).

For $ \alpha > 1 $, the $ i $-th component of the output probability vector $ \mathbf{p} $ is:

$ p_i = \alpha\text{-entmax}_i(\mathbf{z}) = \left[ (\alpha - 1) z_i - \tau(\mathbf{z}) \right]_+^{1/(\alpha-1)} $

Where:

- $ \alpha $: The hyperparameter controlling sparsity. The limit $ \alpha \to 1 $ recovers Softmax (the formula above is undefined at exactly $ \alpha = 1 $, where the exponent $ 1/(\alpha-1) $ diverges), and $ \alpha = 2 $ recovers Sparsemax. For $ \alpha > 1 $, the transformation can produce sparse outputs (exact zeros). The experiments below test $ \alpha = 1.0, 1.5, 2.0 $.
- $ z_i $: The $ i $-th input score (logit).
- $ [x]_+ $: The positive part function (ReLU), i.e., $ [x]_+ = \max\{x, 0\} $. This is what allows the output probabilities to be exactly zero for low scores when $ \alpha > 1 $.
- $ \tau(\mathbf{z}) $: A threshold value (a Lagrange multiplier) that is determined uniquely for each input vector $ \mathbf{z} $ such that the resulting vector $ \mathbf{p} $ sums to 1, i.e., $ \sum_{j=1}^d p_j = 1 $. Finding this $ \tau $ efficiently is often the main challenge in computing α-entmax, and methods like bisection search (as used in `entmax_bisect`) are employed.

In simple terms:

1. The input logits $ z_i $ are scaled by $ (\alpha - 1) $.
2. A threshold $ \tau $ (specific to the input vector $ \mathbf{z} $) is subtracted.
3. The result is passed through a ReLU function $ [\cdot]_+ $ to zero out negative values.
4. The non-zero results are raised to the power of $ 1/(\alpha-1) $.
5. The threshold $ \tau $ is chosen precisely so that these final values sum to 1, forming a valid (potentially sparse) probability distribution.
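
The steps above can be sketched in plain Python. This is a minimal illustration of the bisection idea for $ \alpha > 1 $, not the implementation inside the `entmax` package; the function name `entmax_bisect_sketch` and the example logits are hypothetical.

```python
def entmax_bisect_sketch(z, alpha=1.5, n_iter=50):
    """Approximate alpha-entmax(z) for alpha > 1 by bisection on tau."""
    if alpha <= 1.0:
        raise ValueError("this sketch needs alpha > 1; use softmax for alpha = 1")
    d = len(z)
    scaled = [(alpha - 1.0) * zi for zi in z]  # step 1: scale the logits
    expo = 1.0 / (alpha - 1.0)

    def mass(tau):
        # steps 2-4: subtract tau, clip at zero, raise to 1 / (alpha - 1)
        return sum(max(s - tau, 0.0) ** expo for s in scaled)

    # Bracket tau: at `lo` the largest entry alone contributes 1 (mass >= 1),
    # at `hi` every entry contributes at most 1/d (mass <= 1).
    lo = max(scaled) - 1.0
    hi = max(scaled) - (1.0 / d) ** (alpha - 1.0)
    for _ in range(n_iter):
        mid = 0.5 * (lo + hi)
        if mass(mid) >= 1.0:
            lo = mid  # too much mass: tau must increase
        else:
            hi = mid  # too little mass: tau must decrease
    tau = 0.5 * (lo + hi)

    # step 5: the chosen tau makes the values sum to (approximately) 1;
    # a final renormalisation removes the remaining bisection error.
    p = [max(s - tau, 0.0) ** expo for s in scaled]
    total = sum(p)
    return [pi / total for pi in p]


# alpha = 2 is Sparsemax; for these logits the result is close to [0.75, 0.25, 0.0]
print(entmax_bisect_sketch([1.0, 0.5, -1.0], alpha=2.0))
print(entmax_bisect_sketch([1.0, 0.5, -1.0], alpha=1.5))
```

The lowest logit receives exactly zero probability in both calls, which is the sparsity property that Softmax cannot produce.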


Metrics

## Result and Analysis (Run result below)

| Alpha ($ \alpha $) | Temperature (T) | Val Loss | Utilization | Imbalance | Concentration | TPE CV | Avg Prob CV |
|--------------------|-----------------|----------|-------------|-----------|---------------|--------|-------------|
| 1                  | 0.5             | 2.3981   | 1           | 1         | 0.125         | 0      | 0           |
| 1                  | 1               | 1.618    | 1           | 1         | 0.125         | 0      | 0           |
| 1                  | 10              | 1.5866   | 1           | 1         | 0.125         | 0      | 0           |
| 1.5                | 0.5             | 1.5641   | 1           | 1.6958    | 0.8997        | 0.3955 | 0.5014      |
| 1.5                | 1               | 1.5604   | 1           | 1.2568    | 0.7286        | 0.1769 | 0.2747      |
| 1.5                | 10              | 1.5752   | 1           | 1.1047    | 0.4309        | 0.0793 | 0.2631      |
| 2                  | 0.5             | 1.5702   | 1           | 2.0718    | 0.9728        | 0.5883 | 0.6088      |
| 2                  | 1               | 1.5624   | 1           | 1.7956    | 0.8461        | 0.4227 | 0.4705      |
| 2                  | 10              | 1.5622   | 1           | 1.188     | 0.437         | 0.133  | 0.1614      |

### Best Performance (Validation Loss):

- The lowest validation loss (best performance) was achieved with alpha=1.5, temperature=1.0 (Val Loss: 1.5604).
- Two other configurations were very close:
    - alpha=2.0, temperature=10.0 (Val Loss: 1.5622)
    - alpha=2.0, temperature=1.0 (Val Loss: 1.5624)
- The models with alpha=1.0 performed worse at every temperature (Val Loss between 1.5866 and 2.3981), most severely at temp=0.5, so the alpha=1.0 routing was not optimal for this setup.
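
The ranking above can be reproduced directly from the results table. The snippet below is a small sketch that copies the Val Loss and Imbalance columns from the table by hand; the tuple layout and variable names are illustrative only.

```python
# (alpha, temperature, val_loss, imbalance), copied from the results table
rows = [
    (1.0, 0.5, 2.3981, 1.0),
    (1.0, 1.0, 1.6180, 1.0),
    (1.0, 10.0, 1.5866, 1.0),
    (1.5, 0.5, 1.5641, 1.6958),
    (1.5, 1.0, 1.5604, 1.2568),
    (1.5, 10.0, 1.5752, 1.1047),
    (2.0, 0.5, 1.5702, 2.0718),
    (2.0, 1.0, 1.5624, 1.7956),
    (2.0, 10.0, 1.5622, 1.1880),
]

# Sort by validation loss and show the gap to the best configuration
ranked = sorted(rows, key=lambda r: r[2])
best_loss = ranked[0][2]
for alpha, temp, loss, imbalance in ranked[:3]:
    print(f"alpha={alpha} T={temp} val_loss={loss:.4f} "
          f"(+{loss - best_loss:.4f}) imbalance={imbalance:.4f}")
```

The top three configurations differ by about 0.002 in validation loss, so the ordering among them should be read with some caution.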


### Effect of Alpha ($ \alpha $): (Controls sparsity function shape; higher alpha promotes sparser outputs)


- As expected, increasing alpha generally increased Concentration (average max probability per token) and Imbalance (ratio of max expert load to mean expert load). This indicates sparser, more focused expert selection, but potentially worse load balancing.

- At alpha=1.0 (Softmax), Concentration was minimal (0.125, i.e., 1/num_experts), and Imbalance was perfect (1.0). This means tokens were distributed evenly, with no routing sparsity.
- alpha=2.0 produced the highest concentration and imbalance, especially at lower temperatures. This indicates very sparse routing (most probability mass on one expert), leading to higher load imbalance.

### Effect of Temperature (T): (Scales logits before sparsity function; lower T sharpens, higher T softens)

- Lower temperatures (T=0.5) significantly increased Concentration and Imbalance for alpha > 1.0. This sharpening effect combined with sparser functions (alpha=1.5 or 2.0) led to very high concentration and poor load balancing.
- Higher temperatures (T=10.0) softened the distributions, resulting in lower Concentration and Imbalance compared to T=1.0 or T=0.5 for the same alpha.
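- Interestingly, for alpha=1.0, temperature had no effect on the routing metrics. Dense outputs explain a Utilization and Imbalance of 1.0, because every expert receives a nonzero probability for every token. They do not explain a Concentration of exactly 0.125 with an Avg Prob CV of 0, which means every token received a uniform distribution over the 8 experts at all three temperatures; a trained Softmax router would normally become sharper at lower T. A likely cause is that the entmax formula is undefined at exactly alpha=1 (the exponent 1/(alpha-1) diverges), so passing alpha=1.0 to `entmax_bisect` may yield a degenerate uniform output rather than Softmax. These rows are therefore better read as a uniform-routing baseline until confirmed against `F.softmax`.
- The very poor loss of alpha=1.0, temp=0.5 (2.3981) is not reflected in the routing metrics, which are identical across the three alpha=1.0 runs. It might be due to training instability or run-to-run variance, since the seed is set once before the sweep rather than per configuration.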
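
The sharpening and softening effect of temperature can be illustrated with Sparsemax (alpha = 2), which has a well-known sort-based closed form. The sketch below is self-contained; the logits are made-up values chosen for illustration, not router outputs from the runs above.

```python
def sparsemax(z):
    """Closed-form Sparsemax (alpha = 2) using the sort-based threshold."""
    z_sorted = sorted(z, reverse=True)
    cumsum, tau = 0.0, 0.0
    for j, zj in enumerate(z_sorted, start=1):
        cumsum += zj
        # The support is the largest j for which this condition holds
        if 1.0 + j * zj > cumsum:
            tau = (cumsum - 1.0) / j
    return [max(zi - tau, 0.0) for zi in z]


logits = [1.0, 0.6, 0.2, -0.5]  # hypothetical router logits for 4 experts
summary = {}
for T in (0.5, 1.0, 10.0):
    p = sparsemax([zi / T for zi in logits])  # scale logits before the transform
    support = sum(1 for pi in p if pi > 0)
    summary[T] = (support, max(p))
    print(f"T={T}: active experts={support}, max prob={max(p):.4f}")
```

For these logits the maximum probability falls from about 0.9 at T=0.5 to about 0.32 at T=10, and the number of active experts grows from 2 to 4, which mirrors the Concentration trend in the table.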


### MoE Metrics (Utilization, CVs):

* Utilization was 1.0 (100%) for every run in the table, meaning all experts received at least some tokens during evaluation. This is good.
* TPE CV (Coefficient of Variation for Tokens Per Expert) and Avg Prob CV (CV for Average Expert Probability) generally followed the trends of Imbalance and Concentration. Higher values indicate more uneven distribution of tokens or routing probabilities across experts.
* The lowest CV values (most balanced) were for alpha=1.0 (perfectly balanced) and generally increased with alpha and decreased with temperature.
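
To make the metric definitions concrete, the sketch below recomputes them for a tiny hand-written gating matrix using only the standard library. It follows the same definitions as `SparseMoE.forward` (population standard deviation, activity threshold of 1e-9), but the function name `routing_metrics` and the example matrices are hypothetical, and it assumes at least one token is routed.

```python
from statistics import mean, pstdev


def routing_metrics(gates, eps=1e-9):
    """gates: one row of routing probabilities per token, one column per expert."""
    num_experts = len(gates[0])
    # Tokens per expert (TPE): tokens that give the expert non-negligible probability
    tpe = [sum(1 for row in gates if row[e] > eps) for e in range(num_experts)]
    # Average routing probability each expert receives over all tokens
    avg_prob = [mean(row[e] for row in gates) for e in range(num_experts)]
    return {
        "utilization": sum(1 for t in tpe if t > 0) / num_experts,
        "imbalance": max(tpe) / mean(tpe),            # max load / mean load
        "concentration": mean(max(row) for row in gates),
        "tpe_cv": pstdev(tpe) / mean(tpe),
        "avg_prob_cv": pstdev(avg_prob) / mean(avg_prob),
    }


# Four tokens, two experts: expert 0 is active for 3 tokens, expert 1 for 2
sparse_gates = [[1.0, 0.0], [1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
print(routing_metrics(sparse_gates))

# Uniform routing over 8 experts reproduces the alpha=1.0 rows of the table
uniform_gates = [[0.125] * 8 for _ in range(5)]
print(routing_metrics(uniform_gates))
```

Note that Imbalance and TPE CV only count whether an expert is active for a token, while Concentration and Avg Prob CV depend on the probability values themselves.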


#### Some Trade-offs:

* **Best Performance**: The sweet spot for validation loss seems to be around alpha=1.5, temp=1.0 or alpha=2.0 with temp=1.0 or temp=10.0.
* **Sparsity vs. Performance**: Using alpha=1.5 or alpha=2.0 clearly improved performance over standard Softmax (alpha=1.0). This indicates that sparse routing was beneficial.
* **Temperature Tuning**: Temperature is crucial.
    * Very low temperatures (T=0.5) combined with alpha > 1.0 led to extremely high concentration and imbalance, which slightly hurt performance compared to T=1.0.
    * High temperature (T=10.0) seemed effective at mitigating the imbalance caused by high alpha (alpha=2.0, temp=10.0 performed well).
* **Load Balancing**: There's a trade-off. Achieving the best validation loss involved accepting some degree of imbalance (values around 1.2-1.8 for the top models, compared to 1.0 for Softmax).
* The configuration with the highest imbalance (alpha=2.0, temp=0.5) did not have the best loss.
* The combination alpha=1.5, temperature=1.0 stands out as the best performer with reasonably balanced MoE metrics compared to the other high-performing, high-alpha configurations.


In [None]:
input_path = '/kaggle/input/input-txt/input.txt'

with open(input_path, 'r', encoding='utf-8') as f:
    text = f.read()

## Experts and Routers

In [None]:
def setup_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import time
import math
import torch
import torch.nn as nn
from entmax import entmax_bisect
from sparsemax import Sparsemax
from torch.nn import functional as F
from torch.nn import init

sns.set_palette('pastel')
sns.set_style('whitegrid')
import random  # used by setup_seed()

# Hyperparameters
batch_size = 128
block_size = 32
max_iters = 3000
eval_interval = 100
learning_rate = 1e-3
eval_iters = 400
head_size = 16
n_embed = 128
n_head = 8
n_layer = 8
dropout = 0.1
num_experts = 8
setup_seed()  # seed the Python, NumPy and PyTorch RNGs
device = 'cuda' if torch.cuda.is_available() else (
    'mps' if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
import random
setup_seed()

In [None]:


chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
print(f"Vocab size: {vocab_size}\tData length: {len(text)}")
print(f"Train data length: {len(train_data)}\tValidation data length: {len(val_data)}")


def get_batch(split):
    data_source = train_data if split == 'train' else val_data
    ix = torch.randint(len(data_source) - block_size, (batch_size,))
    x = torch.stack([data_source[i:i + block_size] for i in ix])
    y = torch.stack([data_source[i + 1:i + block_size + 1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


class Head(nn.Module):
    def __init__(self, n_embed, head_size, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)
        self.head_size = head_size

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * self.head_size ** -0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, n_embed, num_heads, block_size, dropout):
        super().__init__()
        assert n_embed % num_heads == 0
        head_size = n_embed // num_heads
        self.heads = nn.ModuleList([Head(n_embed, head_size, block_size, dropout) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


class Expert(nn.Module):
    def __init__(self, n_embed, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed), nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed), nn.Dropout(dropout),
        )

    def forward(self, x): return self.net(x)


class AlphaEntmaxRouter(nn.Module):
    """Linear router whose logits are normalised with alpha-entmax."""

    def __init__(self, n_embed, num_experts, alpha=1.5, temperature=1.0, n_iter=50):
        super().__init__()
        assert temperature > 1e-9, "temperature must be positive"
        self.num_experts = num_experts
        self.alpha = alpha
        self.temperature = temperature
        self.n_iter = n_iter
        self.route_linear = nn.Linear(n_embed, num_experts)
        print(f"Initialized AlphaEntmaxRouter (functional) with alpha = {self.alpha}, temperature = {self.temperature}")

    def forward(self, mh_output):
        logits = self.route_linear(mh_output)
        # Temperature scaling: T < 1 sharpens the distribution, T > 1 softens it
        scaled_logits = logits / self.temperature
        if self.alpha == 1.0:
            # The entmax formula is undefined at alpha = 1 (exponent 1 / (alpha - 1)),
            # so use Softmax, its alpha -> 1 limit, instead of bisection.
            return F.softmax(scaled_logits, dim=-1)
        try:
            router_output = entmax_bisect(scaled_logits, alpha=self.alpha, dim=-1, n_iter=self.n_iter)
        except Exception as e:
            print(f"WARNING: entmax_bisect forward failed. Error: {e}. Falling back to Softmax.")
            router_output = F.softmax(scaled_logits, dim=-1)
        return router_output


class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, dropout, router_alpha=1.5, router_temperature=1.0):
        super(SparseMoE, self).__init__()
        self.router = AlphaEntmaxRouter(n_embed, num_experts, alpha=router_alpha, temperature=router_temperature)
        self.experts = nn.ModuleList([Expert(n_embed, dropout) for _ in range(num_experts)])
        self.num_experts = num_experts
        self.register_buffer('latest_tpe', torch.zeros(num_experts, dtype=torch.float32))
        self.latest_imbalance = 0.0
        self.latest_concentration = 0.0
        self.latest_utilization = 0.0
        self.latest_tpe_cv = 0.0
        self.latest_avg_prob_cv = 0.0

    def forward(self, x):
        batch_size, seq_len, n_embed = x.shape
        num_tokens = batch_size * seq_len
        gating_output = self.router(x)
        with torch.no_grad():
            flat_gating_output_no_grad = gating_output.reshape(-1, self.num_experts).detach()
            is_active = flat_gating_output_no_grad > 1e-9
            tpe = is_active.sum(dim=0).float()
            self.latest_tpe = tpe
            if self.num_experts > 1:
                mean_tpe = tpe.mean()
                std_tpe = tpe.std(unbiased=False)
                self.latest_tpe_cv = (std_tpe / (mean_tpe + 1e-9)).item() if mean_tpe > 1e-9 else 0.0
            else:
                self.latest_tpe_cv = 0.0
            mean_tpe_val = tpe.mean().item()
            if self.num_experts > 0 and mean_tpe_val > 0:
                self.latest_imbalance = tpe.max().item() / (mean_tpe_val + 1e-9)
            else:
                self.latest_imbalance = 1.0
            max_p_per_token, _ = flat_gating_output_no_grad.max(dim=-1)
            self.latest_concentration = max_p_per_token.mean().item() if num_tokens > 0 else 0.0
            num_active_experts = (tpe > 0).sum().item()
            self.latest_utilization = num_active_experts / self.num_experts if self.num_experts > 0 else 0.0
            if self.num_experts > 1 and num_tokens > 0:
                avg_prob = flat_gating_output_no_grad.mean(dim=0)
                mean_avg_prob = avg_prob.mean()
                std_avg_prob = avg_prob.std(unbiased=False)
                self.latest_avg_prob_cv = (
                            std_avg_prob / (mean_avg_prob + 1e-9)).item() if mean_avg_prob > 1e-9 else 0.0
            else:
                self.latest_avg_prob_cv = 0.0
        # Dispatch: each expert processes only the tokens that route to it
        flat_x = x.reshape(-1, n_embed)
        final_output = torch.zeros_like(flat_x)
        flat_gating_output_for_weighting = gating_output.reshape(-1, self.num_experts)
        for i, expert in enumerate(self.experts):
            expert_scores = flat_gating_output_for_weighting[:, i]
            # Token selection uses the detached gates; the weighting below keeps gradients
            token_indices = torch.nonzero(flat_gating_output_no_grad[:, i] > 1e-9).squeeze(-1)
            if token_indices.numel() == 0:
                continue  # no token routed to this expert
            expert_input = flat_x[token_indices]
            active_gating_scores = expert_scores[token_indices].unsqueeze(1)
            expert_output = expert(expert_input)
            weighted_output = expert_output * active_gating_scores
            # Accumulate the gate-weighted expert output back into each token's row
            final_output.index_add_(0, token_indices, weighted_output)
        final_output = final_output.view(batch_size, seq_len, n_embed)
        return final_output


class Block(nn.Module):
    def __init__(self, n_embed, n_head, num_experts, block_size, dropout, router_alpha=1.5, router_temperature=1.0):
        super().__init__()
        assert n_embed % n_head == 0
        self.sa = MultiHeadAttention(n_embed, n_head, block_size, dropout)
        self.smoe = SparseMoE(n_embed, num_experts, dropout, router_alpha=router_alpha,
                              router_temperature=router_temperature)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.smoe(self.ln2(x))
        return x


class SparseMoELanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embed, n_head, n_layer, num_experts, block_size, dropout, router_alpha=1.5,
                 router_temperature=1.0):
        super().__init__()
        self.n_embed = n_embed
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[
            Block(n_embed=n_embed, n_head=n_head, num_experts=num_experts, block_size=block_size, dropout=dropout,
                  router_alpha=router_alpha, router_temperature=router_temperature) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)
        print(f"Initialized SparseMoELanguageModel with {n_layer} layers.")
        print(f"Router settings per layer: alpha={router_alpha}, temperature={router_temperature}")

    def forward(self, idx, targets=None):
        B, T = idx.shape
        device = idx.device
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            B, T, C = logits.shape
            logits_for_loss = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits_for_loss, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


@torch.no_grad()
def estimate_loss(model):
    out = {'loss': {}}
    num_model_layers = n_layer
    accumulated_metrics = {
        f'L{i}': {'imbalance': 0.0, 'concentration': 0.0, 'utilization': 0.0, 'tpe_cv': 0.0, 'avg_prob_cv': 0.0,
                  'count': 0}
        for i in range(num_model_layers)
    }
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for i in range(num_model_layers):
            for key in accumulated_metrics[f'L{i}']:
                accumulated_metrics[f'L{i}'][key] = 0.0 if key != 'count' else 0
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
            if hasattr(model, 'blocks') and isinstance(model.blocks, nn.Sequential):
                for i, block in enumerate(model.blocks):
                    if isinstance(block, Block) and hasattr(block, 'smoe') and isinstance(block.smoe, SparseMoE):
                        smoe_layer = block.smoe
                        layer_key = f'L{i}'
                        if num_experts > 0:
                            accumulated_metrics[layer_key]['imbalance'] += smoe_layer.latest_imbalance
                            accumulated_metrics[layer_key]['concentration'] += smoe_layer.latest_concentration
                            accumulated_metrics[layer_key]['utilization'] += smoe_layer.latest_utilization
                            accumulated_metrics[layer_key]['tpe_cv'] += smoe_layer.latest_tpe_cv
                            accumulated_metrics[layer_key]['avg_prob_cv'] += smoe_layer.latest_avg_prob_cv
                            accumulated_metrics[layer_key]['count'] += 1
        out['loss'][split] = losses.mean()
    averaged_metrics = {f'L{i}': {} for i in range(num_model_layers)}
    for i in range(num_model_layers):
        layer_key = f'L{i}'
        count = accumulated_metrics[layer_key]['count']
        if count > 0:
            for key in accumulated_metrics[layer_key]:
                if key != 'count':
                    averaged_metrics[layer_key][key] = accumulated_metrics[layer_key][key] / count
        else:
            for key in accumulated_metrics[layer_key]:
                if key != 'count': averaged_metrics[layer_key][key] = 0.0
    model.train()
    return out['loss'], averaged_metrics

In [None]:
setup_seed()

In [None]:
alphas = [1.0, 1.5, 2.0]
temperatures = [0.5, 1.0, 10.0]

print(f"Using device: {device}")
print(f"Vocab size: {vocab_size}")
print(f"Data length: {len(text)}")

df_results = pd.DataFrame([], columns=['model', 'alpha', 'temperature', 'iter', 'loss_train', 'loss_val', 'utilization',
                                       'imbalance', 'concentration', 'tpe_cv', 'avg_prob_cv'])

start_time = time.time()
for alpha in alphas:
    print(df_results.head(2))
    for temperature in temperatures:
        key = f"alpha_{alpha}_temp_{temperature}"
        model = SparseMoELanguageModel(
            vocab_size=vocab_size, n_embed=n_embed, n_head=n_head, n_layer=n_layer,
            num_experts=num_experts, block_size=block_size, dropout=dropout,
            router_alpha=alpha, router_temperature=temperature
        )
        model = model.to(device)
        print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.2f} M parameters')
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

        print(f"\n----- {key} -----")
        for step in range(max_iters):  # `step` avoids shadowing the built-in iter()
            if step % eval_interval == 0 or step == max_iters - 1:
                losses, metrics = estimate_loss(model)
                # Convert the 0-d loss tensors to plain Python floats
                loss_train = losses['train'].item()
                loss_val = losses['val'].item()
                # Average each routing metric over all MoE layers
                avg_utilization = sum(metrics[f'L{i}']['utilization'] for i in range(n_layer)) / n_layer
                avg_imbalance = sum(metrics[f'L{i}']['imbalance'] for i in range(n_layer)) / n_layer
                avg_concentration = sum(metrics[f'L{i}']['concentration'] for i in range(n_layer)) / n_layer
                avg_tpe_cv = sum(metrics[f'L{i}']['tpe_cv'] for i in range(n_layer)) / n_layer
                avg_avg_prob_cv = sum(metrics[f'L{i}']['avg_prob_cv'] for i in range(n_layer)) / n_layer

                df_results = pd.concat([df_results if not df_results.empty else None, pd.DataFrame([{
                    'model': key,
                    'alpha': alpha,
                    'temperature': temperature,
                    'iter': step,
                    'loss_train': loss_train, 'loss_val': loss_val,
                    'utilization': avg_utilization, 'imbalance': avg_imbalance,
                    'concentration': avg_concentration, 'tpe_cv': avg_tpe_cv,
                    'avg_prob_cv': avg_avg_prob_cv
                }])])

                print(f"\nSTEP: {step:<1}/{max_iters:<5} || Train Loss: {loss_train:.4f} | Val Loss: {loss_val:.4f}")
                print(
                    f"Layer average || Utilization: {avg_utilization:<6.4f} | Imbalance: {avg_imbalance:<6.4f} | Concentration: {avg_concentration:<6.4f} | TPE CV: {avg_tpe_cv:<6.4f} | Avg Prob CV: {avg_avg_prob_cv:<6.4f}")

            xb, yb = get_batch('train')
            logits, loss = model(xb, yb)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
        end_time = time.time()
        print(f"\nTraining completed for {key}, total elapsed time since sweep start: {end_time - start_time:.2f} seconds")

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


print("DataFrame Info:")
df_results.info()

print("\nFirst 5 rows:")
print(df_results.head())

print("\nLast 5 rows:")
print(df_results.tail())

print("\nUnique Alpha values:", df_results['alpha'].unique())
print("Unique Temperature values:", df_results['temperature'].unique())