In [1]:
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import functools

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
from functools import partial
from jaxtyping import Array, PyTree

In [2]:
device_array = np.array(jax.devices())
mesh = Mesh(device_array, ("model",))

In [3]:
class Embeddings(nn.Module):
    model_dimension: int
    vocab_size: int
    model_dtype: jnp.dtype = jnp.float32

    def setup(self):
        ei = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
        self.embedding = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.model_dimension,
            dtype=self.model_dtype,
            embedding_init=nn.with_partitioning(ei, (None, "model")),
        )
        self.layer_norm = nn.LayerNorm()

    def __call__(self, x: Array, out: bool = False) -> Array:
        if not out:
            x = self.embedding(x)
        else:
            x = self.layer_norm(x)
            x = self.embedding.attend(x)
        return x


dense = Embeddings(1024, 64)

In [4]:
x = jnp.ones((8 * 1024, 1024), dtype=jnp.int32)

var_spec = jax.eval_shape(dense.init, jax.random.PRNGKey(0), x)
var_spec_out = nn.get_partition_spec(
    var_spec
)  # generates a PyTree of PartitionSpecs as out_spec needs to be a pytree
print(var_spec_out)

init_specs = (None, P(None, "model"))

init_fn_sharded = partial(
    shard_map, mesh=mesh, in_specs=init_specs, out_specs=var_spec_out
)(dense.init)

variables = init_fn_sharded(jax.random.PRNGKey(0), x)

apply_fn_sharded = partial(
    shard_map,
    mesh=mesh,
    in_specs=(var_spec_out, P(None, "model")),
    out_specs=P(None, "model"),
)(dense.apply)

output = apply_fn_sharded(variables, x)

{'params': {'embedding': {'embedding': PartitionSpec(None, 'model')}}}


In [5]:
jax.debug.visualize_array_sharding(output[0])