In [1]:
import os

os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'

from functools import partial
from dataclasses import dataclass
import random

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding, Mesh
from jax.debug import visualize_array_sharding as viz

import flax.nnx as nnx
import optax

from jaxpt.modules.config import Config

devices = jax.devices()
print(devices)

mesh = Mesh(devices, ("devices"))
spec = PartitionSpec(None, "devices")
sharding = NamedSharding(mesh, spec)

@dataclass(unsafe_hash=True)
class GLU_Config(Config):
    top_k = 1
    load_factor = 2.00
    n_experts = len(devices)
    n_embed = 64
    n_mlp_hidden = 6
    mlp_bias = True
    dtype = jax.numpy.float32
    mesh = mesh

config = GLU_Config()


class FFN(nnx.Module):
    def __init__(self, config, rngs):
        init = nnx.with_partitioning(
            nnx.initializers.normal(stddev=0.02),
            sharding=(None,))

        self.w1 = nnx.Param(init(rngs.normal.key.value,
            (
                config.n_embed,
                config.n_embed
            )
        ))
    def __call__(self, x):
        y = x @ self.w1
        return y, 0

class Expert(nnx.Module):
    def __init__(self, config, rngs):
        init = nnx.with_partitioning(
            nnx.initializers.normal(stddev=0.02),
            sharding=("devices",))

        self.w1 = nnx.Param(init(rngs.normal.key.value,
            (
                config.n_experts,
                config.n_embed,
                config.n_embed
            )
        ))

    def __call__(self, x, expert_idx):
        # Use only one expert's weights
        w1 = self.w1[expert_idx] 
        x = x @ w1
        return x

@nnx.jit(static_argnums=(0, 1)) #, out_shardings=sharding)
def create_sharded_model(Model, config, rngs):
    model = Model(config=config, rngs=rngs)
    graphdef, state = nnx.split(model) 
    pspecs = nnx.get_partition_spec(state)
    sharded_state = nnx.with_sharding_constraint(
        state, pspecs, mesh=config.mesh
        )
    nnx.update(model, sharded_state)
    return model


