Add a lax.platform_dependent API for writing platform-dependent code.
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may execute on a different
machine, or even on a platform that is not available at lowering time.
This means that it is not safe to write platform-dependent code using
Python conditionals, e.g., based on the current default JAX platform.
The proper way to do this is to introduce a primitive with
platform-specific lowering rules. This change introduces such a
primitive along with a user-facing API.

See more details in the docstring of lax.platform_dependent.
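
For illustration, a minimal sketch (not part of this change) contrasting the
unsafe Python-conditional pattern with the new API:

    import jax
    from jax import lax
    import jax.numpy as jnp

    def f(x):
      # Unsafe: jax.default_backend() is consulted once, at trace time, so
      # the choice would be frozen into the lowered computation:
      #   if jax.default_backend() == "tpu": return x + 3.
      # Safe: both branches are staged out; the choice is made at lowering
      # (or, for multi-platform serialization, just before compilation).
      return lax.platform_dependent(x,
                                    tpu=lambda x: x + 3.,
                                    default=lambda x: x + 2.)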
gnecula committed Nov 2, 2023
1 parent d41078f commit 8feb413
Showing 7 changed files with 255 additions and 82 deletions.
3 changes: 2 additions & 1 deletion jax/_src/lax/control_flow/__init__.py
@@ -20,7 +20,8 @@
    fori_loop, map,
    scan, scan_bind, scan_p,
    _scan_impl, while_loop, while_p)
from jax._src.lax.control_flow.conditionals import cond, cond_p, switch
from jax._src.lax.control_flow.conditionals import (cond, cond_p, switch,
                                                    platform_dependent)
from jax._src.lax.control_flow.solves import (custom_linear_solve, custom_root,
                                              _custom_linear_solve_impl,
                                              linear_solve_p)
114 changes: 113 additions & 1 deletion jax/_src/lax/control_flow/conditionals.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for conditional control flow primitives."""
from __future__ import annotations

import collections
from collections.abc import Sequence
@@ -20,7 +21,7 @@
import inspect
import itertools
import operator
from typing import Callable
from typing import Any, Callable, TypeVar

from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
@@ -882,3 +883,114 @@ def _cond_state_discharge_rule(in_avals, out_avals, *args, branches, linear):
    new_invals.append(
        next(ref_val_iter) if isinstance(aval, AbstractRef) else None)
  return new_invals, out_vals


_T = TypeVar("_T")
def platform_dependent(*args: Any,
                       default: Callable[..., _T] | None = None,
                       **per_platform: Callable[..., _T]):
"""Stages out platform-specific code.
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may be compiled and executed
on a different machine, or even on a platform that is not available at
lowering time. This means that it is not safe to write platform-dependent
code using Python conditionals, e.g., based on the current default
JAX platform. Instead, one can use ``platform_dependent``:
Usage::
def cpu_code(*args): ...
def tpu_code(*args): ...
def other_platforms_code(*args): ...
res = platform_dependent(*args, cpu=cpu_code, tpu=tpu_code,
default=other_platforms_code)
When the staged out code is executed on a CPU, this is equivalent to
``cpu_code(*args)``, on a TPU is equivalent to ``tpu_code(*args)`` and on
any other platform to ``other_platforms_code(*args)``.
Unlike a Python conditional, all alternatives are traced
and staged out to Jaxpr. This is similar to, and is implemented in terms of,
:func:`~switch`, from which it inherits the behavior
under transformations.
Unlike a :func:`~switch` the choice of what gets executed is made earlier:
in most cases during lowering when the lowering platform is known; in the
rare case of multi-platform lowering and serialization, the StableHLO code
will contain a conditional on the actual platform. This conditional is
resolved just in time prior to compilation when the compilation platform is
known. This means that the compiler actually never sees a conditional.
Args:
*args: JAX arrays passed to each of the branches. May be PyTrees.
**per_platform: branches to use for different platforms. The branches are
JAX callables invoked with ``*args``. The keywords are platform names,
e.g., 'cpu', 'tpu', 'cuda', 'rocm'.
default: optional default branch to use for a platform not mentioned in
``per_platform``. If there is no ``default`` there will be an error when
the code is lowered for a platform not mentioned in ``per_platform``.
Returns:
The value ``per_platform[execution_platform](*args)``.
"""
  # Join identical branches
  platform_branches: list[tuple[list[str], Callable]] = []
  for pname, pbranch in per_platform.items():
    if pname == "gpu":
      raise ValueError("Use 'cuda' or 'rocm' for this API.")
    for ps, b in platform_branches:
      if b == pbranch:
        ps.append(pname)
        break
    else:
      platform_branches.append(([pname], pbranch))
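  # E.g., per_platform=dict(cuda=f, rocm=f, cpu=g), with f and g distinct
  # callables, groups into [(["cuda", "rocm"], f), (["cpu"], g)].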

  platforms_lists, branches = util.unzip2(platform_branches)
  platform_index = platform_index_p.bind(
      platforms=tuple(tuple(ps) for ps in platforms_lists),
      has_default=(default is not None))
  if default is not None:
    branches = branches + (default,)
  # Use a switch, to get the proper transformation rules for free. Since
  # platform index has no dependence on the input data, it won't be vectorized
  # under vmap.
  return switch(platform_index, branches, *args)

# A primitive to compute the index of a platform into a list of platforms.
# Args:
# platforms: Sequence[Sequence[str]]: a sequence of sequences of platform
# names. If the current lowering platform is in one of the inner sequences
# returns the index of that inner sequence in the outer sequence.
# has_default: if True, and if the lowering platform is not found in
# `platforms` then return `len(platforms)`. Otherwise, raise an error.
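# For example, with platforms=(("cpu",), ("cuda", "rocm")), lowering for
# "rocm" yields the index 1; with has_default=True, lowering for "tpu"
# yields 2, the index of the default branch.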
platform_index_p = core.Primitive("platform_index")
platform_index_p.multiple_results = False
platform_index_p.def_impl(functools.partial(dispatch.apply_primitive,
                                            platform_index_p))

@platform_index_p.def_abstract_eval
def _platform_index_aval(*_, **__):
  return core.ShapedArray((), np.int32)

def _platform_index_lowering(ctx: mlir.LoweringRuleContext,
                             *,
                             platforms: Sequence[Sequence[str]],
                             has_default: bool):
  def lower_constant(ctx: mlir.LoweringRuleContext, *, i: int) -> mlir.ir.Value:
    return mlir.ir_constants(np.int32(i))
  lowering_rules: tuple[mlir.MultiPlatformLoweringRule, ...] = tuple(
      (ps, partial(lower_constant, i=i))
      for i, ps in enumerate(platforms)
  )
  if has_default:
    lowering_rules = lowering_rules + (
        (None, partial(lower_constant, i=len(platforms))),
    )
  return mlir.lower_multi_platform(
      ctx,
      f"platform_index(platforms={platforms}, has_default={has_default})",
      lowering_rules,
      effects.no_effects)

mlir.register_lowering(platform_index_p, _platform_index_lowering)
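
As a quick check of the "all alternatives are traced" behavior described in
the docstring, a minimal sketch (not part of this commit) that inspects the
staged-out jaxpr:

    import jax
    from jax import lax
    import jax.numpy as jnp

    def f(x):
      return lax.platform_dependent(x,
                                    cpu=lambda x: x + 2.,
                                    default=lambda x: x * 3.)

    # Both branches appear in the jaxpr, guarded by a cond on the value of
    # the platform_index primitive rather than by a Python conditional.
    print(jax.make_jaxpr(f)(jnp.float32(1.)))
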
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
@@ -1500,6 +1500,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"hessenberg",
"tridiagonal",
"eigh_jacobi",
"platform_index",
]

tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
44 changes: 7 additions & 37 deletions jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -1683,46 +1683,16 @@ def test_multi_platform(self):
      self.skipTest("TODO: enable when we can handle i64 platform_index_argument")
    # Checks that we dispatch from TF to the proper JAX platform lowering.

    # A primitive for testing multi-platform lowering. Takes one argument and
    # adds a different value to it: cpu=2., tpu=3., cuda=4., rocm=5.
    _testing_multi_platform_p = core.Primitive("testing_multi_platform")
    # We add a different value to it: cpu=2., tpu=3., cuda=4., rocm=5.
    _testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)

    @_testing_multi_platform_p.def_abstract_eval
    def _testing_multi_platform_abstract_eval(xaval: core.AbstractValue):
      assert xaval.dtype == np.float32  # type: ignore
      return xaval

    @_testing_multi_platform_p.def_impl
    def _testing_multi_platform_impl(x: jax.Array) -> jax.Array:
      to_add = _testing_multi_platform_to_add[platform]
      return x + to_add

    def _testing_multi_platform_lowering(ctx: mlir.LoweringRuleContext,
                                         x: mlir.Value,
                                         *,
                                         platform: str) -> Sequence[mlir.Value]:
      to_add = _testing_multi_platform_to_add[platform]
      to_add_value = mlir.broadcast_in_dim(ctx,
                                           mlir.ir_constant(np.float32(to_add)),
                                           ctx.avals_in[0],
                                           broadcast_dimensions=())
      return mlir.hlo.AddOp(x, to_add_value).results

    # Register a default rule for cuda, to test the default-platform rule
    # selection.
    mlir.register_lowering(_testing_multi_platform_p,
                           functools.partial(_testing_multi_platform_lowering,
                                             platform="cuda"))
    for platform in ["cpu", "tpu", "rocm"]:
      mlir.register_lowering(_testing_multi_platform_p,
                             functools.partial(
                                 _testing_multi_platform_lowering,
                                 platform=platform),
                             platform=platform)

    def f_jax(x):
      return x + lax.platform_dependent(
          tpu=lambda: _testing_multi_platform_to_add["tpu"],
          cuda=lambda: _testing_multi_platform_to_add["cuda"],
          rocm=lambda: _testing_multi_platform_to_add["rocm"],
          default=lambda: _testing_multi_platform_to_add["cpu"]
      )
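    # With this rewrite, f_jax(x) computes x + 3. on TPU, x + 4. on CUDA,
    # x + 5. on ROCm, and x + 2. (the "cpu" value, via default) elsewhere.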

    x = np.float32(.42)
    f_tf = jax2tf.convert(
1 change: 1 addition & 0 deletions jax/lax/__init__.py
@@ -338,6 +338,7 @@
    switch as switch,
    while_loop as while_loop,
    while_p as while_p,
    platform_dependent as platform_dependent,
)
from jax._src.lax.fft import (
    fft as fft,
63 changes: 20 additions & 43 deletions tests/export_test.py
@@ -18,11 +18,12 @@
import logging
import math
import re
from typing import Optional, Sequence
from typing import Optional
import unittest

from absl.testing import absltest
import jax
from jax import lax
from jax import numpy as jnp
from jax import tree_util
from jax.experimental.export import export
@@ -94,56 +95,32 @@ def lowering_testing_primitive_with_effect(ctx, a, *, effect_class_name: str):
lowering_testing_primitive_with_effect)

## Setup for multi-platform lowering
# A primitive for testing multi-platform lowering. Takes one argument and
# adds a different value to it: cpu=2., tpu=3., cuda=4., rocm=5.
# The primitive takes an effect_class_name kwarg that may be None, or
# the name of an effect class.
_testing_multi_platform_p = core.Primitive("testing_multi_platform")
_testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)

def _testing_multi_platform_func(x, *,
                                 effect_class_name: Optional[str] = None):
  return _testing_multi_platform_p.bind(x, effect_class_name=effect_class_name)
  # Behaves like x + 2 * _testing_multi_platform_to_add[platform]
  def for_platform(platform: str):
    if effect_class_name is None:
      return 2. * _testing_multi_platform_to_add[platform]
    else:
      return testing_primitive_with_effect_p.bind(
          _testing_multi_platform_to_add[platform],
          effect_class_name=effect_class_name)

  return x + lax.platform_dependent(
      tpu=lambda: for_platform("tpu"),
      cuda=lambda: for_platform("cuda"),
      rocm=lambda: for_platform("rocm"),
      default=lambda: for_platform("cpu"),
  )

def _testing_multi_platform_fun_expected(x,
                                         platform: str | None = None):
  return x + _testing_multi_platform_to_add[
  return x + 2. * _testing_multi_platform_to_add[
      xb.canonicalize_platform(platform or jtu.device_under_test())
  ]

@_testing_multi_platform_p.def_effectful_abstract_eval
def _testing_multi_platform_abstract_eval(xaval: core.AbstractValue,
                                          effect_class_name: Optional[str]):
  assert xaval.dtype == np.float32  # type: ignore
  effects = set() if effect_class_name is None else set([_testing_effects[effect_class_name]])
  return (xaval, effects)

def _testing_multi_platform_lowering(ctx: mlir.LoweringRuleContext,
                                     x: mlir.Value,
                                     *,
                                     effect_class_name: Optional[str],
                                     platform: str) -> Sequence[mlir.Value]:
  to_add = _testing_multi_platform_to_add[platform]
  to_add_value = mlir.broadcast_in_dim(ctx,
                                       mlir.ir_constant(np.float32(to_add)),
                                       ctx.avals_in[0],
                                       broadcast_dimensions=())
  results = mlir.hlo.AddOp(x, to_add_value).results
  if effect_class_name is not None and "Ordered" in effect_class_name:
    token_in = ctx.tokens_in.get(_testing_effects[effect_class_name])[0]
    ctx.set_tokens_out(mlir.TokenSet({_testing_effects[effect_class_name]: (token_in,)}))
  return results

# Register a default rule for cuda, to test the default-platform rule selection.
mlir.register_lowering(_testing_multi_platform_p,
                       functools.partial(_testing_multi_platform_lowering,
                                         platform="cuda"))
for platform in ["cpu", "tpu", "rocm"]:
  mlir.register_lowering(_testing_multi_platform_p,
                         functools.partial(
                             _testing_multi_platform_lowering,
                             platform=platform),
                         platform=platform)


class JaxExportTest(jtu.JaxTestCase):

@@ -813,8 +790,8 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10]
  def test_multi_platform(self):
    x = np.arange(8, dtype=np.float32)
    exp = export.export(_testing_multi_platform_func,
                        lowering_platforms=("cpu", "tpu", "cuda"))(x)
    self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))
                        lowering_platforms=("tpu", "cpu", "cuda"))(x)
    self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda"))
    module_str = str(exp.mlir_module())
    expected_main_re = (
        r"@main\("
