In [1]:
import os

# Ask XLA to expose 8 host-platform (CPU) devices. This assignment lives in
# the first cell, ahead of the `import jax` in the following cell.
# NOTE(review): presumably it has to run before JAX initialises its backend
# to take effect, and it overwrites any XLA_FLAGS already set — confirm.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'


In [None]:
from functools import partial
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding, Mesh
from jax.debug import visualize_array_sharding as viz

import flax.nnx as nnx

from jaxpt.modules.config import Config

# Enumerate every device JAX can see and show them for reference.
devices = jax.devices()
print(devices)

# One-dimensional mesh: a single named axis, "devices", spanning all devices.
# The axis names are spelled as a real 1-tuple (note the trailing comma).
mesh = Mesh(devices, axis_names=("devices",))

# Replicate the first array dimension and split the second one over "devices".
spec = PartitionSpec(None, "devices")
sharding = NamedSharding(mesh, spec)

@dataclass(unsafe_hash=True)
class GLU_Config(Config):
    """Hyper-parameters for the toy mixture-of-experts layer in this notebook.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a real field. Un-annotated class attributes are ignored by the
    dataclass machinery: they cannot be set through the constructor and they
    take no part in the generated ``__eq__`` / ``__hash__`` (which matters
    because the config is passed to ``nnx.jit`` as a static, hashed argument).
    """

    # Number of router logits kept per token by the top-k selection.
    top_k: int = 2
    # Multiplier applied to B * T when sizing each expert's token buffer.
    load_factor: float = 1.25
    # Number of experts (leading dimension of the stacked expert weight).
    n_experts: int = 8
    # Embedding width of the tokens.
    n_embed: int = 4
    # Hidden width of the MLP.
    n_mlp_hidden: int = 6
    # Whether linear layers are built with a bias term.
    mlp_bias: bool = True
    # dtype handed to the linear layers.
    dtype: jnp.dtype = jax.numpy.float32
    # Device mesh used when applying sharding constraints; defaults to the
    # notebook-level `mesh` built above.
    mesh: Mesh = mesh

config = GLU_Config()


class Expert(nnx.Module):
    """Bank of per-expert square projections stored as one stacked parameter.

    The parameter has shape ``(n_experts, n_embed, n_embed)``; its leading
    (expert) axis is annotated for partitioning over the ``"devices"`` mesh
    axis.
    """

    def __init__(self, config, rngs):
        """Create the stacked expert weights.

        Args:
            config: object providing ``n_experts`` and ``n_embed``.
            rngs: ``nnx.Rngs`` container that supplies the initialisation key.
        """
        init = nnx.with_partitioning(
            nnx.initializers.normal(stddev=0.02),
            sharding=("devices",))

        # Draw a fresh key through the Rngs call interface. Reading a
        # stream's raw base key (`rngs.<stream>.key.value`) does not advance
        # the stream counter, so every module built from the same `rngs`
        # would be initialised from the identical key.
        self.weight = nnx.Param(init(
            rngs.params(),
            (
                config.n_experts,
                config.n_embed,
                config.n_embed,
            ),
        ))

    def __call__(self, x, expert_idx):
        """Apply a single expert's matrix to ``x``.

        Args:
            x: array whose last dimension is ``n_embed``; any leading
                dimensions are carried through the matmul.
            expert_idx: index of the expert to use along the leading axis of
                the stacked weight.

        Returns:
            ``x @ weight[expert_idx]``.
        """
        # Slice along the expert axis so only one expert's weights are used.
        w = self.weight[expert_idx]
        return x @ w  # x: [..., n_embed], w: [n_embed, n_embed]

@nnx.jit(static_argnums=(0, 1))
def create_sharded_model(Model, config, rngs):
    """Instantiate ``Model`` and constrain its state to the annotated sharding.

    Args:
        Model: module class, constructed as ``Model(config=config, rngs=rngs)``
            (static jit argument).
        config: configuration object; ``config.mesh`` is the mesh the sharding
            constraint is applied on (static jit argument).
        rngs: ``nnx.Rngs`` forwarded to the module constructor.

    Returns:
        The module, updated in place with the sharding-constrained state.
    """
    module = Model(config=config, rngs=rngs)

    # Only the state half of the split is needed here.
    _, module_state = nnx.split(module)

    # Partition specs are derived from the sharding annotations on the state.
    partition_specs = nnx.get_partition_spec(module_state)
    constrained_state = nnx.with_sharding_constraint(
        module_state, partition_specs, mesh=config.mesh
    )

    nnx.update(module, constrained_state)
    return module


