In [None]:
import os
# Ask XLA to expose 8 host (CPU) devices. This cell runs ahead of the one
# that imports jax.
os.environ.update(XLA_FLAGS='--xla_force_host_platform_device_count=8')

In [None]:
from functools import partial

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

In [None]:
# Lay the devices out as a 4x2 grid whose two axes are named 'x' and 'y'.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('x', 'y'))

# Float-valued matmul operands: a is 8x16, b is 16x4.
a = jnp.reshape(jnp.arange(8 * 16.), (8, 16))
b = jnp.reshape(jnp.arange(16 * 4.), (16, 4))

@partial(shard_map, mesh=mesh,
         in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  """Per-device body of a sharded matrix multiply.

  Multiplies the local operand blocks, then sums the partial products over
  mesh axis 'y'. For the 8x16 and 16x4 operands used in this notebook the
  shapes are a_block: f32[2, 8], b_block: f32[8, 4], result: f32[2, 4].
  """
  local_product = a_block @ b_block
  return jax.lax.psum(local_product, 'y')

# Call with the full (global) arrays; the function body only sees per-device blocks.
c = matmul_basic(a, b)   # c: f32[8, 4]

In [None]:
from jax.tree_util import tree_map, tree_all

def allclose(a, b, *, atol=1e-2, rtol=1e-2):
  """Report whether every pair of corresponding leaves in `a` and `b` is close.

  Args:
    a, b: pytrees with matching structure; each pair of leaves is compared
      with jnp.allclose.
    atol, rtol: absolute and relative tolerances forwarded to jnp.allclose.
      The defaults reproduce the previously hard-coded 1e-2 values, so
      existing two-argument calls behave exactly as before.

  Returns:
    The result of tree_all over the per-leaf comparisons (truthy only when
    all leaves are close).
  """
  leafwise_close = partial(jnp.allclose, atol=atol, rtol=rtol)
  return tree_all(tree_map(leafwise_close, a, b))

# Compare the shard_map result against the ordinary matrix product.
allclose(c, a @ b)

In [None]:
# Display how the result array c is sharded.
jax.debug.visualize_array_sharding(c)

In [None]:
from jax.sharding import NamedSharding

# Re-place both operands on the mesh: a is split over ('x', 'y'),
# b over 'y' along its first dimension only.
a, b = (
    jax.device_put(operand, NamedSharding(mesh, spec))
    for operand, spec in ((a, P('x', 'y')), (b, P('y', None)))
)

@jax.jit
def matmul_reference(a, b):
  """Compute dot(a, b) under jit and constrain the result's sharding.

  The product is constrained to NamedSharding(mesh, P('x', None)), i.e. the
  same partition spec that matmul_basic declares as its out_specs.
  """
  out_sharding = NamedSharding(mesh, P('x', None))
  return jax.lax.with_sharding_constraint(jnp.dot(a, b), out_sharding)

# Run the jitted version and compare it with the ordinary matrix product.
c_ref = matmul_reference(a, b)
allclose(c_ref, a @ b)

In [None]:
# Show the device layout of each operand and of the shard_map result.
print('a blocks:')
jax.debug.visualize_array_sharding(a)

print('b blocks:')
jax.debug.visualize_array_sharding(b)

print('c blocks:')
jax.debug.visualize_array_sharding(c)

In [None]:
import numpy as np
# One-dimensional mesh over (up to) the first four devices, single axis 'i'.
devices = np.asarray(jax.devices()[:4])
mesh = Mesh(devices, axis_names=('i',))  # mesh.shape['i'] = 4

def check_shmap(f, y):
  """Print whether shard_map(f) matches applying f to each split of y.

  f is run through shard_map with y partitioned over mesh axis 'i'; the result
  is compared with the concatenation of f applied to each of the
  mesh.shape['i'] pieces produced by jnp.split (default axis 0).
  """
  sharded_fn = shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
  sharded_out = sharded_fn(y)
  blockwise_out = jnp.concatenate(list(map(f, jnp.split(y, mesh.shape['i']))))
  print(allclose(sharded_out, blockwise_out))


# Exercise check_shmap with f(block) = block^T @ block on an 8x4 integer array.
check_shmap(lambda blk: jnp.matmul(blk.T, blk), jnp.arange(8 * 4).reshape((8, 4)))

In [None]:
# Problem sizes. NOTE(review): presumably batch, sequence length, heads and
# per-head dim -- confirm the intended meaning.
B, S, H, D = 4, 16, 12, 32

# 4x2 device grid with axes named 'row' and 'col'.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('row', 'col'))

# Build the operands as float32. A plain integer arange yields int32 here, and
# individual products in X @ W (up to ~1.4e10) exceed the int32 range, so an
# integer matmul would wrap around silently and the later comparison would
# only be checking wrapped values.
X = jnp.arange(B * S * H * D, dtype=jnp.float32).reshape(B * S, H * D)
W = jnp.arange(H * D * 4 * H * D, dtype=jnp.float32).reshape(H * D, 4 * H * D)

# Place both operands on the mesh with the same partition specs that GSPMD_OS
# declares as its in_specs.
Xt = jax.device_put(X, NamedSharding(mesh, P('row', 'col')))
Wt = jax.device_put(W, NamedSharding(mesh, P('row', 'col')))

@partial(shard_map, mesh=mesh,
         in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))
