In [1]:
import os

# Set before `jax` is imported (the import happens in the next cell): the flag
# asks XLA to expose 8 host-platform devices, which the mesh / pmap code below
# relies on when running on a single machine.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'


In [None]:
from functools import partial
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding, Mesh
from jax.debug import visualize_array_sharding as viz

import flax.nnx as nnx

from jaxpt.modules.config import Config

# Enumerate the available devices and show them for reference.
devices = jax.devices()
print(devices)

# One-dimensional device mesh whose single axis is named "devices".
mesh = Mesh(devices, ("devices",))
# Partition spec: first array axis unsharded, second axis split over "devices".
spec = PartitionSpec(None, "devices")
sharding = NamedSharding(mesh, spec)

@dataclass(unsafe_hash=True)
class GLU_Config(Config):
    """Hyper-parameters for the toy mixture-of-experts model in this notebook.

    NOTE(review): none of the attributes below carries a type annotation, so
    ``@dataclass`` does not register them as dataclass fields here; they are
    plain class attributes (and therefore do not take part in the generated
    ``__init__`` / ``__eq__`` / ``__hash__``) unless ``Config`` already
    declares fields with the same names -- confirm against
    ``jaxpt.modules.config.Config``.
    """
    # Number of experts selected per token (passed to jax.lax.top_k in MOE).
    top_k = 2
    # Scales the per-expert token capacity: load_factor * B * T // n_experts.
    load_factor = 2.00
    # Number of experts: router output width and leading axis of Expert.weight.
    n_experts = 8
    # Feature size of the tokens and of each (square) expert weight matrix.
    n_embed = 5
    # Not read by any code visible in this notebook.
    n_mlp_hidden = 6
    # Forwarded as `use_bias` to the router's nnx.Linear.
    mlp_bias = True
    # Forwarded as `dtype` to the router's nnx.Linear.
    dtype = jax.numpy.float32
    # Module-level device mesh; read by create_sharded_model via config.mesh.
    mesh = mesh

# Single shared configuration instance used by the rest of the notebook.
config = GLU_Config()


class Linear(nnx.Module):
    """Bias-free square projection with an ``(n_embed, n_embed)`` weight."""

    def __init__(self, config, rngs):
        """Create the weight from a normal(stddev=0.02) initializer.

        Args:
            config: object exposing ``n_embed``.
            rngs: ``nnx.Rngs``; the raw key of its ``normal`` stream seeds the
                initializer.
        """
        weight_init = nnx.with_partitioning(
            nnx.initializers.normal(stddev=0.02),
            sharding=(None,),
        )
        weight_shape = (config.n_embed, config.n_embed)
        self.weight = nnx.Param(weight_init(rngs.normal.key.value, weight_shape))

    def __call__(self, x):
        """Return ``x @ weight`` (last axis of ``x`` contracts with the first weight axis)."""
        return x @ self.weight


class Expert(nnx.Module):
    """Bank of ``n_experts`` square weight matrices stacked on the leading axis.

    The weight has shape ``(n_experts, n_embed, n_embed)`` and its leading
    axis is annotated for partitioning over the ``"devices"`` mesh axis.
    """

    def __init__(self, config, rngs):
        """Create the stacked expert weights from a normal(stddev=0.02) initializer.

        Args:
            config: object exposing ``n_experts`` and ``n_embed``.
            rngs: ``nnx.Rngs``; the raw key of its ``normal`` stream seeds the
                initializer.
        """
        weight_init = nnx.with_partitioning(
            nnx.initializers.normal(stddev=0.02),
            sharding=("devices",),
        )
        weight_shape = (config.n_experts, config.n_embed, config.n_embed)
        self.weight = nnx.Param(weight_init(rngs.normal.key.value, weight_shape))

    def __call__(self, x, expert_idx):
        """Apply the weight matrix of expert ``expert_idx`` to ``x``.

        Args:
            x: array whose last axis has size ``n_embed``.
            expert_idx: index into the leading (expert) axis of the weight.

        Returns:
            ``x @ weight[expert_idx]``.
        """
        selected = self.weight[expert_idx]  # (n_embed, n_embed) slice for one expert
        return x @ selected

