In [None]:
from typing import Literal

import os

os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

import jax

platform : Literal["darwin", "colab", "cuda", "tpu"] = "darwin"

try:
    import google.colab
    platform = "colab"
except ImportError:
    devices = jax.devices()
    if any(d.platform == "gpu" for d in devices):
        platform = "cuda"
    if any(d.platform == "tpu" for d in devices):
        platform = "tpu"

print(f"Running on {platform}")

if platform == "colab":
    !git clone https://github.com/novastar53/jaxpt
    !cd jaxpt && git checkout main && git pull
    !pip install tiktoken datasets --quiet
    #!pip uninstall -y tensorflow
    !pip install tensorboard
    !pip install -U tensorboard-plugin-profile

from pathlib import Path
import sys

if platform == "colab":
    jaxpt_dir = str(Path().absolute() / "jaxpt" / "src" )
else:
    jaxpt_dir = str(Path().absolute().parent / "src" )


sys.path.append(jaxpt_dir)
print(jaxpt_dir)

In [None]:

from functools import partial
from dataclasses import dataclass
import random

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding, Mesh
from jax.debug import visualize_array_sharding as viz

import flax.nnx as nnx
import optax

from jaxpt.modules.config import Config
from jaxpt.utils import create_sharded_model

# Every visible device becomes one slot of a 1-D mesh.
devices = jax.devices()
print(devices)

# The single mesh axis is named "devices"; pass the names as an explicit tuple.
mesh = Mesh(devices, ("devices",))
# Shard an array's leading dimension over the "devices" axis and replicate the
# remaining dimensions.
spec = PartitionSpec("devices",)
sharding = NamedSharding(mesh, spec)

@dataclass(unsafe_hash=True)
class MOE_Config(Config):
    """Hyper-parameters for the toy mixture-of-experts layer in this notebook.

    Every setting is an annotated dataclass field. Un-annotated class attributes
    would not be dataclass fields, so they would be ignored by the generated
    ``__eq__``/``__hash__`` and could not override same-named fields declared on
    ``Config``.
    """
    top_k: int = 2                 # experts selected per token
    load_factor: float = 1.00      # NOTE(review): not read by the MOE forward pass
    n_experts: int = len(devices)  # one expert per visible device
    n_embed: int = 3
    n_mlp_hidden: int = 6          # NOTE(review): not read by Experts/MOE
    mlp_bias: bool = True
    dtype: jnp.dtype = jnp.float32
    mesh: Mesh = mesh

config = MOE_Config()


class Experts(nnx.Module):
    """A bank of ``n_experts`` independent linear maps evaluated in parallel.

    The weight has shape (n_experts, n_embed, n_embed), so each expert is a
    single n_embed -> n_embed map with no hidden layer. Its leading (expert)
    axis is annotated for sharding over the "devices" mesh axis.
    """

    def __init__(self, config, rngs):
        weight_shape = (config.n_experts, config.n_embed, config.n_embed)
        sharded_normal = nnx.with_partitioning(
            nnx.initializers.normal(stddev=0.02),
            sharding=("devices",),
        )
        self.w = nnx.Param(sharded_normal(rngs.default(), weight_shape))

    def __call__(self, x):
        """Multiply slot ``e`` of ``x`` by expert ``e``'s weight matrix.

        The einsum indexes ``x`` as (expert, token, in) and returns
        (expert, token, out). Input and output are both constrained to ``spec``.
        """
        x_sharded = jax.lax.with_sharding_constraint(x, spec)
        projected = jnp.einsum('eti,eio->eto', x_sharded, self.w)
        return jax.lax.with_sharding_constraint(projected, spec)


