Skip to content

Commit

Permalink
Cleanup the code to picking lowering rules based on platform.
Browse files Browse the repository at this point in the history
Previously, we had special-cased the code to pick the lowering
rule for a primitive based on the lowering platform, and separately
we had the code to handle multi-platform lowering. The latter,
called `mlir.lower_multi_platform` had its own special case for
when a single lowering rule applied.

We rename `mlir.lower_multi_platform` to `mlir.lower_per_platform`
to not imply that it is only for multi-platform. We simplify
its API (takes a dictionary instead of a list of tuples).
  • Loading branch information
gnecula committed Nov 19, 2023
1 parent 52b31a4 commit 2d9da6c
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 114 deletions.
176 changes: 75 additions & 101 deletions jax/_src/interpreters/mlir.py
Expand Up @@ -1416,46 +1416,23 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
ctx.name_stack)
with source_info_util.user_context(eqn.source_info.traceback), loc:
override_rule = get_override_lowering_rule(eqn.primitive)
if len(ctx.platforms) == 1:
# Classic, single-platform lowering
# TODO(necula): unify the code paths when multi-platform is finished
platform = ctx.platforms[0]
if override_rule is not None:
rule = override_rule
elif eqn.primitive in _platform_specific_lowerings[platform]:
rule = _platform_specific_lowerings[platform][eqn.primitive]
elif eqn.primitive in xla._backend_specific_translations[platform]:
rule = xla_fallback_lowering(eqn.primitive)
elif eqn.primitive in _lowerings:
rule = _lowerings[eqn.primitive]
elif eqn.primitive in xla._translations:
rule = xla_fallback_lowering(eqn.primitive)
else:
raise NotImplementedError(
f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
f"found for platform {platform}")
platform_rules: dict[str, LoweringRule] = {}
default_rule: Optional[LoweringRule] = None
# See mlir.lower_per_platform for meaning of `platform_rules` and `default_rule`
if override_rule is not None:
default_rule = override_rule
else:
rules: list[MultiPlatformLoweringRule]
# See mlir.lower_multi_platform for the `rules` format
if override_rule is not None:
rules = [(None, override_rule)]
else:
# First the platform-specific rules
rules = []
for p in ctx.platforms:
if eqn.primitive in _platform_specific_lowerings[p]:
rules.append(
([p], _platform_specific_lowerings[p][eqn.primitive]))
elif eqn.primitive in xla._backend_specific_translations[p]:
rules.append(
([p], xla_fallback_lowering(eqn.primitive)))
# Now the catch-all rules
if eqn.primitive in _lowerings:
rules.append(
(None, _lowerings[eqn.primitive])) # type: ignore
elif eqn.primitive in xla._translations:
rules.append(
(None, xla_fallback_lowering(eqn.primitive))) # type: ignore
# First the platform-specific rules
for p in ctx.platforms:
if eqn.primitive in _platform_specific_lowerings[p]:
platform_rules[p] = _platform_specific_lowerings[p][eqn.primitive]
elif eqn.primitive in xla._backend_specific_translations[p]:
platform_rules[p] = xla_fallback_lowering(eqn.primitive)
# Now the default rule
if eqn.primitive in _lowerings:
default_rule = _lowerings[eqn.primitive]
elif eqn.primitive in xla._translations:
default_rule = xla_fallback_lowering(eqn.primitive)

eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
Expand All @@ -1473,13 +1450,10 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)

rule_inputs = map(_unwrap_singleton_ir_values, in_nodes)
if len(ctx.platforms) == 1:
# Classic, single-platform lowering
ans = rule(rule_ctx, *rule_inputs, **eqn.params)
else:
ans = lower_multi_platform(rule_ctx, str(eqn), rules,
eqn.effects,
*rule_inputs, **eqn.params)
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
platform_rules, default_rule,
eqn.effects,
*rule_inputs, **eqn.params)

if effects:
# If there were ordered effects in the primitive, there should be output
Expand Down Expand Up @@ -1510,82 +1484,82 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
core.clean_up_dead_vars(eqn, env, last_used)
return map(read, jaxpr.outvars), tokens

