From 94d58b7270acd6d1d28f28cf5cd894e2aed351a8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 12 Dec 2023 09:16:09 -0800 Subject: [PATCH] `mesh_utils.create_hybrid_device_mesh`: make sorting granules by key user configurable. When sorting by granule key is disabled, the granules are used to create the mesh in the order in which they appear in the sequence of devices. PiperOrigin-RevId: 590228169 --- jax/experimental/mesh_utils.py | 19 +++++++++++---- tests/mesh_utils_test.py | 44 +++++++++++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index faf48a841f07..8b05e511e7f1 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -321,10 +321,13 @@ def create_device_mesh( device_mesh = np.asarray(devices).reshape(mesh_shape) return device_mesh -def create_hybrid_device_mesh(mesh_shape: Sequence[int], - dcn_mesh_shape: Sequence[int], - devices: Optional[Sequence[Any]] = None, *, - process_is_granule: bool = False) -> np.ndarray: +def create_hybrid_device_mesh( + mesh_shape: Sequence[int], + dcn_mesh_shape: Sequence[int], + devices: Optional[Sequence[Any]] = None, *, + process_is_granule: bool = False, + should_sort_granules_by_key: bool = True, +) -> np.ndarray: """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism. Args: @@ -339,6 +342,9 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int], of the slower/outer network. Otherwise it will look for slice_index attributes on devices and use slices as the units. Enabling this is meant as a fallback for platforms (e.g., GPU) that don't set slice_index. + should_sort_granules_by_key: Whether device granules should be sorted by the + granule key, either slice or process index, depending on + process_is_granule. Raises: ValueError: if the number of slices to which the `devices` belong doesn't @@ -356,7 +362,10 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int], granule_dict = collections.defaultdict(list) for dev in devices: granule_dict[getattr(dev, attr)].append(dev) - granules = [granule_dict[key] for key in sorted(granule_dict.keys())] + granules = ( + [granule_dict[key] for key in sorted(granule_dict.keys())] + if should_sort_granules_by_key + else granule_dict.values()) if np.prod(dcn_mesh_shape) != len(granules): raise ValueError( f'Number of slices {len(granules)} must equal the product of ' diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index a4b1fda5c0ea..0ed5ce4c521c 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -63,6 +63,7 @@ def mock_tpu_device(core_on_chip, xd, yd, zd, xp, yp, zp, slice_index): _validate_mocked_process_indices(devices, one_device_per_chip) return devices + # If this function raises, it's a bug in the test code! def _validate_mocked_process_indices(devices, one_device_per_chip): process_to_devices = collections.defaultdict(list) @@ -202,7 +203,48 @@ def test_create_hybrid_device_mesh(self, mesh_shape, dcn_mesh_shape): mesh_shape, dcn_mesh_shape, devices) total_mesh_shape = tuple( m1 * m2 for m1, m2 in zip(mesh_shape, dcn_mesh_shape)) - assert mesh.shape == total_mesh_shape + self.assertEqual(mesh.shape, total_mesh_shape) + + @parameterized.named_parameters( + ('2X4x4x4a', (1, 16, 4), (2, 1, 1)), + ('2X4x4x4b', (1, 4, 16), (1, 2, 1)), + ) + def test_create_hybrid_device_mesh_device_sorting( + self, + mesh_shape: tuple[int, ...], + dcn_mesh_shape: tuple[int, ...], + ): + devices = mock_tpu_devices(4, 4, 4, 'TPU v4', True, 2) + reversed_slices_devices = list( + np.flip(np.array(devices).reshape(2, -1), axis=0).flat) + mesh = mesh_utils.create_hybrid_device_mesh( + mesh_shape, + dcn_mesh_shape, + devices, + should_sort_granules_by_key=False, + ) + sorted_slices_mesh = mesh_utils.create_hybrid_device_mesh( + mesh_shape, + dcn_mesh_shape, + reversed_slices_devices, + should_sort_granules_by_key=True, + ) + np.testing.assert_array_equal(mesh, sorted_slices_mesh) + self.assertSetEqual( + {0, 1}, + {d.slice_index for d in sorted_slices_mesh.flat}, + ) + + reversed_slices_mesh = mesh_utils.create_hybrid_device_mesh( + mesh_shape, + dcn_mesh_shape, + reversed_slices_devices, + should_sort_granules_by_key=False, + ) + self.assertSetEqual( + {1, 0}, + {d.slice_index for d in reversed_slices_mesh.flat}, + ) @parameterized.named_parameters( # Physical ring order over tray