In [25]:
import jax
import jax.numpy as jnp
from jax import random
import math
from typing import Callable
from einops import rearrange

In [27]:
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax

from flax import linen as nn

In [28]:
rng = jax.random.PRNGKey(1)
rng, inp_rng, init_rng = jax.random.split(rng, 3)

B = 4
T = 8
C = 32
# inp = jax.random.randint(inp_rng, (B, T), 0, 256)  # Batch size 8, input size 2
inp = jax.random.normal(inp_rng, (B, T, C)) 
inp.shape

(4, 8, 32)

In [29]:
class FeedForward(nn.Module):
    model_dimension : int
    ff_dim : int
    dropout : float

    def setup(self):
        self.linear1 = nn.Dense(features=self.ff_dim)
        self.linear2 = nn.Dense(features=self.model_dimension)
        self.dropout_layer = nn.Dropout(rate=self.dropout) 
    
    def __call__(self, x, train: bool = True):
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.dropout_layer(x, deterministic=not train)
        x = self.linear2(x)
        return x

In [30]:
class NoisyKGate(nn.Module):
    model_dimension : int
    n_experts : int
    k : int

    def setup(self):
        self.rng = jax.random.PRNGKey(42)
        self.Wg = nn.Dense(features=self.n_experts)
        self.Wnoise = nn.Dense(features=self.n_experts)
    
    def top(self, x):
        k = self.k
        y,i = jax.lax.top_k(x, k)
        y = nn.softmax(y)
        return y, i
    
    def __call__(self, x):
        b = x.shape[0]
        t = x.shape[1]
        Hx = self.Wg(x) + ((jax.random.normal(self.rng, shape=(b, t, self.n_experts))) * nn.softplus(self.Wnoise(x)))
        g_scores, indices = jnp.apply_along_axis(func1d=self.top, axis=-1, arr=Hx)
        return g_scores, indices

In [110]:
class MoE(nn.Module):
    model_dimension : int
    n_experts : int
    k : int
    dropout : float
    
    def setup(self):
        self.experts = [FeedForward(model_dimension=self.model_dimension, ff_dim=4*self.model_dimension, dropout=self.dropout) for i in range(self.n_experts)]
        self.gate = NoisyKGate(model_dimension=self.model_dimension, n_experts=self.n_experts, k=self.k)

    def get_gScores(self, scores, indices, x, train=True):
        

        expert_lambda = [lambda mdl, x: mdl.experts[i](x) for i in range(self.n_experts)] 

        if self.is_mutable_collection('params'):
            for expert_ffn in expert_lambda: 
                _ = expert_ffn(self, x) 
                
        expert_fn = lambda j, experts, x : nn.switch(j, expert_lambda, self, x)
        expert_parallel = jax.vmap(fun=expert_fn, in_axes=(0, None, None), out_axes=(0)) 

        expert_scores = expert_parallel(indices, self.experts, x) # (K) -> (K, C)
        gScore = scores[:, None] * expert_scores #(K, 1), (K, C) -> (K, C)
        gScore = jnp.sum(gScore, axis=0) #(K, C) -> C

        return gScore
    
    def __call__(self, x):
        s, i = self.gate(x)
        gscore_parallel = jax.vmap(fun=jax.vmap(fun=self.get_gScores, in_axes=(0,0,0), out_axes=(0)), in_axes=(0,0,0), out_axes=(0))
        res = gscore_parallel(s, i, x)
        return res

In [107]:
model = MoE(model_dimension=C, n_experts=8, k=2, dropout=0.2)
print(model)

MoE(
    # attributes
    model_dimension = 32
    n_experts = 8
    k = 2
    dropout = 0.2
)


In [109]:
# Initialize the model
params = model.init(init_rng, inp)

print(jax.tree.reduce(lambda acc, p: acc + p.size, jax.tree.leaves(params), 0))
print(inp.shape)


7
8880
(4, 8, 32)


In [102]:
print(inp)

[[[-0.24392003  0.12287012 -0.7633101  ... -0.11720864 -1.5185605
    1.5217804 ]
  [-0.84688747  0.30684185 -0.07057513 ... -0.9426277  -0.38939518
   -0.7442158 ]
  [-0.81580335  0.13533224  0.6767916  ...  0.50519985 -0.9484449
   -0.61324644]
  ...
  [-0.9497825   0.5360208   1.7091829  ... -1.7208983   1.242111
   -0.47621638]
  [ 0.46530464 -0.7567534  -1.0421433  ...  1.1826577  -0.56831634
   -0.29065755]
  [ 1.9868064   0.9201935  -0.22226302 ...  2.0525253  -0.37513527
    1.0895865 ]]

 [[-1.2303954  -2.248714    1.058285   ... -0.1294223  -0.77748317
    2.0040932 ]
  [ 0.15732075  1.1365863   0.08460022 ... -0.66197455 -1.0347629
   -1.2902288 ]
  [-1.5851781  -0.7893366  -0.96769536 ...  0.7012614  -1.1963965
   -1.5259213 ]
  ...
  [ 0.9598219  -0.97783214 -1.4395024  ...  0.39721426  2.2276711
   -0.61668503]
  [ 0.6563103   0.52220154 -0.5418465  ... -0.67837745  2.0536146
   -0.17622429]
  [-0.9547713   0.03570899 -0.00437756 ... -2.0016026   0.79181325
    0.22056514

In [36]:
model.apply(params, inp).shape


AttributeError: 'tuple' object has no attribute 'shape'

In [24]:
n_experts = 8
model_dim = 32
dropout = 0.2

experts = [FeedForward(model_dimension=model_dim, ff_dim=4*model_dim, dropout=dropout) for i in range(n_experts)]

def expert_fn(j):
    jax.lax.switch(j, experts, inp)
    

# expert_fn = lambda j : experts[j](inp, train=True) # (C) -> (C)
expert_parallel = jax.vmap(fun=expert_fn, in_axes=(0), out_axes=(0)) 
expert_dims = expert_parallel(i[0][0]) # (K) -> (K, C)
expert_dims
# gscore = scores[:, None] * experts #(K, 1), (K, C) -> (K, C)
# gscore = jnp.sum(gscore, axis=0) #(K, C) -> C

AttributeError: "FeedForward" object has no attribute "linear1". If "linear1" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.