def GSPMD_OS(Xij, Wij):
    """Per-device matmul that gathers both operands before multiplying.

    The X block is all-gathered over mesh axis 'col' along array axis 1 and
    the W block over mesh axis 'row' along array axis 0; the gathered shapes
    are printed and the product of the two gathered arrays is returned.
    NOTE(review): 'OS' presumably stands for output-stationary -- confirm.
    """
    x_gathered = jax.lax.all_gather(Xij, 'col', axis=1, tiled=True)
    print(x_gathered.shape)
    w_gathered = jax.lax.all_gather(Wij, 'row', axis=0, tiled=True)
    print(w_gathered.shape)
    return jnp.matmul(x_gathered, w_gathered)

# Feed the operands that were already placed on the mesh (Xt, Wt); their
# partition specs match GSPMD_OS's in_specs.
y_ref = GSPMD_OS(Xt, Wt)

# Assert the comparison so a mismatch fails the cell; as a bare expression in
# the middle of a cell its result would be discarded without being shown.
assert allclose(y_ref, jnp.dot(X, W))

# Show the result's shape rather than dumping the whole array.
y_ref.shape

In [None]:
# Build the operands as float32. A plain integer arange yields int32 here, and
# individual products in X1 @ W1 exceed the int32 range, so an integer matmul
# would wrap around silently.
X1 = jnp.arange(B * S * 4 * H * D, dtype=jnp.float32).reshape(B * S, 4 * H * D)
W1 = jnp.arange(H * D * 4 * H * D, dtype=jnp.float32).reshape(4 * H * D, H * D)

# Place each operand with the partition spec GSPMD_IS declares for it in
# in_specs: P('row', 'col') for the activations and P('col', 'row') for the
# weights, so the placement agrees with what the function expects.
Xt = jax.device_put(X1, NamedSharding(mesh, P('row', 'col')))
Wt = jax.device_put(W1, NamedSharding(mesh, P('col', 'row')))

@partial(shard_map, mesh=mesh,
         in_specs=(P('row', 'col'), P('col', 'row')),
         out_specs=P('row', 'col'))
def GSPMD_IS(Xij, Wij):
    """Per-device matmul that gathers W, multiplies, then reduce-scatters.

    The W block is all-gathered over mesh axis 'row' along array axis 1 and its
    shape is printed. The local X block is multiplied by the gathered W, and
    the product is summed over mesh axis 'col' with the result scattered along
    array axis 1 (psum_scatter, tiled).
    NOTE(review): 'IS' presumably stands for input-stationary -- confirm.
    """
    w_gathered = jax.lax.all_gather(Wij, 'row', axis=1, tiled=True)
    print(w_gathered.shape)
    partial_out = jnp.matmul(Xij, w_gathered)
    return jax.lax.psum_scatter(partial_out, 'col', scatter_dimension=1, tiled=True)

# Run the sharded matmul on the placed operands, then compare it with the
# ordinary product of the unsharded inputs (shown as the cell's output).
y_ref = GSPMD_IS(Xt, Wt)
allclose(y_ref, X1 @ W1)