### 

## Soft Mixture of Experts

In [2]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx

In [7]:
class SoftMoE(nnx.Module):
    """Soft Mixture-of-Experts layer with linear experts.

    Every token is softly dispatched to ``n_experts * capacity`` slots,
    each expert transforms its own ``capacity`` slots with an
    ``(n_dim, n_dim)`` matrix, and every token then receives a soft
    combination of all slot outputs.

    Args:
        n_experts: Number of experts.
        capacity: Number of slots owned by each expert.
        n_dim: Feature dimension of the tokens (input and output).
        rngs: ``nnx.Rngs`` whose ``default`` stream seeds the parameters.
    """

    def __init__(self, n_experts: int, capacity: int, n_dim: int, rngs):
        init = nnx.initializers.normal(stddev=0.02)
        # Kept so the forward pass can split the flat slot axis per expert.
        self.n_experts = n_experts
        self.capacity = capacity
        # Slot logits projection: (n_dim, n_experts * capacity).
        self.gate = nnx.Param(
            init(rngs.default(), (n_dim, n_experts * capacity))
        )
        # One linear map per expert: (n_experts, n_dim, n_dim).
        self.experts = nnx.Param(
            init(rngs.default(), (n_experts, n_dim, n_dim))
        )

    def _dispatch(self, x, dw):
        """Mix tokens into slots for one example.

        Args:
            x: Tokens, shape ``(T, D)``.
            dw: Dispatch weights, shape ``(T, S)``.

        Returns:
            Slot inputs laid out feature-major, shape ``(D, S)``.
        """
        return jnp.einsum('td,ts->ds', x, dw)

    def _combine(self, eo, cw):
        """Mix slot outputs back into tokens for one example.

        Args:
            eo: Slot outputs, shape ``(D, S)``.
            cw: Combine weights, shape ``(T, S)``.

        Returns:
            Token outputs, shape ``(T, D)``.
        """
        return jnp.einsum('ds,ts->td', eo, cw)

    def _apply_experts(self, slots):
        """Apply each expert's matrix to the slots it owns.

        The flat slot axis is expert-major (slot ``e * capacity + c`` is
        slot ``c`` of expert ``e``), matching the column layout of
        ``self.gate``.

        Args:
            slots: Batched slot inputs, shape ``(B, D, S)``.

        Returns:
            Batched slot outputs, shape ``(B, D, S)``.
        """
        n_batch, n_dim, n_slots = slots.shape
        grouped = slots.reshape(n_batch, n_dim, self.n_experts, self.capacity)
        out = jnp.einsum('bdec,edf->bfec', grouped, self.experts)
        return out.reshape(n_batch, n_dim, n_slots)

    def __call__(self, x):
        """Run the layer.

        Args:
            x: Tokens, shape ``(B, T, D)``.

        Returns:
            Array with the same ``(B, T, D)`` shape as ``x``.
        """
        g = jnp.einsum('btd,ds->bts', x, self.gate)  # (B, T, S)
        dw = jax.nn.softmax(g, axis=1)  # normalised over tokens, per slot
        cw = jax.nn.softmax(g, axis=2)  # normalised over slots, per token

        slots = jax.vmap(self._dispatch)(x, dw)  # (B, D, S)
        # Previously skipped: without this step self.experts was never used
        # and the layer reduced to a parameterised token-mixing average.
        eo = self._apply_experts(slots)  # (B, D, S)
        y = jax.vmap(self._combine)(eo, cw)  # (B, T, D)

        return y

In [9]:
# Smoke test: run a random batch through SoftMoE and check that the output
# keeps the input's shape.
B, T, C = 16, 24, 3
n_experts = 8
# Slots handed to each expert.
capacity = (B * T) // n_experts
print(capacity)

rngs = nnx.Rngs(default=0)
m = SoftMoE(n_experts=n_experts, capacity=capacity, n_dim=C, rngs=rngs)

# Drawn after the model is built, so the parameters consume the first keys.
x = jax.random.normal(rngs.default(), shape=(B, T, C))
y = m(x)
y.shape
assert x.shape == y.shape

48
