In [256]:
import os

os.environ["XLA_FLAGS"] = (
    "--xla_force_host_platform_device_count=8"  # Use 8 CPU devices
)
import jax

jax.config.update("jax_debug_nans", False)

import jax.numpy as jnp

mesh = jax.make_mesh((8,), ("tensor",))
mesh

Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=('tensor',), axis_types=(Auto,))

In [None]:
import flax.linen as nn
from jaxtyping import Array


class Dense(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, x: Array):
        x = nn.Dense(self.features)(x)
        print(f"x shape before scatter: {x.shape}")
        x = jax.lax.psum_scatter(x, "tensor", scatter_dimension=x.ndim - 1, tiled=True)
        print(f"x shape after scatter: {x.shape}")
        return x


class Embeddings(nn.Module):
    model_dimension: int
    vocab_size: int
    model_dtype: jnp.dtype

    def setup(self):
        self.embedding = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.model_dimension,
            dtype=self.model_dtype,
        )
        self.norm = RMSNorm(
            model_dimension=self.model_dimension, model_dtype=self.model_dtype
        )

    def __call__(self, x: Array, out: bool = False) -> Array:
        if not out:
            x = self.embedding(x)
            x = jax.lax.all_to_all(x, "tensor", split_axis=2, concat_axis=1, tiled=True)
            if self.is_mutable_collection("params"):
                x = jax.lax.all_gather(x, "tensor", axis=-1, tiled=True)
                _ = self.norm(x)
        else:
            x = jax.lax.all_to_all(x, "tensor", split_axis=1, concat_axis=2, tiled=True)
            x = self.norm(x)
            x = self.embedding.attend(x)

        return x


class RMSNorm(nn.Module):
    model_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x: Array) -> Array:
        rms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        rms = jax.lax.pmean(rms, "tensor")
        x = x / jnp.sqrt(rms + 1e-6)

        gamma = self.param(
            "gamma", nn.initializers.ones, (x.shape[-1],), self.model_dtype
        )
        beta = self.param(
            "beta", nn.initializers.zeros, (x.shape[-1],), self.model_dtype
        )

        print(
            f"RMS shape: {rms.shape}, x shape: {x.shape}, gamma shape: {gamma.shape}, beta shape: {beta.shape}"
        )
        x = x * gamma + beta

        return x


