Clean up the handling of single- and multi-platform lowering in ModuleContext

Previously, we introduced support for multi-platform lowering by adding
a new LoweringParameters object that can be used to specify a
cross-platform lowering target or even multiple platforms. But we had
kept ModuleContext.platform in place because some lowering rules were
still referencing it. Now we replace ModuleContext.platform with
ModuleContext.platforms, which removes the redundancy, simplifies the
code, and makes it clearer that lowering rules should not simply assume
single-platform lowering.

PiperOrigin-RevId: 576575376
gnecula authored and jax authors committed Oct 25, 2023
1 parent 468f666 commit edbe49f
Showing 11 changed files with 73 additions and 83 deletions.
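
For orientation, here is a minimal, self-contained sketch of the dispatch pattern this commit converges on: lowering rules consult a ModuleContext.platforms sequence rather than a single platform string. All names below are illustrative stand-ins for the rule registries in jax/_src/interpreters/mlir.py, simplified for exposition, not the real internals:

    from collections.abc import Sequence

    # Hypothetical per-platform and generic rule registries.
    _platform_specific_lowerings = {
        "cpu": {"sin": "sin_cpu_rule"},
        "cuda": {},
        "rocm": {},
        "tpu": {"sin": "sin_tpu_rule"},
    }
    _lowerings = {"sin": "sin_generic_rule"}

    def pick_rules(primitive: str, platforms: Sequence[str]):
        # Returns (platforms, rule) pairs; None means "all remaining platforms".
        if len(platforms) == 1:
            # Classic, single-platform lowering: exactly one applicable rule.
            p = platforms[0]
            rule = _platform_specific_lowerings[p].get(primitive,
                                                       _lowerings.get(primitive))
            return [([p], rule)]
        # Multi-platform lowering: platform-specific rules first, then a
        # default rule covering the remaining platforms.
        rules = [([p], _platform_specific_lowerings[p][primitive])
                 for p in platforms if primitive in _platform_specific_lowerings[p]]
        if primitive in _lowerings:
            rules.append((None, _lowerings[primitive]))
        return rules

    print(pick_rules("sin", ("cpu",)))        # [(['cpu'], 'sin_cpu_rule')]
    print(pick_rules("sin", ("cpu", "cuda"))) # [(['cpu'], 'sin_cpu_rule'), (None, 'sin_generic_rule')]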
2 changes: 1 addition & 1 deletion jax/_src/api.py
@@ -565,7 +565,7 @@ def computation_maker(*args, **kwargs):
core.ClosedJaxpr(jaxpr, consts),
ordered_effects=ordered_effects,
backend_or_name=backend,
-platform=platform,
+platforms=[platform],
axis_context=sharding_impls.ReplicaAxisContext(axis_env_),
name_stack=source_info_util.new_name_stack(
wrap_name(fun_name, "xla_computation")),
2 changes: 1 addition & 1 deletion jax/_src/debugging.py
@@ -385,7 +385,7 @@ def inspect_sharding_partition(shapes, arg_shardings, result_shape,
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
trivial_comp = mlir.build_xla_computation_helper(closed_jaxpr,
name="tmp_xla_computation", platform=module_context.platform,
name="tmp_xla_computation", platforms=module_context.platforms,
backend_or_name=module_context.backend_or_name,
axis_context=module_context.axis_context)
# The trivial computation built here has a dummy tuple as the result,
2 changes: 1 addition & 1 deletion jax/_src/dispatch.py
@@ -554,6 +554,6 @@ def _common_device_put_lowering(ctx, x, *, device, src):
device.memory_kind is not None):
raise NotImplementedError(
"Passing memory_kind to device_put via Shardings is not supported on"
f" platform {ctx.module_context.platform}")
f" platforms {ctx.module_context.platforms}")
return [x]
mlir.register_lowering(device_put_p, _common_device_put_lowering)
104 changes: 48 additions & 56 deletions jax/_src/interpreters/mlir.py
@@ -430,8 +430,7 @@ class LoweringParameters:
# The current lowering platforms, a non-empty tuple containing some of
# 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are
# doing multi-platform lowering, otherwise it can specify cross-platform
-# lowering. The value None specifies default lowering platform, for the
-# platform specified by `ModuleContext.platform`.
+# lowering. The value None specifies the default lowering platform.
# This is used only in export and jax2tf.
platforms: tuple[str, ...] | None = None

@@ -454,23 +453,6 @@ class LoweringParameters:
# native execution (and we can remove this parameter).
replace_tokens_with_dummy: bool = True

-@property
-def override_platform(self) -> str | None:
-"""Overrides the lowering platform for cross-platform lowering.
-One of 'cpu', 'cuda', 'rocm', 'tpu'.
-If None, use the default JAX mechanisms to pick the lowering platform.
-This is currently used for export and jax2tf.
-"""
-if self.platforms is not None:
-return self.platforms[0]
-else:
-return None
-
-@property
-def is_multi_platform(self) -> bool:
-return self.platforms is not None and len(self.platforms) > 1


@dataclasses.dataclass
class ModuleContext:
@@ -480,7 +462,7 @@ class ModuleContext:
ip: ir.InsertionPoint
symbol_table: ir.SymbolTable
backend_or_name: str | xb.XlaBackend | None
-platform: str
+platforms: Sequence[str]
axis_context: AxisContext
name_stack: source_info_util.NameStack
keepalives: list[Any]
@@ -503,7 +485,7 @@ def __init__(
self,
*,
backend_or_name: str | xb.XlaBackend | None,
-platform: str,
+platforms: Sequence[str],
axis_context: AxisContext,
name_stack: source_info_util.NameStack,
keepalives: list[Any],
@@ -519,13 +501,13 @@ def __init__(
cached_call_jaxpr_lowerings: None | (dict[Any,
func_dialect.FuncOp]) = None,
shape_poly_state = None):
-assert platform is not None

self.context = context or make_ir_context()
self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
self.ip = ip or ir.InsertionPoint(self.module.body)
self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
self.backend_or_name = backend_or_name
-self.platform = platform
+self.platforms = platforms
self.axis_context = axis_context
self.name_stack = name_stack
self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
@@ -536,12 +518,18 @@ def __init__(
self.cached_call_jaxpr_lowerings = ({}
if cached_call_jaxpr_lowerings is None
else cached_call_jaxpr_lowerings)
-self.shape_poly_state = shape_poly_state or ShapePolyLoweringState((),
-(platform,))
+self.shape_poly_state = (
+shape_poly_state or ShapePolyLoweringState((), tuple(platforms)))
self.lowering_parameters = lowering_parameters

@property
def backend(self) -> xb.XlaBackend:
+# TODO(necula): clean the use of backend and backend_or_name vs. platforms
+if len(self.platforms) > 1:
+raise NotImplementedError(
+"accessing .backend in multi-lowering setting. This can occur when "
+"lowering a primitive that has not been adapted to multi-platform "
+"lowering")
if self.backend_or_name is None or isinstance(self.backend_or_name, str):
return xb.get_backend(self.backend_or_name)
return self.backend_or_name
@@ -722,7 +710,7 @@ def lower_jaxpr_to_module(
*,
ordered_effects: list[core.Effect],
backend_or_name: str | xb.XlaBackend | None,
-platform: str,
+platforms: Sequence[str],
axis_context: AxisContext,
name_stack: source_info_util.NameStack,
donated_args: Sequence[bool],
@@ -741,13 +729,7 @@ def lower_jaxpr_to_module(
Handles the quirks of the argument/return value passing conventions of the
runtime.
"""
-platforms: tuple[str, ...]
-platform = xb.canonicalize_platform(platform)
-if lowering_parameters.is_multi_platform:
-platforms = tuple(map(xb.canonicalize_platform,
-lowering_parameters.platforms)) # type: ignore
-else:
-platforms = (platform,)
+platforms = tuple(map(xb.canonicalize_platform, platforms))

input_output_aliases = None
in_avals = (jaxpr.in_avals if arg_shardings is None else
@@ -809,7 +791,7 @@ def lower_jaxpr_to_module(
if result_shardings is not None else result_shardings)

ctx = ModuleContext(backend_or_name=backend_or_name,
-platform=platform, axis_context=axis_context,
+platforms=platforms, axis_context=axis_context,
name_stack=name_stack,
keepalives=keepalives,
channel_iterator=channel_iter,
@@ -1349,7 +1331,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
dim_var_values: the list of dimension variables values in the current
IR function, in the order of ctx.shape_poly_state.dim_vars.
"""
-assert ctx.platform != "gpu"
+assert "gpu" not in ctx.platforms
def read(v: core.Atom) -> Sequence[ir.Value]:
if type(v) is core.Literal:
return ir_constants(xla.canonicalize_dtype(v.val))
@@ -1393,14 +1375,15 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
ctx.name_stack)
with source_info_util.user_context(eqn.source_info.traceback), loc:
override_rule = get_override_lowering_rule(eqn.primitive)
-if not ctx.lowering_parameters.is_multi_platform:
+if len(ctx.platforms) == 1:
# Classic, single-platform lowering
# TODO(necula): unify the code paths when multi-platform is finished
+platform = ctx.platforms[0]
if override_rule is not None:
rule = override_rule
-elif eqn.primitive in _platform_specific_lowerings[ctx.platform]:
-rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
-elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
+elif eqn.primitive in _platform_specific_lowerings[platform]:
+rule = _platform_specific_lowerings[platform][eqn.primitive]
+elif eqn.primitive in xla._backend_specific_translations[platform]:
rule = xla_fallback_lowering(eqn.primitive)
elif eqn.primitive in _lowerings:
rule = _lowerings[eqn.primitive]
@@ -1409,7 +1392,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
else:
raise NotImplementedError(
f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
f"found for platform {ctx.platform}")
f"found for platform {platform}")
else:
rules: list[MultiPlatformLoweringRule]
# See mlir.lower_multi_platform for the `rules` format
@@ -1418,7 +1401,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
else:
# First the platform-specific rules
rules = []
-for p in ctx.lowering_parameters.platforms: # type: ignore
+for p in ctx.platforms:
if eqn.primitive in _platform_specific_lowerings[p]:
rules.append(
([p], _platform_specific_lowerings[p][eqn.primitive]))
@@ -1449,7 +1432,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)

rule_inputs = map(_unwrap_singleton_ir_values, in_nodes)
-if not ctx.lowering_parameters.is_multi_platform:
+if len(ctx.platforms) == 1:
# Classic, single-platform lowering
ans = rule(rule_ctx, *rule_inputs, **eqn.params)
else:
@@ -1528,12 +1511,7 @@ def lower_multi_platform(ctx: LoweringRuleContext,
rule_args: the args of the lowering rules.
rule_kwargs: the kwargs of the lowering rules.
"""
-platforms: Sequence[str]
-if ctx.module_context.lowering_parameters.is_multi_platform:
-assert ctx.module_context.lowering_parameters.platforms is not None
-platforms = ctx.module_context.lowering_parameters.platforms
-else:
-platforms = (ctx.module_context.platform,)
+platforms: Sequence[str] = ctx.module_context.platforms
platforms_with_specific_rules: Sequence[str] = util.flatten(
[ps for ps, _ in rules if ps is not None])
platforms_with_default_rule = [p for p in platforms
@@ -1681,25 +1659,32 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
return func_op


-def check_backend_matches(inner_backend, outer_backend):
+def check_backend_matches(inner_backend: Optional[str],
+lowering_platforms: Sequence[str]):
# For nested calls, the outermost call sets the backend for all inner calls;
# it's an error if the inner call has a conflicting explicit backend spec.
if inner_backend is None:
return
+outer_backend, *more_lowering_platforms = lowering_platforms
+if more_lowering_platforms:
+raise NotImplementedError(
+"Multi-platform lowering when a backend= parameter is specified")
if (inner_backend != outer_backend and
outer_backend not in xb.expand_platform_alias(inner_backend)):
raise ValueError(
f"Outer-jit backend specification {outer_backend} must match explicit "
f"inner-jit backend specification {inner_backend}.")


-def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
+def _call_lowering(fn_name, stack_name, call_jaxpr, backend,
+ctx: ModuleContext, avals_in,
avals_out, tokens_in, *args,
dim_var_values: Sequence[ir.Value],
arg_names=None, result_names=None):
del stack_name, avals_in
if isinstance(call_jaxpr, core.Jaxpr):
call_jaxpr = core.ClosedJaxpr(call_jaxpr, ())
-check_backend_matches(backend, ctx.platform)
+check_backend_matches(backend, ctx.platforms)
effects = list(tokens_in.effects())
output_types = map(aval_to_ir_types, avals_out)
output_types = [token_type()] * len(effects) + output_types
@@ -1717,7 +1702,8 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens)))
return out_nodes, tokens_out

-def core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
+def core_call_lowering(ctx: LoweringRuleContext,
+*args, name, backend=None, call_jaxpr):
out_nodes, tokens = _call_lowering(
name, name, call_jaxpr, backend, ctx.module_context,
ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args,
@@ -2137,8 +2123,11 @@ def fallback(ctx: LoweringRuleContext, *args, **params):
raise NotImplementedError(
f"Shape polymorphism for xla_fallback_lowering is not implemented ({ctx.primitive}); b/261682623")

+if len(module_ctx.platforms) > 1:
+raise NotImplementedError(
+"fallback lowering not implemented for multi-platform lowering")
xla_computation = xla.primitive_subcomputation(
-module_ctx.platform, axis_env, prim, ctx.avals_in,
+module_ctx.platforms[0], axis_env, prim, ctx.avals_in,
ctx.avals_out, **params)
xla_module = xla_computation_to_mlir_module(xla_computation)
callee_name = merge_mlir_modules(
@@ -2301,7 +2290,9 @@ def emit_python_callback(
result_layouts: Sequence[Sequence[int] | None] | None = None,
) -> tuple[Sequence[ir.Value], Any, Any]:
"""Emits MLIR that calls back to a provided Python function."""
-platform = ctx.module_context.platform
+if len(ctx.module_context.platforms) > 1:
+raise NotImplementedError("multi-platform lowering for python_callback")
+platform = ctx.module_context.platforms[0]
if platform not in {"cpu", "cuda", "rocm", "tpu"}:
raise ValueError(
f"`EmitPythonCallback` not supported on {platform} backend.")
@@ -2423,7 +2414,8 @@ def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function
return results, token, ifrt_callback

def build_xla_computation_helper(
-closed_jaxpr: core.ClosedJaxpr, *, name: str, platform: str,
+closed_jaxpr: core.ClosedJaxpr, *, name: str,
+platforms: Sequence[str],
backend_or_name: str, axis_context: AxisContext) -> xc.XlaComputation:
"""Helper to generate pmap-style XLA computations for custom partitioners."""
if closed_jaxpr.effects:
@@ -2432,7 +2424,7 @@ def build_xla_computation_helper(
backend_or_name=backend_or_name, ordered_effects=[],
name_stack=source_info_util.NameStack(),
donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
-axis_context=axis_context, platform=platform,
+axis_context=axis_context, platforms=platforms,
lowering_parameters=LoweringParameters())
return xc._xla.mlir.mlir_module_to_xla_computation(
module_to_string(lowering_result.module), use_tuple_args=False,
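
The lower_multi_platform hunk above is cut off mid-function. Extrapolating from its visible lines, a hedged sketch of the coverage check it performs (simplified types, illustrative rule values): every lowering platform must be covered either by a platform-specific rule or by a default rule keyed on None.

    def check_rule_coverage(platforms, rules):
        # rules: list of (platform_list | None, rule); None marks the default rule.
        specific = {p for ps, _ in rules if ps is not None for p in ps}
        default_platforms = [p for p in platforms if p not in specific]
        has_default = any(ps is None for ps, _ in rules)
        if default_platforms and not has_default:
            raise NotImplementedError(
                f"no lowering rule found for platforms {default_platforms}")
        return default_platforms  # these fall through to the default rule

    # "cpu" has no specific rule here, so it is handled by the default rule.
    assert check_rule_coverage(("cpu", "tpu"),
                               [(["tpu"], "tpu_rule"), (None, "generic_rule")]) == ["cpu"]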
8 changes: 4 additions & 4 deletions jax/_src/interpreters/pxla.py
@@ -758,7 +758,7 @@ def lower_parallel_callable(
closed_jaxpr,
ordered_effects=ordered_effects,
backend_or_name=backend,
-platform=lowering_parameters.override_platform or backend.platform,
+platforms=lowering_parameters.platforms or (backend.platform,),
axis_context=sharding_impls.ReplicaAxisContext(axis_env),
name_stack=name_stack,
donated_args=donated_invars,
@@ -1362,7 +1362,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
call_jaxpr, backend=None, in_axes, out_axes,
donated_invars, is_explicit_global_axis_size):
del donated_invars # Unused.
-mlir.check_backend_matches(backend, ctx.module_context.platform)
+mlir.check_backend_matches(backend, ctx.module_context.platforms)
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
if ctx.module_context.axis_env.names and devices is not None:
@@ -1835,7 +1835,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
ordered_effects=ordered_effects,
backend_or_name=backend,
# Optionally, override the lowering platform
-platform=lowering_parameters.override_platform or backend.platform,
+platforms=lowering_parameters.platforms or (backend.platform,),
axis_context=axis_ctx,
name_stack=name_stack,
donated_args=donated_invars,
@@ -2201,7 +2201,7 @@ def lower_mesh_computation(
closed_jaxpr,
ordered_effects=ordered_effects,
backend_or_name=backend,
-platform=lowering_parameters.override_platform or backend.platform,
+platforms=lowering_parameters.platforms or (backend.platform,),
axis_context=axis_ctx,
name_stack=name_stack,
donated_args=donated_invars,
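
The three hunks above repeat a single idiom. In isolation it amounts to the sketch below, where resolve_platforms is a hypothetical name introduced here for illustration, not a helper in this commit:

    def resolve_platforms(lowering_platforms, backend_platform):
        # LoweringParameters.platforms is None by default, meaning
        # "lower for the backend's own platform".
        return lowering_platforms or (backend_platform,)

    assert resolve_platforms(None, "cpu") == ("cpu",)
    assert resolve_platforms(("tpu", "cuda"), "cpu") == ("tpu", "cuda")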
7 changes: 1 addition & 6 deletions jax/_src/lax/parallel.py
@@ -725,12 +725,7 @@ def _allreduce_abstract_eval(*args, axes, axis_index_groups):
for arg, named_shape in zip(args, named_shapes)]

def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
-# TODO(necula): clean this up when we have module_context.platforms
-if ctx.module_context.lowering_parameters.is_multi_platform:
-for_tpu = ("tpu" in ctx.module_context.lowering_parameters.platforms)
-else:
-for_tpu = (ctx.module_context.platform == "tpu")
-if axis_index_groups is not None and for_tpu:
+if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms):
len_0 = len(axis_index_groups[0])
if any(len(g) != len_0 for g in axis_index_groups):
raise ValueError("axis_index_groups must all be the same size for TPU lowering")
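
The rewritten condition preserves the single-platform behavior (platforms is then a one-element sequence) and applies the TPU size constraint whenever "tpu" appears among the lowering platforms. A standalone restatement with illustrative values:

    axis_index_groups = [[0, 1], [2]]  # illustrative: unequal group sizes
    platforms = ("cpu", "tpu")
    try:
        if axis_index_groups is not None and ("tpu" in platforms):
            len_0 = len(axis_index_groups[0])
            if any(len(g) != len_0 for g in axis_index_groups):
                raise ValueError(
                    "axis_index_groups must all be the same size for TPU lowering")
    except ValueError as e:
        print(e)  # triggered here because "tpu" is among the platforms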
9 changes: 6 additions & 3 deletions jax/_src/maps.py
@@ -1305,7 +1305,8 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
global_axis_sizes,
spmd_in_axes, spmd_out_axes,
axis_resources, resource_env, backend):
-mlir.check_backend_matches(backend, ctx.module_context.platform)
+mlir.check_backend_matches(backend, ctx.module_context.platforms)
+del backend, donated_invars
# The only way for any of those two assertions to be violated is when xmap
# is using the SPMD lowering, but then this rule shouldn't even trigger.
assert spmd_in_axes is None and spmd_out_axes is None
@@ -1381,7 +1382,8 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
donated_invars, global_axis_sizes, spmd_in_axes,
spmd_out_axes, axis_resources,
resource_env, backend):
-mlir.check_backend_matches(backend, ctx.module_context.platform)
+mlir.check_backend_matches(backend, ctx.module_context.platforms)
+del backend, donated_invars
plan = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes)

@@ -1447,9 +1449,10 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
donated_invars, global_axis_sizes, spmd_in_axes,
spmd_out_axes, axis_resources,
resource_env, backend):
+del donated_invars
assert spmd_in_axes is None and spmd_out_axes is None
# This first part (up to vtile_manual) is shared with non-MANUAL SPMD rule.
-mlir.check_backend_matches(backend, ctx.module_context.platform)
+mlir.check_backend_matches(backend, ctx.module_context.platforms)
plan = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes)
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
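
For reference, the three call sites above now go through the reworked check_backend_matches shown earlier in this commit's mlir.py diff. A usage sketch, assuming a JAX checkout that contains this commit:

    from jax._src.interpreters import mlir

    mlir.check_backend_matches(None, ["cpu", "tpu"])  # fine: no inner backend requested
    mlir.check_backend_matches("tpu", ["tpu"])        # fine: inner and outer match
    try:
        mlir.check_backend_matches("tpu", ["cpu", "tpu"])
    except NotImplementedError:
        pass  # an explicit backend= is rejected under multi-platform lowering
    try:
        mlir.check_backend_matches("tpu", ["cpu"])
    except ValueError:
        pass  # conflicting explicit inner-jit backend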