class MOE(nnx.Module):
    """Work-in-progress mixture-of-experts layer with one expert per device.

    ``__call__`` uses collectives over the axis named ``"i"``, so it must be
    invoked inside a mapped context that defines that axis (for example
    ``nnx.pmap(axis_name="i")``).

    Token routing is not implemented yet: the routing weights are computed
    but the expert buffers are zero-filled placeholders (see the TODOs).
    """

    def __init__(self, config: Config, rngs: nnx.Rngs):
        """Build the router and the stacked expert weights.

        Args:
            config: provides ``n_embed``, ``n_experts``, ``top_k``,
                ``load_factor``, ``mlp_bias`` and ``dtype``.
            rngs: ``nnx.Rngs`` used for parameter initialisation.
        """
        # The router scores every token against every expert, so its output
        # width has to be `n_experts`. Using any other width (e.g. the MLP
        # hidden size) makes the top-k indices disagree with the expert axis
        # and can leave some experts unreachable.
        self.router_gate = nnx.Linear(
            config.n_embed,
            config.n_experts,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.normal(stddev=0.02),
                sharding=(None, None)),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros,
                sharding=(None,)),
            use_bias=config.mlp_bias,
            dtype=config.dtype,
            rngs=rngs,
        )
        self.expert = Expert(config, rngs)
        self.top_k = config.top_k
        self.n_experts = config.n_experts
        self.load_factor = config.load_factor

    def __call__(self, x):
        """Run the (partial) MoE forward pass on one device's batch.

        Args:
            x: array of shape ``(B, T, C)``.

        Returns:
            Array of shape ``(n_experts, tokens_per_expert, C)`` where
            ``tokens_per_expert = int(load_factor * B * T // n_experts)``.
            Tokens are not yet scattered back to ``(B, T, C)``.
        """
        B, T, C = x.shape

        # Router: keep the top-k logits per token and mask the rest with -inf
        # so the softmax assigns them zero weight.
        logits = self.router_gate(x)
        neg_inf = jnp.full_like(logits, float('-inf'))
        top_k_logits, expert_indices = jax.lax.top_k(logits, self.top_k)
        sparse_logits = jnp.put_along_axis(
            neg_inf, expert_indices, top_k_logits, axis=-1, inplace=False)
        # Not consumed yet; needed once token dispatch/combine is implemented.
        expert_weights = jax.nn.softmax(sparse_logits, axis=-1)

        # Fixed-capacity buffer per expert, matching the input dtype.
        tokens_per_expert = int((self.load_factor * B * T) // self.n_experts)
        expert_inputs = jnp.zeros(
            (self.n_experts, tokens_per_expert, C), dtype=x.dtype)
        # TODO: load the tokens into the expert inputs and track order

        # Exchange the leading (expert) axis across the "i" axis so each
        # device gathers the inputs destined for its own expert.
        expert_inputs = jax.lax.all_to_all(expert_inputs, "i", 0, 0)
        device_index = jax.lax.axis_index("i")
        expert_outputs = self.expert(expert_inputs, device_index)
        # Send the results back with the same exchange.
        expert_outputs = jax.lax.all_to_all(expert_outputs, "i", 0, 0)

        # TODO: rearrange the tokens in their original shape

        return expert_outputs


@nnx.pmap(axis_name="i")
def step(x):
    """Apply the notebook-global ``model`` to one per-device slice of ``x``.

    The function is mapped over the leading axis of its argument under the
    axis name ``"i"``.
    """
    return model(x)

B, T, C = 16, 4, 4

key = jax.random.key(0)
# Use independent keys for parameter initialisation and for the input batch;
# handing the same key to both would reuse the same randomness twice.
init_key, data_key = jax.random.split(key)
rngs = nnx.Rngs(init_key)

model = create_sharded_model(MOE, config, rngs)

# Leading axis is the mapped (device) axis: one (B, T, C) batch per device.
x = jax.random.normal(key=data_key, shape=(len(devices), B, T, C))
print(x.shape)
y = step(x)
print(y.shape)
 

[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
(8, 16, 4, 4)
(8, 8, 10, 4)


In [124]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx

class SlicedLinear(nnx.Module):
    """Stack of ``num_experts`` dense matrices; one is selected per call."""

    def __init__(self, in_dim, out_dim, num_experts):
        """Initialise the stacked weights from a fixed key (seed 0).

        The parameter has shape ``(num_experts, in_dim, out_dim)``.
        """
        weight_shape = (num_experts, in_dim, out_dim)
        init_key = jax.random.key(0)
        self.weight = nnx.Param(jax.random.normal(init_key, weight_shape))

    def __call__(self, x, expert_idx):
        """Return ``x @ weight[expert_idx]``.

        ``x`` is expected to have ``in_dim`` as its last dimension.
        """
        # Pick a single expert's matrix along the leading (expert) axis.
        expert_weight = self.weight[expert_idx]
        return x @ expert_weight


@nnx.jit
def step(m, x, i):
    """Jitted forward pass: call module ``m`` on ``x`` with expert index ``i``.

    Note: this rebinds the name ``step``, replacing the pmapped ``step``
    defined in an earlier cell of this notebook.
    """
    return m(x, i)


# Eight experts, each mapping 4 input features to 8 output features.
l = SlicedLinear(in_dim=4, out_dim=8, num_experts=8)

# Random batch of 8 rows, pushed through expert 0.
input_key = jax.random.key(1)
x = jax.random.normal(input_key, (8, 4))
y = step(l, x, 0)

In [56]:
import jax
import jax.numpy as jnp
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.experimental import mesh_utils

# Build a 1-D device mesh over 8 devices; its single axis is named "expert"
# (one expert per device).
devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, axis_names=('expert',))

# Problem sizes
batch_size = 16
hidden_dim = 32
num_experts = 8
top_k = 2
tokens_per_expert = 8
expert_output_dim = 32

# Random input tokens, shape (batch_size, hidden_dim)
key = random.PRNGKey(42)
x = random.normal(key, (batch_size, hidden_dim))

# Gate routing: random logits stand in for a learned router. Keep the top-k
# experts per token and normalise their scores with a softmax.
router_key = random.fold_in(key, 1)
logits = random.normal(router_key, (batch_size, num_experts))
topk_vals, topk_idx = jax.lax.top_k(logits, top_k)
gate_scores = jax.nn.softmax(topk_vals, axis=-1)

# Per-expert token buffers: the next free slot for each expert, a validity
# mask per slot, and the token payloads themselves.
counter = jnp.zeros((num_experts,), dtype=int)
mask_buf = jnp.zeros((num_experts, tokens_per_expert), dtype=bool)
input_buf = jnp.zeros((num_experts, tokens_per_expert, hidden_dim))

def dispatch(i, carry):
    """Loop body that places token ``i`` into the buffers of its chosen experts.

    Args:
        i: index of the token in the notebook-global ``x``.
        carry: tuple ``(buf, mask, ctr)`` holding the token buffer, the
            slot-validity mask and the per-expert slot counter.

    Returns:
        The updated ``(buf, mask, ctr)`` tuple.

    Reads the notebook globals ``x``, ``topk_idx``, ``gate_scores`` and
    ``top_k``. The token is scaled by its gate score before being stored.
    NOTE(review): there is no explicit capacity check; if an expert is picked
    more often than it has slots, the slot index runs past the buffer —
    confirm the out-of-bounds update behaviour is what is intended.
    """
    buf, mask, ctr = carry
    token = x[i]
    for choice in range(top_k):
        expert = topk_idx[i, choice]
        slot = ctr[expert]  # next free slot for this expert
        weighted_token = token * gate_scores[i, choice]
        buf = buf.at[expert, slot].set(weighted_token)
        mask = mask.at[expert, slot].set(True)
        ctr = ctr.at[expert].set(slot + 1)
    return buf, mask, ctr

# Route every token into the per-expert buffers.
input_buf, mask_buf, _ = jax.lax.fori_loop(0, batch_size, dispatch, (input_buf, mask_buf, counter))

# Shard tokens and weights across devices along the leading (expert) axis
sharding = NamedSharding(mesh, PartitionSpec("expert", None, None))
input_sharded = jax.device_put(input_buf, sharding)
mask_sharded = jax.device_put(mask_buf, NamedSharding(mesh, PartitionSpec("expert", None)))

# Create per-expert weights from a dedicated subkey. Deriving them from
# fold_in(key, i) for i in range(num_experts) would hand expert 1 the same
# key, fold_in(key, 1), that already produced the routing logits.
w1_key, w2_key = random.split(random.fold_in(key, 2))
W1 = random.normal(w1_key, (num_experts, hidden_dim, expert_output_dim))
W2 = random.normal(w2_key, (num_experts, expert_output_dim, hidden_dim))

W1_sharded = jax.device_put(W1, NamedSharding(mesh, PartitionSpec("expert", None, None)))
W2_sharded = jax.device_put(W2, NamedSharding(mesh, PartitionSpec("expert", None, None)))

# MoE forward
@jax.jit
def moe_forward(x, mask, W1, W2):
    """Apply every expert's two-layer ReLU network to its own token buffer.

    All four arguments are mapped over their leading axis, one slice per
    expert. For each expert the result is ``relu(tokens @ W1) @ W2``; rows
    whose mask entry is False are replaced with zeros.

    Args:
        x: token buffers, one ``(tokens, features)`` slice per expert.
        mask: boolean validity flags, one ``(tokens,)`` slice per expert.
        W1: first-layer weights, one matrix per expert.
        W2: second-layer weights, one matrix per expert.

    Returns:
        The stacked per-expert outputs.
    """
    def _expert_ffn(tokens, valid, w_in, w_out):
        hidden = jax.nn.relu(tokens @ w_in)
        projected = hidden @ w_out
        # Broadcast the per-token flag across the feature dimension.
        return jnp.where(valid[:, None], projected, 0.0)

    per_expert = jax.vmap(_expert_ffn)  # maps axis 0 of every argument
    return per_expert(x, mask, W1, W2)

# Run the jitted forward pass on the expert-sharded buffers and weights,
# with `mesh` entered as the active mesh context.
with mesh:
    output = moe_forward(input_sharded, mask_sharded, W1_sharded, W2_sharded)