In [1]:
from typing import Literal

import os

os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

import jax

platform : Literal["darwin", "colab", "cuda", "tpu"] = "darwin"

try:
    import google.colab
    platform = "colab"
except ImportError:
    devices = jax.devices()
    if any(d.platform == "gpu" for d in devices):
        platform = "cuda"
    if any(d.platform == "tpu" for d in devices):
        platform = "tpu"

print(f"Running on {platform}")

if platform == "colab":
    !git clone https://github.com/novastar53/jaxpt
    !cd jaxpt && git checkout main && git pull
    !pip install tiktoken datasets --quiet
    #!pip uninstall -y tensorflow
    !pip install tensorboard
    !pip install -U tensorboard-plugin-profile

from pathlib import Path
import sys

if platform == "colab":
    jaxpt_dir = str(Path().absolute() / "jaxpt" / "src" )
else:
    jaxpt_dir = str(Path().absolute().parent / "src" )


sys.path.append(jaxpt_dir)
print(jaxpt_dir)

Running on darwin
/Users/vikram/dev/jaxpt/src


In [2]:

from functools import partial
from dataclasses import dataclass
import random

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding, Mesh
from jax.debug import visualize_array_sharding as viz

import flax.nnx as nnx
import optax

from jaxpt.modules.config import Config
#from jaxpt.utils import create_sharded_model


devices = jax.devices()
print(devices)

mesh = Mesh(devices, ("devices"))
spec = PartitionSpec("devices",)
sharding = NamedSharding(mesh, spec)

@nnx.jit(static_argnums=(0, 1)) #, out_shardings=sharding)
def create_sharded_model(Model, config, rngs):
    model = Model(config=config, rngs=rngs)
    graphdef, state = nnx.split(model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = nnx.with_sharding_constraint(
        state, pspecs, mesh=config.mesh
        )
    nnx.update(model, sharded_state)
    return model



@dataclass(unsafe_hash=True)
class MOE_Config(Config):
    n_layer = 1
    top_k = 2
    load_factor = 1.00
    n_experts = len(devices)
    n_embed = 3 
    n_mlp_hidden = 6
    mlp_bias = True
    dtype = jax.numpy.float32
    mesh = mesh

config = MOE_Config()


class Experts(nnx.Module):
    def __init__(self, config, rngs):
        w_c_fc_init = nnx.with_partitioning(
            nnx.initializers.normal(stddev=0.02),
            sharding=("devices",))
        
        b_init = nnx.with_partitioning(
            nnx.initializers.zeros,
            sharding=("devices",))
        
        w_c_proj_init = nnx.with_partitioning(
            nnx.initializers.normal(stddev=0.02 * (2 * config.n_layer) ** -0.5),
            sharding=("devices",)
        )

        self.w_c_fc = nnx.Param(w_c_fc_init(rngs.default(),
            (
                config.n_experts,
                config.n_embed,
                config.n_mlp_hidden
            )
        ))
        self.b_c_fc = nnx.Param(b_init(rngs.default(),
        (
            config.n_experts,
            1,
            config.n_mlp_hidden
        )))

        self.w_gate = nnx.Param(w_c_fc_init(rngs.default(),
        (
            config.n_experts,
            config.n_embed,
            config.n_mlp_hidden
        )))
        self.b_gate = nnx.Param(b_init(rngs.default(),
        (
            config.n_experts,
            1,
            config.n_mlp_hidden
        )))

        self.w_c_proj = nnx.Param(
            w_c_proj_init(
                rngs.default(),
                (
                    config.n_experts,
                    config.n_mlp_hidden,
                    config.n_embed
                ))
        )
        self.b_c_proj = nnx.Param(
            b_init(
                rngs.default(),
                (
                    config.n_experts,
                    1,
                    config.n_embed
                )
            )
        )

    def __call__(self, x):
        x = jax.lax.with_sharding_constraint(x, spec)
        h = jnp.einsum('eti,eih->eth', x, self.w_c_fc) + self.b_c_fc
        g = jnp.einsum('eti,eih->eth', x, self.w_gate) + self.b_gate
        g = nnx.silu(g)
        og = jnp.einsum('eth,eth->eth', h, g)
        o = jnp.einsum('eth,eho->eto', og, self.w_c_proj) + self.b_c_proj
        o = jax.lax.with_sharding_constraint(o, spec)
        return o


class MOE(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.router_gate = nnx.Linear(
            config.n_embed,
            config.n_experts,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.normal(stddev=0.02),
                sharding=(None,)),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros,
            sharding=(None,)),
            use_bias=config.mlp_bias,
            dtype=config.dtype,
            rngs=rngs,
        )
        self.experts = Experts(config, rngs)
        self.top_k = config.top_k
        self.n_experts = config.n_experts
        self.load_factor = config.load_factor
        self.add_noise = False
        self.rngs = rngs

    def __call__(self, x):
        B, C = x.shape
        logits = self.router_gate(x) # B, n_experts
        logits = jax.lax.with_sharding_constraint(logits, spec) # B, n_experts
        #if self.add_noise:
        #    logits += 0.01 * jax.random.normal(key=self.rngs.gate_noise(), shape=logits.shape)

        top_k_logits, expert_indices = jax.lax.top_k(logits, self.top_k) # B, top_k
        zeros = jnp.full_like(logits, float('-inf')) # B, n_experts
        zeros = jax.lax.with_sharding_constraint(zeros, spec)
        sparse_logits = jnp.put_along_axis(
                zeros, expert_indices, top_k_logits, axis=-1, inplace=False) # B, n_experts
        sparse_logits = jax.lax.with_sharding_constraint(sparse_logits, spec)
        expert_probs = jax.nn.softmax(sparse_logits, axis=-1) # B, n_experts  
        expert_probs = jax.lax.with_sharding_constraint(expert_probs, spec)
        
        expert_indices_mask = jax.nn.one_hot(expert_indices, num_classes=self.n_experts, axis=-1) # B, n_experts, 2
        expert_indices_mask = jax.lax.with_sharding_constraint(expert_indices_mask, spec)
        expert_indices_mask = jnp.sum(expert_indices_mask, axis=1)  # B, n_experts
        expert_token_positions = jnp.cumsum(expert_indices_mask, axis=0) * expert_indices_mask # B, n_experts

        expert_token_positions = jax.lax.with_sharding_constraint(expert_token_positions, spec)
        expert_input_experts, expert_input_token_idxs = jnp.nonzero(expert_token_positions.T, size=B * self.top_k) # B * top_k, B * top_k
        expert_input_positions = jnp.int32(expert_token_positions.T[expert_input_experts, expert_input_token_idxs]) - 1 # B * top_k
        expert_input_positions = jax.lax.with_sharding_constraint(expert_input_positions, spec)
         
        expert_inputs = jnp.zeros((self.n_experts, self.top_k * B, C))
        expert_inputs = jax.lax.with_sharding_constraint(expert_inputs, spec)
        expert_inputs = expert_inputs.at[expert_input_experts, expert_input_positions].set(x[expert_input_token_idxs])
        input_counters = jnp.max(expert_input_positions, axis=0)

        f = input_counters / B
        P = jnp.mean(expert_probs, axis=0)
        aux_loss = jnp.sum(f * P) / (self.n_experts ** 2)

        expert_outputs = self.experts(expert_inputs) # n_experts, expert_capacity
        expert_outputs = jax.lax.with_sharding_constraint(expert_outputs, spec)

        y = jnp.zeros_like(x)
        y = jax.lax.with_sharding_constraint(y, spec)
        y = y.at[expert_input_token_idxs].add(expert_outputs[expert_input_experts, expert_input_positions]*expert_probs[expert_input_token_idxs, expert_input_experts][..., None])
        y = jax.lax.with_sharding_constraint(y, spec)

        return y, 0