class MOE(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.router_gate = nnx.Linear(
            config.n_embed,
            config.n_experts,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.normal(stddev=0.02),
                sharding=(None,)),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros, 
            sharding=(None,)),
            use_bias=config.mlp_bias,
            dtype=config.dtype,
            rngs=rngs,
        )
        self.expert = Expert(config, rngs)        
        self.top_k = config.top_k
        self.n_experts = config.n_experts
        self.load_factor = config.load_factor
        self.add_noise = False
        self.rngs = rngs

    def __call__(self, x):
        B, C = x.shape
        logits = self.router_gate(x) # B, n_experts
        if self.add_noise:
            logits += jax.random.normal(key=self.rngs.gate_noise(), shape=logits.shape)
        top_k_logits, expert_indices = jax.lax.top_k(logits, self.top_k) # B, top_k

        zeros = jnp.full_like(logits, float('-inf')) # B, n_experts
        sparse_logits = jnp.put_along_axis(
                zeros, expert_indices, top_k_logits, axis=-1, inplace=False) # b, n_experts
        expert_weights = jax.nn.softmax(sparse_logits, axis=-1) # B, n_experts

        max_tokens_per_expert = int((self.load_factor * B) // self.n_experts)
        expert_inputs = jnp.zeros((self.n_experts, max_tokens_per_expert, C))
        input_counters = jnp.zeros((self.n_experts,), dtype=jnp.uint8)

        def update_expert_inputs(i, carry):
            expert_inputs, counters = carry
            for j in range(self.top_k):
                expert_idx = expert_indices[i, j]
                token_pos = counters[expert_idx]
                expert_inputs = expert_inputs.at[expert_idx, token_pos].set(x[i])
                counters = counters.at[expert_idx].add(1)

            return expert_inputs, counters
        
        expert_inputs, input_counters = jax.lax.fori_loop(
            0, B, update_expert_inputs, (
                expert_inputs,
                input_counters
            )
        )

        # Gather the current expert's inputs 
        expert_inputs_gathered = jax.lax.all_to_all(expert_inputs, "i", 0, 0)
        # Run the current expert on the gathered inputs
        device_index = jax.lax.axis_index("i")
        # expert_outputs = expert_inputs_gathered @ (jnp.eye(C, C)*device_index)
        expert_outputs = self.expert(expert_inputs, device_index)
        # Redistribute the outputs back to the devices of origin
        expert_outputs_gathered = jax.lax.all_to_all(expert_outputs, "i", 0, 0)

        output_counters = jnp.zeros((self.n_experts,), dtype=jnp.uint8)
        #y = jnp.zeros((B,))
        y = jnp.zeros_like(x)
        def update_expert_outputs(i, carry):
            y, counters = carry
            for j in range(self.top_k):
                expert_idx = expert_indices[i, j]
                token_pos = counters[expert_idx]
                y = y.at[i].add(
                    expert_outputs_gathered[expert_idx, token_pos] * expert_weights[i, expert_idx])
                counters = counters.at[expert_idx].add(1)

            return y, counters

        y, output_counters = jax.lax.fori_loop(
            0, B, update_expert_outputs, (
                y,
                output_counters
            )
        )

        return y, expert_inputs, expert_outputs_gathered, input_counters, expert_indices, expert_weights

def loss_fn(model, x, y):
    y_pred, expert_inputs, expert_outputs, counters, expert_indices, expert_weights = model(x)
    #loss = -jnp.mean(y * jnp.log(y_pred + 1e-7) + (1 - y) * jnp.log(1 - y_pred + 1e-7))
    loss = jnp.mean((y - y_pred)**2)
    return loss, (expert_outputs, expert_inputs, counters, expert_indices, expert_weights, y_pred)

@nnx.pmap(axis_name="i", in_axes=(None, 0, 0), out_axes=(0))
def step(state, x, y):
    (loss, (expert_outputs, expert_inputs, counters, expert_indices, expert_weights, y_pred)), grads = nnx.value_and_grad(
        loss_fn, has_aux=True)(state.model, x, y)
    grads = jax.lax.pmean(grads, axis_name="i")
    state.update(grads)
    loss = jax.lax.pmean(loss, axis_name="i")
    return loss, grads, expert_outputs, expert_inputs, counters, expert_indices, expert_weights, y_pred

D, B, C = len(devices) * 1000, 16, config.n_embed 
   
default = jax.random.key(0)
gate_noise = jax.random.key(42)
rngs = nnx.Rngs(default=default, gate_noise=gate_noise)

#model = MOE(config, rngs)
model = create_sharded_model(MOE, config, rngs)
model.train(add_noise=True)
tx = optax.adam(1e-3)
state = nnx.Optimizer(model, tx)

x = jax.random.normal(jax.random.key(1000), (D * B, C))

expert_ids = (x[:, 0] > 0).astype(jnp.int32)
t = [
    jax.random.normal(jax.random.key(2000), (C, C)),
    jax.random.normal(jax.random.key(3000), (C, C)),
]
def transform(xi, eid):
    return jnp.where(eid == 1, xi @ t[0], xi @ t[1])

y = jax.vmap(lambda xi, ei: transform(xi, ei))(x, expert_ids)

x = x.reshape(D//len(devices), len(devices), B, C)
y = y.reshape(D//len(devices), len(devices), B, C)

indices = list(range(D//len(devices)))
for e in range(100):
    for i in indices:
        loss, grads, expert_outputs, expert_inputs, counters, expert_indices, expert_weights, y_pred = step(state, x[i], y[i])
        if i % 1000 == 0:
            print(e, i, loss[0])


[CpuDevice(id=0), CpuDevice(id=1)]
0 0 68.96336
1 0 63.26626
2 0 60.674118
3 0 58.08045
4 0 57.524876
5 0 57.85038
6 0 56.621796
7 0 58.736324
8 0 56.391182
9 0 56.643585
10 0 57.761555
11 0 57.837578
12 0 56.986187
13 0 58.267403
14 0 58.81224
15 0 59.377
16 0 61.09084
17 0 61.505043
18 0 58.084328
19 0 59.36465
20 0 58.707935
21 0 56.86256
22 0 57.731113
23 0 61.200844
24 0 61.464096
25 0 57.238068
26 0 55.824425
27 0 56.6616
28 0 58.494892
29 0 58.357327
30 0 55.158035
31 0 56.34519
32 0 58.275215
33 0 58.631767
34 0 63.283398
35 0 58.351143


KeyboardInterrupt: 