Skip to content

Commit

Permalink
Propagate name through ExecuteReplicated for dispatch.check_special
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 477351323
  • Loading branch information
yashk2810 authored and jax authors committed Sep 28, 2022
1 parent 933b6a2 commit b4e1d0a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
18 changes: 10 additions & 8 deletions jax/interpreters/pxla.py
Expand Up @@ -1615,10 +1615,11 @@ def from_hlo(xla_computation,
pci.backend, xla_computation, compile_options, host_callbacks)
handle_args = InputsHandler(
compiled.local_devices(), in_shardings, input_indices, InputsHandlerMode.pmap)
execute_fun = ExecuteReplicated(compiled, pci.backend, handle_args,
handle_outs, unordered_effects,
ordered_effects, keepalive,
bool(host_callbacks), set(range(len(input_indices))))
execute_fun = ExecuteReplicated(compiled, "parallel computation",
pci.backend, handle_args, handle_outs,
unordered_effects, ordered_effects,
keepalive, bool(host_callbacks),
set(range(len(input_indices))))
fingerprint = getattr(compiled, "fingerprint", None)

return PmapExecutable(compiled, execute_fun, fingerprint, pci.avals)
Expand Down Expand Up @@ -1977,17 +1978,18 @@ def partitioned_sharding_spec(num_partitions: int,

class ExecuteReplicated:
"""The logic to shard inputs, execute a replicated model, returning outputs."""
__slots__ = ['xla_executable', 'backend', 'in_handler', 'out_handler',
__slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler',
'has_unordered_effects', 'ordered_effects', 'keepalive',
'has_host_callbacks', '_local_devices', 'kept_var_idx',
'__weakref__']

def __init__(self, xla_executable, backend, in_handler: InputsHandler,
def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
out_handler: ResultsHandler,
unordered_effects: List[core.Effect],
ordered_effects: List[core.Effect], keepalive: Any,
has_host_callbacks: bool, kept_var_idx: Set[int]):
self.xla_executable = xla_executable
self.name = name
self.backend = backend
self.in_handler = in_handler
self.out_handler = out_handler
Expand Down Expand Up @@ -2047,7 +2049,7 @@ def __call__(self, *args):
for bufs in out_bufs:
if xb.use_sharded_buffer and isinstance(bufs, xb.xla_client.ShardedBuffer):
bufs = cast(xb.xla_client.ShardedBuffer, bufs).get_device_buffers()
dispatch.check_special("parallel computation", bufs)
dispatch.check_special(self.name, bufs)
return self.out_handler(out_bufs)


Expand Down Expand Up @@ -3292,7 +3294,7 @@ def from_hlo(name: str,
kept_var_idx, bool(host_callbacks),
from_lower_sharding_computation=True)
else:
unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args,
unsafe_call = ExecuteReplicated(xla_executable, name, backend, handle_args,
handle_outs, unordered_effects,
ordered_effects, keepalive,
bool(host_callbacks), kept_var_idx)
Expand Down
2 changes: 1 addition & 1 deletion tests/debug_nans_test.py
Expand Up @@ -135,7 +135,7 @@ def testXmap(self):
with jax.experimental.maps.Mesh(np.array(jax.local_devices()[:1]), ('x',)):
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in parallel computation"):
r"invalid value \(nan\) encountered in xmap"):
ans = f(jnp.array([0.]))
ans.block_until_ready()

Expand Down

0 comments on commit b4e1d0a

Please sign in to comment.