class Attention(nn.Module):
    model_dimension: int
    num_heads: int
    model_dtype: jnp.dtype

    def setup(self):
        self.qkv = Dense(features=self.model_dimension * 3)
        self.out = Dense(features=self.model_dimension)

    def __call__(self, x: Array) -> Array:
        B, S, D = x.shape
        print(
            f"Input shape: {x.shape}, Batch: {B}, Sequence Length: {S}, Model Dimension: {D}"
        )
        x = self.qkv(x)  # (batch, seq_len, model_dimension // Tensor * 3)
        print(f"x shape: {x.shape}")
        x = x.reshape(
            B, S, self.num_heads, D // self.num_heads * 3
        )  # (batch, seq_len, num_heads, model_dimension // Tensor * 3)
        x = jax.lax.all_to_all(x, "tensor", split_axis=2, concat_axis=3, tiled=True)
        print(f"x shape after all_to_all: {x.shape}")
        q, k, v = jnp.split(x, 3, axis=-1)
        print(f"q shape: {q.shape}, k shape: {k.shape}, v shape: {v.shape}")
        score = jnp.einsum("bshd,bShd->bhsS", q, k) / jnp.sqrt(D // self.num_heads)
        score = jax.nn.softmax(score, axis=-1)
        print(f"score shape: {score.shape}")
        x = jnp.einsum("bhsS,bShd->bshd", score, v)
        print(f"x shape {x.shape}")

        x = jax.lax.all_to_all(x, "tensor", split_axis=3, concat_axis=2, tiled=True)

        print(f"x shape after all_to_all: {x.shape}")
        x = x.reshape(B, S, D)
        print(f"x shape after reshape: {x.shape}")
        x = self.out(x)

        return x

In [476]:
model1 = RMSNorm(model_dimension=32, model_dtype=jnp.float32)
x = jnp.ones((2, 8, 32), dtype=jnp.float32)
init_key = jax.random.PRNGKey(0)
params1 = model1.init(init_key, x, eval=True)["params"]
test_out = model1.apply({"params": params1}, x, eval=True)
print(test_out.shape)

TypeError: RMSNorm.__call__() got an unexpected keyword argument 'eval'

In [477]:
model = Embeddings(model_dimension=32, vocab_size=16, model_dtype=jnp.float32)
x = jnp.ones((2, 8), dtype=jnp.int32)
init_key = jax.random.PRNGKey(0)
params = model.init(init_key, x, eval=True)["params"]
print(model.apply({"params": params}, x, eval=True).shape)

TypeError: Embeddings.__call__() got an unexpected keyword argument 'eval'

In [478]:
params

{'embedding': {'embedding': Array([[ 1.07713148e-01, -1.26537129e-01,  6.29703328e-02,
          -1.95410237e-01,  5.84000796e-02, -2.08789915e-01,
          -3.08006644e-01,  1.42043397e-01, -3.14659290e-02,
           2.00732164e-02, -3.74735296e-01, -9.57105495e-03,
          -1.77905113e-01,  1.31225690e-01,  7.06995949e-02,
           8.56466219e-02,  3.86315644e-01,  1.02088936e-01,
           6.53506815e-02, -7.32118904e-04, -3.34644876e-02,
           9.92537811e-02, -5.35630845e-02,  2.26300865e-01,
           6.99746236e-02, -3.59276868e-02, -1.37261957e-01,
           2.63334811e-01, -4.09673393e-01,  3.60953480e-01,
           1.35335410e-02,  2.75004357e-01],
         [-1.76977396e-01, -6.37361258e-02, -8.02243203e-02,
           1.58917308e-01, -2.11864665e-01, -9.58523303e-02,
          -1.98677201e-02,  6.85228333e-02,  1.78853989e-01,
           1.26904353e-01, -3.99270505e-02, -9.51974839e-02,
          -1.49512723e-01,  2.98589945e-01,  3.47766168e-02,
          -8.8

In [479]:
def display_param_sharding(params):
    jax.tree.map(lambda p: jax.debug.visualize_array_sharding(p), params)

In [None]:
from functools import partial
from jax.sharding import PartitionSpec as P

model_tensor = Embeddings()

tensor_keys = jax.random.split(init_key, 8)
tensor_keys = jnp.array(tensor_keys)


@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P(None, "tensor"), P("tensor")),
    out_specs=P("tensor"),
)
def get_var_spec(x, key):
    out = model_tensor.init(key[0], x)["params"]
    return out


var_spec = jax.eval_shape(lambda x, key: get_var_spec(x, key), x, tensor_keys)
print(var_spec)


def get_sharding(current_shape):
    if current_shape.ndim < 2:
        return P()
    return P("tensor")


out_spec = jax.tree.map(lambda x: get_sharding(x), var_spec)


@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P(None, "tensor"), P("tensor")),
    out_specs=out_spec,
)
def init_tensor(x, key):
    out = model_tensor.init(key[0], x)["params"]
    print(out.keys())
    return out


@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(out_spec, P(None, "tensor")),
    out_specs=(P(None, None, "tensor")),
)
def predict_tensor(params, x):
    return model_tensor.apply({"params": params}, x)


@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(out_spec, P(None, None, "tensor")),
    out_specs=(P(None, "tensor")),
)
def out_predict_tensor(params, x):
    return model_tensor.apply({"params": params}, x, out=True)


print(out_spec)

RMS shape: (2, 8, 1), x shape: (2, 8, 32), gamma shape: (32,), beta shape: (32,)
{'embedding': {'embedding': ShapeDtypeStruct(shape=(128, 32), dtype=float32)}, 'norm': {'beta': ShapeDtypeStruct(shape=(256,), dtype=float32), 'gamma': ShapeDtypeStruct(shape=(256,), dtype=float32)}}
{'embedding': {'embedding': PartitionSpec('tensor',)}, 'norm': {'beta': PartitionSpec(), 'gamma': PartitionSpec()}}


In [481]:
params_tensor = init_tensor(x, tensor_keys)

RMS shape: (2, 8, 1), x shape: (2, 8, 32), gamma shape: (32,), beta shape: (32,)
dict_keys(['embedding', 'norm'])


In [482]:
display_param_sharding(params_tensor)

