[shmap] Support multiple axes for standard collectives in shard_map
ppham27 committed Feb 6, 2024
1 parent 9a098e9 commit c850b10
Showing 2 changed files with 93 additions and 19 deletions.
9 changes: 5 additions & 4 deletions jax/experimental/shard_map.py
@@ -923,13 +923,14 @@ def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params):
 
 def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params):
   # The standard collective rewrite may insert a pbroadcast on the input.
-  if type(axis_name) is tuple: raise NotImplementedError  # TODO
   if params.get('axis_index_groups') is not None: raise NotImplementedError
+  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
   x_rep, = in_rep
-  if axis_name in in_rep:
-    x = pbroadcast(x, (axis_name,))
+  axis_name_set = set(axis_name)
+  if pbroadcast_axis_name := axis_name_set & x_rep:
+    x = pbroadcast(x, tuple(pbroadcast_axis_name))
   out_val = prim.bind(x, axis_name=axis_name, **params)
-  return [out_val], [x_rep - {axis_name}]
+  return [out_val], [x_rep - axis_name_set]
 
 
 for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),
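The sketch below (illustrative, not part of the commit) shows what the relaxed rewrite enables: a standard collective can now name a tuple of mesh axes inside shard_map, which is folded into a single joint group. The 2x2 mesh, axis names, and gather_all helper are assumptions for the example, which needs at least four devices.

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Illustrative 2x2 mesh over the first four devices.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), axis_names=('x', 'y'))

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P(('x', 'y'), None),), out_specs=P(('x', 'y'), None))
def gather_all(block):
  # A tuple of axis names here used to raise NotImplementedError; the rewrite
  # now treats ('x', 'y') as one joint group of four shards.
  return jax.lax.all_gather(block, ('x', 'y'), axis=0, tiled=True)

a = jnp.arange(16.0).reshape(4, 4)
out = gather_all(a)
assert out.addressable_data(0).shape == (4, 4)  # each device holds all of `a`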
103 changes: 88 additions & 15 deletions tests/shard_map_test.py
@@ -125,10 +125,17 @@ def test_all_gather(self):
     @partial(shard_map, mesh=mesh,
              in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))
     def fwd(a):
-      return lax.all_gather(a, 'z', axis=0, tiled=True)
+      return (
+          lax.all_gather(a, 'z', axis=0, tiled=True),
+          lax.all_gather(a, ('x', 'y'), axis=-1, tiled=True),
+      )
 
-    c = fwd(a)
+    c, d = fwd(a)
     self.assertEqual(c.addressable_data(0).shape, (8, 2))
     for i, a_shard in enumerate(np.split(a, 4, axis=1)):
       self.assertAllClose(c.addressable_data(2 * i), a_shard)
+    self.assertEqual(d.addressable_data(0).shape, (4, 8))
+    for i, a_shard in enumerate(np.split(a, 2, axis=0)):
+      self.assertAllClose(d.addressable_data(i), a_shard)
 
   def test_matmul_partial(self):
     raise unittest.SkipTest("invalid replication asserted by out_spec?")
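Shape bookkeeping for the new all_gather assertions above (a reader's sketch, assuming the 2x2x2 ('x', 'y', 'z') mesh and the (8, 8) input `a` this test builds above the shown context):

# per-device block under P('z', ('x', 'y')):        (8/2, 8/4) = (4, 2)
# all_gather over 'z',        axis=0,  tiled=True:  (2*4, 2)   = (8, 2)  -> c
# all_gather over ('x', 'y'), axis=-1, tiled=True:  (4, 4*2)   = (4, 8)  -> d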
@@ -156,10 +163,17 @@ def test_matmul_reduce_scatter(self):
              out_specs=P(('z', 'y'), None))
     def fwd(a, b):
       c = jnp.matmul(a, b)  # [B.z, F] {y.unreduced}
-      return lax.psum_scatter(c, 'y', scatter_dimension=0, tiled=True)
+      return (
+          lax.psum_scatter(c, 'y', scatter_dimension=0, tiled=True),
+          lax.psum_scatter(c, ('z', 'y'), scatter_dimension=0, tiled=True),
+      )
 
-    c = fwd(a, b)
     expected = jnp.matmul(a, b)
+    c, d = fwd(a, b)
     self.assertEqual(c.addressable_data(0).shape, (2, 8))
     self.assertAllClose(expected, c)
+    self.assertEqual(d.addressable_data(0).shape, (1, 8))
+    self.assertAllClose(expected[:4] + expected[4:], d)
 
   def test_collective_permute(self):
     devices = np.array(jax.devices()[:8])  # Take up to 8 devices
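Why the new multi-axis psum_scatter assertion above sums the two batch halves (a reader's sketch, assuming mesh sizes z=2 and y=2, consistent with the single-axis assertions):

# Each z-shard holds a (4, 8) block of y-partial matmul results. Scattering
# over the joint ('z', 'y') group (4 participants) sums the y-partials and
# additionally reduce-scatters across z, so rows i and i + 4 of the full
# product are added before being scattered:
#   d_global[i] = expected[i] + expected[i + 4],  for i in [0, 4)
# hence assertAllClose(expected[:4] + expected[4:], d).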
@@ -169,8 +183,9 @@
         jax.sharding.NamedSharding(mesh, P('x', None)))
 
     @jax.jit
-    @partial(shard_map, mesh=mesh, in_specs=(P('x', None),),
-             out_specs=P('x', None))
+    @partial(
+        shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
+    )
     def fwd(a):
       axis_size = lax.psum(1, 'x')
       perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
@@ -179,18 +194,74 @@ def fwd(a):
     c = fwd(a)
     self.assertAllClose(c[1, :], a[0, :])
 
-  def test_all_to_all(self):
-    devices = np.array(jax.devices()[:8])  # Take up to 8 devices
-    mesh = Mesh(devices, axis_names=('x'))
+  def test_collective_permute_with_multiple_axis_names(self):
+    mesh = Mesh(
+        np.array(jax.devices()[:8]).reshape((2, 2, 2)),
+        axis_names=('x', 'y', 'z'),
+    )
+    a = jax.device_put(
+        jnp.arange(8 * 8).reshape((4, 16)),
+        jax.sharding.NamedSharding(mesh, P('x', ('y', 'z'))),
+    )
+
+    @jax.jit
+    @partial(
+        shard_map,
+        mesh=mesh,
+        in_specs=(P('x', ('y', 'z')),),
+        out_specs=P('x', ('y', 'z')),
+    )
+    def fwd(a):
+      xy_axis_size = lax.psum(1, ('x', 'y'))
+      yz_axis_size = lax.psum(1, ('y', 'z'))
+      xy_perm = [(j, (j + 1) % xy_axis_size) for j in range(xy_axis_size)]
+      yz_perm = [(j, (j + 1) % yz_axis_size) for j in range(yz_axis_size)]
+      return (
+          lax.ppermute(a, ('x', 'y'), perm=xy_perm),
+          lax.ppermute(a, ('y', 'z'), perm=yz_perm),
+      )
+
+    c, d = fwd(a)
+    for i in range(8):
+      self.assertAllClose(
+          a.addressable_data(i), c.addressable_data((i + 2) % 8)
+      )
+      self.assertAllClose(
+          a.addressable_data(i), d.addressable_data(4 * (i // 4) + (i + 1) % 4)
+      )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=8)
+      ),
+      dict(
+          testcase_name='_multiple_axis_names',
+          axis_name=('x', 'y'),
+          mesh_axes=dict(x=4, y=2),
+      ),
+  )
+  def test_all_to_all(self, axis_name, mesh_axes):
+    devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))])
+    mesh = Mesh(
+        devices.reshape(tuple(mesh_axes.values())),
+        axis_names=tuple(mesh_axes.keys()),
+    )
     a = jax.device_put(
         jnp.arange(8 * 8).reshape((8, 8)),
-        jax.sharding.NamedSharding(mesh, P('x', None)))
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
 
     @jax.jit
-    @partial(shard_map, mesh=mesh,
-             in_specs=(P('x', None),), out_specs=P(None, 'x'))
+    @partial(
+        shard_map,
+        mesh=mesh,
+        in_specs=(P(axis_name, None),),
+        out_specs=P(None, axis_name),
+    )
     def fwd(a):
-      return lax.all_to_all(a, 'x', split_axis=1, concat_axis=1, tiled=True)
+      return lax.all_to_all(
+          a, axis_name, split_axis=1, concat_axis=1, tiled=True
+      )
 
     c = fwd(a)
     assert (c == jnp.reshape(a.T, (1, 64))).all()
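The device-index arithmetic in the two new tests above, as a reader's sketch (assuming the row-major device order that the Mesh construction implies):

# On the 2x2x2 mesh, device i has coords (x, y, z) = (i // 4, (i // 2) % 2, i % 2),
# so its joint ('x', 'y') index is x*2 + y = i // 2, and its joint ('y', 'z')
# index is y*2 + z = i % 4.
#
# ppermute over ('x', 'y') shifts the joint index by one at fixed z, sending
# device i's data to ((i // 2 + 1) % 4) * 2 + i % 2 == (i + 2) % 8, which is
# the c assertion. ppermute over ('y', 'z') shifts at fixed x, sending it to
# (i // 4) * 4 + (i % 4 + 1) % 4 == 4 * (i // 4) + (i + 1) % 4, the d assertion.
#
# In the parameterized all_to_all test, a row-major 4x2 ('x', 'y') mesh makes
# the joint axis behave like a single size-8 axis, so both named cases satisfy
# the same a.T reshape check.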
@@ -860,7 +931,9 @@ def test_dce(self):
     def f(x, y, z):
       @partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P(None, 'i')),
               out_specs=(P(None, None), P(None, 'i'), P('i', 'j')))
-      def g(y, z): return jnp.sin(x), jnp.cos(z), jnp.tan(y)
+      def g(y, z):
+        return jnp.sin(x), jnp.cos(z), jnp.tan(y)
+
       return g(y, z)
 
     x = jnp.zeros((4, 4))
