In [1]:
import os

# Ask XLA for two host (CPU) devices; this is set before JAX is imported in
# the next cell so the flag is in place when the backend initializes.
os.environ.update(XLA_FLAGS='--xla_force_host_platform_device_count=2')


In [None]:
from functools import partial
from dataclasses import dataclass
import random

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding, Mesh
from jax.debug import visualize_array_sharding as viz

import flax.nnx as nnx

from jaxpt.modules.config import Config

devices = jax.devices()
print(devices)

# A one-dimensional mesh over all visible devices.  Note the trailing comma:
# `("devices")` is just the string "devices", not a tuple of axis names.
mesh = Mesh(devices, ("devices",))
# Replicate dim 0, split dim 1 across the "devices" axis.
spec = PartitionSpec(None, "devices")
sharding = NamedSharding(mesh, spec)

@dataclass(unsafe_hash=True)
class GLU_Config(Config):
    """Hyper-parameters for the toy mixture-of-experts experiment.

    None of the attributes below carry a type annotation, so ``@dataclass``
    treats them as plain class attributes rather than fields.
    ``unsafe_hash=True`` gives instances a ``__hash__`` (presumably so a
    config can be passed as a static argument to ``nnx.jit``).
    """

    # Model sizes and numerics.
    n_embed = 16
    n_mlp_hidden = 6
    mlp_bias = True
    dtype = jnp.float32

    # Routing: one expert per visible device, each token sent to top_k of them.
    n_experts = len(devices)
    top_k = 1
    load_factor = 2.00

    # The module-level device mesh (resolved from globals at class creation).
    mesh = mesh


config = GLU_Config()


class FFN(nnx.Module):
    """Dense two-layer ReLU feed-forward block: ``relu(x @ w1) @ w2``.

    Hidden width is ``3 * config.n_embed``.  ``__call__`` returns an
    ``(output, aux)`` pair, with a constant ``0`` as the auxiliary value.
    """

    def __init__(self, config, rngs):
        init = nnx.with_partitioning(
            nnx.initializers.normal(stddev=0.02),
            sharding=(None,))
        hidden = config.n_embed * 3

        # Call the stream to get a fresh key per parameter.  The previous
        # `rngs.normal.key.value` returned the stream's base key without
        # advancing it, so w1 and w2 were drawn from the same key.
        self.w1 = nnx.Param(init(rngs.normal(), (config.n_embed, hidden)))
        self.w2 = nnx.Param(init(rngs.normal(), (hidden, config.n_embed)))

    def __call__(self, x):
        """Return ``(relu(x @ w1) @ w2, 0)``."""
        y = nnx.relu(x @ self.w1) @ self.w2
        return y, 0

class Expert(nnx.Module):
    """Bank of ``n_experts`` linear maps held in one stacked weight tensor.

    ``w1`` has shape ``(n_experts, n_embed, n_embed)`` and carries partitioning
    metadata that shards its first (expert) axis over the ``"devices"`` mesh
    axis.
    """

    def __init__(self, config, rngs):
        init = nnx.with_partitioning(
            nnx.initializers.normal(stddev=0.02),
            sharding=("devices",))

        # Calling the stream draws a fresh key; `rngs.normal.key.value`
        # returned the same un-advanced base key on every access.
        self.w1 = nnx.Param(init(
            rngs.normal(),
            (config.n_experts, config.n_embed, config.n_embed),
        ))

    def __call__(self, x, expert_idx):
        """Apply only expert ``expert_idx``: return ``x @ w1[expert_idx]``.

        ``expert_idx`` may be a traced integer such as
        ``jax.lax.axis_index(...)``.
        """
        return x @ self.w1[expert_idx]

