Skip to content

Commit

Permalink
Remove trivial execution from jax since it leads to 100x slower dispa…
Browse files Browse the repository at this point in the history
…tch time.

Trivial computations were added for a pre-omnistaging world. After omnistaging, JAX produces fewer trivial computations, so there is no need for this to exist.

In the future, if we want to support forwarding of inputs to outputs, it will need to be done in a different way that the C++ dispatch path knows about.

```
jit_trivial_dispatch                                   246µs ± 3%                4µs ± 1%  -98.52%          (p=0.008 n=5+5)
jit_trivial                                            250µs ± 3%                5µs ± 1%  -98.19%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 560141018
  • Loading branch information
yashk2810 authored and jax authors committed Aug 25, 2023
1 parent c71eedf commit 970f4c9
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 150 deletions.
129 changes: 5 additions & 124 deletions jax/_src/interpreters/pxla.py
Expand Up @@ -2016,22 +2016,6 @@ def lower_sharding_computation(
"To fix this error, run your `jitted` computation inside "
"`with jax.spmd_mode('allow_all'):` context manager.")

has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
kept_outputs = [True] * len(global_out_avals)

# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if (not always_lower and not (jaxpr.effects or has_outfeed) and
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
all(is_unspecified(o) for o in out_shardings)):
return MeshComputation(
str(name_stack), None, True, donated_invars, jaxpr=jaxpr,
consts=closed_jaxpr.consts, global_in_avals=global_in_avals,
global_out_avals=global_out_avals, in_shardings=in_shardings,
backend=backend, da_object=da_object,
committed=committed, kept_var_idx=kept_var_idx, keepalive=None)

# 2. Build up the HLO
semantic_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
semantic_out_shardings = SemanticallyEqualShardings(out_shardings)
Expand All @@ -2049,7 +2033,6 @@ def lower_sharding_computation(
return MeshComputation(
str(name_stack),
module,
False,
donated_invars,
global_in_avals=global_in_avals,
global_out_avals=global_out_avals,
Expand Down Expand Up @@ -2223,7 +2206,6 @@ def lower_mesh_computation(
return MeshComputation(
str(name_stack),
lowering_result.module,
False,
donated_invars,
global_in_avals=global_in_avals,
global_out_avals=global_out_avals,
Expand All @@ -2248,35 +2230,23 @@ class MeshComputation(stages.XlaLowering):
_executable: MeshExecutable | None

def __init__(self, name: str, hlo: ir.Module | None,
is_trivial: bool, donated_invars: Sequence[bool], **compile_args):
donated_invars: Sequence[bool], **compile_args):
self._name = name
self._hlo = hlo
self.is_trivial = is_trivial
self._donated_invars = donated_invars
self.compile_args = compile_args
self._executable = None

# -- stages.XlaLowering overrides

def stablehlo(self) -> ir.Module:
if self.is_trivial:
raise ValueError("A trivial computation has no HLO")
return self._hlo

def compile(
self,
compiler_options=None,
) -> MeshExecutable:
def compile(self, compiler_options=None) -> MeshExecutable:
if self._executable is None or compiler_options is not None:
if self.is_trivial:
executable = MeshExecutable.from_trivial_jaxpr(
**self.compile_args)
else:
executable = UnloadedMeshExecutable.from_hlo(
self._name,
self._hlo,
**self.compile_args,
compiler_options=compiler_options)
executable = UnloadedMeshExecutable.from_hlo(
self._name, self._hlo, **self.compile_args,
compiler_options=compiler_options)
if compiler_options is None:
self._executable = executable
return executable
Expand Down Expand Up @@ -2735,32 +2705,6 @@ def unsafe_call(self) -> Callable[..., Any]:
self._unsafe_call = self.build_unsafe_call()
return self._unsafe_call

@staticmethod
def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
in_shardings, backend, da_object,
committed, kept_var_idx, keepalive) -> MeshExecutable:
assert keepalive is None
if hasattr(backend, "compile_replicated"):
return _compile_replicated_mesh_executable_from_trivial_jaxpr(
jaxpr, consts, global_in_avals, global_out_avals, in_shardings,
backend, da_object, committed, kept_var_idx, 1)

out_shardings = _out_shardings_for_trivial(
jaxpr, consts, in_shardings, da_object)
indices = _get_input_indices(global_out_avals, out_shardings, da_object)
# TODO(yashkatariya): Make local_device_assignment directly usable in the
# downstream code without tuple conversion.
local_device_assignment = tuple(da_object.addressable_device_list)
handle_ins = InputsHandler(local_device_assignment, out_shardings, indices)
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings, committed,
[False] * len(global_out_avals))
unsafe_call = partial(_execute_trivial, jaxpr, consts, handle_ins,
handle_outs, kept_var_idx)
return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, False, kept_var_idx,
None)

# -- stages.XlaExecutable overrides

def xla_extension_executable(self):
Expand Down Expand Up @@ -2853,47 +2797,6 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
return in_shardings, out_shardings, committed, tuple(local_devices)


def _out_shardings_for_trivial(
jaxpr: core.Jaxpr, consts: Sequence[Any],
in_shardings: Sequence[sharding_impls.XLACompatibleSharding],
device_assignment: Sequence[xc.Device],
) -> list[sharding_impls.XLACompatibleSharding]:
# For each jaxpr output, compute a Sharding by:
# * if the output is a forwarded input, get the corresponding in_sharding;
# * if the output is a constant Array, get its .sharding attribute;
# * otherwise, the output is a literal or numpy.ndarray constant, so give it
# a replicated sharding
from jax._src import array

if len(device_assignment) > 1:
rep = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
in_shardings = tuple(
i._original_sharding if hasattr(i, '_original_sharding') else i
for i in in_shardings)
else:
dev, = device_assignment
rep = sharding_impls.SingleDeviceSharding(dev)
in_shardings = (sharding_impls.SingleDeviceSharding(dev),) * len(in_shardings)

shardings: dict[core.Var, sharding_impls.XLACompatibleSharding] = {}
for constvar, constval in zip(jaxpr.constvars, consts):
if isinstance(constval, array.ArrayImpl):
shardings[constvar] = constval.sharding
map(shardings.setdefault, jaxpr.invars, in_shardings)
return [rep if isinstance(x, core.Literal) else shardings.get(x, rep)
for x in jaxpr.outvars]


def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args):
env: dict[core.Var, Any] = {}
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
map(env.setdefault, jaxpr.invars, pruned_args)
map(env.setdefault, jaxpr.constvars, consts)
outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
for v in jaxpr.outvars]
return out_handler(in_handler(outs))


@weakref_lru_cache
def _compile_replicated_mesh_executable_from_hlo(
computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
Expand Down Expand Up @@ -2926,28 +2829,6 @@ def _compile_replicated_mesh_executable_from_hlo(
kept_var_idx, jaxpr_debug_info, None)


def _compile_replicated_mesh_executable_from_trivial_jaxpr(
jaxpr, consts, global_in_avals, global_out_avals, in_shardings, backend,
da_object, committed, kept_var_idx, pmap_nreps):
out_shardings = _out_shardings_for_trivial(
jaxpr, consts, in_shardings, da_object)

input_indices = _get_input_indices(global_in_avals, in_shardings, da_object) # type: ignore
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings, committed,
[False] * len(global_out_avals))
# Use the standard out_handler.
unsafe_call = backend.compile_replicated(
is_trivial=True, jaxpr=jaxpr, consts=consts,
device_assignment=da_object, in_avals=global_in_avals,
in_indices=input_indices, in_shardings=in_shardings,
kept_var_idx=kept_var_idx, out_handler=handle_outs,
out_shardings=out_shardings, pmap_nreps=pmap_nreps)
return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, False, kept_var_idx,
None)


@lru_cache
def create_mesh_pspec_sharding(
mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None,
Expand Down
26 changes: 13 additions & 13 deletions tests/api_test.py
Expand Up @@ -661,15 +661,15 @@ def f(x):
def test_trivial_computations(self):
x = jnp.array([1, 2, 3])
y = self.jit(lambda x: x)(x)
self.assertEqual(x.unsafe_buffer_pointer(), y.unsafe_buffer_pointer())
self.assertNotEqual(x.unsafe_buffer_pointer(), y.unsafe_buffer_pointer())

z1, z2 = self.jit(lambda x: (x, x))(x)
self.assertEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
self.assertNotEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())

x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
z1, z2, z3 = self.jit(lambda x, y: (y, 1, x))(x1, x2)
self.assertEqual(z1.unsafe_buffer_pointer(), x2.unsafe_buffer_pointer())
self.assertEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
self.assertNotEqual(z1.unsafe_buffer_pointer(), x2.unsafe_buffer_pointer())
self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
self.assertEqual(z2, 1)

def test_trivial_computations_with_tokens(self):
Expand Down Expand Up @@ -1176,7 +1176,7 @@ def test_jit_lower_no_pruning(self):
self.assertLen(compiled._executable.in_avals, 2)
# Also works with jax.jit
jitted_f = self.jit(lambda x, y: x, keep_unused=True)
with jtu.count_device_put() as count:
with jtu.count_pjit_cpp_cache_miss() as count:
_ = jitted_f(1, 2)
self.assertEqual(count[0], 1)

Expand Down Expand Up @@ -3273,15 +3273,15 @@ def unflatten(unused_aux_data, children):
def test_trivial_computations(self):
x = jnp.array([1, 2, 3])
y = api.jit(lambda x: x)(x)
self.assertEqual(x.unsafe_buffer_pointer(), y.unsafe_buffer_pointer())
self.assertNotEqual(x.unsafe_buffer_pointer(), y.unsafe_buffer_pointer())

z1, z2 = api.jit(lambda x: (x, x))(x)
self.assertEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
self.assertNotEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())

x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
z1, z2, z3 = api.jit(lambda x, y: (y, 1, x))(x1, x2)
self.assertEqual(z1.unsafe_buffer_pointer(), x2.unsafe_buffer_pointer())
self.assertEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
self.assertNotEqual(z1.unsafe_buffer_pointer(), x2.unsafe_buffer_pointer())
self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
self.assertEqual(z2, 1)

def test_nested_jit_hoisting(self):
Expand Down Expand Up @@ -5455,19 +5455,19 @@ def test_vjp_caching(self):
# https://github.com/google/jax/issues/9661
identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
_, f_vjp = jax.vjp(identity, 1.)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841
for _ in range(20):
f_vjp(1.)[0].block_until_ready()
self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd
self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd

def test_vjp_caching_static_argnums(self):
identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x),
static_argnums=(1,))
_, f_vjp = jax.vjp(identity, 1., True)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841
for _ in range(20):
f_vjp(1.)[0].block_until_ready()
self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd
self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd

def test_fwd_caching(self):
# see above test also
Expand Down
4 changes: 2 additions & 2 deletions tests/lax_numpy_test.py
Expand Up @@ -3119,14 +3119,14 @@ def testArrayCopy(self, dtype, func):
_ptr = lambda x: x.unsafe_buffer_pointer()

self.assertEqual(_ptr(x), _ptr(x_view))
self.assertEqual(_ptr(x), _ptr(x_view_jit))
self.assertNotEqual(_ptr(x), _ptr(x_view_jit))
self.assertNotEqual(_ptr(x), _ptr(x_copy))
self.assertNotEqual(_ptr(x), _ptr(x_copy_jit))

x.delete()

self.assertTrue(x_view.is_deleted())
self.assertTrue(x_view_jit.is_deleted())
self.assertFalse(x_view_jit.is_deleted())

self.assertFalse(x_copy.is_deleted())
self.assertFalse(x_copy_jit.is_deleted())
Expand Down
4 changes: 2 additions & 2 deletions tests/lax_test.py
Expand Up @@ -2995,8 +2995,8 @@ def shard_foo_array_handler(x, devices, indices, sharding):
return pxla.batched_device_put(
aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])

def foo_array_constant_handler(x, c):
return array._array_mlir_constant_handler(x.data, c)
def foo_array_constant_handler(x):
return array._array_mlir_constant_handler(x.data)

def make_lowering(*, shape):
return jnp.zeros((*shape, 2), 'uint32')
Expand Down
3 changes: 2 additions & 1 deletion tests/multi_device_test.py
Expand Up @@ -116,7 +116,8 @@ def test_computation_follows_data(self):
z1, z2 = jax.jit(lambda x: (x, x))(x_uncommitted)
self.assert_uncommitted_to_device(z1, devices[0])
self.assert_uncommitted_to_device(z2, devices[0])
self.assertEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
# trivial computation does not exist in JAX anymore.
self.assertNotEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())

x2_uncommitted = jnp.array([2, 3])
z1, z2, z3 = jax.jit(lambda x, y: (y, 1, x))(x_uncommitted, x2_uncommitted)
Expand Down
12 changes: 4 additions & 8 deletions tests/pjit_test.py
Expand Up @@ -1987,14 +1987,10 @@ def test_pjit_single_device_sharding_cache(self):
a = jnp.arange(16).reshape((8, 2))
f = pjit(lambda x: x)

out = f(a)
cache_info1 = pjit_lib._pjit_lower_cached.cache_info()

_ = f(out)
cache_info2 = pjit_lib._pjit_lower_cached.cache_info()

self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
with jtu.count_pjit_cpp_cache_miss() as count:
out = f(a)
_ = f(out)
self.assertEqual(count[0], 1)

def test_pjit_different_device_recompilation(self):
if jax.device_count() < 2:
Expand Down

0 comments on commit 970f4c9

Please sign in to comment.