In [483]:
x_tensor = predict_tensor(params_tensor, x)
print(x_tensor.shape)

(2, 8, 32)


In [484]:
jax.debug.visualize_array_sharding(x_tensor[0])

In [485]:
x_tensor.shape

(2, 8, 32)

In [486]:
x.shape

(2, 8)

In [487]:
x_tensor_out = out_predict_tensor(params_tensor, x_tensor)

RMS shape: (2, 1, 1), x shape: (2, 1, 32), gamma shape: (32,), beta shape: (32,)


In [488]:
print(x_tensor_out.shape)

(2, 8, 16)


In [489]:
display_param_sharding(x_tensor_out[0])

In [491]:
from functools import partial
from jax.sharding import PartitionSpec as P

model_tensor = Attention(model_dimension=64, num_heads=8, model_dtype=jnp.float32)

tensor_keys = jax.random.split(init_key, 8)
tensor_keys = jnp.array(tensor_keys)
x = jnp.ones((2, 8, 64), dtype=jnp.float32)


@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P(None, None, "tensor"), P("tensor")),
    out_specs=P("tensor"),
)
def init_var_spec(x, key):
    return model_tensor.init(key[0], x)["params"]


var_spec = jax.eval_shape(lambda key, x: init_var_spec(x, key), tensor_keys, x)


def get_sharding(current_shape):
    if current_shape.ndim < 2:
        return P()
    else:
        return P("tensor")


out_spec = jax.tree.map(lambda x: get_sharding(x), var_spec)


@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P(None, None, "tensor"), P("tensor")),
    out_specs=out_spec,
)
def init_tensor(x, key):
    return model_tensor.init(key[0], x)["params"]


@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(out_spec, P(None, None, "tensor")),
    out_specs=(P(None, None, "tensor")),
)
def predict_tensor(params, x):
    return model_tensor.apply({"params": params}, x)


@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(out_spec, P(None, None, "tensor")),
    out_specs=(P(None, None, "tensor")),
)
def out_predict_tensor(params, x):
    return model_tensor.apply({"params": params}, x, out=True)

Input shape: (2, 8, 8), Batch: 2, Sequence Length: 8, Model Dimension: 8
x shape before scatter: (2, 8, 192)
x shape after scatter: (2, 8, 24)
x shape: (2, 8, 24)
x shape after all_to_all: (2, 8, 1, 24)
q shape: (2, 8, 1, 8), k shape: (2, 8, 1, 8), v shape: (2, 8, 1, 8)
score shape: (2, 1, 8, 8)
x shape (2, 8, 1, 8)
x shape after all_to_all: (2, 8, 8, 1)
x shape after reshape: (2, 8, 8)
x shape before scatter: (2, 8, 64)
x shape after scatter: (2, 8, 8)


In [492]:
attention_params = init_tensor(x, tensor_keys)

Input shape: (2, 8, 8), Batch: 2, Sequence Length: 8, Model Dimension: 8
x shape before scatter: (2, 8, 192)
x shape after scatter: (2, 8, 24)
x shape: (2, 8, 24)
x shape after all_to_all: (2, 8, 1, 24)
q shape: (2, 8, 1, 8), k shape: (2, 8, 1, 8), v shape: (2, 8, 1, 8)
score shape: (2, 1, 8, 8)
x shape (2, 8, 1, 8)
x shape after all_to_all: (2, 8, 8, 1)
x shape after reshape: (2, 8, 8)
x shape before scatter: (2, 8, 64)
x shape after scatter: (2, 8, 8)


In [493]:
display_param_sharding(attention_params)

In [494]:
predict_tensor(attention_params, x).shape

Input shape: (2, 8, 8), Batch: 2, Sequence Length: 8, Model Dimension: 8
x shape before scatter: (2, 8, 192)
x shape after scatter: (2, 8, 24)
x shape: (2, 8, 24)
x shape after all_to_all: (2, 8, 1, 24)
q shape: (2, 8, 1, 8), k shape: (2, 8, 1, 8), v shape: (2, 8, 1, 8)
score shape: (2, 1, 8, 8)
x shape (2, 8, 1, 8)
x shape after all_to_all: (2, 8, 8, 1)
x shape after reshape: (2, 8, 8)
x shape before scatter: (2, 8, 64)
x shape after scatter: (2, 8, 8)


(2, 8, 64)