@nnx.jit(static_argnums=(0, 1))
def create_sharded_model(Model, config, rngs):
    """Construct ``Model(config=config, rngs=rngs)`` under jit and shard it.

    ``Model`` and ``config`` are static jit arguments, so both must be
    hashable.  The partition specs recorded on the module's variables are
    applied as sharding constraints over ``config.mesh`` and the constrained
    state is written back into the module before it is returned.
    """
    module = Model(config=config, rngs=rngs)
    _, module_state = nnx.split(module)
    constrained = nnx.with_sharding_constraint(
        module_state,
        nnx.get_partition_spec(module_state),
        mesh=config.mesh,
    )
    nnx.update(module, constrained)
    return module


class MOE(nnx.Module):
    """Top-k routed mixture-of-experts layer with one expert per device.

    Tokens are exchanged with ``jax.lax.all_to_all`` over the mapped axis
    ``"i"``, and each device runs expert ``jax.lax.axis_index("i")``.  The
    layer therefore has to be called inside a ``pmap`` with ``axis_name="i"``.
    It assumes that axis has exactly ``config.n_experts`` devices (TODO:
    confirm if other layouts are ever needed).
    """

    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.router_gate = nnx.Linear(
            config.n_embed,
            config.n_experts,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.normal(stddev=0.02),
                sharding=(None,)),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros,
                sharding=(None,)),
            use_bias=config.mlp_bias,
            dtype=config.dtype,
            rngs=rngs,
        )
        self.expert = Expert(config, rngs)
        self.top_k = config.top_k
        self.n_experts = config.n_experts
        self.load_factor = config.load_factor

    def __call__(self, x):
        """Route each token to its top-k experts and mix their outputs.

        Args:
          x: this device's tokens, shape ``(B, C)``.

        Returns:
          ``(y, counts)``.
          - ``y`` has the shape of ``x``.
          - ``counts`` is an ``(n_experts,)`` histogram of the experts chosen
            on this device.

          Each expert buffer holds at most ``int(load_factor * B //
          n_experts)`` tokens from this device.  Tokens beyond that are
          dropped and contribute zero to ``y``.
        """
        B, C = x.shape
        logits = self.router_gate(x)  # (B, n_experts)
        top_k_logits, expert_indices = jax.lax.top_k(logits, self.top_k)  # (B, top_k)

        # Softmax over the selected experts only: every other logit is -inf.
        neg_inf = jnp.full_like(logits, float('-inf'))
        sparse_logits = jnp.put_along_axis(
            neg_inf, expert_indices, top_k_logits, axis=-1, inplace=False)
        expert_weights = jax.nn.softmax(sparse_logits, axis=-1)

        capacity = int((self.load_factor * B) // self.n_experts)
        # int32: the previous uint8 counters wrapped after 255 tokens, sending
        # later tokens back to slot 0 where they overwrote earlier ones.
        counters = jnp.zeros((self.n_experts,), dtype=jnp.int32)

        def dispatch(i, carry):
            # Append token i to the next free slot of each of its experts.
            expert_inputs, counters = carry
            for j in range(self.top_k):
                expert_idx = expert_indices[i, j]
                slot = counters[expert_idx]
                # Slots >= capacity are out of bounds; drop those writes.
                expert_inputs = expert_inputs.at[expert_idx, slot].set(
                    x[i], mode="drop")
                counters = counters.at[expert_idx].add(1)
            return expert_inputs, counters

        expert_inputs, _ = jax.lax.fori_loop(
            0, B, dispatch,
            (jnp.zeros((self.n_experts, capacity, C)), counters))

        # Swap buffers so this device holds every device's tokens for its own
        # expert, run that expert, then send the results back to their origin.
        expert_inputs = jax.lax.all_to_all(expert_inputs, "i", 0, 0)
        device_index = jax.lax.axis_index("i")
        expert_outputs = self.expert(expert_inputs, device_index)
        expert_outputs = jax.lax.all_to_all(expert_outputs, "i", 0, 0)

        def combine(i, carry):
            # Replay the dispatch order (counters restart from zero) to find
            # the slot that holds each of token i's expert outputs.
            y, counters = carry
            for j in range(self.top_k):
                expert_idx = expert_indices[i, j]
                slot = counters[expert_idx]
                # A traced out-of-range read is clamped to the last slot.  That
                # would hand a dropped token another token's output, so mask it.
                kept = slot < capacity
                contribution = (
                    expert_outputs[expert_idx, slot]
                    * expert_weights[i, expert_idx])
                y = y.at[i].add(jnp.where(kept, contribution, 0))
                counters = counters.at[expert_idx].add(1)
            return y, counters

        y, _ = jax.lax.fori_loop(
            0, B, combine, (jnp.zeros_like(x), counters))

        return y, jnp.bincount(expert_indices.flatten(), length=self.n_experts)

def loss_fn(model, x, y):
    """Mean-squared error between the model's prediction and ``y``.

    ``model(x)`` must return a ``(prediction, aux)`` pair.  ``aux`` is passed
    through unchanged so this can be used with
    ``value_and_grad(..., has_aux=True)``.

    Returns:
      ``(loss, aux)``.
    """
    prediction, aux = model(x)
    squared_error = jnp.square(y - prediction)
    return jnp.mean(squared_error), aux

@nnx.pmap(axis_name="i", in_axes=(None, 0, 0), out_axes=(0))
def step(model, x, y):
    """One SGD step (learning rate 0.01), data-parallel over axis ``"i"``.

    ``model`` is broadcast to every device, while ``x`` and ``y`` are split
    along their leading axis.  Gradients are averaged across devices before
    the update, so every replica applies the same parameter update.

    Returns:
      ``(loss, grads, expert_counts)`` per device:
      - ``loss`` is the cross-device mean loss.
      - ``grads`` are the device-averaged gradients.
      - ``expert_counts`` is the auxiliary output of ``loss_fn``.
    """
    (loss, expert_counts), grads = nnx.value_and_grad(loss_fn, has_aux=True)(
        model, x, y)
    # Average gradients over the mapped axis.  Without this, each device
    # updated its copy of the replicated parameters with only its local
    # gradient, so the copies could disagree.
    grads = jax.lax.pmean(grads, axis_name="i")
    state = nnx.state(model)
    # jax.tree_map is deprecated (and removed in recent JAX releases).
    state = jax.tree_util.tree_map(
        lambda param, g: param - 0.01 * g,
        state, grads,
        is_leaf=lambda x: isinstance(x, nnx.Param),
    )
    nnx.update(model, state)
    loss = jax.lax.pmean(loss, axis_name="i")
    return loss, grads, expert_counts

D, B, C = 1000, 16, config.n_embed

key = jax.random.key(0)
# Use independent keys for parameter init and data; previously the same key
# seeded both.
init_key, data_key = jax.random.split(key)
rngs = nnx.Rngs(init_key)

model = create_sharded_model(MOE, config, rngs)

# D steps, each with one (B, C) batch per device.
x = jax.random.normal(key=data_key, shape=(D, len(devices), B, C))
print(x.shape)
t = jnp.stack([
    jax.random.normal(key=jax.random.key(1000), shape=(C, C)),
    jax.random.normal(key=jax.random.key(2000), shape=(C, C)),
])
print(t.shape)
# NOTE(review): both halves of the data use t[0], so t[1] is never used.
# Confirm whether the second half was meant to use t[1].
y = jnp.concatenate([
    x[:D//2] @ t[0],
    x[D//2:] @ t[0]
], axis=0)

expert_bincounts = jnp.zeros((config.n_experts,))

global_step = 0
for e in range(100):
    indices = list(range(D))
    random.shuffle(indices)
    for i in indices:
        loss, grads, expert_counts = step(model, x[i], y[i])
        expert_counts = jnp.sum(expert_counts, axis=0)
        expert_bincounts += expert_counts
        # Log by step count.  The old `i % 1000 == 0` tested the shuffled
        # sample index, so it fired at a random point within each epoch.
        if global_step % 1000 == 0:
            print(loss[0])
        global_step += 1
