diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index dc45196f3eb4..a987d8d0faf4 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -923,13 +923,14 @@ def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params):
 
 def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params):
   # The standard collective rewrite may insert a pbroadcast on the input.
-  if type(axis_name) is tuple: raise NotImplementedError  # TODO
   if params.get('axis_index_groups') is not None: raise NotImplementedError
+  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
   x_rep, = in_rep
-  if axis_name in in_rep:
-    x = pbroadcast(x, (axis_name,))
+  axis_name_set = set(axis_name)
+  if pbroadcast_axis_name := axis_name_set & x_rep:
+    x = pbroadcast(x, tuple(pbroadcast_axis_name))
   out_val = prim.bind(x, axis_name=axis_name, **params)
-  return [out_val], [x_rep - {axis_name}]
+  return [out_val], [x_rep - axis_name_set]
 
 
 for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index 0c087fd1025d..d9776a3e6216 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -125,10 +125,17 @@ def test_all_gather(self):
     @partial(shard_map, mesh=mesh, in_specs=(P('z', ('x', 'y')),),
              out_specs=P('z', ('x', 'y')))
     def fwd(a):
-      return lax.all_gather(a, 'z', axis=0, tiled=True)
-
-    c = fwd(a)
+      return (
+          lax.all_gather(a, 'z', axis=0, tiled=True),
+          lax.all_gather(a, ('x', 'y'), axis=-1, tiled=True),
+      )
+    c, d = fwd(a)
     self.assertEqual(c.addressable_data(0).shape, (8, 2))
+    for i, a_shard in enumerate(np.split(a, 4, axis=1)):
+      self.assertAllClose(c.addressable_data(2 * i), a_shard)
+    self.assertEqual(d.addressable_data(0).shape, (4, 8))
+    for i, a_shard in enumerate(np.split(a, 2, axis=0)):
+      self.assertAllClose(d.addressable_data(i), a_shard)
 
   def test_matmul_partial(self):
     raise unittest.SkipTest("invalid replication asserted by out_spec?")
@@ -156,10 +163,17 @@ def test_matmul_reduce_scatter(self):
              out_specs=P(('z', 'y'), None))
     def fwd(a, b):
       c = jnp.matmul(a, b)  # [B.z, F] {y.unreduced}
-      return lax.psum_scatter(c, 'y', scatter_dimension=0, tiled=True)
+      return (
+          lax.psum_scatter(c, 'y', scatter_dimension=0, tiled=True),
+          lax.psum_scatter(c, ('z', 'y'), scatter_dimension=0, tiled=True),
+      )
 
-    c = fwd(a, b)
+    expected = jnp.matmul(a, b)
+    c, d = fwd(a, b)
     self.assertEqual(c.addressable_data(0).shape, (2, 8))
+    self.assertAllClose(expected, c)
+    self.assertEqual(d.addressable_data(0).shape, (1, 8))
+    self.assertAllClose(expected[:4] + expected[4:], d)
 
   def test_collective_permute(self):
     devices = np.array(jax.devices()[:8])  # Take up to 8 devices
@@ -169,8 +183,9 @@ def test_collective_permute(self):
         jax.sharding.NamedSharding(mesh, P('x', None)))
 
     @jax.jit
-    @partial(shard_map, mesh=mesh, in_specs=(P('x', None),),
-             out_specs=P('x', None))
+    @partial(
+        shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
+    )
     def fwd(a):
       axis_size = lax.psum(1, 'x')
       perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
@@ -179,18 +194,74 @@ def fwd(a):
     c = fwd(a)
     self.assertAllClose(c[1, :], a[0, :])
 
-  def test_all_to_all(self):
-    devices = np.array(jax.devices()[:8])  # Take up to 8 devices
-    mesh = Mesh(devices, axis_names=('x'))
+  def test_collective_permute_with_multiple_axis_names(self):
+    mesh = Mesh(
+        np.array(jax.devices()[:8]).reshape((2, 2, 2)),
+        axis_names=('x', 'y', 'z'),
+    )
+    a = jax.device_put(
+        jnp.arange(8 * 8).reshape((4, 16)),
+        jax.sharding.NamedSharding(mesh, P('x', ('y', 'z'))),
+    )
+
+    @jax.jit
+    @partial(
+        shard_map,
+        mesh=mesh,
+        in_specs=(P('x', ('y', 'z')),),
+        out_specs=P('x', ('y', 'z')),
+    )
+    def fwd(a):
+      xy_axis_size = lax.psum(1, ('x', 'y'))
+      yz_axis_size = lax.psum(1, ('y', 'z'))
+      xy_perm = [(j, (j + 1) % xy_axis_size) for j in range(xy_axis_size)]
+      yz_perm = [(j, (j + 1) % yz_axis_size) for j in range(yz_axis_size)]
+      return (
+          lax.ppermute(a, ('x', 'y'), perm=xy_perm),
+          lax.ppermute(a, ('y', 'z'), perm=yz_perm),
+      )
+
+    c, d = fwd(a)
+    for i in range(8):
+      self.assertAllClose(
+          a.addressable_data(i), c.addressable_data((i + 2) % 8)
+      )
+      self.assertAllClose(
+          a.addressable_data(i), d.addressable_data(4 * (i // 4) + (i + 1) % 4)
+      )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=8)
+      ),
+      dict(
+          testcase_name='_multiple_axis_names',
+          axis_name=('x', 'y'),
+          mesh_axes=dict(x=4, y=2),
+      ),
+  )
+  def test_all_to_all(self, axis_name, mesh_axes):
+    devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))])
+    mesh = Mesh(
+        devices.reshape(tuple(mesh_axes.values())),
+        axis_names=tuple(mesh_axes.keys()),
+    )
     a = jax.device_put(
         jnp.arange(8 * 8).reshape((8, 8)),
-        jax.sharding.NamedSharding(mesh, P('x', None)))
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
 
     @jax.jit
-    @partial(shard_map, mesh=mesh,
-             in_specs=(P('x', None),), out_specs=P(None, 'x'))
+    @partial(
+        shard_map,
+        mesh=mesh,
+        in_specs=(P(axis_name, None),),
+        out_specs=P(None, axis_name),
+    )
     def fwd(a):
-      return lax.all_to_all(a, 'x', split_axis=1, concat_axis=1, tiled=True)
+      return lax.all_to_all(
+          a, axis_name, split_axis=1, concat_axis=1, tiled=True
+      )
 
     c = fwd(a)
     assert (c == jnp.reshape(a.T, (1, 64))).all()
@@ -860,7 +931,9 @@ def test_dce(self):
     def f(x, y, z):
       @partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P(None, 'i')),
                out_specs=(P(None, None), P(None, 'i'), P('i', 'j')))
-      def g(y, z): return jnp.sin(x), jnp.cos(z), jnp.tan(y)
+      def g(y, z):
+        return jnp.sin(x), jnp.cos(z), jnp.tan(y)
+
       return g(y, z)
 
     x = jnp.zeros((4, 4))
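A minimal standalone sketch of the replication bookkeeping that the rewritten _standard_collective_rewrite performs (split_replication below is a hypothetical helper for illustration, not part of this change): only the overlap between the collective's axis names and the value's current replication set needs a pbroadcast, and those axes drop out of the replication set after the collective.

def split_replication(x_rep, axis_name):
  # Normalize a single axis name to a tuple, as the rewrite now does.
  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
  axis_name_set = set(axis_name)
  to_pbroadcast = axis_name_set & x_rep  # axes that need a pbroadcast first
  out_rep = x_rep - axis_name_set        # replication left after the collective
  return to_pbroadcast, out_rep

# A value replicated over {'x', 'y'} fed to a collective over ('y', 'z') only
# needs a pbroadcast over 'y', and stays replicated over 'x' afterwards:
assert split_replication({'x', 'y'}, ('y', 'z')) == ({'y'}, {'x'})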