In [1]:
from typing import Literal

import os

os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'

import jax

platform : Literal["darwin", "colab", "cuda", "tpu"] = "darwin"

try:
    import google.colab
    platform = "colab"
except ImportError:
    devices = jax.devices()
    if any(d.platform == "gpu" for d in devices):
        platform = "cuda"
    if any(d.platform == "tpu" for d in devices):
        platform = "tpu"

print(f"Running on {platform}")

if platform == "colab":
    !git clone https://github.com/novastar53/jaxpt
    !cd jaxpt && git checkout main && git pull
    !pip install tiktoken datasets --quiet
    #!pip uninstall -y tensorflow
    !pip install tensorboard
    !pip install -U tensorboard-plugin-profile

from pathlib import Path
import sys

if platform == "colab":
    jaxpt_dir = str(Path().absolute() / "jaxpt" / "src" )
else:
    jaxpt_dir = str(Path().absolute().parent / "src" )


sys.path.append(jaxpt_dir)
print(jaxpt_dir)

Running on darwin
/Users/vikram/dev/jaxpt/src


In [2]:

from functools import partial
from dataclasses import dataclass
import random

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding, Mesh
from jax.debug import visualize_array_sharding as viz

import flax.nnx as nnx
import optax

from jaxpt.modules.moe import MOE
from jaxpt.modules.config import Config
from jaxpt.utils import create_sharded_model


devices = jax.devices()
print(devices)

mesh = Mesh(devices, ("devices"))
spec = PartitionSpec("devices",)
sharding = NamedSharding(mesh, spec)


@dataclass(unsafe_hash=True)
class MOE_Config(Config):
    n_layer: int = 1
    top_k: int = 2
    load_factor: float = 1.00
    n_experts: int = len(devices)
    n_embed: int = 3 
    n_mlp_hidden: int = 6
    mlp_bias: bool = True
    dtype: jnp.dtype = jax.numpy.bfloat16
    param_dtype: jnp.dtype = jax.numpy.float32
    mesh: jax.sharding.Mesh = mesh

config = MOE_Config()


def loss_fn(model, x, y):
    y_pred, aux_loss = model(x)
    loss = jnp.mean((y - y_pred)**2) + 0.01 * aux_loss
    return loss

@nnx.jit
def step(state, x, y):
    loss, grads = nnx.value_and_grad(loss_fn, has_aux=False)(state.model, x, y)
    state.update(grads)
    return loss, grads


from time import time

with mesh:
    D, B, T, C = 1000, 4 * len(devices), 5, config.n_embed

    default = jax.random.key(69)
    gate_noise = jax.random.key(42)
    rngs = nnx.Rngs(default=default, gate_noise=gate_noise)
    #model = create_sharded_model(MOE, config, rngs)
    model = MOE(config, rngs)
    nnx.display(model)
    model.train(add_noise=True, aux_loss=True)
    tx = optax.adam(1e-2)
    state = nnx.Optimizer(model, tx)

    x = jax.random.normal(jax.random.key(1000), (D * B * T, C))

    expert_ids = (x[:, 0] > 0).astype(jnp.int32)
    t = [
        jax.random.normal(jax.random.key(2000), (C, C)),
        jax.random.normal(jax.random.key(3000), (C, C)),
    ]
    def transform(xi, eid):
        return jnp.where(eid == 1, xi @ t[0], xi @ t[1])

    y = jax.vmap(lambda xi, ei: transform(xi, ei))(x, expert_ids)

    x = x.reshape(D, B, T, C)
    y = y.reshape(D, B, T, C)

    indices = list(range(D))

    #with jax.profiler.trace("./tensorboard"):
    for e in range(20):
        for i in indices:
            start = time()
            x_i = jax.device_put(x[i], sharding)
            y_i = jax.device_put(y[i], sharding)
            loss, grads = step(state, x_i, y_i)
            if i % 1000 == 0:
                end = time()
                iter_time = 1024 * (end - start) / 1000
                print(f"{e=}, {i=}, {loss.item()=}, {iter_time=:0.4f}")

[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]


e=0, i=0, loss.item()=3.8815596103668213, iter_time=1.0489
e=1, i=0, loss.item()=0.26602426171302795, iter_time=0.0008
e=2, i=0, loss.item()=0.24425984919071198, iter_time=0.0009
e=3, i=0, loss.item()=0.12493135035037994, iter_time=0.0008
e=4, i=0, loss.item()=0.06272979825735092, iter_time=0.0008
e=5, i=0, loss.item()=0.053908221423625946, iter_time=0.0009


KeyboardInterrupt: 

In [None]:
%load_ext tensorboard
%tensorboard --logdir ./tensorboard

Reusing TensorBoard on port 6006 (pid 2359), started 3 days, 7:38:22 ago. (Use '!kill 2359' to kill it.)