From 4dd0d851393758c98b8002eea9d138e31b69f1f3 Mon Sep 17 00:00:00 2001
From: Sudhakar
Date: Fri, 23 Sep 2022 12:11:56 -0700
Subject: [PATCH] add multihost pjit tests

---
 tests/multiprocess_gpu_test.py | 290 ++++++++++++++++++++++++++++++++-
 1 file changed, 289 insertions(+), 1 deletion(-)

diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py
index 5e8e58bc1762..9977b6b01353 100644
--- a/tests/multiprocess_gpu_test.py
+++ b/tests/multiprocess_gpu_test.py
@@ -17,16 +17,23 @@
 import sys
 import threading
 import unittest
+import functools
 
 from absl.testing import absltest
 from absl.testing import parameterized
+import numpy as np
 
 import jax
+from jax import experimental
 from jax.config import config
 from jax._src import distributed
 import jax.numpy as jnp
 from jax._src.lib import xla_extension_version
 from jax._src import test_util as jtu
+from jax._src import util
+from jax.experimental import global_device_array
+from jax.experimental import maps
+from jax.experimental import pjit
 
 try:
   import portpicker
@@ -40,7 +47,6 @@
 
 config.parse_flags_with_absl()
 
-
 @unittest.skipIf(not portpicker, "Test requires portpicker")
 class DistributedTest(jtu.JaxTestCase):
 
@@ -170,6 +176,49 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
   if pytest is not None:
     pytestmark = pytest.mark.SlurmMultiNodeGpuTest
 
+  def sorted_devices(self):
+    devices = sorted(jax.devices(), key=lambda d: (d.id, d.host_id))
+    if len(devices) != 16:
+      raise unittest.SkipTest(
+          "Test assumes that it runs on 16 devices (2 nodes)")
+    return devices
+
+  def create_2d_non_contiguous_mesh(self):
+    devices = self.sorted_devices()
+    device_mesh = np.array([[devices[0], devices[2]],
+                            [devices[4], devices[6]],
+                            [devices[1], devices[3]],
+                            [devices[5], devices[7]],
+                            [devices[8], devices[10]],
+                            [devices[12], devices[14]],
+                            [devices[9], devices[11]],
+                            [devices[13], devices[15]]])
+    # The mesh looks like this (the integers are process indices):
+    #  0  2
+    #  4  6
+    #  1  3
+    #  5  7
+    #  8 10
+    # 12 14
+    #  9 11
+    # 13 15
+    assert [d.id for d in device_mesh.flat
+           ] == [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]
+    return maps.Mesh(device_mesh, ("x", "y"))
+
+  def setUp(self):
+    super().setUp()
+    self.xmap_spmd_lowering_enabled = jax.config.experimental_xmap_spmd_lowering
+    jax.config.update("experimental_xmap_spmd_lowering", True)
+    self.gda_enabled = jax.config.jax_parallel_functions_output_gda
+    jax.config.update("jax_parallel_functions_output_gda", True)
+
+  def tearDown(self):
+    jax.config.update("experimental_xmap_spmd_lowering",
+                      self.xmap_spmd_lowering_enabled)
+    jax.config.update("jax_parallel_functions_output_gda", self.gda_enabled)
+    super().tearDown()
+
   def test_gpu_multi_node_initialize_and_psum(self):
 
     # Hookup the ENV vars expected to be set already in the SLURM environment
@@ -224,5 +273,244 @@ def test_gpu_multi_node_transparent_initialize_and_psum(self):
     self.assertEqual(y[0], jax.device_count())
     print(y)
 
+  # TODO(sudhakarsingh27): Change/omit this test in favor of `jax.Array`,
+  # since `GlobalDeviceArray` is going to be deprecated.
+  def test_pjit_gda_multi_input_multi_output(self):
+    jax.distributed.initialize()
+    global_mesh = jtu.create_global_mesh((8, 2), ("x", "y"))
+    global_input_shape = (16, 2)
+    global_input_data = np.arange(
+        util.prod(global_input_shape)).reshape(global_input_shape)
+
+    def cb(index):
+      return global_input_data[index]
+
+    mesh_axes1 = experimental.PartitionSpec("x", "y")
+    gda1 = global_device_array.GlobalDeviceArray.from_callback(
+        global_input_shape, global_mesh, mesh_axes1, cb)
+    mesh_axes2 = experimental.PartitionSpec("x")
+    gda2 = global_device_array.GlobalDeviceArray.from_callback(
+        global_input_shape, global_mesh, mesh_axes2, cb)
+    mesh_axes3 = experimental.PartitionSpec(("x", "y"))
+    gda3 = global_device_array.GlobalDeviceArray.from_callback(
+        global_input_shape, global_mesh, mesh_axes3, cb)
+
+    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
+
+      @functools.partial(
+          pjit.pjit,
+          # A single `FROM_GDA` is replicated across all the inputs.
+          in_axis_resources=pjit.FROM_GDA,
+          out_axis_resources=(mesh_axes1, None, mesh_axes2))
+      def f(x, y, z):
+        return x @ x.T, y, z
+
+      out1, out2, out3 = f(gda1, gda2, gda3)
+
+    self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
+    self.assertEqual(out1.shape, (16, 16))
+    self.assertEqual(out1.local_shards[0].data.shape, (2, 8))
+    self.assertDictEqual(out1.mesh.shape, {"x": 8, "y": 2})
+    expected_matrix_mul = global_input_data @ global_input_data.T
+    for s in out1.local_shards:
+      np.testing.assert_array_equal(np.asarray(s.data),
+                                    expected_matrix_mul[s.index])
+
+    self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
+    self.assertEqual(out2.shape, (16, 2))
+    self.assertEqual(out2.local_shards[0].data.shape, (16, 2))
+    for s in out2.local_shards:
+      np.testing.assert_array_equal(np.asarray(s.data), global_input_data)
+
+    self.assertIsInstance(out3, global_device_array.GlobalDeviceArray)
+    self.assertEqual(out3.shape, (16, 2))
+    self.assertEqual(out3.local_shards[0].data.shape, (2, 2))
+    for s in out3.local_shards:
+      np.testing.assert_array_equal(np.asarray(s.data),
+                                    global_input_data[s.index])
+
+  # TODO(sudhakarsingh27): Change/omit this test in favor of `jax.Array`,
+  # since `GlobalDeviceArray` is going to be deprecated.
+  def test_pjit_gda_non_contiguous_mesh(self):
+    jax.distributed.initialize()
+    devices = self.sorted_devices()
+    mesh_devices = np.array(devices[0:8:2] + devices[1:8:2] +
+                            devices[8:16:2] + devices[9:16:2])
+    # The device order in the mesh below is:
+    # [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]
+    # and the global input data is:
+    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+    # TODO(sudhakarsingh27): The process-to-GPU mapping appears arbitrary;
+    # figure out why.
+    global_mesh = maps.Mesh(mesh_devices, ("x",))
+    global_input_shape = (16,)
+    mesh_axes = experimental.PartitionSpec("x")
+    global_input_data = np.arange(
+        util.prod(global_input_shape)).reshape(global_input_shape)
+
+    def cb(index):
+      return global_input_data[index]
+
+    gda1 = global_device_array.GlobalDeviceArray.from_callback(
+        global_input_shape, global_mesh, mesh_axes, cb)
+
+    # device_id -> (index, replica_id)
+    expected_idx_rid = {
+        0: ((slice(0, 1),), 0),
+        1: ((slice(4, 5),), 0),
+        2: ((slice(1, 2),), 0),
+        3: ((slice(5, 6),), 0),
+        4: ((slice(2, 3),), 0),
+        5: ((slice(6, 7),), 0),
+        6: ((slice(3, 4),), 0),
+        7: ((slice(7, 8),), 0),
+        8: ((slice(8, 9),), 0),
+        9: ((slice(12, 13),), 0),
+        10: ((slice(9, 10),), 0),
+        11: ((slice(13, 14),), 0),
+        12: ((slice(10, 11),), 0),
+        13: ((slice(14, 15),), 0),
+        14: ((slice(11, 12),), 0),
+        15: ((slice(15, 16),), 0),
+    }
+
+    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
+      f = pjit.pjit(lambda x: x,
+                    in_axis_resources=pjit.FROM_GDA,
+                    out_axis_resources=mesh_axes)
+      out = f(gda1)
+      for s in out.local_shards:
+        device_id = s.device.id
+        expected_index = expected_idx_rid[device_id][0]
+        expected_replica_id = expected_idx_rid[device_id][1]
+        self.assertEqual(s.index, expected_index)
+        self.assertEqual(s.replica_id, expected_replica_id)
+        self.assertEqual(s.data.shape, (1,))
+        np.testing.assert_array_equal(np.asarray(s.data),
+                                      global_input_data[expected_index])
+
+  # TODO(sudhakarsingh27): Change/omit this test in favor of `jax.Array`,
+  # since `GlobalDeviceArray` is going to be deprecated.
+  def test_pjit_gda_non_contiguous_mesh_2d(self):
+    jax.distributed.initialize()
+    global_mesh = self.create_2d_non_contiguous_mesh()
+    global_input_shape = (16, 2)
+    mesh_axes = experimental.PartitionSpec("x", "y")
+    global_input_data = np.arange(
+        util.prod(global_input_shape)).reshape(global_input_shape)
+
+    def cb(index):
+      return global_input_data[index]
+
+    gda1 = global_device_array.GlobalDeviceArray.from_callback(
+        global_input_shape, global_mesh, mesh_axes, cb)
+
+    # device_id -> (index, replica_id)
+    expected_idx_rid = {
+        0: ((slice(0, 2), slice(0, 1)), 0),
+        1: ((slice(4, 6), slice(0, 1)), 0),
+        2: ((slice(0, 2), slice(1, 2)), 0),
+        3: ((slice(4, 6), slice(1, 2)), 0),
+        4: ((slice(2, 4), slice(0, 1)), 0),
+        5: ((slice(6, 8), slice(0, 1)), 0),
+        6: ((slice(2, 4), slice(1, 2)), 0),
+        7: ((slice(6, 8), slice(1, 2)), 0),
+        8: ((slice(8, 10), slice(0, 1)), 0),
+        9: ((slice(12, 14), slice(0, 1)), 0),
+        10: ((slice(8, 10), slice(1, 2)), 0),
+        11: ((slice(12, 14), slice(1, 2)), 0),
+        12: ((slice(10, 12), slice(0, 1)), 0),
+        13: ((slice(14, 16), slice(0, 1)), 0),
+        14: ((slice(10, 12), slice(1, 2)), 0),
+        15: ((slice(14, 16), slice(1, 2)), 0),
+    }
+
+    with global_mesh:
+      f = pjit.pjit(lambda x: x,
+                    in_axis_resources=pjit.FROM_GDA,
+                    out_axis_resources=mesh_axes)
+      out = f(gda1)
+
+    for s in out.local_shards:
+      device_id = s.device.id
+      expected_index = expected_idx_rid[device_id][0]
+      expected_replica_id = expected_idx_rid[device_id][1]
+      self.assertEqual(s.index, expected_index)
+      self.assertEqual(s.replica_id, expected_replica_id)
+      self.assertEqual(s.data.shape, (2, 1))
+      np.testing.assert_array_equal(np.asarray(s.data),
+                                    global_input_data[expected_index])
+
+    with global_mesh:
+      f = pjit.pjit(lambda x: x,
+                    in_axis_resources=experimental.PartitionSpec(None),
+                    out_axis_resources=mesh_axes)
+      # Fully replicated values allow a non-contiguous mesh.
+      out = f(global_input_data)
+      self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
+
+    with global_mesh:
+      f = pjit.pjit(lambda x: x,
+                    in_axis_resources=None,
+                    out_axis_resources=mesh_axes)
+      # Fully replicated values allow a non-contiguous mesh.
+      out = f(global_input_data)
+      self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
+
+    gda2 = global_device_array.GlobalDeviceArray.from_callback(
+        global_input_shape, global_mesh, experimental.PartitionSpec(None), cb)
+
+    with global_mesh:
+      f = pjit.pjit(lambda x, y: (x, y),
+                    in_axis_resources=(None, None),
+                    out_axis_resources=(mesh_axes, mesh_axes))
+      # Fully replicated values + GDA allow a non-contiguous mesh.
+      out1, out2 = f(global_input_data, gda2)
+      self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
+      self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
+
+  # TODO(sudhakarsingh27): Change/omit this test in favor of `jax.Array`,
+  # since `GlobalDeviceArray` is going to be deprecated.
+  def test_pjit_gda_non_contiguous_mesh_2d_aot(self):
+    jax.distributed.initialize()
+    global_mesh = self.create_2d_non_contiguous_mesh()
+    global_input_shape = (8, 2)
+    mesh_axes = experimental.PartitionSpec("x", "y")
+    global_input_data = np.arange(
+        util.prod(global_input_shape)).reshape(global_input_shape)
+    gda1 = global_device_array.GlobalDeviceArray.from_callback(
+        global_input_shape, global_mesh, mesh_axes,
+        lambda idx: global_input_data[idx])
+
+    with global_mesh:
+      f = pjit.pjit(lambda x, y: (x, y),
+                    in_axis_resources=experimental.PartitionSpec("x", "y"),
+                    out_axis_resources=experimental.PartitionSpec("x", "y"))
+      inp_aval = jax.ShapedArray((8, 2), jnp.int32)
+      # A `ShapedArray` input is treated as global when lowered and compiled,
+      # so it can bypass the contiguous-mesh restriction.
+      compiled = f.lower(inp_aval, gda1, _global_avals=True).compile()
+      out1, out2 = compiled(gda1, gda1)
+      self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
+      self.assertEqual(out1.shape, (8, 2))
+      self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
+      self.assertEqual(out2.shape, (8, 2))
+
+  # TODO(sudhakarsingh27): Change/omit this test in favor of `jax.Array`,
+  # since `GlobalDeviceArray` is going to be deprecated.
+  def test_pjit_gda_eval_shape(self):
+    jax.distributed.initialize()
+
+    with jtu.create_global_mesh((16,), ("x",)):
+
+      @functools.partial(pjit.pjit,
+                         in_axis_resources=experimental.PartitionSpec(None),
+                         out_axis_resources=experimental.PartitionSpec("x"))
+      def f():
+        return jnp.zeros([32, 10])
+
+      self.assertEqual(f().shape, (32, 10))
+      self.assertEqual(jax.eval_shape(f).shape, (32, 10))
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
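--
Note (not part of the patch): the TODOs above anticipate porting these tests
from `GlobalDeviceArray` to `jax.Array`. Below is a minimal sketch of what
such a port could look like for the first test. It assumes a newer JAX
release (0.4+) in which `jax.Array` is the default array type, `jax.jit`
accepts `out_shardings`, and `jax.make_array_from_callback` is available;
the function `f` and all variable names are illustrative, not part of this
patch.

import functools

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

jax.distributed.initialize()  # same SLURM-driven setup as the tests above

# An 8x2 mesh over the 16 devices, mirroring
# test_pjit_gda_multi_input_multi_output.
mesh = Mesh(np.asarray(jax.devices()).reshape(8, 2), ("x", "y"))
in_sharding = NamedSharding(mesh, PartitionSpec("x", "y"))

global_input_shape = (16, 2)
global_input_data = np.arange(
    np.prod(global_input_shape)).reshape(global_input_shape)

# Each process materializes only its addressable shards from the callback,
# much as GlobalDeviceArray.from_callback did.
arr = jax.make_array_from_callback(
    global_input_shape, in_sharding, lambda idx: global_input_data[idx])

@functools.partial(
    jax.jit, out_shardings=NamedSharding(mesh, PartitionSpec("x")))
def f(x):
  # Same computation as the first output of the GDA test.
  return x @ x.T

out = f(arr)  # a sharded jax.Array; no FROM_GDA equivalent is needed
assert out.shape == (16, 16)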