# Basic Dense Mixture of Experts Model


In [None]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from typing import Any

class Router(nnx.Module):
    """Linear gating network that emits one routing logit per expert.

    Args:
        dim: Size of the input's last (feature) axis.
        num_experts: Number of experts, i.e. number of logits produced.
        rngs: RNG streams used to initialise the projection.
    """

    def __init__(self, dim: int, num_experts: int, *, rngs: nnx.Rngs):
        self.w1 = nnx.Linear(
            in_features=dim,
            out_features=num_experts,
            rngs=rngs,
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        """Project ``x`` to unnormalised logits of trailing size ``num_experts``."""
        logits = self.w1(x)
        return logits

class Expert(nnx.Module):
    """A single expert: one square affine layer (``dim -> dim``).

    Args:
        dim: Input and output feature size.
        rngs: RNG streams used to initialise the layer.
    """

    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(in_features=dim, out_features=dim, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the affine map to the last axis of ``x``."""
        projected = self.linear(x)
        return projected

class SimpleMoE(nnx.Module):
    """Soft top-k mixture of ``Expert`` layers gated by a ``Router``.

    Every expert is evaluated on every token. The router's top-``top_k``
    logits are softmax-normalised, and all other experts get weight 0.
    The resulting weights mix the expert outputs.

    Args:
        dim: Feature size of the input and of every expert.
        rngs: RNG streams used to initialise the router and the experts.
        num_experts: Number of experts (default 2, the previous hard-coded value).
        top_k: Experts per token that receive non-zero weight (default 2,
            the previous hard-coded value). Must be in ``[1, num_experts]``.

    Raises:
        ValueError: if ``top_k`` is outside ``[1, num_experts]``.
    """

    def __init__(
        self,
        dim: int,
        *,
        rngs: nnx.Rngs,
        num_experts: int = 2,
        top_k: int = 2,
    ):
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )
        self.router = Router(dim, num_experts=num_experts, rngs=rngs)
        self.experts = [Expert(dim, rngs=rngs) for _ in range(num_experts)]
        self.top_k = top_k

    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
        """Mix the expert outputs for ``x``.

        Args:
            x: Array whose last axis has size ``dim``. Any leading axes are
                treated as independent tokens.

        Returns:
            ``(result, lb_loss, expert_weights)``:
            - ``result`` has the shape of ``x``.
            - ``lb_loss`` is a scalar load-balancing auxiliary loss.
            - ``expert_weights`` has the leading shape of ``x`` and a
              trailing axis of size ``num_experts``.
        """
        gate_logits = self.router(x)  # (..., num_experts)
        num_experts = gate_logits.shape[-1]

        top_k_logits, expert_indices = jax.lax.top_k(gate_logits, self.top_k)
        # Keep only the top-k logits; -inf elsewhere gives weight 0 after softmax.
        sparse_logits = jnp.put_along_axis(
            jnp.full_like(gate_logits, float('-inf')),
            expert_indices,
            top_k_logits,
            axis=-1,
            inplace=False,
        )
        expert_weights = jax.nn.softmax(sparse_logits, axis=-1)

        # Load-balancing loss: num_experts * sum_e(mean_prob_e ** 2). It
        # reaches its minimum (1.0) when the router's average probability
        # mass is uniform over experts. It must use probabilities averaged
        # over *all* tokens and be scaled by the expert count. Raw logits
        # can be negative, and the old axis-0 mean / shape[1] factor picked
        # the sequence axis instead of the expert axis.
        router_probs = jax.nn.softmax(gate_logits, axis=-1)
        mean_probs = jnp.mean(router_probs.reshape(-1, num_experts), axis=0)
        lb_loss = num_experts * jnp.sum(mean_probs ** 2)

        result = jnp.zeros_like(x)
        for i, expert in enumerate(self.experts):
            # `i:i + 1` keeps a trailing unit axis so the weight broadcasts
            # over the feature axis for inputs of any rank.
            result += expert(x) * expert_weights[..., i:i + 1]

        return result, lb_loss, expert_weights

In [None]:
import optax 

# D mini-batches, each of shape (B, T, C); C is the model/expert feature size.
D, B, T, C = 10000, 16, 4, 3

# Dense MoE over C-dimensional features, optimised with Adam.
model = SimpleMoE(dim=C, rngs=nnx.Rngs(0))
tx = optax.adam(1e-3)
state = nnx.Optimizer(model, tx)

# Synthetic regression data. The sign of each sample's first feature picks
# which of two fixed random (C, C) matrices produces its target.
x = jax.random.normal(jax.random.key(1000), (D * B * T, C))
expert_ids = jnp.asarray(x[:, 0] > 0, dtype=jnp.int32)
t = [jax.random.normal(jax.random.key(seed), (C, C)) for seed in (2000, 3000)]
def transform(xi, eid):
    """Ground-truth target: ``xi @ t[0]`` where ``eid == 1``, else ``xi @ t[1]``."""
    via_first = xi @ t[0]
    via_second = xi @ t[1]
    return jnp.where(eid == 1, via_first, via_second)

# Map every sample (paired with its id) through the ground-truth transform.
y = jax.vmap(transform)(x, expert_ids)

def loss_fn(model, x, y):
    """Mean squared error of the model's prediction against ``y``.

    ``model(x)`` returns ``(prediction, lb_loss, gates)``. The load-balancing
    term is currently not added to the objective. Returns ``(loss, gates)``
    so the gates travel as the auxiliary output of
    ``value_and_grad(..., has_aux=True)``.
    """
    prediction, _lb_loss, gates = model(x)
    residual = y - prediction
    return jnp.mean(residual ** 2), gates

@nnx.jit
def step(state, x, y):
    """Take one optimiser step on ``(x, y)``.

    Returns ``(loss, gates, grads)`` as computed before the update.
    """
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, gates), grads = grad_fn(state.model, x, y)
    state.update(grads)
    return loss, gates, grads

# Regroup the flat samples into D mini-batches of shape (B, T, C).
x, y = (arr.reshape(D, B, T, C) for arr in (x, y))

for e in range(10):  # epochs
    for i in range(D):
        loss, gates, grads = step(state, x[i], y[i])
        if not i % 1000:
            print(i, loss)

# Sparse Mixture of Experts Model

In [None]:
import os

from functools import partial
from dataclasses import dataclass
import random

import jax
import jax.numpy as jnp

import flax.nnx as nnx
import optax

from jaxpt.modules.config import Config


@dataclass(unsafe_hash=True)
class GLU_Config(Config):
    """Hyper-parameters for the toy sparse MoE below.

    The fields carry type annotations so ``@dataclass`` registers them.
    Un-annotated assignments are plain class attributes: they cannot be
    overridden through the constructor (e.g. ``GLU_Config(top_k=1)``) and
    do not take part in the generated ``__eq__``/``__hash__``/``__repr__``.
    """

    top_k: int = 2              # experts selected per token (router top-k)
    load_factor: float = 1.00   # stored by MOE; not read by the routing code shown here
    n_experts: int = 2          # number of experts / router outputs
    n_embed: int = 3            # token feature size C (router input, expert matrix size)
    n_mlp_hidden: int = 6       # not read by the code shown here
    mlp_bias: bool = True       # whether the router projection has a bias
    dtype: type = jax.numpy.float32  # dtype passed to the router projection


config = GLU_Config()


class Experts(nnx.Module):
    """Bank of ``config.n_experts`` bias-free square linear maps in one tensor.

    The weights have shape ``(n_experts, n_embed, n_embed)`` and are drawn
    from a normal distribution with stddev 0.02.
    """

    def __init__(self, config, rngs):
        initializer = nnx.initializers.normal(stddev=0.02)
        weight_shape = (config.n_experts, config.n_embed, config.n_embed)
        self.w1 = nnx.Param(initializer(rngs.default(), weight_shape))

    def __call__(self, x, expert_idx):
        """Right-multiply ``x`` by the weight matrix of expert ``expert_idx``."""
        return x @ self.w1[expert_idx]


class MOE(nnx.Module):
    """Sparse top-k mixture of experts with buffer-based dispatch.

    Forward pass:
    1. Each token is routed to its ``config.top_k`` highest-scoring experts.
    2. Tokens are scattered into a per-expert input buffer.
    3. Each expert processes its whole buffer with one call.
    4. Outputs are gathered back and weighted by the softmax over the
       selected router logits.
    """

    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.router_gate = nnx.Linear(
            config.n_embed,
            config.n_experts,
            kernel_init=nnx.initializers.normal(stddev=0.02),
            bias_init=nnx.initializers.zeros,
            use_bias=config.mlp_bias,
            dtype=config.dtype,
            rngs=rngs,
        )
        self.experts = Experts(config, rngs)
        self.top_k = config.top_k
        self.n_experts = config.n_experts
        self.load_factor = config.load_factor
        self.add_noise = False
        self.rngs = rngs

    def __call__(self, x):
        """Apply the mixture to ``x``.

        Args:
            x: Array whose last axis is the feature axis (size
                ``config.n_embed``). All leading axes, e.g. ``(B, T)``, are
                flattened into one token axis.

        Returns:
            Array with the same shape and dtype as ``x``.
        """
        orig_shape, orig_dtype = x.shape, x.dtype
        C = orig_shape[-1]
        x = x.reshape(-1, C)
        n_tokens = x.shape[0]

        logits = self.router_gate(x)  # (N, n_experts)
        # NOTE: router noise (self.add_noise / self.rngs.gate_noise) is not applied.
        top_k_logits, expert_indices = jax.lax.top_k(logits, self.top_k)  # (N, top_k)
        sparse_logits = jnp.put_along_axis(
            jnp.full_like(logits, float('-inf')),
            expert_indices,
            top_k_logits,
            axis=-1,
            inplace=False,
        )  # (N, n_experts), -inf outside each token's top-k
        expert_weights = jax.nn.softmax(sparse_logits, axis=-1)  # (N, n_experts)

        flat_experts, slots, token_ids = self._assign_slots(expert_indices)

        # top_k selects distinct experts per token, so one expert receives at
        # most n_tokens assignments and a capacity of n_tokens cannot overflow.
        expert_inputs = jnp.zeros((self.n_experts, n_tokens, C))
        expert_inputs = expert_inputs.at[flat_experts, slots].set(x[token_ids])

        expert_outputs = jnp.stack(
            [self.experts(expert_inputs[e], e) for e in range(self.n_experts)]
        )  # (n_experts, N, C)

        gathered = expert_outputs[flat_experts, slots]    # (N * top_k, C)
        gates = expert_weights[token_ids, flat_experts]   # (N * top_k,)
        y_pred = (
            (gathered * gates[:, None])
            .reshape(n_tokens, self.top_k, C)
            .sum(axis=1)
        )
        return y_pred.astype(orig_dtype).reshape(orig_shape)

    def _assign_slots(self, expert_indices):
        """Map every (token, k) routing choice to an (expert, slot) buffer cell.

        Choices are enumerated token-major (token 0's k choices, then
        token 1's, ...). An assignment's slot is the number of earlier
        assignments to the same expert. Counting uses int32 cumulative sums,
        so it cannot wrap the way a small fixed-width counter would.

        Returns:
            ``(flat_experts, slots, token_ids)``, each of shape ``(N * top_k,)``.
        """
        n_tokens = expert_indices.shape[0]
        flat_experts = expert_indices.reshape(-1)
        one_hot = jax.nn.one_hot(flat_experts, self.n_experts, dtype=jnp.int32)
        slots = jnp.sum((jnp.cumsum(one_hot, axis=0) - 1) * one_hot, axis=-1)
        token_ids = jnp.repeat(jnp.arange(n_tokens), self.top_k)
        return flat_experts, slots, token_ids

def loss_fn(model, x, y):
    """Mean-squared-error objective for a model that returns a single prediction.

    Returns ``(loss, y_pred)``. The prediction is the auxiliary output for
    ``nnx.value_and_grad(..., has_aux=True)``.
    """
    prediction = model(x)
    squared_error = (y - prediction) ** 2
    return jnp.mean(squared_error), prediction

@nnx.jit
def step(state, x, y):
    """Run one optimiser step on ``(x, y)``.

    Returns ``(loss, grads, y_pred)`` as computed before the parameter update.
    """
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, y_pred), grads = grad_fn(state.model, x, y)
    state.update(grads)
    return loss, grads, y_pred

# D mini-batches of shape (B, T, C); C matches the model's embedding size.
D, B, T, C = 1000, 2 * config.n_experts, 16, config.n_embed

# Named RNG streams: "default" initialises parameters. "gate_noise" is only
# referenced by the disabled router-noise code.
default, gate_noise = jax.random.key(69), jax.random.key(42)
rngs = nnx.Rngs(default=default, gate_noise=gate_noise)

model = MOE(config, rngs)
model.train(add_noise=False)
tx = optax.adam(1e-2)
state = nnx.Optimizer(model, tx)

x = jax.random.normal(jax.random.key(1000), (D, B, T, C))

# 1 where the first feature is positive, else 0. The trailing unit axis lets
# it broadcast against a (B, T, C) slice inside `transform`.
expert_ids = jnp.asarray(x[..., 0] > 0, dtype=jnp.int32)[..., jnp.newaxis]
t = [jax.random.normal(jax.random.key(seed), (C, C)) for seed in (2000, 3000)]

def transform(xi, eid):
    """Target rule: rows where ``eid == 1`` use ``t[0]``, all others use ``t[1]``."""
    when_one = xi @ t[0]
    otherwise = xi @ t[1]
    return jnp.where(eid == 1, when_one, otherwise)

# Build targets batch-by-batch along the leading D axis; x and y are already
# shaped (D, B, T, C), so no reshape is needed here.
y = jax.vmap(transform)(x, expert_ids)

# Batches are visited in the same fixed order every epoch (no shuffling).
indices = [*range(D)]
for e in range(100):  # epochs
    for i in indices:
        loss, grads, y_pred = step(state, x[i], y[i])
        if not i % 1000:
            print(e, i, loss)