def loss_fn(model, x, y):
    y_pred, aux_loss = model(x)
    loss = jnp.mean((y - y_pred)**2) + 0.01 * aux_loss
    return loss

@nnx.jit
def step(state, x, y):
    x = jax.lax.with_sharding_constraint(x, spec)
    y = jax.lax.with_sharding_constraint(y, spec)
    loss, grads = nnx.value_and_grad(loss_fn)(state.model, x, y)
    state.update(grads)
    return loss, grads



[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]


In [3]:
from time import time


with mesh:
    D, B, C = 1000, 2 * len(devices), config.n_embed

    default = jax.random.key(69)
    gate_noise = jax.random.key(42)
    rngs = nnx.Rngs(default=default, gate_noise=gate_noise)
    model = create_sharded_model(MOE, config, rngs)
    model.train(add_noise=True)
    tx = optax.adam(1e-3)
    state = nnx.Optimizer(model, tx)

    x = jax.random.normal(jax.random.key(1000), (D * B, C))

    expert_ids = (x[:, 0] > 0).astype(jnp.int32)
    t = [
        jax.random.normal(jax.random.key(2000), (C, C)),
        jax.random.normal(jax.random.key(3000), (C, C)),
    ]
    def transform(xi, eid):
        return jnp.where(eid == 1, xi @ t[0], xi @ t[1])

    y = jax.vmap(lambda xi, ei: transform(xi, ei))(x, expert_ids)

    x = x.reshape(D, B, C)
    y = y.reshape(D, B, C)

    indices = list(range(D))
    #with jax.profiler.trace("./tensorboard"):
    for e in range(20):
        for i in indices:
            start = time()
            loss, grads = step(state, x[i], y[i])
            if i % 1000 == 0:
                end = time()
                iter_time = 1024 * (end - start) / 1000
                print(f"{e=}, {i=}, {loss.item()=}, {iter_time=:0.4f}")


e=0, i=0, loss.item()=5.034303188323975, iter_time=0.4012
e=1, i=0, loss.item()=0.8223907351493835, iter_time=0.0019
e=2, i=0, loss.item()=0.3427053391933441, iter_time=0.0013
e=3, i=0, loss.item()=0.18538902699947357, iter_time=0.0016
e=4, i=0, loss.item()=0.12252093851566315, iter_time=0.0013
e=5, i=0, loss.item()=0.07758034765720367, iter_time=0.0015
e=6, i=0, loss.item()=0.04556037485599518, iter_time=0.0016
e=7, i=0, loss.item()=0.08833136409521103, iter_time=0.0012
e=8, i=0, loss.item()=0.0814942717552185, iter_time=0.0013
e=9, i=0, loss.item()=0.02062944881618023, iter_time=0.0015
e=10, i=0, loss.item()=0.015844417735934258, iter_time=0.0009
e=11, i=0, loss.item()=0.016143733635544777, iter_time=0.0014
e=12, i=0, loss.item()=0.08357572555541992, iter_time=0.0013
e=13, i=0, loss.item()=0.07738518714904785, iter_time=0.0012
e=14, i=0, loss.item()=0.07250624895095825, iter_time=0.0014
e=15, i=0, loss.item()=0.06801272928714752, iter_time=0.0017
e=16, i=0, loss.item()=0.062490642070

In [4]:
%load_ext tensorboard
%tensorboard --logdir ./tensorboard