class MOE(nnx.Module):
    """Top-k routed mixture-of-experts layer over a 2-D (tokens, n_embed) batch.

    A linear router scores each token against every expert. Each token is
    dispatched to its ``top_k`` highest-scoring experts, and the expert outputs
    are weighted by a softmax taken over the selected logits only. The weighted
    outputs are then summed back per token.

    No token is dropped: every expert has one slot per token in the batch.
    ``__call__`` returns ``(y, aux_loss)``, where ``aux_loss`` is the
    Switch-Transformer load-balancing loss.
    """

    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.router_gate = nnx.Linear(
            config.n_embed,
            config.n_experts,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.normal(stddev=0.02),
                sharding=(None,)),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros,
            sharding=(None,)),
            use_bias=config.mlp_bias,
            dtype=config.dtype,
            rngs=rngs,
        )
        self.experts = Experts(config, rngs)
        self.top_k = config.top_k
        self.n_experts = config.n_experts
        self.load_factor = config.load_factor  # stored; not used by __call__
        # Toggled via model.train(add_noise=...). Gate-noise injection is
        # currently disabled, so __call__ does not read this flag.
        self.add_noise = False
        self.rngs = rngs

    def _load_balancing_loss(self, routing_mask, expert_probs):
        """Switch-Transformer auxiliary loss ``n_experts * sum_e f_e * P_e``.

        Args:
            routing_mask: (B, n_experts) 0/1 matrix, 1 where a token is routed
                to an expert.
            expert_probs: (B, n_experts) router probabilities.

        Returns:
            Scalar loss. It equals 1.0 when both the per-expert fractions of
            routing assignments ``f_e`` and the mean router probabilities
            ``P_e`` are uniform.
        """
        tokens_per_expert = jnp.sum(routing_mask, axis=0)      # (n_experts,)
        f = tokens_per_expert / jnp.sum(tokens_per_expert)     # assignment fraction
        P = jnp.mean(expert_probs, axis=0)                     # mean router prob
        return self.n_experts * jnp.sum(f * P)

    def __call__(self, x):
        B, C = x.shape
        logits = self.router_gate(x)  # (B, n_experts)
        logits = jax.lax.with_sharding_constraint(logits, spec)

        # Keep each token's top_k logits and set the rest to -inf, so the
        # softmax assigns probability only to the selected experts.
        top_k_logits, expert_indices = jax.lax.top_k(logits, self.top_k)  # (B, top_k)
        neg_inf = jnp.full_like(logits, float('-inf'))
        sparse_logits = jnp.put_along_axis(
            neg_inf, expert_indices, top_k_logits, axis=-1, inplace=False)
        expert_probs = jax.nn.softmax(sparse_logits, axis=-1)  # (B, n_experts)

        # routing_mask[b, e] == 1 iff token b is routed to expert e.
        routing_mask = jnp.sum(
            jax.nn.one_hot(expert_indices, num_classes=self.n_experts, axis=-1),
            axis=1)  # (B, top_k, n_experts) -> (B, n_experts)
        # 1-based slot of each routed token within its expert's queue (0 if not routed).
        token_slots = jnp.cumsum(routing_mask, axis=0) * routing_mask  # (B, n_experts)

        # Each token picks top_k distinct experts, so there are exactly
        # B * top_k (expert, token) assignments; nonzero lists them expert-major.
        assign_expert, assign_token = jnp.nonzero(
            token_slots.T, size=B * self.top_k)
        assign_slot = token_slots.T[assign_expert, assign_token].astype(jnp.int32) - 1
        assign_prob = expert_probs[assign_token, assign_expert]  # (B * top_k,)

        # An expert receives each token at most once, so B slots per expert suffice.
        capacity = B
        expert_inputs = jnp.zeros((self.n_experts, capacity, C), dtype=x.dtype)
        expert_inputs = expert_inputs.at[assign_expert, assign_slot].set(x[assign_token])
        expert_inputs = jax.lax.with_sharding_constraint(expert_inputs, spec)

        aux_loss = self._load_balancing_loss(routing_mask, expert_probs)

        expert_outputs = self.experts(expert_inputs)  # (n_experts, capacity, C)
        expert_outputs = jax.lax.with_sharding_constraint(expert_outputs, spec)

        # Scatter each weighted expert output back onto its source token.
        weighted = expert_outputs[assign_expert, assign_slot] * assign_prob[:, None]
        y = jnp.zeros_like(x).at[assign_token].add(weighted)
        y = jax.lax.with_sharding_constraint(y, spec)

        return y, aux_loss

def loss_fn(model, x, y, aux_loss_weight=0.01):
    """Mean-squared error of ``model(x)`` against ``y`` plus a weighted auxiliary loss.

    Args:
        model: callable returning ``(y_pred, aux_loss)``.
        x: model input.
        y: regression target, broadcast-compatible with ``y_pred``.
        aux_loss_weight: multiplier applied to the model's auxiliary loss.

    Returns:
        Scalar loss ``mean((y - y_pred)**2) + aux_loss_weight * aux_loss``.
    """
    y_pred, aux_loss = model(x)
    mse = jnp.mean((y - y_pred) ** 2)
    return mse + aux_loss_weight * aux_loss

@nnx.jit
def step(state, x, y):
    """Run one optimisation step on the batch ``(x, y)``.

    Both arrays are constrained to ``spec``. Gradients of ``loss_fn`` with
    respect to ``state.model`` are then applied through ``state.update``.

    Returns:
        ``(loss, grads)`` for this batch.
    """
    sharded_x, sharded_y = (
        jax.lax.with_sharding_constraint(batch, spec) for batch in (x, y)
    )
    compute_loss_and_grads = nnx.value_and_grad(loss_fn)
    loss, grads = compute_loss_and_grads(state.model, sharded_x, sharded_y)
    state.update(grads)
    return loss, grads



In [None]:
with mesh:
    # D batches of B tokens with C features. B is a multiple of the device
    # count so the token axis shards evenly.
    D, B, C = 1000, 2 * len(devices), config.n_embed
    n_epochs = 100

    default = jax.random.key(69)
    gate_noise = jax.random.key(42)
    rngs = nnx.Rngs(default=default, gate_noise=gate_noise)
    model = create_sharded_model(MOE, config, rngs)
    model.train(add_noise=True)
    tx = optax.adam(1e-3)
    state = nnx.Optimizer(model, tx)

    x = jax.random.normal(jax.random.key(1000), (D * B, C))

    # Synthetic target: tokens with x[:, 0] > 0 are mapped by t[0] and the rest
    # by t[1], so a good router only needs two distinct linear experts.
    expert_ids = (x[:, 0] > 0).astype(jnp.int32)
    t = [
        jax.random.normal(jax.random.key(2000), (C, C)),
        jax.random.normal(jax.random.key(3000), (C, C)),
    ]
    def transform(xi, eid):
        return jnp.where(eid == 1, xi @ t[0], xi @ t[1])

    y = jax.vmap(transform)(x, expert_ids)

    x = x.reshape(D, B, C)
    y = y.reshape(D, B, C)

    # To profile, wrap the loop below in: with jax.profiler.trace("./tensorboard"):
    for e in range(n_epochs):
        # Report the mean over the whole epoch rather than one batch's loss.
        epoch_losses = []
        for i in range(D):
            loss, grads = step(state, x[i], y[i])
            epoch_losses.append(loss)
        print(e, float(jnp.mean(jnp.stack(epoch_losses))))
