In [2]:
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import functools

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
from functools import partial

In [3]:
# Lay every visible device out along a single mesh axis named "model".
# (The XLA_FLAGS setting in the imports cell asks XLA for 8 host devices.)
device_array = np.asarray(jax.devices())
mesh = Mesh(devices=device_array, axis_names=("model",))

In [21]:
class TPDense(nn.Module):
  """Column-parallel (tensor-parallel) dense layer built on ``nn.Dense``.

  Both parameters carry logical partitioning metadata over the ``'model'``
  mesh axis: the kernel (shape ``(in_features, features)``) is split along its
  output axis and the bias along its only axis. Each shard therefore owns a
  contiguous slice of output columns together with the matching bias entries.

  Attributes:
    features: number of output features produced by this instance. Inside
      ``shard_map`` the module only sees per-shard blocks, so an instance run
      there should be built with the per-shard count
      (global features // number of 'model' shards).
    dtype: forwarded unchanged to ``nn.Dense``'s ``dtype`` argument.
  """
  features: int
  dtype: jnp.dtype = jnp.bfloat16

  @nn.compact
  def __call__(self, x):
    """Project the last axis of ``x`` to ``self.features`` outputs via ``nn.Dense``."""
    kernel_init = nn.with_partitioning(
        nn.linear.default_kernel_init, (None, 'model'))
    # Shard the bias like the kernel's output axis. Left unannotated,
    # nn.get_partition_spec reports it as replicated (P()), and under
    # shard_map every shard would then be forced to share one bias slice.
    bias_init = nn.with_partitioning(nn.initializers.zeros_init(), ('model',))
    return nn.Dense(
        self.features,
        dtype=self.dtype,
        kernel_init=kernel_init,
        bias_init=bias_init)(x)


dense = TPDense(4096)

In [22]:
x = jnp.ones((8 * 1024, 1024))

# Column-parallel layout: the activations are replicated on every device and
# the kernel is split along its output-feature axis. Each shard computes a
# disjoint slice of the output columns, with no cross-device reduction.
# (Sharding x along its last axis would instead split the contraction
# dimension, which would require a psum that this layer does not perform.)
n_shards = mesh.shape["model"]
assert dense.features % n_shards == 0, (
    f"features={dense.features} is not divisible by the {n_shards} 'model' shards")

# eval_shape on the *global* module yields the Partitioned metadata for every
# variable without allocating anything. get_partition_spec turns that into
# the pytree of PartitionSpecs that shard_map's out_specs expects.
# This way the PartitionSpecs for the init variables are known before init runs.
var_spec = jax.eval_shape(dense.init, jax.random.PRNGKey(0), x)
var_spec_out = nn.get_partition_spec(var_spec)

# Inside shard_map the module sees per-shard blocks, so it has to be built
# with the per-shard feature count. out_specs then stitches the local
# (in_features, features // n_shards) kernels back into the global
# (in_features, features) kernel described by var_spec_out.
local_dense = dense.clone(features=dense.features // n_shards)

# Both the PRNG key and x are replicated across the 'model' axis.
init_specs = (P(), P())


def _sharded_init(rng, inputs):
  """Initialise this shard's slice of the variables from a shard-specific key."""
  # Without this fold_in every shard would draw identical kernel slices.
  rng = jax.random.fold_in(rng, jax.lax.axis_index("model"))
  return local_dense.init(rng, inputs)


init_fn_sharded = shard_map(
    _sharded_init, mesh=mesh, in_specs=init_specs, out_specs=var_spec_out)

variables = init_fn_sharded(jax.random.PRNGKey(0), x)

# Variables arrive sharded as described by var_spec_out and x is replicated.
# Each shard writes its (batch, features // n_shards) block of the output.
apply_fn_sharded = shard_map(
    local_dense.apply,
    mesh=mesh,
    in_specs=(var_spec_out, P()),
    out_specs=P(None, "model"))

output = apply_fn_sharded(variables, x)


In [24]:
# Show how each leaf of the result is laid out across the mesh.
jax.tree_util.tree_map(lambda leaf: leaf.sharding, output)

NamedSharding(mesh=Mesh('model': 8, axis_types=(Auto,)), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)