From b4e1d0af8a163ef7f1e36993e55e340ebf8c7006 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Tue, 27 Sep 2022 21:31:58 -0700
Subject: [PATCH] Propagate `name` through ExecuteReplicated for
 `dispatch.check_special`

PiperOrigin-RevId: 477351323
---
 jax/interpreters/pxla.py | 18 ++++++++++--------
 tests/debug_nans_test.py |  2 +-
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index a0ea04d621e9..e5efbbff1833 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -1615,10 +1615,11 @@ def from_hlo(xla_computation,
       pci.backend, xla_computation, compile_options, host_callbacks)
   handle_args = InputsHandler(
       compiled.local_devices(), in_shardings, input_indices, InputsHandlerMode.pmap)
-  execute_fun = ExecuteReplicated(compiled, pci.backend, handle_args,
-                                  handle_outs, unordered_effects,
-                                  ordered_effects, keepalive,
-                                  bool(host_callbacks), set(range(len(input_indices))))
+  execute_fun = ExecuteReplicated(compiled, "parallel computation",
+                                  pci.backend, handle_args, handle_outs,
+                                  unordered_effects, ordered_effects,
+                                  keepalive, bool(host_callbacks),
+                                  set(range(len(input_indices))))
   fingerprint = getattr(compiled, "fingerprint", None)
 
   return PmapExecutable(compiled, execute_fun, fingerprint, pci.avals)
@@ -1977,17 +1978,18 @@ def partitioned_sharding_spec(num_partitions: int,
 
 class ExecuteReplicated:
   """The logic to shard inputs, execute a replicated model, returning outputs."""
-  __slots__ = ['xla_executable', 'backend', 'in_handler', 'out_handler',
+  __slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler',
                'has_unordered_effects', 'ordered_effects', 'keepalive',
                'has_host_callbacks', '_local_devices', 'kept_var_idx',
                '__weakref__']
 
-  def __init__(self, xla_executable, backend, in_handler: InputsHandler,
+  def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
                out_handler: ResultsHandler,
                unordered_effects: List[core.Effect],
                ordered_effects: List[core.Effect], keepalive: Any,
                has_host_callbacks: bool, kept_var_idx: Set[int]):
     self.xla_executable = xla_executable
+    self.name = name
     self.backend = backend
     self.in_handler = in_handler
     self.out_handler = out_handler
@@ -2047,7 +2049,7 @@ def __call__(self, *args):
     for bufs in out_bufs:
       if xb.use_sharded_buffer and isinstance(bufs, xb.xla_client.ShardedBuffer):
         bufs = cast(xb.xla_client.ShardedBuffer, bufs).get_device_buffers()
-      dispatch.check_special("parallel computation", bufs)
+      dispatch.check_special(self.name, bufs)
     return self.out_handler(out_bufs)
 
 
@@ -3292,7 +3294,7 @@ def from_hlo(name: str,
                                         kept_var_idx, bool(host_callbacks),
                                         from_lower_sharding_computation=True)
   else:
-    unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args,
+    unsafe_call = ExecuteReplicated(xla_executable, name, backend, handle_args,
                                     handle_outs, unordered_effects,
                                     ordered_effects, keepalive,
                                     bool(host_callbacks), kept_var_idx)
diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py
index 22f02a6bc704..c508fc49a20f 100644
--- a/tests/debug_nans_test.py
+++ b/tests/debug_nans_test.py
@@ -135,7 +135,7 @@ def testXmap(self):
     with jax.experimental.maps.Mesh(np.array(jax.local_devices()[:1]), ('x',)):
       with self.assertRaisesRegex(
           FloatingPointError,
-          r"invalid value \(nan\) encountered in parallel computation"):
+          r"invalid value \(nan\) encountered in xmap"):
         ans = f(jnp.array([0.]))
         ans.block_until_ready()
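
Note (not part of the commit): with `name` threaded into ExecuteReplicated, the
FloatingPointError raised by dispatch.check_special under `jax_debug_nans` names the
originating computation. The pmap path keeps the generic "parallel computation" label,
while the lower_sharding_computation path now reports the computation's own name
(hence the updated test expecting "xmap"). Below is a minimal sketch of the observable
behavior, assuming a single-host JAX setup with `jax_debug_nans` enabled; it is
illustrative only and not part of the patch.

    import jax
    import jax.numpy as jnp

    # Enable NaN checking so check_special runs on computation outputs.
    jax.config.update("jax_debug_nans", True)

    # 0./0. yields a NaN on every device participating in the pmap.
    f = jax.pmap(lambda x: x / x)

    try:
      f(jnp.zeros(jax.local_device_count()))
    except FloatingPointError as e:
      # With this change the message carries the name passed to
      # ExecuteReplicated, e.g.:
      #   invalid value (nan) encountered in parallel computation
      print(e)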