diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index 4b0ee80a8dda..05dcade84999 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -20,7 +20,8 @@ fori_loop, map, scan, scan_bind, scan_p, _scan_impl, while_loop, while_p) -from jax._src.lax.control_flow.conditionals import cond, cond_p, switch +from jax._src.lax.control_flow.conditionals import (cond, cond_p, switch, + platform_dependent) from jax._src.lax.control_flow.solves import (custom_linear_solve, custom_root, _custom_linear_solve_impl, linear_solve_p) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index c1a8b4de9ac9..ab54f96d98d9 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for conditional control flow primitives.""" +from __future__ import annotations import collections from collections.abc import Sequence @@ -20,7 +21,7 @@ import inspect import itertools import operator -from typing import Callable +from typing import Any, Callable, TypeVar from jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util @@ -882,3 +883,114 @@ def _cond_state_discharge_rule(in_avals, out_avals, *args, branches, linear): new_invals.append( next(ref_val_iter) if isinstance(aval, AbstractRef) else None) return new_invals, out_vals + + +_T = TypeVar("_T") +def platform_dependent(*args: Any, + default: Callable[..., _T] | None = None, + **per_platform: Callable[..., _T]): + """Stages out platform-specific code. + + In JAX the actual platform on which a computation is run is determined + very late, e.g., based on where the data is located. When using AOT + lowering or serialization, the computation may be compiled and executed + on a different machine, or even on a platform that is not available at + lowering time. This means that it is not safe to write platform-dependent + code using Python conditionals, e.g., based on the current default + JAX platform. Instead, one can use ``platform_dependent``: + + Usage:: + + def cpu_code(*args): ... + def tpu_code(*args): ... + def other_platforms_code(*args): ... + res = platform_dependent(*args, cpu=cpu_code, tpu=tpu_code, + default=other_platforms_code) + + When the staged out code is executed on a CPU, this is equivalent to + ``cpu_code(*args)``, on a TPU is equivalent to ``tpu_code(*args)`` and on + any other platform to ``other_platforms_code(*args)``. + Unlike a Python conditional, all alternatives are traced + and staged out to Jaxpr. This is similar to, and is implemented in terms of, + :func:`~switch`, from which it inherits the behavior + under transformations. + + Unlike a :func:`~switch` the choice of what gets executed is made earlier: + in most cases during lowering when the lowering platform is known; in the + rare case of multi-platform lowering and serialization, the StableHLO code + will contain a conditional on the actual platform. This conditional is + resolved just in time prior to compilation when the compilation platform is + known. This means that the compiler actually never sees a conditional. + + Args: + *args: JAX arrays passed to each of the branches. May be PyTrees. + **per_platform: branches to use for different platforms. The branches are + JAX callables invoked with ``*args``. The keywords are platform names, + e.g., 'cpu', 'tpu', 'cuda', 'rocm'. + default: optional default branch to use for a platform not mentioned in + ``per_platform``. If there is no ``default`` there will be an error when + the code is lowered for a platform not mentioned in ``per_platform``. + + Returns: + The value ``per_platform[execution_platform](*args)``. + """ + # Join identical branches + platform_branches: list[tuple[list[str], Callable]] = [] + for pname, pbranch in per_platform.items(): + if pname == "gpu": + raise ValueError("Use 'cuda' or 'rocm' for this API.") + for ps, b in platform_branches: + if b == pbranch: + ps.append(pname) + break + else: + platform_branches.append(([pname], pbranch)) + + platforms_lists, branches = util.unzip2(platform_branches) + platform_index = platform_index_p.bind( + platforms=tuple(tuple(ps) for ps in platforms_lists), + has_default=(default is not None)) + if default is not None: + branches = branches + (default,) + # Use a switch, to get the proper transformation rules for free. Since + # platform index has no dependence on the input data, it won't be vectorized + # under vmap. + return switch(platform_index, branches, *args) + +# A primitive to compute the index of a platform into a list of platforms. +# Args: +# platforms: Sequence[Sequence[str]]: a sequence of sequences of platform +# names. If the current lowering platform is in one of the inner sequences +# returns the index of that inner sequence in the outer sequence. +# has_default: if True, and if the lowering platform is not found in +# `platforms` then return `len(platforms)`. Otherwise, raise an error. +platform_index_p = core.Primitive("platform_index") +platform_index_p.multiple_results = False +platform_index_p.def_impl(functools.partial(dispatch.apply_primitive, + platform_index_p)) + +@platform_index_p.def_abstract_eval +def _platform_index_aval(*_, **__): + return core.ShapedArray((), np.int32) + +def _platform_index_lowering(ctx: mlir.LoweringRuleContext, + *, + platforms: Sequence[Sequence[str]], + has_default: bool): + def lower_constant(ctx: mlir.LoweringRuleContext, *, i: int) -> mlir.ir.Value: + return mlir.ir_constants(np.int32(i)) + lowering_rules: tuple[mlir.MultiPlatformLoweringRule, ...] = tuple( + (ps, partial(lower_constant, i=i)) + for i, ps in enumerate(platforms) + ) + if has_default: + lowering_rules = lowering_rules + ( + (None, partial(lower_constant, i=len(platforms))), + ) + return mlir.lower_multi_platform( + ctx, + f"platform_index(platforms={platforms}, has_default={has_default})", + lowering_rules, + effects.no_effects) + +mlir.register_lowering(platform_index_p, _platform_index_lowering) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index ff127a3721bb..56cbb0182bb5 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1500,6 +1500,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "hessenberg", "tridiagonal", "eigh_jacobi", + "platform_index", ] tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index cee9e7f60318..a4adadf0f962 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -1683,46 +1683,16 @@ def test_multi_platform(self): self.skipTest("TODO: enable when we can handle i64 platform_index_argument") # Checks that we dispatch from TF to the proper JAX platform lowering. - # A primitive for testing multi-platform lowering. Takes one argument and - # adds a different value to it: cpu=2., tpu=3., cuda=.4, rocm=5. - _testing_multi_platform_p = core.Primitive("testing_multi_platform") + # We add a different value to it: cpu=2., tpu=3., cuda=.4, rocm=5. _testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.) - @_testing_multi_platform_p.def_abstract_eval - def _testing_multi_platform_abstract_eval(xaval: core.AbstractValue): - assert xaval.dtype == np.float32 # type: ignore - return xaval - - @_testing_multi_platform_p.def_impl - def _testing_multi_platform_impl(x: jax.Array) -> jax.Array: - to_add = _testing_multi_platform_to_add[platform] - return x + to_add - - def _testing_multi_platform_lowering(ctx: mlir.LoweringRuleContext, - x: mlir.Value, - *, - platform: str) -> Sequence[mlir.Value]: - to_add = _testing_multi_platform_to_add[platform] - to_add_value = mlir.broadcast_in_dim(ctx, - mlir.ir_constant( - np.float32(to_add)), - ctx.avals_in[0], - broadcast_dimensions=()) - return mlir.hlo.AddOp(x, to_add_value).results - - # Register a default rule for cuda, to test the default-platform rule selection. - mlir.register_lowering(_testing_multi_platform_p, - functools.partial(_testing_multi_platform_lowering, - platform="cuda")) - for platform in ["cpu", "tpu", "rocm"]: - mlir.register_lowering(_testing_multi_platform_p, - functools.partial( - _testing_multi_platform_lowering, - platform=platform), - platform=platform) - def f_jax(x): - return _testing_multi_platform_p.bind(x) + return x + lax.platform_dependent( + tpu=lambda: _testing_multi_platform_to_add["tpu"], + cuda=lambda: _testing_multi_platform_to_add["cuda"], + rocm=lambda: _testing_multi_platform_to_add["rocm"], + default=lambda: _testing_multi_platform_to_add["cpu"] + ) x = np.float32(.42) f_tf = jax2tf.convert( diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index a87f5ac45cf5..07a26a3ac1bf 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -338,6 +338,7 @@ switch as switch, while_loop as while_loop, while_p as while_p, + platform_dependent as platform_dependent, ) from jax._src.lax.fft import ( fft as fft, diff --git a/tests/export_test.py b/tests/export_test.py index 6a4ac4524a2d..e93532002e88 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -18,11 +18,12 @@ import logging import math import re -from typing import Optional, Sequence +from typing import Optional import unittest from absl.testing import absltest import jax +from jax import lax from jax import numpy as jnp from jax import tree_util from jax.experimental.export import export @@ -94,56 +95,32 @@ def lowering_testing_primitive_with_effect(ctx, a, *, effect_class_name: str): lowering_testing_primitive_with_effect) ## Setup for multi-platform lowering -# A primitive for testing multi-platform lowering. Takes one argument and -# adds a different value to it: cpu=2., tpu=3., cuda=.4, rocm=5. -# The primitive takes an event_class_name kwarg that may be None, or -# the name of an effect class. -_testing_multi_platform_p = core.Primitive("testing_multi_platform") _testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.) def _testing_multi_platform_func(x, *, effect_class_name: Optional[str] = None): - return _testing_multi_platform_p.bind(x, effect_class_name=effect_class_name) + # Behaves like x + 2 * _testing_multi_platform_to_add[platform] + def for_platform(platform: str): + if effect_class_name is None: + return 2. * _testing_multi_platform_to_add[platform] + else: + return testing_primitive_with_effect_p.bind( + _testing_multi_platform_to_add[platform], + effect_class_name=effect_class_name) + + return x + lax.platform_dependent( + tpu=lambda: for_platform("tpu"), + cuda=lambda: for_platform("cuda"), + rocm=lambda: for_platform("rocm"), + default=lambda: for_platform("cpu"), + ) def _testing_multi_platform_fun_expected(x, platform: str | None = None): - return x + _testing_multi_platform_to_add[ + return x + 2. * _testing_multi_platform_to_add[ xb.canonicalize_platform(platform or jtu.device_under_test()) ] -@_testing_multi_platform_p.def_effectful_abstract_eval -def _testing_multi_platform_abstract_eval(xaval: core.AbstractValue, - effect_class_name: Optional[str]): - assert xaval.dtype == np.float32 # type: ignore - effects = set() if effect_class_name is None else set([_testing_effects[effect_class_name]]) - return (xaval, effects) - -def _testing_multi_platform_lowering(ctx: mlir.LoweringRuleContext, - x: mlir.Value, - *, - effect_class_name: Optional[str], - platform: str) -> Sequence[mlir.Value]: - to_add = _testing_multi_platform_to_add[platform] - to_add_value = mlir.broadcast_in_dim(ctx, - mlir.ir_constant(np.float32(to_add)), - ctx.avals_in[0], - broadcast_dimensions=()) - results = mlir.hlo.AddOp(x, to_add_value).results - if effect_class_name is not None and "Ordered" in effect_class_name: - token_in = ctx.tokens_in.get(_testing_effects[effect_class_name])[0] - ctx.set_tokens_out(mlir.TokenSet({_testing_effects[effect_class_name]: (token_in,)})) - return results - -# Register a default rule for cuda, to test the default-platform rule selection. -mlir.register_lowering(_testing_multi_platform_p, - functools.partial(_testing_multi_platform_lowering, - platform="cuda")) -for platform in ["cpu", "tpu", "rocm"]: - mlir.register_lowering(_testing_multi_platform_p, - functools.partial(_testing_multi_platform_lowering, - platform=platform), - platform=platform) - class JaxExportTest(jtu.JaxTestCase): @@ -813,8 +790,8 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10] def test_multi_platform(self): x = np.arange(8, dtype=np.float32) exp = export.export(_testing_multi_platform_func, - lowering_platforms=("cpu", "tpu", "cuda"))(x) - self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda")) + lowering_platforms=("tpu", "cpu", "cuda"))(x) + self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda")) module_str = str(exp.mlir_module()) expected_main_re = ( r"@main\(" diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index ba4b10dc054f..ca067e10ddbf 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -14,6 +14,7 @@ import collections +import contextlib from functools import partial import itertools import operator @@ -2719,6 +2720,116 @@ def bad_switchfun_jvp(primals, tangents): self.assertAllClose(expected2, expected3) self.assertAllClose(expected3, actual) + def test_platform_dependent(self): + def f(x): + return lax.platform_dependent(x, cpu=jnp.sin, default=jnp.cos) + + x = np.arange(3, dtype=np.float32) + res = f(x) + self.assertAllClose( + res, + np.sin(x) if jtu.device_under_test() == "cpu" else np.cos(x)) + + def test_platform_dependent_no_args(self): + def f(x): + return lax.platform_dependent(cpu=lambda: jnp.sin(x), + default=lambda: jnp.cos(x)) + + x = np.arange(3, dtype=np.float32) + res = f(x) + self.assertAllClose( + res, + np.sin(x) if jtu.device_under_test() == "cpu" else np.cos(x)) + + def test_platform_dependent_lowering(self): + def f(x): + return lax.platform_dependent(x, cpu=jnp.sin, default=jnp.cos) + + x = np.arange(3, dtype=np.float32) + lowered = jax.jit(f).lower(x) + stablehlo = lowered.as_text() + self.assertIn("stablehlo.case", stablehlo) + self.assertIn("stablehlo.sine", stablehlo) + self.assertIn("stablehlo.cosine", stablehlo) + + # The HLO has been canonicalized and contains only the branch we need + hlo = lowered.as_text("hlo") + if jtu.device_under_test() == "cpu": + self.assertIn(" sine", hlo) + self.assertNotIn(" cosine", hlo) + else: + self.assertNotIn(" sine", hlo) + self.assertIn(" cosine", hlo) + + def test_platform_dependent_multiple_identical_branches(self): + x = np.arange(3, dtype=np.float32) + def f(x): + return lax.platform_dependent( + x, + cpu=jnp.sin, + tpu=jnp.sin, + default=lambda x: x) + res = f(x) + self.assertAllClose( + res, + np.sin(x) if jtu.device_under_test() in ["cpu", "tpu"] else x) + # We only lower the common branches once + stablehlo = jax.jit(f).lower(x).as_text() + sines = re.findall(r"stablehlo.sine", stablehlo) + self.assertEqual(1, len(sines)) + + def test_platform_dependent_no_default(self): + ctx = contextlib.ExitStack() + if jtu.device_under_test() != "tpu": + ctx.enter_context( + self.assertRaisesRegex(ValueError, + "translation rule .* not found for platform")) + with ctx: + lax.platform_dependent( + 3., + tpu=lambda x: x + 2.) + + def test_platform_dependent_batched(self): + def f(x): + return lax.platform_dependent(x, cpu=jnp.sin, default=jnp.cos) + + xs = np.arange(3, dtype=np.float32) + self.assertAllClose( + jax.vmap(f)(xs), + np.sin(xs) if jtu.device_under_test() == "cpu" else np.cos(xs)) + # We can still fold the un-needed branch + hlo = jax.jit(jax.vmap(f)).lower(xs).as_text('hlo') + expect_a_sine = (jtu.device_under_test() == "cpu") + self.assertEqual(expect_a_sine, " sine(" in hlo) + self.assertEqual(not expect_a_sine, " cosine(" in hlo) + + def test_platform_dependent_grad(self): + # For a function "lax.dot(x, x)", we choose two branches with very different + # implementations (a dot and a scan), and therefore different residuals, + # so that we can verify whether the residuals are as we expect (we don't + # get residuals from a different platform. + x = np.arange(8, dtype=np.float32) + def f_impl_dot(x): # x: f32[8] + return jnp.dot(x, x) + def f_impl_scan(x): + def scan_body(carry, x_i): + return (carry + x_i * x_i, None) + return lax.scan(scan_body, np.float32(0.), x)[0] + + def f(x): + return jnp.sin(lax.platform_dependent(x, + cpu=f_impl_dot, + default=f_impl_scan)) + self.assertAllClose( + jax.grad(f)(x), + jax.grad(lambda x: jnp.sin(f_impl_dot(x)))(x)) + + # Check that we do not have contamination of computations across platforms + hlo = jax.jit(jax.grad(f)).lower(x).as_text('hlo') + expect_a_dot = (jtu.device_under_test() == "cpu") + self.assertEqual(expect_a_dot, " dot(" in hlo) + self.assertEqual(not expect_a_dot, " while(" in hlo) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())