In [14]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from typing import Any

class Router(nnx.Module):
    """Gating network: a single linear projection from features to expert logits."""

    def __init__(self, dim: int, num_experts: int, *, rngs: nnx.Rngs):
        """Create the projection.

        Args:
            dim: Size of the last axis of the inputs.
            num_experts: Number of logits produced per input vector.
            rngs: NNX random streams used to initialise the layer parameters.
        """
        self.w1 = nnx.Linear(in_features=dim, out_features=num_experts, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Return unnormalised routing logits; the last axis has ``num_experts`` entries."""
        logits = self.w1(x)
        return logits

class Expert(nnx.Module):
    """One expert: a single ``dim -> dim`` linear layer."""

    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        """Create the expert's linear layer.

        Args:
            dim: Input and output feature size.
            rngs: NNX random streams used to initialise the layer parameters.
        """
        self.linear = nnx.Linear(in_features=dim, out_features=dim, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the linear layer to ``x`` and return the result."""
        transformed = self.linear(x)
        return transformed

class SimpleMoE(nnx.Module):
    """Mixture-of-experts layer with top-k softmax gating and dense expert compute.

    Every expert is evaluated on every input vector; the router only decides
    how the expert outputs are weighted. Experts outside an input's top-k get a
    weight of exactly zero because their logits are masked to ``-inf`` before
    the softmax.
    """

    def __init__(
        self,
        dim: int,
        *,
        rngs: nnx.Rngs,
        num_experts: int = 8,
        top_k: int = 2,
    ):
        """Build the router and the experts.

        Args:
            dim: Feature size of the last input axis (also the output size).
            rngs: NNX random streams used to initialise all sub-layers.
            num_experts: Number of experts to create.
            top_k: Number of experts that receive non-zero weight per input.

        Raises:
            ValueError: If ``top_k`` is not in ``[1, num_experts]``.
        """
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts]; got top_k={top_k}, "
                f"num_experts={num_experts}"
            )
        self.router = Router(dim, num_experts=num_experts, rngs=rngs)
        self.experts = [
            Expert(dim, rngs=rngs)
            for _ in range(num_experts)
        ]
        self.top_k = top_k

    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
        """Route ``x`` through the experts.

        Args:
            x: Array whose last axis has size ``dim``; any number of leading
                axes is accepted.

        Returns:
            A tuple ``(result, lb_loss, gate_logits)``:
            ``result`` has the same shape as ``x`` and is the gate-weighted sum
            of the expert outputs; ``lb_loss`` is a scalar auxiliary term;
            ``gate_logits`` are the raw router outputs.
        """
        gate_logits = self.router(x)
        num_experts = gate_logits.shape[-1]

        # Keep only the top-k logits per input and mask the rest to -inf so the
        # softmax assigns them zero weight.
        top_k_logits, expert_indices = jax.lax.top_k(gate_logits, self.top_k)
        masked_logits = jnp.put_along_axis(
            jnp.full_like(gate_logits, float('-inf')),
            expert_indices,
            top_k_logits,
            axis=-1,
            inplace=False,
        )
        expert_weights = jax.nn.softmax(masked_logits, axis=-1)

        # Average over every leading axis so the term is well defined for
        # inputs of any rank (identical to a mean over axis 0 for 2-D input).
        # NOTE(review): this is computed on raw logits rather than on softmax
        # probabilities -- confirm that is the intended regulariser before
        # adding it to the training loss.
        mean_gates = jnp.mean(gate_logits.reshape(-1, num_experts), axis=0)
        lb_loss = num_experts * jnp.sum(mean_gates ** 2)

        result = jnp.zeros_like(x)
        for i, expert in enumerate(self.experts):
            # Index the expert axis with `...` so leading axes of any rank
            # broadcast correctly against the expert output.
            result += expert(x) * expert_weights[..., i:i + 1]

        return result, lb_loss, gate_logits

In [None]:
import optax 

# Problem sizes: D samples, each of shape (B, C).
D, B, C = 10000, 16, 8

# The model width is tied to C so that changing the feature size keeps the
# model and the data consistent.
model = SimpleMoE(dim=C, rngs=nnx.Rngs(0))
tx = optax.adam(1e-3)
state = nnx.Optimizer(model, tx)

# Synthetic data: Gaussian base samples and three fixed random (C, C) maps.
x = jax.random.normal(jax.random.key(1000), (D, B, C))
t = [
    jax.random.normal(jax.random.key(seed), (C, C))
    for seed in (2000, 3000, 4000)
]

# The first half of the samples is mapped through t[0], the second half
# through t[1], so y is drawn from two different linear regimes.
y = jnp.concatenate([
    x[:D//2] @ t[0],
    x[D//2:] @ t[1]
])

# z applies one shared map to y; the training loop uses y[i] as the input and
# z[i] as the regression target.
z = y @ t[2]


def loss_fn(model, x, y):
    """Mean-squared-error between ``model(x)`` predictions and ``y``.

    ``model`` must return ``(prediction, lb_loss, gates)``. The auxiliary
    ``lb_loss`` is currently not part of the objective.

    Returns:
        ``(loss, gates)`` so the function can be used with ``has_aux=True``.
    """
    prediction, _lb_loss, gate_values = model(x)
    residual = y - prediction
    mse = jnp.mean(residual ** 2)  # + _lb_loss
    return mse, gate_values

@nnx.jit
def step(state, x, y):
    """Run one optimisation step on a single ``(x, y)`` pair.

    Differentiates ``loss_fn`` with respect to ``state.model``, applies the
    gradients through ``state.update`` and returns ``(loss, gates)`` from the
    forward pass.
    """
    value_and_grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss_value, gate_values), gradients = value_and_grad_fn(state.model, x, y)
    state.update(gradients)
    return loss_value, gate_values

import random

SHUFFLE_SEED = 0   # fixed seed so the sample order is the same on every run
NUM_EPOCHS = 10
LOG_EVERY = 1000   # printing every step floods the output and slows training

# Shuffle once; the same order is reused in every epoch.
indices = list(range(D))
random.Random(SHUFFLE_SEED).shuffle(indices)

for epoch in range(NUM_EPOCHS):
    for it, i in enumerate(indices):
        # y[i] is the model input and z[i] the regression target.
        loss, gates = step(state, y[i], z[i])
        if it % LOG_EVERY == 0:
            print(f"epoch {epoch} step {it} sample {i} loss {float(loss):.6f}")