From 607e7033a6b51ca48558f9e049c3e149850e4a5e Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Lespiau
Date: Fri, 18 Feb 2022 03:18:19 -0800
Subject: [PATCH] Turn execute_replicated into a class so we can access its
 fields.

It's more readable than inspecting the internals of a `functools.partial`.

PiperOrigin-RevId: 429523075
---
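(Note, placed above the diffstat so it stays out of the commit message: a
minimal sketch of what the C++ pmap fast path does before and after this
change, using only the attributes touched in the api.py hunk below. The name
`entry` is a placeholder for whatever `pxla.parallel_callable.most_recent_entry()`
returns; this is an illustration, not code from the patch.)

    # Before: peek inside the functools.partial built around execute_replicated.
    if getattr(entry[0], "func", None) is pxla.execute_replicated:
      xla_executable, backend_, in_handler, out_handler = entry[0].args

    # After: an ExecuteReplicated instance exposes the same objects as named fields.
    if isinstance(entry[0], pxla.ExecuteReplicated):
      xla_executable = entry[0].xla_executable
      in_handler = entry[0].in_handler
      out_handler = entry[0].out_handler
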
 jax/_src/api.py          | 11 ++++++-----
 jax/interpreters/pxla.py | 43 +++++++++++++++++++++++++++++--------------
 tests/pmap_test.py       |  7 +++++++
 3 files changed, 42 insertions(+), 19 deletions(-)

diff --git a/jax/_src/api.py b/jax/_src/api.py
index 92d74cd40a71..a582902c543f 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -2100,18 +2100,19 @@ def cache_miss(*args, **kwargs):
     execute = pxla.parallel_callable.most_recent_entry()
     use_fastpath = (
         execute is not None and
-        # We don't support JAX extension backends. In particular, some
-        # extentions do not return a partial with a `func` attribute.
-        getattr(execute[0], "func", None) is pxla.execute_replicated and
+        # We don't support JAX extension backends.
+        isinstance(execute[0], pxla.ExecuteReplicated) and
         # No tracers in the outputs. Checking for ShardedDeviceArray should be
         # sufficient, but we use the more general `DeviceArray`.
         all(isinstance(x, device_array.DeviceArray) for x in out_flat))
     ### If we can use the fastpath, we return required info to the caller.
     if use_fastpath:
-      xla_executable, backend_, in_handler, out_handler = execute[0].args
+      execute_replicated = execute[0]
+      out_handler = execute_replicated.out_handler
+      in_handler = execute_replicated.in_handler
       fastpath_data = _PmapFastpathData(
           version=1,
-          xla_executable=xla_executable,
+          xla_executable=execute_replicated.xla_executable,
           in_handler=in_handler,
           out_handler=out_handler,
           out_pytree_def=out_pytree_def,
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 40ae415c873f..f1a4c5976155 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -1216,8 +1216,7 @@ def from_hlo(xla_computation,
         pci.backend, xla_computation, compile_options)
     handle_args = InputsHandler(
         compiled.local_devices(), input_sharding_specs, input_indices)
-    execute_fun = partial(
-        execute_replicated, compiled, pci.backend, handle_args, handle_outs)
+    execute_fun = ExecuteReplicated(compiled, pci.backend, handle_args, handle_outs)
     fingerprint = getattr(compiled, "fingerprint", None)
 
     return PmapExecutable(compiled, execute_fun, fingerprint, pci.avals)
@@ -1386,6 +1385,12 @@ def __init__(self, local_devices, sharding_specs, input_indices):
   def __call__(self, input_buffers):
     return self.handler(input_buffers)
 
+  def __str__(self):
+    return ("InputsHandler(\n"
+            f"local_devices={self.local_devices},\n"
+            f"sharding_specs={self.sharding_specs},\n"
+            f"input_indices={self.input_indices})")
+
 
 class ResultsHandler:
   __slots__ = ("handlers", "out_specs", "out_indices", "unmapped_local_out_avals")
@@ -1540,14 +1545,25 @@ def partitioned_sharding_spec(num_partitions: int,
         mesh_mapping=map(ShardedAxis, range(len(partitions))))
 
 
-@profiler.annotate_function
-def execute_replicated(compiled, backend, in_handler, out_handler, *args):
-  input_bufs = in_handler(args)
-  out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
-  if dispatch.needs_check_special():
-    for bufs in out_bufs:
-      dispatch.check_special("parallel computation", bufs)
-  return out_handler(out_bufs)
+class ExecuteReplicated:
+  """The logic to shard inputs, execute a replicated model, returning outputs."""
+  __slots__ = ['xla_executable', 'backend', 'in_handler', 'out_handler']
+
+  def __init__(self, xla_executable, backend, in_handler: InputsHandler,
+               out_handler: ResultsHandler):
+    self.xla_executable = xla_executable
+    self.backend = backend
+    self.in_handler = in_handler
+    self.out_handler = out_handler
+
+  @profiler.annotate_function
+  def __call__(self, *args):
+    input_bufs = self.in_handler(args)
+    out_bufs = self.xla_executable.execute_sharded_on_local_devices(input_bufs)
+    if dispatch.needs_check_special():
+      for bufs in out_bufs:
+        dispatch.check_special("parallel computation", bufs)
+    return self.out_handler(out_bufs)
 
 
 xla_pmap_p = core.MapPrimitive('xla_pmap')
@@ -2326,10 +2342,9 @@ def from_hlo(name: str,
     else:
       with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} "
                                      "in {elapsed_time} sec"):
-        compiled = dispatch.compile_or_get_cached(backend, computation, compile_options)
-      handle_args = InputsHandler(compiled.local_devices(), input_specs, input_indices)
-      unsafe_call = partial(execute_replicated, compiled, backend, handle_args, handle_outs)
-      xla_executable = compiled
+        xla_executable = dispatch.compile_or_get_cached(backend, computation, compile_options)
+      handle_args = InputsHandler(xla_executable.local_devices(), input_specs, input_indices)
+      unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args, handle_outs)
 
     return MeshExecutable(xla_executable, unsafe_call, input_avals)
 
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index a00dde7a131f..a5ef0ccd8ae7 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -1903,6 +1903,13 @@ class CppPmapTest(PythonPmapTest):
   def pmap(self):
     return src_api._cpp_pmap
 
+  def pmap_fast_path_is_enabled(self):
+    num_devices = jax.device_count()
+    f = jax.pmap(lambda x: x+1)
+    size = f._cache_size()
+    f(np.zeros([num_devices], dtype=np.float32))
+    self.assertEqual(f._cache_size(), size+1)
+
 
 class VmapOfPmapTest(jtu.JaxTestCase):
 