diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 944b208db87c..e8c23b0d81df 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1416,46 +1416,23 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None ctx.name_stack) with source_info_util.user_context(eqn.source_info.traceback), loc: override_rule = get_override_lowering_rule(eqn.primitive) - if len(ctx.platforms) == 1: - # Classic, single-platform lowering - # TODO(necula): unify the code paths when multi-platform is finished - platform = ctx.platforms[0] - if override_rule is not None: - rule = override_rule - elif eqn.primitive in _platform_specific_lowerings[platform]: - rule = _platform_specific_lowerings[platform][eqn.primitive] - elif eqn.primitive in xla._backend_specific_translations[platform]: - rule = xla_fallback_lowering(eqn.primitive) - elif eqn.primitive in _lowerings: - rule = _lowerings[eqn.primitive] - elif eqn.primitive in xla._translations: - rule = xla_fallback_lowering(eqn.primitive) - else: - raise NotImplementedError( - f"MLIR translation rule for primitive '{eqn.primitive.name}' not " - f"found for platform {platform}") + platform_rules: dict[str, LoweringRule] = {} + default_rule: Optional[LoweringRule] = None + # See mlir.lower_per_platform for meaning of `platform_rules` and `default_rule` + if override_rule is not None: + default_rule = override_rule else: - rules: list[MultiPlatformLoweringRule] - # See mlir.lower_multi_platform for the `rules` format - if override_rule is not None: - rules = [(None, override_rule)] - else: - # First the platform-specific rules - rules = [] - for p in ctx.platforms: - if eqn.primitive in _platform_specific_lowerings[p]: - rules.append( - ([p], _platform_specific_lowerings[p][eqn.primitive])) - elif eqn.primitive in xla._backend_specific_translations[p]: - rules.append( - ([p], xla_fallback_lowering(eqn.primitive))) - # Now the catch-all rules - if eqn.primitive in _lowerings: - rules.append( - (None, _lowerings[eqn.primitive])) # type: ignore - elif eqn.primitive in xla._translations: - rules.append( - (None, xla_fallback_lowering(eqn.primitive))) # type: ignore + # First the platform-specific rules + for p in ctx.platforms: + if eqn.primitive in _platform_specific_lowerings[p]: + platform_rules[p] = _platform_specific_lowerings[p][eqn.primitive] + elif eqn.primitive in xla._backend_specific_translations[p]: + platform_rules[p] = xla_fallback_lowering(eqn.primitive) + # Now the default rule + if eqn.primitive in _lowerings: + default_rule = _lowerings[eqn.primitive] + elif eqn.primitive in xla._translations: + default_rule = xla_fallback_lowering(eqn.primitive) eqn_ctx = ctx.replace(name_stack=source_info.name_stack) effects = list(effects_lib.ordered_effects.filter_in(eqn.effects)) @@ -1473,13 +1450,10 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env) rule_inputs = map(_unwrap_singleton_ir_values, in_nodes) - if len(ctx.platforms) == 1: - # Classic, single-platform lowering - ans = rule(rule_ctx, *rule_inputs, **eqn.params) - else: - ans = lower_multi_platform(rule_ctx, str(eqn), rules, - eqn.effects, - *rule_inputs, **eqn.params) + ans = lower_per_platform(rule_ctx, str(eqn.primitive), + platform_rules, default_rule, + eqn.effects, + *rule_inputs, **eqn.params) if effects: # If there were ordered effects in the primitive, there should be output @@ -1510,82 +1484,82 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None core.clean_up_dead_vars(eqn, env, last_used) return map(read, jaxpr.outvars), tokens -# See docstring for lower_multi_platform. -MultiPlatformLoweringRule = tuple[Optional[Sequence[str]], Callable] -def lower_multi_platform(ctx: LoweringRuleContext, - description: str, - rules: Sequence[MultiPlatformLoweringRule], - effects: effects_lib.Effects, - *rule_args: ir.Value, - **rule_kwargs) -> ir.Value: - """Emits single- or multi-platform code for a primitive. + +def lower_per_platform(ctx: LoweringRuleContext, + description: str, + platform_rules: dict[str, LoweringRule], + default_rule: Optional[LoweringRule], + effects: effects_lib.Effects, + *rule_args: ir.Value, + **rule_kwargs) -> ir.Value: + """Emits code for a primitive for the current lowering platform(s). For example, given - ctx.module_context.lowering_parameters.platforms = ("cpu", "gpu", "tpu") + platform_rules = dict(tpu=rule0, cpu=rule0) + default_rule = rule1 + and - rules = [(["tpu", "cpu"], rule0), - (None, rule1) + ctx.module_context.lowering_parameters.platforms = ("cpu",) + + emits: + rule0(ctx, *rule_args, **rule_kwargs) + + In case of multi-platform lowering, e.g., if + ctx.module_context.lowering_parameters.platforms = ("cpu", "cuda", "tpu") + emits: rule_idx = case current_platform_idx: 0: return 0 # cpu rule index - 1: return 1 # gpu rule index + 1: return 1 # cuda rule index 2: return 0 # tpu rule index output = case rule_idx 0: return rule0(*rule_args, **rule_kwargs) 1: return rule1(*rule_args, **rule_kwargs) - If the primitive has a single lowering rule for all platforms of interest, - skips the conditionals and emits the same code as for classic single-platform - lowering. - Args: ctx: lowering context. description: a string to include in error messages. - rules: a sequence of per-platform rules. Each entry is a tuple, with the - first element specifying the platforms, either a sequence of applicable - platform names (maybe empty), or None to denote a default entry to use - when no other entry applies. The second element of the tuple is a - lowering rule, i.e., a function to invoke with a - LoweringRuleContext (a sub-context of `ctx`), - and `*rule_args` and `**rule_kwargs`. + platform_rules: map platform names, e.g., "cpu", "cuda", to + `LoweringRule`s, for the platforms that have non-default lowering. + default_rule: an optional rule to use for platforms not in `platform_rules`. + effects: the set of effects for the current primitive. rule_args: the args of the lowering rules. rule_kwargs: the kwargs of the lowering rules. """ platforms: Sequence[str] = ctx.module_context.platforms - platforms_with_specific_rules: Sequence[str] = util.flatten( - [ps for ps, _ in rules if ps is not None]) - platforms_with_default_rule = [p for p in platforms - if p not in platforms_with_specific_rules] - kept_rules: list[MultiPlatformLoweringRule] = [] # Only the rules for platforms of interest + # Special case the common case (single-platform lowering) + if len(platforms) == 1: + rule = platform_rules.get(platforms[0], default_rule) + if rule is None: + raise NotImplementedError( + f"MLIR translation rule for primitive '{description}' not " + f"found for platform {platforms[0]}") + + # Multi-platform lowering + kept_rules: list[LoweringRule] = [] # Only the rules for the platforms of interest platform_to_kept_rules_idx: dict[str, int] = {} - for ps, r in rules: - rule_index = len(kept_rules) - if ps is not None: - # Keep only rules that mention the platforms of interest - interesting_ps = [p for p in platforms if p in ps] # type: ignore - if interesting_ps: - for p in interesting_ps: - assert p not in platform_to_kept_rules_idx - platform_to_kept_rules_idx[p] = rule_index - kept_rules.append((interesting_ps, r)) - elif platforms_with_default_rule: - for p in platforms_with_default_rule: - assert p not in platform_to_kept_rules_idx - platform_to_kept_rules_idx[p] = rule_index - kept_rules.append((platforms_with_default_rule, r)) - - platforms_without_rules = [p for p in platforms - if p not in platform_to_kept_rules_idx] - if platforms_without_rules: - raise ValueError( - f"MLIR translation rule for primitive '{description}' not " - f"found for platforms {platforms_without_rules}") - assert kept_rules + for p, prule in platform_rules.items(): + if p not in platforms: + continue + platform_to_kept_rules_idx[p] = len(kept_rules) + kept_rules.append(prule) + + platforms_without_specific_rule = [p for p in platforms + if p not in platform_to_kept_rules_idx] + if platforms_without_specific_rule: + if default_rule is None: + raise NotImplementedError( + f"MLIR translation rule for primitive '{description}' not " + f"found for platforms {platforms_without_specific_rule}") + for p in platforms_without_specific_rule: + platform_to_kept_rules_idx[p] = len(kept_rules) + kept_rules.append(default_rule) - # Maybe there is a single rule left, just apply the rule, no conditionals. + assert kept_rules + # If there is a single rule left just apply the rule, without conditionals. if len(kept_rules) == 1: - return kept_rules[0][1](ctx, *rule_args, **rule_kwargs) + return kept_rules[0](ctx, *rule_args, **rule_kwargs) assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules) assert len(ctx.dim_var_values) >= 1, "Must have a platform_index variable" @@ -1609,7 +1583,7 @@ def lower_multi_platform(ctx: LoweringRuleContext, case_op = hlo.CaseOp(util.flatten(output_types), index=rule_idx_op, num_branches=len(kept_rules)) - for i, (_, rule) in enumerate(kept_rules): + for i, rule in enumerate(kept_rules): inner_ctx = ctx.replace() branch = case_op.regions[i].blocks.append() with ir.InsertionPoint(branch): diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 376e35852ebf..192ba2648665 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -938,7 +938,7 @@ def other_platforms_code(*args): ... platform_branches: list[tuple[list[str], Callable]] = [] for pname, pbranch in per_platform.items(): if pname == "gpu": - raise ValueError("Use 'cuda' or 'rocm' for this API.") + raise ValueError("Use 'cuda' or 'rocm' for lax.platform_dependent.") for ps, b in platform_branches: if b == pbranch: ps.append(pname) @@ -979,18 +979,17 @@ def _platform_index_lowering(ctx: mlir.LoweringRuleContext, has_default: bool): def lower_constant(ctx: mlir.LoweringRuleContext, *, i: int) -> mlir.ir.Value: return mlir.ir_constants(np.int32(i)) - lowering_rules: tuple[mlir.MultiPlatformLoweringRule, ...] = tuple( - (ps, partial(lower_constant, i=i)) - for i, ps in enumerate(platforms) - ) - if has_default: - lowering_rules = lowering_rules + ( - (None, partial(lower_constant, i=len(platforms))), - ) - return mlir.lower_multi_platform( + platform_rules: dict[str, mlir.LoweringRule] = {} + for i, ps in enumerate(platforms): + rule = partial(lower_constant, i=i) + for p in ps: + platform_rules[p] = rule + + default_rule = ( + partial(lower_constant, i=len(platforms)) if has_default else None) + return mlir.lower_per_platform( ctx, f"platform_index(platforms={platforms}, has_default={has_default})", - lowering_rules, - effects.no_effects) + platform_rules, default_rule, effects.no_effects) mlir.register_lowering(platform_index_p, _platform_index_lowering) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index ca067e10ddbf..6f08e1317882 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2782,7 +2782,7 @@ def test_platform_dependent_no_default(self): ctx = contextlib.ExitStack() if jtu.device_under_test() != "tpu": ctx.enter_context( - self.assertRaisesRegex(ValueError, + self.assertRaisesRegex(NotImplementedError, "translation rule .* not found for platform")) with ctx: lax.platform_dependent(