### 

## Soft Mixture of Experts

In [33]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax

In [101]:
class SoftMoE(nnx.Module):
    def __init__(self, n_experts, capacity, n_dim, rngs):
        init = nnx.initializers.normal(stddev=0.02)
        self.gate = nnx.Param(
            init(rngs.default(), (n_experts, capacity, n_dim))
        )
        self.experts = nnx.Param(
            init(rngs.default(), (n_experts, n_dim, n_dim))
        )
        self.out = nnx.Linear(
            n_dim, 2, rngs=rngs
        )

    def _dispatch(self, x, dw):
        ei = jnp.einsum('td,tec->ecd', x, dw)
        eo = jnp.einsum('ecd,edd->ecd', ei, self.experts)
        return eo 

    def _combine(self, eo, cw):
        y = jnp.einsum('ecd,tec->td', eo, cw)
        return y

    def __call__(self, x):
        g = jnp.einsum('btd,ecd->btec', x, self.gate) # B, T, E, capacity
        dw = jax.nn.softmax(g, axis=1)
        cw = jax.nn.softmax(g, axis=(2, 3))
        eo = jax.vmap(lambda x, w: self._dispatch(x, w))(x, dw) # B, E, capacity, C
        y = jax.vmap(lambda o, w: self._combine(o, w))(eo, cw) # B, T, C
        return self.out(y) # B, T, 2


In [115]:
D, B, T, C = 1000, 16, 24, 3
n_experts = 2
capacity = B * T // n_experts
print(capacity)

# create model
rngs = nnx.Rngs(default=0)
m = SoftMoE(n_experts, capacity, C, rngs)

# create dataset
x = jax.random.normal(
   jax.random.key(100),
   (D*B*T, C)
)

r1 = jax.random.normal(jax.random.key(200), (C, C))
r2 = jax.random.normal(jax.random.key(400), (C, C))

y = jax.vmap(lambda x: x.sum() > 0)(x).astype(jnp.int16)
x = jax.vmap(lambda label, x: jax.lax.cond(label == 0, lambda x: x @ r1, lambda x: x @ r2, operand=x))(y, x)
x = x.reshape(D, B, T, C)
y = y.reshape(D, B, T)

# define optimizer
tx = optax.adam(0.01)
optimizer = nnx.Optimizer(m, tx, wrt=nnx.Param)

# define loss function
def loss_fn(m, x, y):
    logits = m(x)
    preds = jnp.argmax(logits, axis=-1) 
    acc = jnp.sum(preds == y) / (B * T)
    loss = optax.losses.softmax_cross_entropy_with_integer_labels(logits, y)
    return loss.mean(), acc

# define step function
@nnx.jit
def step_fn(m, opt, x, y):
    (loss, acc), grads = nnx.value_and_grad(loss_fn, has_aux=True)(m, x, y)
    opt.update(m, grads)
    return loss, acc


for e in range(10):
    for i in range(D):
        inputs = x[i]
        labels = y[i]
        loss, acc = step_fn(m, optimizer, inputs, labels)
    print(loss, acc)



192
0.09010549 0.9661459
0.08453002 0.9713542
0.08273953 0.96875
0.082466975 0.96875
0.08248196 0.96875
0.082847305 0.9661459
0.083282135 0.9661459
0.08367075 0.9661459
0.08404602 0.9661459
0.08454438 0.9661459


In [103]:
m.experts[1]

Array([[ 1.0953181 , -0.00379145,  0.01928024],
       [-0.026022  ,  0.68088347, -0.00745997],
       [ 0.00885581, -0.02380599,  1.2311255 ]], dtype=float32)

In [104]:
r2

Array([[ 0.6075079 , -0.5394046 , -0.52474177],
       [-0.07122789,  0.01275766,  1.0289301 ],
       [-1.1398981 ,  0.5853884 ,  1.153948  ]], dtype=float32)