In [5]:
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import functools

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict


In [None]:
# Requested tensor-parallel group size; the remaining devices form the
# data-parallel axis. With the 8 forced host devices this gives a 2 x 4 mesh.
TP_SIZE = 4

# Flat list of every visible device (kept 1-D so it can be inspected as-is).
device_array = np.array(jax.devices())

# Mesh requires the device array to have exactly one dimension per axis name,
# so the flat array must be reshaped to (dp, tp) before building the mesh.
# gcd shrinks the tp group to a divisor of the device count, which keeps the
# reshape valid even when fewer devices are available (e.g. a single CPU).
tp_size = int(np.gcd(TP_SIZE, device_array.size))
mesh = Mesh(device_array.reshape(-1, tp_size), ("dp", "tp"))

In [8]:
# Display the flat array of devices that the mesh is built from.
device_array

array([CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3),
       CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)],
      dtype=object)

In [16]:
class ShardedDense(nn.Module):
    """Dense layer whose output features are split over a tensor-parallel axis.

    Each participant along ``tp_axis_name`` holds a kernel of shape
    ``(input_dim, features)`` (and optionally a bias of shape ``(features,)``),
    computes its local slice of the output, and the slices are concatenated
    along the last axis with a tiled ``all_gather``. The returned array
    therefore has ``features * <size of tp axis>`` entries on its last axis.

    The module must be called in a context where ``tp_axis_name`` is a bound
    axis name (for example inside ``shard_map`` over the mesh), because
    ``jax.lax.all_gather`` is a named-axis collective.

    Attributes:
        features: Number of output features computed per participant.
        kernel_init: Initializer for the kernel.
        bias_init: Initializer for the bias.
        use_bias: Whether to add a bias to the local output.
        tp_axis_name: Name of the mesh axis the output features are split over.
    """
    features: int
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros
    use_bias: bool = True
    tp_axis_name: str = 'tp'

    @nn.compact
    def __call__(self, x):
        """Apply the sharded projection to ``x`` and gather the full output.

        Args:
            x: Input array whose last axis is the feature axis being contracted.

        Returns:
            The locally computed outputs of all ``tp_axis_name`` participants,
            concatenated along the last axis.
        """
        # The contraction size comes from the input itself.
        input_dim = x.shape[-1]

        # `self.param` passes the 'params' RNG to the initializer itself, so only
        # the shape is supplied here. `with_partitioning` records that the
        # output-feature axis of the kernel is split over the tp mesh axis.
        # NOTE(review): inside shard_map every participant receives the same
        # 'params' RNG unless the caller folds in the axis index -- confirm the
        # init path does this if distinct shards are required.
        weights = self.param(
            'weights',
            nn.with_partitioning(self.kernel_init, (None, self.tp_axis_name)),
            (input_dim, self.features),
        )
        y_local = x @ weights

        if self.use_bias:
            # The bias is rank 1, so its single axis is the one that is split.
            bias = self.param(
                'bias',
                nn.with_partitioning(self.bias_init, (self.tp_axis_name,)),
                (self.features,),
            )
            y_local = y_local + bias

        # Concatenate the per-participant feature slices along the feature
        # (last) axis; `tiled=True` concatenates instead of stacking on a new axis.
        y_global = jax.lax.all_gather(
            y_local, self.tp_axis_name, axis=y_local.ndim - 1, tiled=True
        )

        return y_global

In [9]:
# List the devices JAX sees (8 CPU devices are forced through XLA_FLAGS).
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]