@nnx.jit(static_argnums=(0, 1))
def create_sharded_model(Model, config, rngs):
    """Build ``Model(config, rngs)`` under jit with sharding constraints applied.

    The module's state is split off, constrained to the partition specs
    derived from that state on ``config.mesh``, and written back into the
    module before it is returned.

    Args:
        Model: module class, called as ``Model(config=config, rngs=rngs)``
            (static argument).
        config: configuration exposing ``mesh`` (static argument).
        rngs: ``nnx.Rngs`` forwarded to the module constructor.

    Returns:
        The constructed module with its state updated to the constrained values.
    """
    module = Model(config=config, rngs=rngs)
    _, module_state = nnx.split(module)
    constrained_state = nnx.with_sharding_constraint(
        module_state,
        nnx.get_partition_spec(module_state),
        mesh=config.mesh,
    )
    nnx.update(module, constrained_state)
    return module


class MOE(nnx.Module):
    """Top-k routed mixture-of-experts layer that runs one expert per device.

    ``__call__`` uses ``jax.lax.all_to_all`` and ``jax.lax.axis_index`` on the
    axis named ``"i"``, so it must be invoked inside a mapped computation that
    binds that axis (``step`` does so via ``nnx.pmap(axis_name="i")``). The
    all-to-all splits the leading ``n_experts`` axis across ``"i"``, so
    ``n_experts`` is expected to equal the number of devices on that axis.
    """

    def __init__(self, config: Config, rngs: nnx.Rngs):
        """Create the router and the expert weight bank.

        Args:
            config: exposes ``n_embed``, ``n_experts``, ``mlp_bias``, ``dtype``,
                ``top_k`` and ``load_factor``.
            rngs: random streams used to initialise the router and the experts.
        """
        # Router: one logit per expert for every token.
        self.router_gate = nnx.Linear(
            config.n_embed,
            config.n_experts,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.normal(stddev=0.02),
                sharding=(None,)),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros,
            sharding=(None,)),
            use_bias=config.mlp_bias,
            dtype=config.dtype,
            rngs=rngs,
        )
        self.expert = Expert(config, rngs)
        self.top_k = config.top_k
        self.n_experts = config.n_experts
        self.load_factor = config.load_factor

    def __call__(self, x):
        """Route every token to its top-k experts and sum the gated outputs.

        Args:
            x: array of shape ``(B, T, C)``.

        Returns:
            Array of shape ``(B, T, C)``. Each token's row is the sum of the
            outputs of the experts it was routed to (inputs are pre-scaled by
            the softmax gate weight). Assignments that exceed an expert's
            capacity ``load_factor * B * T // n_experts`` contribute zero.
        """
        B, T, C = x.shape
        n_tokens = B * T
        x_flat = x.reshape(-1, C)
        logits = self.router_gate(x_flat)  # (B * T, n_experts)
        top_k_logits, expert_indices = jax.lax.top_k(logits, self.top_k)  # (B * T, top_k)

        # Softmax over the selected experts only: every other logit is -inf
        # and therefore receives a gate weight of exactly zero.
        masked_logits = jnp.full_like(logits, float('-inf'))  # (B * T, n_experts)
        sparse_logits = jnp.put_along_axis(
                masked_logits, expert_indices, top_k_logits, axis=-1, inplace=False)
        expert_weights = jax.nn.softmax(sparse_logits, axis=-1)

        max_tokens_per_expert = int((self.load_factor * B * T) // self.n_experts)
        expert_inputs = jnp.zeros((self.n_experts, max_tokens_per_expert, C))
        # int32 rather than uint8: a uint8 counter wraps around after 255
        # assignments and would start overwriting slot 0 of the expert buffer.
        counters = jnp.zeros((self.n_experts,), dtype=jnp.int32)

        def update_expert_inputs(i, carry):
            # Place token i (scaled by its gate weight) into the next free slot
            # of each of its top-k experts.
            expert_inputs, counters = carry
            for j in range(self.top_k):
                expert_idx = expert_indices[i, j]
                token_pos = counters[expert_idx]
                # mode="drop": assignments beyond the expert's capacity are
                # discarded instead of touching an existing slot.
                expert_inputs = expert_inputs.at[expert_idx, token_pos].set(
                    x_flat[i] * expert_weights[i, expert_idx], mode="drop")
                counters = counters.at[expert_idx].add(1)

            return expert_inputs, counters

        # The final counters are discarded on purpose: the combine loop below
        # restarts from the zeroed `counters` to replay the same slot order.
        expert_inputs, _ = jax.lax.fori_loop(
            0, n_tokens, update_expert_inputs, (
                expert_inputs,
                counters
            )
        )

        # Exchange buffers so each device receives the inputs addressed to its expert.
        expert_inputs = jax.lax.all_to_all(expert_inputs, "i", 0, 0)
        # Run this device's expert on the gathered inputs.
        device_index = jax.lax.axis_index("i")
        expert_outputs = self.expert(expert_inputs, device_index)
        # Send the outputs back to the devices the tokens came from.
        expert_outputs = jax.lax.all_to_all(expert_outputs, "i", 0, 0)

        y = jnp.zeros_like(x_flat)

        def update_expert_outputs(i, carry):
            # Replay the dispatch order to find each (token, expert) slot and
            # accumulate the expert outputs for token i.
            y, counters = carry
            for j in range(self.top_k):
                expert_idx = expert_indices[i, j]
                token_pos = counters[expert_idx]
                # Over-capacity assignments were dropped during dispatch; an
                # out-of-range read would be clamped to an unrelated slot, so
                # mask it to zero explicitly.
                kept = token_pos < max_tokens_per_expert
                contribution = jnp.where(
                    kept, expert_outputs[expert_idx, token_pos], 0)
                # Accumulate (not overwrite) so all top-k experts contribute.
                y = y.at[i].add(contribution)
                counters = counters.at[expert_idx].add(1)

            return y, counters

        y, _ = jax.lax.fori_loop(
            0, n_tokens, update_expert_outputs, (
                y,
                counters
            )
        )

        return y.reshape(B, T, C)

def loss_fn(model, x, y):
    """Mean squared error between ``model(x)`` and the target ``y``.

    Args:
        model: callable applied to ``x``.
        x: model input.
        y: target, broadcast-compatible with ``model(x)``.

    Returns:
        Scalar mean of the element-wise squared residuals.
    """
    residual = model(x) - y
    return jnp.mean(jnp.square(residual))

@nnx.pmap(axis_name="i", in_axes=(None, 0, 0), out_axes=0)
def step(model, x, y):
    """Run one SGD step (learning rate 0.001) on every device of axis ``"i"``.

    Args:
        model: module to train; broadcast to all devices (``in_axes=None``).
        x: inputs with a leading device axis.
        y: targets with a leading device axis.

    Returns:
        ``(loss, grads)``, both stacked along a leading device axis. ``loss``
        is the cross-device mean of the per-device losses; ``grads`` are the
        per-device gradients as computed before any cross-device reduction.
    """
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    # Restrict the state to nnx.Param so its tree structure matches `grads`
    # (nnx.value_and_grad differentiates with respect to nnx.Param by default).
    params = nnx.state(model, nnx.Param)
    # jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
    params = jax.tree_util.tree_map(
        lambda param, g: param - 0.001 * g,
        params, grads,
        is_leaf=lambda x: isinstance(x, nnx.Param)
    )
    nnx.update(model, params)
    # NOTE(review): gradients are applied per device without a pmean/psum, so
    # replicas may diverge -- confirm whether that is intended.
    loss = jax.lax.pmean(loss, axis_name="i")
    return loss, grads

# D: number of pre-generated batches, B: per-device batch size,
# T: tokens per sequence, C: feature size (matches n_embed in GLU_Config).
D, B, T, C = 40000, 16, 4, 5

# Derive independent keys for model initialisation, the inputs and the target
# map. Reusing a single key for several draws makes those draws correlated.
key = jax.random.key(0)
model_key, x_key, t_key = jax.random.split(key, 3)
rngs = nnx.Rngs(model_key)

model = create_sharded_model(MOE, config, rngs)
# Synthetic regression task: targets are a fixed random linear map of the inputs.
x = jax.random.normal(key=x_key, shape=(D, len(devices), B, T, C))
t = jax.random.normal(key=t_key, shape=(C, C))
y = x @ t

for epoch in range(100):
    for i in range(D):
        # x[i] / y[i] carry a leading device axis that pmap splits across devices.
        loss, grads = step(model, x[i], y[i])
        if i % 1000 == 0:
            # `loss` holds one (identical, pmean-ed) value per device.
            print(loss[0])


In [None]:
# Inspect the gradient tree returned by the most recent `step` call in the
# training cell above (`grads` is defined there).
nnx.display(grads)