# See docstring for lower_multi_platform.
MultiPlatformLoweringRule = tuple[Optional[Sequence[str]], Callable]

def lower_multi_platform(ctx: LoweringRuleContext,
description: str,
rules: Sequence[MultiPlatformLoweringRule],
effects: effects_lib.Effects,
*rule_args: ir.Value,
**rule_kwargs) -> ir.Value:
"""Emits single- or multi-platform code for a primitive.

def lower_per_platform(ctx: LoweringRuleContext,
description: str,
platform_rules: dict[str, LoweringRule],
default_rule: Optional[LoweringRule],
effects: effects_lib.Effects,
*rule_args: ir.Value,
**rule_kwargs) -> ir.Value:
"""Emits code for a primitive for the current lowering platform(s).
For example, given
ctx.module_context.lowering_parameters.platforms = ("cpu", "gpu", "tpu")
platform_rules = dict(tpu=rule0, cpu=rule0)
default_rule = rule1
and
rules = [(["tpu", "cpu"], rule0),
(None, rule1)
ctx.module_context.lowering_parameters.platforms = ("cpu",)
emits:
rule0(ctx, *rule_args, **rule_kwargs)
In case of multi-platform lowering, e.g., if
ctx.module_context.lowering_parameters.platforms = ("cpu", "cuda", "tpu")
emits:
rule_idx = case current_platform_idx:
0: return 0 # cpu rule index
1: return 1 # gpu rule index
1: return 1 # cuda rule index
2: return 0 # tpu rule index
output = case rule_idx
0: return rule0(*rule_args, **rule_kwargs)
1: return rule1(*rule_args, **rule_kwargs)
If the primitive has a single lowering rule for all platforms of interest,
skips the conditionals and emits the same code as for classic single-platform
lowering.
Args:
ctx: lowering context.
description: a string to include in error messages.
rules: a sequence of per-platform rules. Each entry is a tuple, with the
first element specifying the platforms, either a sequence of applicable
platform names (maybe empty), or None to denote a default entry to use
when no other entry applies. The second element of the tuple is a
lowering rule, i.e., a function to invoke with a
LoweringRuleContext (a sub-context of `ctx`),
and `*rule_args` and `**rule_kwargs`.
platform_rules: map platform names, e.g., "cpu", "cuda", to
`LoweringRule`s, for the platforms that have non-default lowering.
default_rule: an optional rule to use for platforms not in `platform_rules`.
effects: the set of effects for the current primitive.
rule_args: the args of the lowering rules.
rule_kwargs: the kwargs of the lowering rules.
"""
platforms: Sequence[str] = ctx.module_context.platforms
platforms_with_specific_rules: Sequence[str] = util.flatten(
[ps for ps, _ in rules if ps is not None])
platforms_with_default_rule = [p for p in platforms
if p not in platforms_with_specific_rules]
kept_rules: list[MultiPlatformLoweringRule] = [] # Only the rules for platforms of interest
# Special case the common case (single-platform lowering)
if len(platforms) == 1:
rule = platform_rules.get(platforms[0], default_rule)
if rule is None:
raise NotImplementedError(
f"MLIR translation rule for primitive '{description}' not "
f"found for platform {platforms[0]}")

# Multi-platform lowering
kept_rules: list[LoweringRule] = [] # Only the rules for the platforms of interest
platform_to_kept_rules_idx: dict[str, int] = {}
for ps, r in rules:
rule_index = len(kept_rules)
if ps is not None:
# Keep only rules that mention the platforms of interest
interesting_ps = [p for p in platforms if p in ps] # type: ignore
if interesting_ps:
for p in interesting_ps:
assert p not in platform_to_kept_rules_idx
platform_to_kept_rules_idx[p] = rule_index
kept_rules.append((interesting_ps, r))
elif platforms_with_default_rule:
for p in platforms_with_default_rule:
assert p not in platform_to_kept_rules_idx
platform_to_kept_rules_idx[p] = rule_index
kept_rules.append((platforms_with_default_rule, r))

platforms_without_rules = [p for p in platforms
if p not in platform_to_kept_rules_idx]
if platforms_without_rules:
raise ValueError(
f"MLIR translation rule for primitive '{description}' not "
f"found for platforms {platforms_without_rules}")
assert kept_rules
for p, prule in platform_rules.items():
if p not in platforms:
continue
platform_to_kept_rules_idx[p] = len(kept_rules)
kept_rules.append(prule)

platforms_without_specific_rule = [p for p in platforms
if p not in platform_to_kept_rules_idx]
if platforms_without_specific_rule:
if default_rule is None:
raise NotImplementedError(
f"MLIR translation rule for primitive '{description}' not "
f"found for platforms {platforms_without_specific_rule}")
for p in platforms_without_specific_rule:
platform_to_kept_rules_idx[p] = len(kept_rules)
kept_rules.append(default_rule)

# Maybe there is a single rule left, just apply the rule, no conditionals.
assert kept_rules
# If there is a single rule left just apply the rule, without conditionals.
if len(kept_rules) == 1:
return kept_rules[0][1](ctx, *rule_args, **rule_kwargs)
return kept_rules[0](ctx, *rule_args, **rule_kwargs)

assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules)
assert len(ctx.dim_var_values) >= 1, "Must have a platform_index variable"
Expand All @@ -1609,7 +1583,7 @@ def lower_multi_platform(ctx: LoweringRuleContext,
case_op = hlo.CaseOp(util.flatten(output_types),
index=rule_idx_op,
num_branches=len(kept_rules))
for i, (_, rule) in enumerate(kept_rules):
for i, rule in enumerate(kept_rules):
inner_ctx = ctx.replace()
branch = case_op.regions[i].blocks.append()
with ir.InsertionPoint(branch):
Expand Down
23 changes: 11 additions & 12 deletions jax/_src/lax/control_flow/conditionals.py
Expand Up @@ -938,7 +938,7 @@ def other_platforms_code(*args): ...
platform_branches: list[tuple[list[str], Callable]] = []
for pname, pbranch in per_platform.items():
if pname == "gpu":
raise ValueError("Use 'cuda' or 'rocm' for this API.")
raise ValueError("Use 'cuda' or 'rocm' for lax.platform_dependent.")
for ps, b in platform_branches:
if b == pbranch:
ps.append(pname)
Expand Down Expand Up @@ -979,18 +979,17 @@ def _platform_index_lowering(ctx: mlir.LoweringRuleContext,
has_default: bool):
def lower_constant(ctx: mlir.LoweringRuleContext, *, i: int) -> mlir.ir.Value:
return mlir.ir_constants(np.int32(i))
lowering_rules: tuple[mlir.MultiPlatformLoweringRule, ...] = tuple(
(ps, partial(lower_constant, i=i))
for i, ps in enumerate(platforms)
)
if has_default:
lowering_rules = lowering_rules + (
(None, partial(lower_constant, i=len(platforms))),
)
return mlir.lower_multi_platform(
platform_rules: dict[str, mlir.LoweringRule] = {}
for i, ps in enumerate(platforms):
rule = partial(lower_constant, i=i)
for p in ps:
platform_rules[p] = rule

default_rule = (
partial(lower_constant, i=len(platforms)) if has_default else None)
return mlir.lower_per_platform(
ctx,
f"platform_index(platforms={platforms}, has_default={has_default})",
lowering_rules,
effects.no_effects)
platform_rules, default_rule, effects.no_effects)

mlir.register_lowering(platform_index_p, _platform_index_lowering)
2 changes: 1 addition & 1 deletion tests/lax_control_flow_test.py
Expand Up @@ -2782,7 +2782,7 @@ def test_platform_dependent_no_default(self):
ctx = contextlib.ExitStack()
if jtu.device_under_test() != "tpu":
ctx.enter_context(
self.assertRaisesRegex(ValueError,
self.assertRaisesRegex(NotImplementedError,
"translation rule .* not found for platform"))
with ctx:
lax.platform_dependent(
Expand Down

0 comments on commit 2d9da6c

Please sign in to comment.