From c12929b0120b771717d76ebe3b4f15124cb33c95 Mon Sep 17 00:00:00 2001
From: Tao Wang
Date: Mon, 2 Oct 2023 13:04:03 -0700
Subject: [PATCH] Add more API setup for the mock GPU client.

Also clean up the previous mock GPU client API.

PiperOrigin-RevId: 570153877
---
 jax/_src/xla_bridge.py | 30 +++++++++++++++++-
 tests/BUILD            | 15 +++++++++
 tests/mock_gpu_test.py | 69 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 113 insertions(+), 1 deletion(-)
 create mode 100644 tests/mock_gpu_test.py

diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 23e52cdebdd0..a9f5b750ce51 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -84,6 +84,18 @@
     'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
     'comma-separate list of integer device IDs.')
 
+_USE_MOCK_GPU_CLIENT = jax_config.DEFINE_bool(
+    name="use_mock_gpu_client",
+    default=False,
+    help="If True, use a mock GPU client instead of a real one.",
+)
+
+_MOCK_NUM_GPUS = jax_config.DEFINE_integer(
+    name="mock_num_gpus",
+    default=1,
+    help="Number of GPUs for the mock GPU client.",
+)
+
 
 # Backends
 
@@ -221,12 +233,28 @@ def make_gpu_client(
   if platform_name == "cuda":
     _check_cuda_versions()
 
+  if xla_extension_version <= 199:
+    return xla_client.make_gpu_client(
+        distributed_client=distributed.global_state.client,
+        node_id=distributed.global_state.process_id,
+        num_nodes=distributed.global_state.num_processes,
+        platform_name=platform_name,
+        allowed_devices=allowed_devices,
+    )
+
+  use_mock_gpu_client = _USE_MOCK_GPU_CLIENT.value
+  num_nodes = (
+      _MOCK_NUM_GPUS.value
+      if use_mock_gpu_client
+      else distributed.global_state.num_processes
+  )
+
   return xla_client.make_gpu_client(
       distributed_client=distributed.global_state.client,
       node_id=distributed.global_state.process_id,
-      num_nodes=distributed.global_state.num_processes,
+      num_nodes=num_nodes,
       platform_name=platform_name,
       allowed_devices=allowed_devices,
+      mock=use_mock_gpu_client,  # type: ignore[call-arg]
   )
 
 
diff --git a/tests/BUILD b/tests/BUILD
index e5dd22fa0184..27f2c115db66 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -231,6 +231,21 @@ jax_test(
     ],
 )
 
+jax_test(
+    name = "mock_gpu_test",
+    srcs = ["mock_gpu_test.py"],
+    disable_backends = [
+        "cpu",
+        "tpu",
+    ],
+    tags = [
+        "config-cuda-only",
+    ],
+    deps = [
+        "//jax:experimental",
+    ],
+)
+
 jax_test(
     name = "array_test",
     srcs = ["array_test.py"],
diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py
new file mode 100644
index 000000000000..dced8882c357
--- /dev/null
+++ b/tests/mock_gpu_test.py
@@ -0,0 +1,69 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+import math
+
+from absl.testing import absltest
+import jax
+from jax import config
+from jax._src import test_util as jtu
+from jax._src.lib import xla_extension_version
+import jax.numpy as jnp
+from jax.sharding import NamedSharding
+from jax.sharding import PartitionSpec as P
+import numpy as np
+
+config.parse_flags_with_absl()
+
+
+class MockGPUTest(jtu.JaxTestCase):
+
+  def setUp(self):
+    super().setUp()
+    jax.config.update('use_mock_gpu_client', True)
+
+  def tearDown(self):
+    jax.config.update('use_mock_gpu_client', False)
+    jax.config.update('mock_num_gpus', 1)
+    super().tearDown()
+
+  def testMockWithSharding(self):
+    if xla_extension_version < 200:
+      return
+    num_shards = 16
+    jax.config.update('mock_num_gpus', num_shards)
+    mesh_shape = (num_shards,)
+    axis_names = ('x',)
+    mesh_devices = np.array(jax.devices()).reshape(mesh_shape)
+    mesh = jax.sharding.Mesh(mesh_devices, axis_names)
+    @partial(
+        jax.jit,
+        in_shardings=NamedSharding(mesh, P('x',)),
+        out_shardings=NamedSharding(mesh, P('x',)),
+    )
+    def f(x, y):
+      z = x @ y
+      return z @ y
+
+    shape = (64, 64)
+    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
+    y = x + 1
+    f_lowered = f.lower(x, y)
+    hlo = f_lowered.compiler_ir()
+    self.assertIn('sharding = "{devices=[16,1]<=[16]}"', str(hlo))
+
+
+if __name__ == '__main__':
+  absltest.main(testLoader=jtu.JaxTestLoader())
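
Usage sketch (editor's note, not part of the patch): a minimal example of how the
two new config options might be exercised outside the unit test. It assumes a
CUDA-enabled jaxlib with xla_extension_version > 199 (older versions take the
early-return path above and ignore the mock settings), and assumes the mock
client exposes exactly 'mock_num_gpus' devices, as the test relies on. Only the
option names 'use_mock_gpu_client' and 'mock_num_gpus' come from the diff; the
mesh and the function f below are illustrative.

    from functools import partial

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Both options must be set before the GPU backend is first initialized;
    # make_gpu_client reads them when the client is created.
    jax.config.update('use_mock_gpu_client', True)
    jax.config.update('mock_num_gpus', 8)

    # One-dimensional mesh over all (mock) devices.
    mesh = Mesh(np.array(jax.devices()), ('x',))

    @partial(
        jax.jit,
        in_shardings=NamedSharding(mesh, P('x')),
        out_shardings=NamedSharding(mesh, P('x')),
    )
    def f(x):
      return x @ x.T

    x = jnp.arange(64 * 64, dtype=jnp.float32).reshape((64, 64))
    # Lower rather than run: the point of the mock client is to inspect the
    # sharded HLO for a multi-GPU topology without owning the hardware.
    print(f.lower(x).compiler_ir())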