diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index fc1336e06dd9..bf9a7e1b29fd 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -22,6 +22,7 @@ import jax import numpy as np +_TPU_V2 = 'TPU v2' _TPU_V3 = 'TPU v3' _TPU_V4 = 'TPU v4' @@ -63,6 +64,9 @@ }, } +# Physical ordering of core IDs in a tray that creates a ring +_TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5) + def _create_device_mesh_for_tpu_v4( physical_mesh: np.ndarray, mesh_shape: Sequence[int] @@ -245,13 +249,24 @@ def _create_device_mesh(process_0_devices, global_devices, device_kind, mesh_shape: Sequence[int], contiguous_submeshes: bool = False) -> np.ndarray: # TODO(zhangqiaorjc): Handle TPU versions other than v4 more generally. - if device_kind == _TPU_V3: - device_mesh = np.asarray(global_devices).reshape(mesh_shape) - if mesh_shape[-1] == 8: - logging.info('Re-order TPUv3 device mesh for better performance.') - perm = np.array([0, 1, 2, 3, 6, 7, 4, 5]) - device_mesh = device_mesh[:, :, perm] - return device_mesh + if device_kind in (_TPU_V2, _TPU_V3): + if len(global_devices) == 8: + logging.info('Reordering mesh to physical ring order on single-tray TPU v2/v3.') + device_mesh = np.asarray(global_devices) + device_mesh = device_mesh[np.array(_TRAY_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + elif mesh_shape[-1] == 8: + device_mesh = np.asarray(global_devices).reshape(mesh_shape) + logging.info('Reordering mesh to physical ring order on each TPU v2/v3 tray.') + perm = np.array(_TRAY_RING_ORDER) + device_mesh = device_mesh[..., perm] + return device_mesh + else: + # TODO(skye): implement 2D mesh_shape logic here: + # https://github.com/tensorflow/lingvo/blob/0df40cf604dfcd14e28f7087d73687a0bd2fe5c6/lingvo/core/gshard_utils.py#L187 + # (possibly replaces above mesh_shape[-1] == 8 case) + return np.asarray(global_devices).reshape(mesh_shape) elif device_kind == _TPU_V4: physical_mesh = _jax_devices_order_normalized( process_0_devices, global_devices) diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 5a4d9868db0e..f322acdfdbc8 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -19,6 +19,8 @@ import dataclasses from typing import Sequence +import numpy as np + from absl import logging from absl.testing import absltest from absl.testing import parameterized @@ -30,6 +32,7 @@ @dataclasses.dataclass class MockTpuDevice: """Mock TPU device for testing.""" + id: int platform: str device_kind: str process_index: int @@ -46,14 +49,14 @@ def mock_devices(x, y, z, dev_kind, two_cores_per_chip): for i in range(0, x, 2): # Local 2x2 subgrid of chips, with 2 cores per chip. host_devices = [ - MockTpuDevice('tpu', dev_kind, process_index, (i, j, k), 0), - MockTpuDevice('tpu', dev_kind, process_index, (i, j, k), 1), - MockTpuDevice('tpu', dev_kind, process_index, (i + 1, j, k), 0), - MockTpuDevice('tpu', dev_kind, process_index, (i + 1, j, k), 1), - MockTpuDevice('tpu', dev_kind, process_index, (i, j + 1, k), 0), - MockTpuDevice('tpu', dev_kind, process_index, (i, j + 1, k), 1), - MockTpuDevice('tpu', dev_kind, process_index, (i + 1, j + 1, k), 0), - MockTpuDevice('tpu', dev_kind, process_index, (i + 1, j + 1, k), 1), + MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i, j, k), 0), + MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i, j, k), 1), + MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i + 1, j, k), 0), + MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i + 1, j, k), 1), + MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i, j + 1, k), 0), + MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i, j + 1, k), 1), + MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i + 1, j + 1, k), 0), + MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i + 1, j + 1, k), 1), ] if two_cores_per_chip: # Only include core_on_chip = 0. @@ -61,6 +64,14 @@ def mock_devices(x, y, z, dev_kind, two_cores_per_chip): devices.extend(host_devices) # Simulate one process per host (1 host = 2x2x1 slice) process_index += 1 + + # id grows in (z, y, x) major order + for d in devices: + i, j, k = d.coords + d.id = k*x*y + j*x + i + if not two_cores_per_chip: + d.id = d.id * 2 + d.core_on_chip + _validate_mocked_process_indices(devices, two_cores_per_chip) return devices @@ -86,8 +97,19 @@ def _validate_mocked_process_indices(devices, two_cores_per_chip): expected.add((min_coords[0] + x, min_coords[1] + y, min_coords[2])) assert set(d.coords for d in local_devices) == expected, local_devices + +def mock_2x2_devices(): + """Hard-coded reproduction of jax.devices() output on v3-2x2.""" + return mock_devices(2, 2, 1, 'TPU v3', False) + + +def mock_4x4_devices(): + """Hard-coded reproduction of jax.devices() output on v3-4x4.""" + return mock_devices(4, 4, 1, 'TPU v3', False) + + def mock_8x8_devices(): - """Hard-coded reproduction of jax.devices() output on 8x8.""" + """Hard-coded reproduction of jax.devices() output on v3-8x8.""" return mock_devices(8, 8, 1, 'TPU v3', False) @@ -185,6 +207,28 @@ def test_create_device_mesh_for_tpu_v4(self, devices, mesh_shape, physical_mesh, mesh_shape) self.assertEqual(assignment, expected_assignment) + @parameterized.named_parameters( + # Physical ring order over tray + ('2x2_1d', mock_2x2_devices, [8], [0, 1, 2, 3, 6, 7, 4, 5]), + # Reshaped physical ring order over tray + ('2x2_2d', mock_2x2_devices, [2, 4], [[0, 1, 2, 3], + [6, 7, 4, 5]]), + # 4 per-tray rings + ('4x4_2d', mock_4x4_devices, [4, 8], [[ 0, 1, 2, 3, 10, 11, 8, 9], + [ 4, 5, 6, 7, 14, 15, 12, 13], + [16, 17, 18, 19, 26, 27, 24, 25], + [20, 21, 22, 23, 30, 31, 28, 29]]), + ) + def test_v3_create_device_mesh(self, devices, mesh_shape, + expected_device_id_mesh): + jax_local_devices_from_process_0 = mock_2x2_devices() + global_devices = devices() + mesh = mesh_utils._create_device_mesh( + jax_local_devices_from_process_0, global_devices, + global_devices[0].device_kind, mesh_shape, contiguous_submeshes=False) + device_id_mesh = np.vectorize(lambda d: d.id)(mesh) + self.assertAllClose(np.array(expected_device_id_mesh), device_id_mesh) + def _assert_contiguous_submeshes(self, global_device_mesh): global_mesh = Mesh(global_device_mesh, list(range(global_device_mesh.ndim))) max_process_index = max(d.process_index