Turn execute_replicated into a class so we can access its fields.
It's more readable than inspecting the internals of a `functools.partial`.

PiperOrigin-RevId: 429523075
jblespiau authored and jax authors committed Feb 18, 2022
1 parent d123a10 commit 607e703
Showing 3 changed files with 42 additions and 19 deletions.
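In a nutshell, the change swaps a `functools.partial`, whose captured arguments could only be recovered by positionally unpacking its `.args` tuple, for a class that exposes them as named attributes. An illustrative sketch of the before/after (not code from this commit; all names below are hypothetical):

```python
import functools

def execute(executable, backend, in_handler, out_handler, *args):
  del args  # illustration only
  return executable, backend

# Before: the callable is a partial; its captured fields are only
# reachable by position through the partial's internals.
p = functools.partial(execute, "exec", "cpu", "in", "out")
executable, backend, in_handler, out_handler = p.args

# After: the callable is a small class with named, documented fields.
class Execute:
  __slots__ = ["executable", "backend", "in_handler", "out_handler"]

  def __init__(self, executable, backend, in_handler, out_handler):
    self.executable = executable
    self.backend = backend
    self.in_handler = in_handler
    self.out_handler = out_handler

  def __call__(self, *args):
    return execute(self.executable, self.backend,
                   self.in_handler, self.out_handler, *args)

e = Execute("exec", "cpu", "in", "out")
assert e.backend == "cpu"  # named access, no positional unpacking
```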
11 changes: 6 additions & 5 deletions jax/_src/api.py
@@ -2100,18 +2100,19 @@ def cache_miss(*args, **kwargs):
     execute = pxla.parallel_callable.most_recent_entry()
     use_fastpath = (
         execute is not None and
-        # We don't support JAX extension backends. In particular, some
-        # extentions do not return a partial with a `func` attribute.
-        getattr(execute[0], "func", None) is pxla.execute_replicated and
+        # We don't support JAX extension backends.
+        isinstance(execute[0], pxla.ExecuteReplicated) and
         # No tracers in the outputs. Checking for ShardedDeviceArray should be
         # sufficient, but we use the more general `DeviceArray`.
         all(isinstance(x, device_array.DeviceArray) for x in out_flat))
     ### If we can use the fastpath, we return required info to the caller.
     if use_fastpath:
-      xla_executable, backend_, in_handler, out_handler = execute[0].args
+      execute_replicated = execute[0]
+      out_handler = execute_replicated.out_handler
+      in_handler = execute_replicated.in_handler
       fastpath_data = _PmapFastpathData(
           version=1,
-          xla_executable=xla_executable,
+          xla_executable=execute_replicated.xla_executable,
           in_handler=in_handler,
           out_handler=out_handler,
           out_pytree_def=out_pytree_def,
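The `isinstance` check above is sturdier than the old `getattr(..., "func", ...)` probe, which silently returned `False` whenever the cached callable was not a `functools.partial`; that is exactly the case extension backends hit. A minimal sketch with stand-in types (not JAX's real ones):

```python
import functools

class ExecuteReplicated:  # stand-in for pxla.ExecuteReplicated
  def __call__(self, *args):
    return args

def execute_replicated(*args):  # stand-in for the old module-level function
  return args

old_style = functools.partial(execute_replicated, "exec", "cpu", "in", "out")
new_style = ExecuteReplicated()

# Old probe: only matches a partial wrapping exactly `execute_replicated`.
assert getattr(old_style, "func", None) is execute_replicated
assert getattr(new_style, "func", None) is None  # no `func` attribute at all

# New probe: names the intended type directly.
assert isinstance(new_style, ExecuteReplicated)
```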
43 changes: 29 additions & 14 deletions jax/interpreters/pxla.py
@@ -1216,8 +1216,7 @@ def from_hlo(xla_computation,
         pci.backend, xla_computation, compile_options)
     handle_args = InputsHandler(
         compiled.local_devices(), input_sharding_specs, input_indices)
-    execute_fun = partial(
-        execute_replicated, compiled, pci.backend, handle_args, handle_outs)
+    execute_fun = ExecuteReplicated(compiled, pci.backend, handle_args, handle_outs)
     fingerprint = getattr(compiled, "fingerprint", None)

     return PmapExecutable(compiled, execute_fun, fingerprint, pci.avals)
@@ -1386,6 +1385,12 @@ def __init__(self, local_devices, sharding_specs, input_indices):
   def __call__(self, input_buffers):
     return self.handler(input_buffers)

+  def __str__(self):
+    return ("InputsHandler(\n"
+            f"local_devices={self.local_devices},\n"
+            f"sharding_specs={self.sharding_specs},\n"
+            f"input_indices={self.input_indices})")
+

 class ResultsHandler:
   __slots__ = ("handlers", "out_specs", "out_indices", "unmapped_local_out_avals")
@@ -1540,14 +1545,25 @@ def partitioned_sharding_spec(num_partitions: int,
                       mesh_mapping=map(ShardedAxis, range(len(partitions))))


-@profiler.annotate_function
-def execute_replicated(compiled, backend, in_handler, out_handler, *args):
-  input_bufs = in_handler(args)
-  out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
-  if dispatch.needs_check_special():
-    for bufs in out_bufs:
-      dispatch.check_special("parallel computation", bufs)
-  return out_handler(out_bufs)
+class ExecuteReplicated:
+  """The logic to shard inputs, execute a replicated model, returning outputs."""
+  __slots__ = ['xla_executable', 'backend', 'in_handler', 'out_handler']
+
+  def __init__(self, xla_executable, backend, in_handler: InputsHandler,
+               out_handler: ResultsHandler):
+    self.xla_executable = xla_executable
+    self.backend = backend
+    self.in_handler = in_handler
+    self.out_handler = out_handler
+
+  @profiler.annotate_function
+  def __call__(self, *args):
+    input_bufs = self.in_handler(args)
+    out_bufs = self.xla_executable.execute_sharded_on_local_devices(input_bufs)
+    if dispatch.needs_check_special():
+      for bufs in out_bufs:
+        dispatch.check_special("parallel computation", bufs)
+    return self.out_handler(out_bufs)


 xla_pmap_p = core.MapPrimitive('xla_pmap')
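A side note on the new class: it declares `__slots__`, which keeps instances lightweight and makes attribute typos fail loudly instead of silently creating stray fields. A quick sketch of that behavior (hypothetical class, not the real one):

```python
class Sketch:
  __slots__ = ["xla_executable", "backend"]

  def __init__(self, xla_executable, backend):
    self.xla_executable = xla_executable
    self.backend = backend

s = Sketch("exec", "cpu")
s.backend = "gpu"      # declared slot: fine
try:
  s.backendd = "tpu"   # misspelled: rejected, no silent stray attribute
except AttributeError as err:
  print(err)
```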
@@ -2326,10 +2342,9 @@ def from_hlo(name: str,
   else:
     with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} "
                                    "in {elapsed_time} sec"):
-      compiled = dispatch.compile_or_get_cached(backend, computation, compile_options)
-  handle_args = InputsHandler(compiled.local_devices(), input_specs, input_indices)
-  unsafe_call = partial(execute_replicated, compiled, backend, handle_args, handle_outs)
-  xla_executable = compiled
+      xla_executable = dispatch.compile_or_get_cached(backend, computation, compile_options)
+  handle_args = InputsHandler(xla_executable.local_devices(), input_specs, input_indices)
+  unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args, handle_outs)

   return MeshExecutable(xla_executable, unsafe_call, input_avals)

7 changes: 7 additions & 0 deletions tests/pmap_test.py
@@ -1903,6 +1903,13 @@ class CppPmapTest(PythonPmapTest):
   def pmap(self):
     return src_api._cpp_pmap

+  def pmap_fast_path_is_enabled(self):
+    num_devices = jax.device_count()
+    f = jax.pmap(lambda x: x+1)
+    size = f._cache_size()
+    f(np.zeros([num_devices], dtype=np.float32))
+    self.assertEqual(f._cache_size(), size+1)
+

 class VmapOfPmapTest(jtu.JaxTestCase):

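The new test helper leans on the pmapped function's private `_cache_size()` counter: a call that takes the C++ fastpath should add exactly one cache entry. Distilled into a standalone script, it would look roughly like this (assumes a working JAX install; `_cache_size` is a private API and may change):

```python
import jax
import numpy as np

f = jax.pmap(lambda x: x + 1)
before = f._cache_size()
f(np.zeros([jax.device_count()], dtype=np.float32))
# One successful fastpath execution populates exactly one new cache entry.
assert f._cache_size() == before + 1
```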
