Improve TPU v2 and v3 mesh_utils.create_device_mesh logic.
* Fixes a bug when a non-3D mesh was requested
* Adds new logic when requesting a single-host mesh
* Extends logic to v2 as well as v3
skye committed Mar 8, 2022
1 parent 7319982 commit bcee442
Showing 2 changed files with 75 additions and 16 deletions.
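
The public entry point affected is mesh_utils.create_device_mesh. A minimal usage sketch of the new single-host behavior (not part of this commit; it assumes the public create_device_mesh(mesh_shape) wrapper around the _create_device_mesh helper changed below, run on a one-tray TPU v2/v3 with 8 cores):

from jax.experimental import mesh_utils

# On a single tray (8 cores), the new logic reorders devices into the
# tray's physical ring before reshaping, so neighbors in the logical
# mesh are also physically adjacent.
mesh = mesh_utils.create_device_mesh((2, 4))  # any shape using all 8 cores
print(mesh.shape)  # (2, 4)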
29 changes: 22 additions & 7 deletions jax/experimental/mesh_utils.py
@@ -22,6 +22,7 @@
import jax
import numpy as np

_TPU_V2 = 'TPU v2'
_TPU_V3 = 'TPU v3'
_TPU_V4 = 'TPU v4'

@@ -63,6 +64,9 @@
},
}

# Physical ordering of core IDs in a tray that creates a ring
_TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5)


def _create_device_mesh_for_tpu_v4(
physical_mesh: np.ndarray, mesh_shape: Sequence[int]
@@ -245,13 +249,24 @@ def _create_device_mesh(process_0_devices, global_devices, device_kind,
mesh_shape: Sequence[int],
contiguous_submeshes: bool = False) -> np.ndarray:
# TODO(zhangqiaorjc): Handle TPU versions other than v4 more generally.
if device_kind == _TPU_V3:
device_mesh = np.asarray(global_devices).reshape(mesh_shape)
if mesh_shape[-1] == 8:
logging.info('Re-order TPUv3 device mesh for better performance.')
perm = np.array([0, 1, 2, 3, 6, 7, 4, 5])
device_mesh = device_mesh[:, :, perm]
return device_mesh
if device_kind in (_TPU_V2, _TPU_V3):
if len(global_devices) == 8:
logging.info('Reordering mesh to physical ring order on single-tray TPU v2/v3.')
device_mesh = np.asarray(global_devices)
device_mesh = device_mesh[np.array(_TRAY_RING_ORDER)]
device_mesh = device_mesh.reshape(mesh_shape)
return device_mesh
elif mesh_shape[-1] == 8:
device_mesh = np.asarray(global_devices).reshape(mesh_shape)
logging.info('Reordering mesh to physical ring order on each TPU v2/v3 tray.')
perm = np.array(_TRAY_RING_ORDER)
device_mesh = device_mesh[..., perm]
return device_mesh
else:
# TODO(skye): implement 2D mesh_shape logic here:
# https://github.com/tensorflow/lingvo/blob/0df40cf604dfcd14e28f7087d73687a0bd2fe5c6/lingvo/core/gshard_utils.py#L187
# (possibly replaces above mesh_shape[-1] == 8 case)
return np.asarray(global_devices).reshape(mesh_shape)
elif device_kind == _TPU_V4:
physical_mesh = _jax_devices_order_normalized(
process_0_devices, global_devices)
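The single-tray branch above boils down to one gather plus a reshape. A self-contained sketch (ring_order_mesh is a hypothetical name; the constant is copied from the diff):

import numpy as np

_TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5)

def ring_order_mesh(device_ids, mesh_shape):
    # Put the 8 devices into physical ring order, then reshape.
    mesh = np.asarray(device_ids)[np.array(_TRAY_RING_ORDER)]
    return mesh.reshape(mesh_shape)

print(ring_order_mesh(list(range(8)), [2, 4]))
# [[0 1 2 3]
#  [6 7 4 5]]

Note how ids 4 and 5 trade places with 6 and 7 relative to naive order, which is what closes the ring on the physical tray.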
62 changes: 53 additions & 9 deletions tests/mesh_utils_test.py
@@ -19,6 +19,8 @@
import dataclasses
from typing import Sequence

import numpy as np

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
@@ -30,6 +32,7 @@
@dataclasses.dataclass
class MockTpuDevice:
"""Mock TPU device for testing."""
id: int
platform: str
device_kind: str
process_index: int
@@ -46,21 +49,29 @@ def mock_devices(x, y, z, dev_kind, two_cores_per_chip)
for i in range(0, x, 2):
# Local 2x2 subgrid of chips, with 2 cores per chip.
host_devices = [
MockTpuDevice('tpu', dev_kind, process_index, (i, j, k), 0),
MockTpuDevice('tpu', dev_kind, process_index, (i, j, k), 1),
MockTpuDevice('tpu', dev_kind, process_index, (i + 1, j, k), 0),
MockTpuDevice('tpu', dev_kind, process_index, (i + 1, j, k), 1),
MockTpuDevice('tpu', dev_kind, process_index, (i, j + 1, k), 0),
MockTpuDevice('tpu', dev_kind, process_index, (i, j + 1, k), 1),
MockTpuDevice('tpu', dev_kind, process_index, (i + 1, j + 1, k), 0),
MockTpuDevice('tpu', dev_kind, process_index, (i + 1, j + 1, k), 1),
MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i, j, k), 0),
MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i, j, k), 1),
MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i + 1, j, k), 0),
MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i + 1, j, k), 1),
MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i, j + 1, k), 0),
MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i, j + 1, k), 1),
MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i + 1, j + 1, k), 0),
MockTpuDevice(-1, 'tpu', dev_kind, process_index, (i + 1, j + 1, k), 1),
]
if two_cores_per_chip:
# Only include core_on_chip = 0.
host_devices = host_devices[::2]
devices.extend(host_devices)
# Simulate one process per host (1 host = 2x2x1 slice)
process_index += 1

# id grows in (z, y, x) major order
for d in devices:
i, j, k = d.coords
d.id = k*x*y + j*x + i
if not two_cores_per_chip:
d.id = d.id * 2 + d.core_on_chip

_validate_mocked_process_indices(devices, two_cores_per_chip)
return devices
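
The id scheme in the new block above can be spot-checked by hand. A worked example for the 2x2x1 case, where each core is its own device (two_cores_per_chip=False):

# Chip (i=1, j=1, k=0) on a 2x2x1 slice: chip index is k*x*y + j*x + i.
chip_index = 0 * (2 * 2) + 1 * 2 + 1  # = 3
# Its two cores get ids chip_index*2 + core_on_chip, i.e. 6 and 7,
# so the eight devices of the tray carry ids 0..7.
assert chip_index * 2 + 0 == 6 and chip_index * 2 + 1 == 7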

@@ -86,8 +97,19 @@ def _validate_mocked_process_indices(devices, two_cores_per_chip)
expected.add((min_coords[0] + x, min_coords[1] + y, min_coords[2]))
assert set(d.coords for d in local_devices) == expected, local_devices


def mock_2x2_devices():
"""Hard-coded reproduction of jax.devices() output on v3-2x2."""
return mock_devices(2, 2, 1, 'TPU v3', False)


def mock_4x4_devices():
"""Hard-coded reproduction of jax.devices() output on v3-4x4."""
return mock_devices(4, 4, 1, 'TPU v3', False)


def mock_8x8_devices():
"""Hard-coded reproduction of jax.devices() output on 8x8."""
"""Hard-coded reproduction of jax.devices() output on v3-8x8."""
return mock_devices(8, 8, 1, 'TPU v3', False)


@@ -185,6 +207,28 @@ def test_create_device_mesh_for_tpu_v4(self, devices, mesh_shape,
physical_mesh, mesh_shape)
self.assertEqual(assignment, expected_assignment)

@parameterized.named_parameters(
# Physical ring order over tray
('2x2_1d', mock_2x2_devices, [8], [0, 1, 2, 3, 6, 7, 4, 5]),
# Reshaped physical ring order over tray
('2x2_2d', mock_2x2_devices, [2, 4], [[0, 1, 2, 3],
[6, 7, 4, 5]]),
# 4 per-tray rings
('4x4_2d', mock_4x4_devices, [4, 8], [[ 0, 1, 2, 3, 10, 11, 8, 9],
[ 4, 5, 6, 7, 14, 15, 12, 13],
[16, 17, 18, 19, 26, 27, 24, 25],
[20, 21, 22, 23, 30, 31, 28, 29]]),
)
def test_v3_create_device_mesh(self, devices, mesh_shape,
expected_device_id_mesh):
jax_local_devices_from_process_0 = mock_2x2_devices()
global_devices = devices()
mesh = mesh_utils._create_device_mesh(
jax_local_devices_from_process_0, global_devices,
global_devices[0].device_kind, mesh_shape, contiguous_submeshes=False)
device_id_mesh = np.vectorize(lambda d: d.id)(mesh)
self.assertAllClose(np.array(expected_device_id_mesh), device_id_mesh)

def _assert_contiguous_submeshes(self, global_device_mesh):
global_mesh = Mesh(global_device_mesh, list(range(global_device_mesh.ndim)))
max_process_index = max(d.process_index
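
The '4x4_2d' expectation in test_v3_create_device_mesh can be re-derived by hand. A sketch, assuming each host's eight cores appear contiguously in the global device list (as mock_devices constructs them):

import numpy as np

ring = np.array([0, 1, 2, 3, 6, 7, 4, 5])
# One row per tray: device ids in enumeration order on a v3-4x4.
trays = np.array([[ 0,  1,  2,  3,  8,  9, 10, 11],
                  [ 4,  5,  6,  7, 12, 13, 14, 15],
                  [16, 17, 18, 19, 24, 25, 26, 27],
                  [20, 21, 22, 23, 28, 29, 30, 31]])
print(trays[..., ring])  # reproduces the expected [4, 8] mesh above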
