# Basic Dense Mixture of Experts Model

## Overview

This implementation demonstrates a **dense mixture of experts** approach where:
- **All experts process all tokens** (computationally expensive)
- **Top-k routing** selects the k most relevant experts per token
- **Soft weighting** combines expert outputs using routing scores

## Architecture

### Components:
- **Router**: Linear layer that outputs gating logits for each expert
- **Experts**: Simple linear transformations (dim → dim)
- **Load Balancing Loss**: Encourages uniform expert utilization

### Algorithm Flow:
1. **Routing**: Compute gating scores for each expert
2. **Selection**: Top-k selection of experts per token
3. **Weighting**: Mask non-selected experts to `-inf` and apply softmax, so only the top-k experts receive non-zero weight
4. **Computation**: All experts process all tokens (dense)
5. **Combination**: Weighted sum of expert outputs
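
Steps 2 and 3 can be sketched in isolation. The helper below is a minimal illustration rather than the notebook's exact code: `top_k_gate` is a hypothetical name, and it masks by comparing against the k-th largest logit (which assumes no ties at that threshold) instead of using `jnp.put_along_axis`.

```python
import jax
import jax.numpy as jnp

def top_k_gate(gate_logits: jax.Array, k: int) -> jax.Array:
    """Per-token expert weights that are non-zero only for the top-k experts."""
    # k-th largest logit per token, kept as a trailing axis for broadcasting
    top_vals, _ = jax.lax.top_k(gate_logits, k)
    threshold = top_vals[..., -1:]
    # Logits below the threshold become -inf, so softmax gives them weight 0
    masked = jnp.where(gate_logits >= threshold, gate_logits, -jnp.inf)
    return jax.nn.softmax(masked, axis=-1)

logits = jnp.array([[2.0, 0.5, 1.0, -1.0],
                    [0.0, 3.0, -2.0, 1.0]])
weights = top_k_gate(logits, k=2)
```

Each row of `weights` sums to 1 and has exactly two non-zero entries, which is what makes the later weighted sum "sparse" in its weights even though every expert still runs.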

## Key Characteristics

### Pros:
- Simple implementation and debugging
- Smooth gradients from all experts
- Explicit load balancing through an auxiliary loss

### Cons:
- **Inefficient**: O(n_experts) computation per token
- Not scalable for large numbers of experts
- High memory usage due to dense computation

### Load Balancing Loss:
```python
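# Average the gating logits over every token (all axes except the expert axis)
num_experts = gate_logits.shape[-1]
mean_gates = jnp.mean(gate_logits.reshape(-1, num_experts), axis=0)
lb_loss = num_experts * jnp.sum(mean_gates ** 2)
```
This penalty is the number of experts times the sum of squared mean gating logits. It shrinks every expert's average logit toward zero, which keeps the router from favoring one expert by a large margin. Note that it operates on raw logits: only when the same expression is applied to softmax probabilities does it become equivalent to minimizing the variance of the average gating scores, with its minimum at a uniform distribution.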
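
One way to see the effect of a penalty of this form is to apply it to softmax probabilities instead of raw logits. That substitution is an assumption made here for illustration, and `squared_mean_gate_loss` is a hypothetical helper, not part of the notebook. With probabilities, the value is 1 for a perfectly balanced router and approaches the number of experts when all tokens collapse onto one expert.

```python
import jax
import jax.numpy as jnp

def squared_mean_gate_loss(gate_logits: jax.Array) -> jax.Array:
    """n_experts * sum(mean_prob ** 2), averaged over all tokens."""
    n_experts = gate_logits.shape[-1]
    probs = jax.nn.softmax(gate_logits, axis=-1)
    mean_probs = probs.reshape(-1, n_experts).mean(axis=0)
    return n_experts * jnp.sum(mean_probs ** 2)

# 8 tokens, 4 experts
balanced = jnp.zeros((8, 4))                      # every expert equally likely
collapsed = jnp.zeros((8, 4)).at[:, 0].set(20.0)  # expert 0 dominates

balanced_loss = squared_mean_gate_loss(balanced)
collapsed_loss = squared_mean_gate_loss(collapsed)
```

Here `balanced_loss` is 1.0 and `collapsed_loss` is close to 4.0, so minimizing the term pushes the router toward the balanced case.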

## Use Cases

- Educational purposes and prototyping
- Small-scale experiments with few experts
- Baseline for comparing with sparse implementations
- Situations where all expert contributions are needed


In [5]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from typing import Any

class Router(nnx.Module):
    def __init__(self, dim: int, num_experts: int, *, rngs: nnx.Rngs):
        self.w1 = nnx.Linear(dim, num_experts, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.w1(x)

class Expert(nnx.Module):
    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.linear(x)

class SimpleMoE(nnx.Module):
    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        num_experts = 2
        self.router = Router(dim, num_experts=num_experts, rngs=rngs)
        self.experts = nnx.List([
            Expert(dim, rngs=rngs)
            for _ in range(num_experts)
        ])
        self.top_k = 2

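    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
        # Gating logits for every token: (..., num_experts)
        gate_logits = self.router(x)
        # Keep the top-k logits per token and mask the rest to -inf
        top_k_logits, expert_indices = jax.lax.top_k(gate_logits, self.top_k)
        zeros = jnp.full_like(gate_logits, float('-inf'))
        sparse_logits = jnp.put_along_axis(
            zeros, expert_indices, top_k_logits, axis=-1, inplace=False
        )
        # Softmax over the masked logits gives zero weight to non-selected experts
        expert_weights = jax.nn.softmax(sparse_logits, axis=-1)

        # Load-balancing penalty on each expert's mean gating logit, averaged over
        # all tokens (the input is (B, T, C), so the expert axis is the last one)
        num_experts = gate_logits.shape[-1]
        mean_gates = jnp.mean(gate_logits.reshape(-1, num_experts), axis=0)
        lb_loss = num_experts * jnp.sum(mean_gates ** 2)

        # Dense step: every expert processes every token
        outputs = [e(x) for e in self.experts]

        # Weighted sum of the expert outputs
        result = jnp.zeros_like(x)
        for i, o in enumerate(outputs):
            result += o * expert_weights[..., i:i+1]

        return result, lb_loss, expert_weights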

In [12]:
import optax 

D, B, T, C = 10000, 2, 5, 3

model = SimpleMoE(dim=C, rngs=nnx.Rngs(0))
tx = optax.adam(1e-3)
state = nnx.Optimizer(model, tx, wrt=nnx.Param)

x = jax.random.normal(jax.random.key(1000), (D * B * T, C))

expert_ids = (x[:, 0] > 0).astype(jnp.int32)
t = [
    jax.random.normal(jax.random.key(2000), (C, C)),
    jax.random.normal(jax.random.key(3000), (C, C)),
]
def transform(xi, eid):
    return jnp.where(eid == 1, xi @ t[0], xi @ t[1])

y = jax.vmap(lambda xi, ei: transform(xi, ei))(x, expert_ids)

def loss_fn(model, x, y):
    y_pred, lb_loss, gates = model(x)
    loss = jnp.mean((y - y_pred)**2) # + lb_loss
    return loss, gates

@nnx.jit
def step(model, state, x, y):
    (loss, gates), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model, x, y)
    state.update(model, grads)
    return loss, gates, grads

x = x.reshape(D, B, T, C)
y = y.reshape(D, B, T, C)

for e in range(10):
    for i in range(D):
        loss, gates, grads = step(model, state, x[i], y[i])
        if i % 1000 == 0:
            print(i, loss)

0 2.8386207
1000 1.4691024
2000 0.6264829
3000 0.22666237
4000 0.32224867
5000 0.27109587
6000 0.036007397
7000 0.042003997
8000 0.359906
9000 0.019899525
0 0.07967947
1000 0.048338573
2000 0.07231833
3000 0.0011955692
4000 0.061102558
5000 0.027696146
6000 0.0011745306
7000 0.0014423532
8000 0.25469956
9000 0.0031980982
0 0.037122283
1000 0.021085124
2000 0.020036932
3000 0.0002973358
4000 0.028141744
5000 0.0031959899
6000 0.0005252546
7000 0.0006971828
8000 0.23215021
9000 0.0010879862
0 0.019682594
1000 0.009787401
2000 0.00533594
3000 0.00018117022
4000 0.013458978
5000 0.00086001644
6000 0.00029388236
7000 0.0004293031
8000 0.21831232
9000 0.00060997024
0 0.012715654
1000 0.004591771
2000 0.0015240598
3000 0.00012839105
4000 0.0068611316
5000 0.0005576007
6000 0.00019687101
7000 0.00030134188
8000 0.20550406
9000 0.00043756963
0 0.009739019
1000 0.002218916
2000 0.0005265196
3000 0.00010175501
4000 0.0037555827
5000 0.000450459
6000 0.00014797169
7000 0.00023094774
8000 0.1936279

# Sparse Mixture of Experts Model

## Overview

This implementation demonstrates a **sparse mixture of experts** approach that overcomes the computational inefficiency of dense MoE:
- **Only selected experts process tokens** (computationally efficient)
- **Expert capacity management** prevents overflow using load factors
- **Efficient buffer management** for token-to-expert routing

## Architecture

### Components:
- **Router**: Linear layer with configurable bias for expert selection
- **Experts**: Parameterized linear experts with efficient indexing
- **Expert Buffers**: Preallocated, fixed-size buffers for expert inputs
- **Capacity Management**: A load factor intended to control expert buffer sizes

### Algorithm Flow:
1. **Routing**: Compute gating scores for each expert
2. **Selection**: Top-k selection of experts per token
3. **Buffer Allocation**: Create expert input buffers using load factor
4. **Token Distribution**: Efficiently assign tokens to selected experts
5. **Expert Processing**: Each expert processes only its assigned tokens
6. **Output Assembly**: Combine expert outputs using routing weights
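
The six steps above can be sketched without a Python-level token loop. This is an illustrative variant under stated assumptions, not the notebook's implementation: `route`, `sparse_moe`, `dense_moe`, and the explicit `capacity` argument are hypothetical, slots are computed with a cumulative sum over one-hot expert assignments, and tokens that do not fit in a buffer are simply dropped. A dense reference is included so the two paths can be compared.

```python
import jax
import jax.numpy as jnp

def route(x, router_w, top_k):
    """Steps 1-2: gating logits, top-k experts, softmax over the selected logits."""
    top_logits, expert_idx = jax.lax.top_k(x @ router_w, top_k)   # (N, k) each
    return jax.nn.softmax(top_logits, axis=-1), expert_idx

def sparse_moe(x, router_w, expert_w, top_k, capacity):
    """x: (N, C) tokens, router_w: (C, E), expert_w: (E, C, C)."""
    n_tokens, dim = x.shape
    n_experts = router_w.shape[1]
    weights, expert_idx = route(x, router_w, top_k)

    # Steps 3-4: slot of every (token, choice) pair inside its expert's buffer,
    # i.e. how many earlier pairs were routed to the same expert
    flat_expert = expert_idx.reshape(-1)                               # (N*k,)
    one_hot = jax.nn.one_hot(flat_expert, n_experts, dtype=jnp.int32)  # (N*k, E)
    slot = jnp.sum((jnp.cumsum(one_hot, axis=0) - 1) * one_hot, axis=-1)
    keep = slot < capacity                     # pairs that fit in the buffer
    token_id = jnp.repeat(jnp.arange(n_tokens), top_k)

    buffers = jnp.zeros((n_experts, capacity, dim), x.dtype)
    # mode="drop" discards writes whose slot is past the end of the buffer
    buffers = buffers.at[flat_expert, slot].set(x[token_id], mode="drop")

    # Step 5: each expert multiplies only its own buffer
    expert_out = jnp.einsum("ecd,edf->ecf", buffers, expert_w)

    # Step 6: read results back and combine them with the routing weights
    safe_slot = jnp.minimum(slot, capacity - 1)
    gathered = expert_out[flat_expert, safe_slot]                      # (N*k, C)
    gathered = gathered * (weights.reshape(-1, 1) * keep[:, None])
    return jnp.zeros_like(x).at[token_id].add(gathered)

def dense_moe(x, router_w, expert_w, top_k):
    """Reference: every expert processes every token, then top-k outputs are mixed."""
    weights, expert_idx = route(x, router_w, top_k)
    all_out = jnp.einsum("nc,ecd->end", x, expert_w)                   # (E, N, C)
    picked = all_out[expert_idx, jnp.arange(x.shape[0])[:, None]]      # (N, k, C)
    return jnp.sum(picked * weights[..., None], axis=1)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (16, 4))
router_w = jax.random.normal(k2, (4, 8))
expert_w = jax.random.normal(k3, (8, 4, 4))

y_sparse = sparse_moe(x, router_w, expert_w, top_k=2, capacity=16)
y_dense = dense_moe(x, router_w, expert_w, top_k=2)
max_diff = jnp.max(jnp.abs(y_sparse - y_dense))
```

With `capacity` equal to the number of tokens no assignment is dropped, so the sparse and dense outputs should agree up to floating-point error. Lowering `capacity` trades accuracy for smaller buffers.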

## Key Innovations

### Sparse Computation:
Unlike dense MoE, only the selected experts process each token:
```python
# Dense: All experts process all tokens
outputs = [expert(x) for expert in experts]  # O(n_experts) cost

# Sparse: Only selected experts process tokens
expert_inputs = expert_inputs.at[expert_idx, token_pos].set(x[i])  # O(top_k) cost
```

### Capacity Management:
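The `load_factor` parameter is meant to scale expert buffer sizes relative to the expected number of tokens per expert (the reference code below stores it but still allocates worst-case buffers of `top_k × B × T` slots per expert):
- `load_factor = 1.0`: Buffers sized for a perfectly balanced load (any imbalance can overflow)
- `load_factor > 1.0`: Extra capacity to absorb load imbalance
- Higher values cost more memory but drop fewer tokens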
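
A common convention sizes each buffer as the balanced share of token assignments scaled by the load factor. The formula and the helper name `expert_capacity` are assumptions for illustration; the notebook's code does not compute a capacity this way.

```python
import math

def expert_capacity(n_tokens: int, n_experts: int, top_k: int, load_factor: float) -> int:
    """Slots per expert: balanced share of (token, expert) assignments times load_factor."""
    return math.ceil(load_factor * n_tokens * top_k / n_experts)

# 64 tokens routed to 2 of 8 experts each: 128 assignments, 16 per expert if balanced
exact = expert_capacity(n_tokens=64, n_experts=8, top_k=2, load_factor=1.0)
padded = expert_capacity(n_tokens=64, n_experts=8, top_k=2, load_factor=1.25)
```

In this example `exact` is 16 and `padded` is 20, so a load factor of 1.25 leaves room for four extra tokens per expert before anything overflows.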

### Efficient Indexing:
Uses per-expert counters to place each token in the next free slot of its expert's buffer:
```python
input_counters = jnp.zeros((n_experts,), dtype=jnp.int32)
token_pos = input_counters[expert_idx]  # next free slot for this expert
expert_inputs = expert_inputs.at[expert_idx, token_pos].set(x[i])
input_counters = input_counters.at[expert_idx].add(1)
```
```
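
The counter dtype bounds how many slots a buffer can address. The short check below (variable names are illustrative) shows that an 8-bit counter wraps around after 255, which would silently overwrite slot 0 once an expert receives more tokens than that, while a 32-bit counter keeps counting.

```python
import jax.numpy as jnp

narrow = jnp.full((2,), 255, dtype=jnp.uint8)
wide = jnp.full((2,), 255, dtype=jnp.int32)

# Increment the counter of expert 0 in both arrays
narrow_next = narrow.at[0].add(1)   # uint8 arithmetic wraps around to 0
wide_next = wide.at[0].add(1)       # int32 continues to 256
```

For the small shapes used in this notebook either dtype works, but a wider integer type is the safer default for larger batches.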

## Performance Characteristics

### Computational Complexity:
- **Dense MoE**: O(seq_len × n_experts × dim²)
- **Sparse MoE**: O(seq_len × top_k × dim²)
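- **Speedup**: Up to ~n_experts/top_k fewer expert FLOPs, provided each expert only multiplies the tokens routed to it (padding buffers to the worst case erodes this gain)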
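
The ratio follows directly from counting multiply-adds. The sketch below uses illustrative sizes and a hypothetical helper `expert_matmul_flops`, and it counts only the expert matmuls (router and combine costs are ignored).

```python
def expert_matmul_flops(n_tokens: int, experts_per_token: int, dim: int) -> int:
    """One (dim x dim) matmul per token per expert costs about 2 * dim**2 FLOPs."""
    return 2 * n_tokens * experts_per_token * dim * dim

n_tokens, n_experts, top_k, dim = 1024, 8, 2, 512

dense_flops = expert_matmul_flops(n_tokens, n_experts, dim)   # every expert sees every token
sparse_flops = expert_matmul_flops(n_tokens, top_k, dim)      # only the selected experts
speedup = dense_flops / sparse_flops
```

With 8 experts and `top_k = 2`, `speedup` is 4.0, matching `n_experts / top_k`.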

### Memory Usage:
- Expert buffers: `n_experts × top_k × batch_size × seq_len × dim`
- Scales with both `n_experts` and `top_k`; sizing buffers by capacity instead of the worst case removes the `n_experts` factor
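
To make the buffer cost concrete, the sketch below compares worst-case buffers (`top_k × n_tokens` slots per expert) with capacity-sized buffers. The sizes, the float32 assumption, and the helper `buffer_bytes` are illustrative and not taken from the notebook.

```python
def buffer_bytes(n_experts: int, slots_per_expert: int, dim: int, bytes_per_elem: int = 4) -> int:
    """Total size of the expert input buffers, assuming float32 elements by default."""
    return n_experts * slots_per_expert * dim * bytes_per_elem

n_tokens, n_experts, top_k, dim = 1024, 8, 2, 512

# Worst case: every expert can hold every (token, choice) assignment
worst_case = buffer_bytes(n_experts, top_k * n_tokens, dim)
# Capacity-sized with load factor 1.0: top_k * n_tokens / n_experts slots per expert
capacity_sized = buffer_bytes(n_experts, top_k * n_tokens // n_experts, dim)
```

Here `worst_case` is 32 MiB and `capacity_sized` is 4 MiB, a factor of `n_experts` apart.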

### Load Balancing:
- No explicit auxiliary loss (unlike dense version)
- Relies on learned routing to distribute load
- Can be enhanced with load balancing penalties if needed

## Comparison with Dense Implementation

| Aspect | Dense MoE | Sparse MoE |
|--------|-----------|------------|
| Computation | All experts × all tokens | Top-k experts × tokens |
| Memory | Lower (no buffers) | Higher (expert buffers) |
| Scalability | Poor (linear in n_experts) | Excellent (constant in n_experts) |
| Load Balancing | Explicit loss term | Implicit via routing |
| Implementation | Simple | Complex (buffer management) |

## Connection to Production MoE Models

This sparse implementation shares core concepts with production models like `Tiny_MoE`, `Tiny_MoE_2`, and `Tiny_MoE_3`:

### Similarities:
- Top-k expert selection
- Sparse token-to-expert routing
- Efficient expert buffer management
- Configurable load factors

### Production Enhancements:
- **Advanced Load Balancing**: Auxiliary losses (load balance + z-loss)
- **Mixed Precision**: bfloat16 computation, float32 parameters
- **Device Sharding**: Distributed training across multiple devices
- **Expert Weight Priority**: Weight-based expert selection ordering
- **Soft MoE**: Differentiable routing (Tiny_MoE_3)
- **Gated Linear Units**: More powerful expert networks
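
As an example of the auxiliary losses mentioned above, a router z-loss is commonly defined as the mean squared log-sum-exp of the router logits. Whether the production models use exactly this form is an assumption; `router_z_loss` is a hypothetical helper shown only to illustrate the idea of penalizing large logits.

```python
import jax
import jax.numpy as jnp

def router_z_loss(logits: jax.Array) -> jax.Array:
    """Mean over tokens of logsumexp(logits) ** 2; grows with the logit magnitude."""
    return jnp.mean(jax.nn.logsumexp(logits, axis=-1) ** 2)

small_logits = jnp.zeros((8, 4))          # logsumexp = log(4) for every token
large_logits = 10.0 * jnp.ones((8, 4))    # same routing probabilities, larger logits

z_small = router_z_loss(small_logits)
z_large = router_z_loss(large_logits)
```

Both inputs give uniform routing probabilities, yet `z_large` is much bigger than `z_small`, so the term discourages logit growth without changing which experts are preferred.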

## Use Cases

- **Large-scale models** where computational efficiency is critical
- **Training on limited resources** with many experts
- **Inference scenarios** requiring fast processing
- **Foundation for advanced MoE variants** (Switch Transformer, GLaM, etc.)

## Limitations and Extensions

### Current Limitations:
- No explicit load balancing loss
- Fixed-capacity buffers (tokens overflow if undersized, slots are wasted if oversized)
- Single-device implementation

### Potential Extensions:
- Add load balancing and auxiliary losses
- Implement capacity dropping for overflow handling
- Add noise injection for better exploration
- Extend to multi-device training with expert sharding
- Implement different routing strategies (e.g., expert-choice or hash-based routing)
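
Noise injection, for instance, can be sketched as Gaussian noise added to the router logits before top-k selection, in the spirit of the commented-out lines in the `MOE` module. The helper `noisy_top_k` and its `noise_scale` parameter are hypothetical names used only for this illustration.

```python
import jax
import jax.numpy as jnp

def noisy_top_k(logits: jax.Array, key: jax.Array, top_k: int, noise_scale: float = 1.0):
    """Perturb the logits with Gaussian noise, then select the top-k experts."""
    noise = noise_scale * jax.random.normal(key, logits.shape)
    return jax.lax.top_k(logits + noise, top_k)

key = jax.random.PRNGKey(0)
logits = jnp.array([[2.0, 0.5, 1.0, -1.0],
                    [0.0, 3.0, -2.0, 1.0]])

noisy_vals, noisy_idx = noisy_top_k(logits, key, top_k=2)
# With noise_scale=0 the selection matches plain top-k
clean_vals, clean_idx = noisy_top_k(logits, key, top_k=2, noise_scale=0.0)
plain_vals, plain_idx = jax.lax.top_k(logits, 2)
```

During training the noise occasionally routes tokens to experts that would otherwise be ignored; at evaluation time `noise_scale` would typically be set to zero.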

This sparse implementation provides the foundation for understanding modern, efficient MoE systems used in large language models.

In [None]:
from dataclasses import dataclass

import jax
import jax.numpy as jnp

import flax.nnx as nnx
import optax

from jaxpt.modules.config import Config


@dataclass(unsafe_hash=True)
class GLU_Config(Config):
    top_k = 2
    load_factor = 1.00
    n_experts = 2
    n_embed = 3
    n_mlp_hidden = 6
    mlp_bias = True
    dtype = jax.numpy.float32

config = GLU_Config()


class Experts(nnx.Module):
    def __init__(self, config, rngs):
        init = nnx.initializers.normal(stddev=0.02)
        self.w1 = nnx.Param(init(rngs.default(),
            (
                config.n_experts,
                config.n_embed,
                config.n_embed
            )
        ))

    def __call__(self, x, expert_idx):
        w1 = self.w1[expert_idx] 
        x = x @ w1
        return x


class MOE(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.router_gate = nnx.Linear(
            config.n_embed,
            config.n_experts,
            kernel_init=nnx.initializers.normal(stddev=0.02),
            bias_init=nnx.initializers.zeros, 
            use_bias=config.mlp_bias,
            dtype=config.dtype,
            rngs=rngs,
        )
        self.experts = Experts(config, rngs)        
        self.top_k = config.top_k
        self.n_experts = config.n_experts
        self.load_factor = config.load_factor
        self.add_noise = False
        self.rngs = rngs

    def __call__(self, x):
        B, T, C = x.shape
        x = x.reshape(-1, C)  # (B*T, C)
        # Expert logits for each token
        logits = self.router_gate(x)  # (B*T, n_experts)
        #if self.add_noise:
        #    logits += 1 * jax.random.normal(key=self.rngs.gate_noise(), shape=logits.shape)
        #
        # Logits of the top_k experts and the indices of those experts
        top_k_logits, expert_indices = jax.lax.top_k(logits, self.top_k)  # (B*T, top_k)

        # Start from -inf everywhere so non-selected experts get zero weight
        zeros = jnp.full_like(logits, float('-inf'))  # (B*T, n_experts)
        # Write the top_k logits back at their expert positions
        sparse_logits = jnp.put_along_axis(
                zeros, expert_indices, top_k_logits, axis=-1, inplace=False)  # (B*T, n_experts)
        # Softmax across each row gives the expert weights
        expert_weights = jax.nn.softmax(sparse_logits, axis=-1)  # (B*T, n_experts)

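        # Buffers holding the inputs for each expert (worst case: top_k * B * T slots each)
        expert_inputs = jnp.zeros((self.n_experts, self.top_k * B * T, C))
        # int32 counters: uint8 would wrap around once an expert holds more than 255 tokens
        input_counters = jnp.zeros((self.n_experts,), dtype=jnp.int32)

        # For token i, append it to the buffer of each of its top_k experts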
        def update_expert_inputs(i, carry):
            expert_inputs, counters = carry
            for j in range(self.top_k):
                expert_idx = expert_indices[i, j]
                token_pos = counters[expert_idx]
                expert_inputs = expert_inputs.at[expert_idx, token_pos].set(x[i])
                counters = counters.at[expert_idx].add(1)

            return expert_inputs, counters

        # Loop through all the tokens and assign them to the top k experts        
        expert_inputs, input_counters = jax.lax.fori_loop(
            0, B * T, update_expert_inputs, (
                expert_inputs,
                input_counters
            )
        )

        # Create a tensor for the expert outputs
        expert_outputs = jnp.zeros_like(expert_inputs)

        # Loop through each expert and transform the inputs
        for i in range(self.n_experts):
            expert_outputs = expert_outputs.at[i].set(
                self.experts(expert_inputs[i], i)
                )


        # Now we need to scatter the tokens back to their original positions
        # (int32 counters avoid the 255-token wraparound of uint8)
        output_counters = jnp.zeros((self.n_experts,), dtype=jnp.int32)
        y_pred = jnp.zeros_like(x)
        def update_expert_outputs(i, carry):
            y_pred, output_counters = carry
            for j in range(self.top_k):
                expert_idx = expert_indices[i, j]
                token_pos = output_counters[expert_idx]
                y_pred = y_pred.at[i].add(
                    expert_outputs[expert_idx, token_pos] * expert_weights[i, expert_idx])
                output_counters = output_counters.at[expert_idx].add(1)

            return y_pred, output_counters

        # Loop through the transformed tokens and assign them to their original positions
        y_pred, output_counters = jax.lax.fori_loop(
            0, B * T, update_expert_outputs, (
                y_pred,
                output_counters
            )
        )

        # Reshape the output to its original shape
        y_pred = y_pred.reshape(B, T, C)
        return y_pred

def loss_fn(model, x, y):
    y_pred = model(x)
    loss = jnp.mean((y - y_pred)**2)
    return loss, y_pred

@nnx.jit
def step(model, state, x, y):
    (loss, y_pred), grads = nnx.value_and_grad(
        loss_fn, has_aux=True)(model, x, y)
    state.update(model, grads)
    return loss, grads, y_pred

D, B, T, C = 1000, config.n_experts, 5, config.n_embed

default = jax.random.key(69)
gate_noise = jax.random.key(42)
rngs = nnx.Rngs(default=default, gate_noise=gate_noise)

model = MOE(config, rngs)
model.train(add_noise=False)
tx = optax.adam(1e-2)
# Same Optimizer API as the dense example: explicit `wrt` filter and update(model, grads)
state = nnx.Optimizer(model, tx, wrt=nnx.Param)

x = jax.random.normal(jax.random.key(1000), (D, B, T, C))

# Target: tokens with a positive first feature use t[0], all others use t[1]
expert_ids = (x[:, :, :, 0] > 0).astype(jnp.int32)[..., None]
t = [
    jax.random.normal(jax.random.key(2000), (C, C)),
    jax.random.normal(jax.random.key(3000), (C, C)),
]

def transform(xi, eid):
    return jnp.where(eid == 1, xi @ t[0], xi @ t[1])

y = jax.vmap(transform)(x, expert_ids)

indices = list(range(D))
for e in range(100):
    for i in indices:
        loss, grads, y_pred = step(model, state, x[i], y[i])
        # D == 1000, so this logs once per epoch (at i == 0)
        if i % 1000 == 0:
            print(e, i, loss)


0 0 2.6899827
1 0 0.07806056
2 0 0.05633394
3 0 0.04081124
4 0 0.029184442
5 0 0.020910963
6 0 0.015457934
7 0 0.012084823
8 0 0.010068251
9 0 0.008866145
10 0 0.008118208
11 0 0.007610542
12 0 0.0072278054
13 0 0.0069115274
14 0 0.006632954
15 0 0.0063779377
16 0 0.006139335
17 0 0.005913367
18 0 0.0056978953
19 0 0.0054915864
20 0 0.005293556
21 0 0.005103143
22 0 0.004919832
23 0 0.004743188
24 0 0.00457285
25 0 0.0044085
26 0 0.0042498484
27 0 0.0040966603
28 0 0.0039487346
29 0 0.003805874
30 0 0.0036679066
31 0 0.0035347173
32 0 0.0034061493
33 0 0.003282117
34 0 0.0031625007
35 0 0.0030472109
36 0 0.0029361544
37 0 0.002829238
38 0 0.0027263872
39 0 0.0026275036
40 0 0.0025324963
41 0 0.0024412859
42 0 0.0023537574
43 0 0.0022698208
44 0 0.0021893624
45 0 0.0021122745
46 0 0.0020384467
47 0 0.00196777
48 0 0.0019001174
49 0 0.0018353804
50 0 0.0017734382
51 0 0.0017141727
52 0 0.0016574716
53 0 0.0016032109
54 0 0.0015512869
55 0 0.0015015854
56 0 0.0014539972
57 0 0.0014084165
