# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mesh utils."""
import collections
from collections.abc import Sequence
import dataclasses
import numpy as np
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax._src import test_util
@dataclasses.dataclass(frozen=True)
class MockTpuDevice:
"""Mock TPU device for testing."""
id: int
platform: str
device_kind: str
process_index: int
coords: Sequence[int]
core_on_chip: int
  slice_index: int = 0


def mock_tpu_devices(x, y, z, dev_kind, one_device_per_chip, num_slices=1,
reorder=False):
"""Produce fake jax.devices() output for a TPU slice."""
cores_per_chip = 1 if one_device_per_chip else 2
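  # Each process (host) manages a fixed 2x2x1 block of chips, so the "d"
  # counts below are chips per process along each axis and the "p" counts
  # are the number of processes along each axis.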
nxd, nyd, nzd = (2, 2, 1)
  nxp, nyp, nzp = x // nxd, y // nyd, z // nzd

  def mock_tpu_device(core_on_chip, xd, yd, zd, xp, yp, zp, slice_index):
process_index = xp + nxp * (yp + nyp * (zp + nzp * slice_index))
coords = (xd + nxd * xp, yd + nyd * yp, zd + nzd * zp)
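    # Linearize the device id with core_on_chip varying fastest, then chip
    # and process x, chip and process y, chip and process z, and slice last.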
device_id = core_on_chip + cores_per_chip * (xd + nxd * (xp + nxp * (
yd + nyd * (yp + nyp * (zd + nzd * (zp + nzp * slice_index))))))
return MockTpuDevice(device_id, 'tpu', dev_kind, process_index, coords,
                         core_on_chip, slice_index)

  devices = [mock_tpu_device(core_on_chip, xd, yd, zd, xp, yp, zp, slice_index)
for slice_index in range(num_slices)
for zp in range(nzp) for yp in range(nyp) for xp in range(nxp)
for zd in range(nzd) for yd in range(nyd) for xd in range(nxd)
for core_on_chip in range(cores_per_chip)]
if reorder:
devices = devices[::-1]
_validate_mocked_process_indices(devices, one_device_per_chip)
  return devices


# If this function raises, it's a bug in the test code!
def _validate_mocked_process_indices(devices, one_device_per_chip):
process_to_devices = collections.defaultdict(list)
for d in devices:
process_to_devices[d.process_index].append(d)
for local_devices in process_to_devices.values():
if one_device_per_chip:
# 4 devices per process
assert len(local_devices) == 4, local_devices
else:
# 8 devices per process
assert len(local_devices) == 8, local_devices
# All devices have same z coord
assert len({d.coords[2] for d in local_devices}) == 1, local_devices
# All devices in a 2x2 subgrid
min_coords = min(d.coords for d in local_devices)
expected = set()
for x, y in [(0,0), (0,1), (1,0), (1,1)]:
expected.add((min_coords[0] + x, min_coords[1] + y, min_coords[2]))
assert {d.coords for d in local_devices} == expected, local_devices


def mock_2x2_devices():
  """Hard-coded reproduction of jax.devices() output on v3-2x2."""
  return mock_tpu_devices(2, 2, 1, 'TPU v3', False)


def mock_4x4_devices():
  """Hard-coded reproduction of jax.devices() output on v3-4x4."""
  return mock_tpu_devices(4, 4, 1, 'TPU v3', False)


def mock_8x8_devices(one_device_per_chip=False):
  """Hard-coded reproduction of jax.devices() output on v3-8x8."""
  return mock_tpu_devices(8, 8, 1, 'TPU v3', one_device_per_chip)


def mock_2x2x1_devices(one_device_per_chip):
  """Hard-coded reproduction of jax.devices() output on 2x2x1."""
  return mock_tpu_devices(2, 2, 1, 'TPU v4', one_device_per_chip)


def mock_2x2x4_devices(one_device_per_chip):
  """Hard-coded reproduction of jax.devices() output on 2x2x4."""
  return mock_tpu_devices(2, 2, 4, 'TPU v4', one_device_per_chip)


def mock_4x4x4_devices(one_device_per_chip):
  """Hard-coded reproduction of jax.devices() output on 4x4x4."""
  return mock_tpu_devices(4, 4, 4, 'TPU v4', one_device_per_chip)


def mock_4x4x8_devices(one_device_per_chip):
  """Hard-coded reproduction of jax.devices() output on 4x4x8."""
  return mock_tpu_devices(4, 4, 8, 'TPU v4', one_device_per_chip)


def mock_8x8x8_devices(one_device_per_chip):
  """Hard-coded reproduction of jax.devices() output on 8x8x8."""
  return mock_tpu_devices(8, 8, 8, 'TPU v4', one_device_per_chip)


def mock_4x8x8_devices(one_device_per_chip):
  """Hard-coded reproduction of jax.devices() output on 4x8x8."""
  return mock_tpu_devices(4, 8, 8, 'TPU v4', one_device_per_chip)


def mock_4x8x16_devices(one_device_per_chip):
  """Hard-coded reproduction of jax.devices() output on 4x8x16."""
  return mock_tpu_devices(4, 8, 16, 'TPU v4', one_device_per_chip)


def mock_8x8x16_devices(one_device_per_chip):
  """Hard-coded reproduction of jax.devices() output on 8x8x16."""
  return mock_tpu_devices(8, 8, 16, 'TPU v4', one_device_per_chip)


class MeshUtilsTest(test_util.JaxTestCase):

  @parameterized.named_parameters(
('2x2x1_t', mock_2x2x1_devices, True, (2, 2, 1, 1)),
('2x2x1_f', mock_2x2x1_devices, False, (2, 2, 1, 2)),
('8x8x16_t', mock_8x8x16_devices, True, (8, 8, 16, 1)),
('8x8x16_f', mock_8x8x16_devices, False, (8, 8, 16, 2)),
)
def test_bounds_from_last_device(self, devices, one_device_per_chip,
expected_bounds):
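    # The last device in the list has the maximal coordinates, so its coords
    # and core_on_chip determine the full (x, y, z, cores) bounds.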
self.assertEqual(
mesh_utils._bounds_from_last_device(devices(one_device_per_chip)[-1]),
        expected_bounds)

  @parameterized.named_parameters(
('4x4x4_t', (4, 4, 4), True),
('4x4x4_f', (4, 4, 4), False),
('8x8x16_t', (8, 8, 16), True),
('8x8x16_f', (8, 8, 16), False),
)
def test_get_physical_tpu_mesh(self, xyz, reorder):
x, y, z = xyz
jax_devices = mock_tpu_devices(x, y, z, 'TPU v4', True, reorder=reorder)
normalized = mesh_utils._get_physical_tpu_mesh(jax_devices)
self.assertEqual(normalized.shape, xyz)
# major_to_minor: x, y, z
for i in range(x):
for j in range(y):
for k in range(z):
          self.assertEqual(normalized[i, j, k].coords, (i, j, k))

  @parameterized.named_parameters(
('2x2x1', mock_2x2x1_devices, [1, 1, 4], [(), (), (0, 1, 2)]),
('2x2x4', mock_2x2x4_devices, [1, 4, 4], [(), (2,), (0, 1)]),
('4x4x4', mock_4x4x4_devices, [1, 16, 4], [(), (1, 2), (0,)]),
('4x4x8a', mock_4x4x8_devices, [1, 16, 8], [(), (0, 1), (2,)]),
('4x4x8b', mock_4x4x8_devices, [1, 8, 16], [(), (2,), (0, 1)]),
('4x4x8c', mock_4x4x8_devices, [16, 8, 1], [(0, 1), (2,), ()]),
('4x8x8', mock_4x8x8_devices, [1, 32, 8], [(), (0, 2), (1,)]),
('8x8x8', mock_8x8x8_devices, [1, 64, 8], [(), (1, 2), (0,)]),
('8x8x16', mock_8x8x16_devices, [1, 64, 16], [(), (0, 1), (2,)]),
('8x8', mock_8x8_devices, [8, 8], [(1,), (0, 2)])
)
def test_create_device_mesh_for_nd_torus(self, devices, mesh_shape,
expected_assignment):
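    # Each tuple in expected_assignment lists the physical torus axes
    # (0=x, 1=y, 2=z) that get folded into the corresponding mesh axis.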
jax_devices = devices(True)
physical_mesh = mesh_utils._get_physical_tpu_mesh(jax_devices)
_, assignment = mesh_utils._create_device_mesh_for_nd_torus(
physical_mesh, mesh_shape)
    self.assertEqual(assignment, expected_assignment)

  @parameterized.named_parameters(
('2X4x4x4a', (1, 16, 4), (2, 1, 1)),
('2X4x4x4b', (1, 4, 16), (1, 2, 1)),
)
def test_create_hybrid_device_mesh(self, mesh_shape, dcn_mesh_shape):
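    # Two slices connected over DCN: mesh_shape partitions devices within a
    # slice, dcn_mesh_shape partitions across slices.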
devices = mock_tpu_devices(4, 4, 4, 'TPU v4', True, 2)
mesh = mesh_utils.create_hybrid_device_mesh(
mesh_shape, dcn_mesh_shape, devices)
total_mesh_shape = tuple(
m1 * m2 for m1, m2 in zip(mesh_shape, dcn_mesh_shape))
    self.assertEqual(mesh.shape, total_mesh_shape)

  @parameterized.named_parameters(
('2X4x4x4a', (1, 16, 4), (2, 1, 1)),
('2X4x4x4b', (1, 4, 16), (1, 2, 1)),
)
def test_create_hybrid_device_mesh_device_sorting(
self,
mesh_shape: tuple[int, ...],
dcn_mesh_shape: tuple[int, ...],
):
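    # Feeding the slices in reverse order with should_sort_granules_by_key=True
    # should reproduce the mesh built from the correctly ordered device list.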
devices = mock_tpu_devices(4, 4, 4, 'TPU v4', True, 2)
reversed_slices_devices = list(
np.flip(np.array(devices).reshape(2, -1), axis=0).flat)
mesh = mesh_utils.create_hybrid_device_mesh(
mesh_shape,
dcn_mesh_shape,
devices,
should_sort_granules_by_key=False,
)
sorted_slices_mesh = mesh_utils.create_hybrid_device_mesh(
mesh_shape,
dcn_mesh_shape,
reversed_slices_devices,
should_sort_granules_by_key=True,
)
np.testing.assert_array_equal(mesh, sorted_slices_mesh)
self.assertSetEqual(
{0, 1},
{d.slice_index for d in sorted_slices_mesh.flat},
)
reversed_slices_mesh = mesh_utils.create_hybrid_device_mesh(
mesh_shape,
dcn_mesh_shape,
reversed_slices_devices,
should_sort_granules_by_key=False,
)
self.assertSetEqual(
{1, 0},
{d.slice_index for d in reversed_slices_mesh.flat},
    )

  @parameterized.named_parameters(
# Physical ring order over tray
('2x2_1d', mock_2x2_devices, [8], [0, 1, 2, 3, 6, 7, 4, 5]),
# Reshaped physical ring order over tray
('2x2_2d', mock_2x2_devices, [2, 4], [[0, 1, 2, 3],
[6, 7, 4, 5]]),
# 4 per-tray rings
('4x4_2d', mock_4x4_devices, [4, 8], [[ 0, 1, 2, 3, 10, 11, 8, 9],
[ 4, 5, 6, 7, 14, 15, 12, 13],
[16, 17, 18, 19, 26, 27, 24, 25],
[20, 21, 22, 23, 30, 31, 28, 29]]),
)
def test_v3_create_device_mesh(self, devices, mesh_shape,
expected_device_id_mesh):
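    # v3 chips have two cores; the expected meshes follow the physical ring
    # order over each 2x2 tray, as noted in the per-case comments above.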
global_devices = devices()
mesh = mesh_utils.create_device_mesh(
mesh_shape, devices=global_devices, contiguous_submeshes=False)
device_id_mesh = np.vectorize(lambda d: d.id)(mesh)
    self.assertAllClose(np.array(expected_device_id_mesh), device_id_mesh)

  def _assert_contiguous_submeshes(self, global_device_mesh):
global_mesh = Mesh(global_device_mesh, list(range(global_device_mesh.ndim)))
max_process_index = max(d.process_index
for d in global_device_mesh.flatten())
for p_idx in range(max_process_index + 1):
# Raises an error if non-contiguous
      global_mesh._local_mesh(p_idx)

  def test_create_contiguous_submeshes_for_tpu_v4(self):
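    # _TRANSPOSE_TRICKS maps each supported physical topology to the mesh
    # shapes for which contiguous per-process submeshes can be constructed.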
v4 = mesh_utils._TPU_V4
for topology, mesh_shapes in mesh_utils._TRANSPOSE_TRICKS.items():
logging.vlog(1, "topology: %s", topology)
devices = mock_tpu_devices(topology[0], topology[1], topology[2], v4,
one_device_per_chip=True)
for mesh_shape in mesh_shapes:
logging.vlog(1, " mesh_shape: %s", mesh_shape)
mesh = mesh_utils.create_device_mesh(
mesh_shape, devices=devices, contiguous_submeshes=True)
        self._assert_contiguous_submeshes(mesh)

  def test_create_contiguous_submeshes_for_tpu_v4_leading_1_dims(self):
v4 = mesh_utils._TPU_V4
for topology, mesh_shapes in mesh_utils._TRANSPOSE_TRICKS.items():
logging.vlog(1, "topology: %s", topology)
devices = mock_tpu_devices(topology[0], topology[1], topology[2], v4,
one_device_per_chip=True)
for mesh_shape in mesh_shapes:
logging.vlog(1, ' mesh_shape: %s', (1, 1) + mesh_shape + (1, 1))
mesh = mesh_utils.create_device_mesh(
(1, 1) + mesh_shape + (1, 1),
devices=devices,
contiguous_submeshes=True)
        self._assert_contiguous_submeshes(mesh)

  def test_create_contiguous_submeshes_errors(self):
v4 = mesh_utils._TPU_V4
topology = (4, 4, 8)
mesh_shape = (1, 16, 8)
devices = mock_tpu_devices(topology[0], topology[1], topology[2], v4,
one_device_per_chip=True)
with self.assertRaisesWithLiteralMatch(
ValueError,
"create_device_mesh cannot create contiguous submeshes for "
"physical mesh topology (4, 4, 8)"):
mesh_utils.create_device_mesh(
mesh_shape, devices=devices, contiguous_submeshes=True)
topology = (4, 8, 8)
mesh_shape = (1, 128, 2)
devices = mock_tpu_devices(topology[0], topology[1], topology[2], v4,
one_device_per_chip=True)
with self.assertRaisesWithLiteralMatch(
ValueError,
"create_device_mesh cannot create contiguous submeshes for mesh_shape "
"(1, 128, 2) and physical mesh topology (4, 8, 8). "
'Available mesh_shapes: [(64, 4), (4, 64)]'):
mesh_utils.create_device_mesh(
        mesh_shape, devices=devices, contiguous_submeshes=True)


if __name__ == '__main__':
absltest.main()