diff --git a/jax/_src/api.py b/jax/_src/api.py
index 12ceab07231c..bedbbfece051 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -582,7 +582,8 @@ def computation_maker(*args, **kwargs):
             wrap_name(fun_name, "xla_computation")),
         donated_args=donated_invars,
         arg_shardings=None,
-        result_shardings=None)
+        result_shardings=None,
+        lowering_parameters=mlir.LoweringParameters())
     built = xc._xla.mlir.mlir_module_to_xla_computation(
         mlir.module_to_string(lowering_result.module),
         use_tuple_args=tuple_args,
@@ -1904,8 +1905,8 @@ def lower(*args, **kwargs) -> stages.Lowered:
     Returns:
       A ``Lowered`` instance representing the post-map lowering.
     """
-    _experimental_lowering_platform = kwargs.pop(
-        '_experimental_lowering_platform', None)
+    lowering_parameters = kwargs.pop(
+        '_experimental_lowering_parameters', mlir.LoweringParameters())
     p = _prepare_pmap(
         fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
         devices, backend, axis_size, args, kwargs)
@@ -1920,7 +1921,7 @@ def lower(*args, **kwargs) -> stages.Lowered:
         donated_invars=p.donated_invars,
         is_explicit_global_axis_size=p.is_explicit_global_axis_size,
         avals=abstract_args,
-        lowering_platform=_experimental_lowering_platform)
+        lowering_parameters=lowering_parameters)
     return stages.Lowered.from_flat_info(
         computation, p.in_tree, abstract_args,
         donate_tuple, p.out_tree())
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 532f9103c7e8..16884b3760dc 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -149,7 +149,7 @@ def prim_fun(*args):
   computation = sharded_lowering(
       flat_fun, prim.name, donated_invars, keep_unused=False,
       inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings,
-      lowering_platform=None)
+      lowering_parameters=mlir.LoweringParameters())
   compiled = computation.compile()
   if xla_extension_version >= 192:
     if config.jax_disable_jit:
@@ -169,7 +169,8 @@ def prim_fun(*args):
 
 def sharded_lowering(
     fun: lu.WrappedFun, name: str, donated_invars: Sequence[bool],
     keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
-    in_shardings: Sequence[Sharding | None], lowering_platform: str | None
+    in_shardings: Sequence[Sharding | None],
+    lowering_parameters: mlir.LoweringParameters
 ) -> pxla.MeshComputation:
   in_shardings_unspec = [UNSPECIFIED if i is None else i for i in in_shardings]
@@ -179,7 +180,8 @@ def sharded_lowering(
   return pxla.lower_sharding_computation(
       fun, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
       in_avals, keep_unused=keep_unused, inline=inline,
-      devices_from_context=None, lowering_platform=lowering_platform)
+      devices_from_context=None,
+      lowering_parameters=lowering_parameters)
 
 
 def simple_impl(prim):
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 9553d35a85db..6dac7ecf562a 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -390,12 +390,6 @@ def make_ir_context() -> ir.Context:
 ]
 
 class ShapePolyLoweringState:
-  # The current lowering platforms, a non-empty tuple containing some of
-  # 'cpu', 'cuda', 'rocm', 'tpu'.
-  # TODO: this state should be in ModuleContext, but since for now
-  # multi-platform lowering is implemented only for jax_export, like shape
-  # polymorphism, we keep it here.
-  lowering_platforms: tuple[str, ...]
   # The names of the dimension variables, sorted by name. This is the order in
   # which they are passed to the IR functions that need them. This is only
   # used for native serialization with polymorphic shapes when
@@ -410,17 +404,48 @@ class ShapePolyLoweringState:
   # from an inner call to a polymorphic Exported.
   uses_dim_vars: bool
 
-  def __init__(self, dim_vars: tuple[str, ...],
-               lowering_platforms: tuple[str, ...]):
-    self.lowering_platforms = lowering_platforms
+  # Whether the first dimension variable is a platform index argument.
+  has_platform_index_argument: bool
+
+  def __init__(self,
+               dim_vars: tuple[str, ...],
+               lowering_platforms: tuple[str, ...] | None):
     self.uses_dim_vars = (len(dim_vars) > 0)
-    if len(lowering_platforms) > 1:
+    if lowering_platforms is not None and len(lowering_platforms) > 1:
       dim_vars = ("platform_index_",) + tuple(dim_vars)
+      self.has_platform_index_argument = True
+    else:
+      self.has_platform_index_argument = False
     self.dim_vars = dim_vars
 
+
+@dataclasses.dataclass(frozen=True)
+class LoweringParameters:
+  # A mapping between primitives and user-defined LoweringRules.
+  # When lowering a primitive, give priority to the rule in this map over
+  # existing Jax rules.
+  override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None
+
+  # The current lowering platforms, a non-empty tuple containing some of
+  # 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries, we are
+  # doing multi-platform lowering; otherwise it can specify cross-platform
+  # lowering. The value None specifies the default lowering platform.
+  # This is used only in export and jax2tf.
+  platforms: tuple[str, ...] | None = None
+
   @property
-  def has_platform_index_argument(self):
-    return len(self.lowering_platforms) > 1
+  def override_platform(self) -> str | None:
+    """Overrides the lowering platform for cross-platform lowering.
+
+    One of 'cpu', 'cuda', 'rocm', 'tpu'.
+    If None, use the default JAX mechanisms to pick the lowering platform.
+    This is currently used for export and jax2tf.
+    """
+    if self.platforms is not None:
+      return self.platforms[0]
+    else:
+      return None
+
 
 @dataclasses.dataclass
 class ModuleContext:
@@ -443,10 +468,7 @@ class ModuleContext:
   cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]
   cached_call_jaxpr_lowerings: dict[Any, func_dialect.FuncOp]
 
-  # A mapping between primitives and user-defined LoweringRules.
-  # When lowering a primitive, give priorioty to the rule in this map over
-  # existing Jax rules.
-  override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None
+  lowering_parameters: LoweringParameters
 
   @property
   def axis_env(self) -> sharding_impls.AxisEnv:
@@ -454,6 +476,7 @@ def axis_env(self) -> sharding_impls.AxisEnv:
 
   def __init__(
       self,
+      *,
       backend_or_name: str | xb.XlaBackend | None,
       platform: str,
       axis_context: AxisContext,
@@ -461,6 +484,7 @@ def __init__(
       keepalives: list[Any],
       channel_iterator: Iterator[int],
      host_callbacks: list[Any],
+      lowering_parameters: LoweringParameters,
       context: ir.Context | None = None,
       module: ir.Module | None = None,
       ip: ir.InsertionPoint | None = None,
@@ -469,8 +493,6 @@ def __init__(
                                           func_dialect.FuncOp]) = None,
       cached_call_jaxpr_lowerings: None | (dict[Any,
                                            func_dialect.FuncOp]) = None,
-      override_lowering_rules: None | (
-          tuple[tuple[core.Primitive, LoweringRule]]) = None,
       shape_poly_state = None):
     assert platform is not None
     self.context = context or make_ir_context()
@@ -489,9 +511,9 @@ def __init__(
     self.cached_call_jaxpr_lowerings = ({}
                                         if cached_call_jaxpr_lowerings is None
                                         else cached_call_jaxpr_lowerings)
-    self.override_lowering_rules = override_lowering_rules
     self.shape_poly_state = shape_poly_state or ShapePolyLoweringState((),
                                                                        (platform,))
+    self.lowering_parameters = lowering_parameters
 
   @property
   def backend(self) -> xb.XlaBackend:
@@ -664,6 +686,7 @@ def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]:
 def lower_jaxpr_to_module(
     module_name: str,
     jaxpr: core.ClosedJaxpr,
+    *,
     ordered_effects: list[core.Effect],
     backend_or_name: str | xb.XlaBackend | None,
     platform: str | tuple[str, ...],
@@ -678,24 +701,19 @@ def lower_jaxpr_to_module(
     num_replicas: int = 1,
     num_partitions: int = 1,
     all_default_mem_kind: bool = True,
-    override_lowering_rules: None | (
-        tuple[tuple[core.Primitive, LoweringRule]]) = None,
+    lowering_parameters: LoweringParameters,
 ) -> LoweringResult:
   """Lowers a top-level jaxpr to an MLIR module.
 
   Handles the quirks of the argument/return value passing conventions of
   the runtime.
   """
-  # TODO(necula): for now we receive the tuple of lowering platforms through
-  # the `platform` arg. For now we lower only for the first specified platform
-  # TODO(necula): change to "platforms" here and elsewhere.
-  if isinstance(platform, str):
-    platforms = (platform,)
-  else:
-    platforms = tuple(platform)  # type: ignore
-    platform = platform[0]
+  if lowering_parameters.platforms is not None:
+    # Only for multi-platform lowering
+    # TODO(necula): for now we lower only for the first platform
+    platform = lowering_parameters.platforms[0]
 
-  platform = xb.canonicalize_platform(platform)
+  platform = xb.canonicalize_platform(platform)  # type: ignore
   if not xb.is_known_platform(platform):
     raise ValueError(f"Unknown platform {platform}")
   input_output_aliases = None
@@ -750,11 +768,16 @@ def lower_jaxpr_to_module(
       map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
       if result_shardings is not None else result_shardings)
 
-  ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
-                      keepalives, channel_iter, host_callbacks,
-                      override_lowering_rules=override_lowering_rules,
-                      shape_poly_state=ShapePolyLoweringState(dim_vars,
-                                                              platforms))
+  ctx = ModuleContext(backend_or_name=backend_or_name,
+                      platform=platform, axis_context=axis_context,
+                      name_stack=name_stack,
+                      keepalives=keepalives,
+                      channel_iterator=channel_iter,
+                      host_callbacks=host_callbacks,
+                      lowering_parameters=lowering_parameters,
+                      shape_poly_state=ShapePolyLoweringState(
+                          dim_vars,
+                          lowering_parameters.platforms))
   with ctx.context, ir.Location.unknown(ctx.context):
     # Remove module name characters that XLA would alter. This ensures that
     # XLA computation preserves the module name.
@@ -1292,9 +1315,9 @@ def write(v: core.Var, node: Sequence[ir.Value]):
       env[v] = tuple(node)
 
   def get_lowering(primitive: core.Primitive) -> LoweringRule | None:
-    if ctx.override_lowering_rules is None:
+    if ctx.lowering_parameters.override_lowering_rules is None:
       return None
-    for p, rule in ctx.override_lowering_rules:
+    for p, rule in ctx.lowering_parameters.override_lowering_rules:
       if primitive is p:
         return rule
     return None
@@ -2187,7 +2210,8 @@ def build_xla_computation_helper(
       backend_or_name=backend_or_name, ordered_effects=[],
       name_stack=source_info_util.NameStack(),
       donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
-      axis_context=axis_context, platform=platform)
+      axis_context=axis_context, platform=platform,
+      lowering_parameters=LoweringParameters())
   return xc._xla.mlir.mlir_module_to_xla_computation(
       module_to_string(lowering_result.module),
       use_tuple_args=False, return_tuple=False)
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index dc7e4d333f0c..cb13d8654ec2 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -560,7 +560,8 @@ def parallel_callable(fun: lu.WrappedFun,
   pmap_computation = lower_parallel_callable(
       fun, backend_name, axis_name, axis_size, global_axis_size, devices,
       name, in_axes, out_axes_thunk, donated_invars,
-      is_explicit_global_axis_size, avals, lowering_platform=None)
+      is_explicit_global_axis_size, avals,
+      lowering_parameters=mlir.LoweringParameters())
   pmap_executable = pmap_computation.compile()
   return WeakRefList([pmap_executable.unsafe_call,
                       pmap_executable.fingerprint])
@@ -661,7 +662,7 @@ def lower_parallel_callable(
     is_explicit_global_axis_size: bool,
     avals: Sequence[core.AbstractValue],
     *,
-    lowering_platform: str | None):
+    lowering_parameters: mlir.LoweringParameters):
   # Determine global_axis_size for use in AxisEnv.
   # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
   # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@@ -755,18 +756,19 @@ def lower_parallel_callable(
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name,
       closed_jaxpr,
-      ordered_effects,
-      backend,
-      lowering_platform or backend.platform,
-      sharding_impls.ReplicaAxisContext(axis_env),
-      name_stack,
-      donated_invars,
+      ordered_effects=ordered_effects,
+      backend_or_name=backend,
+      platform=lowering_parameters.override_platform or backend.platform,
+      axis_context=sharding_impls.ReplicaAxisContext(axis_env),
+      name_stack=name_stack,
+      donated_args=donated_invars,
       replicated_args=replicated_args,
       arg_shardings=None,
       result_shardings=None,
       arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
       result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
-      num_replicas=replicas.num_global_replicas)
+      num_replicas=replicas.num_global_replicas,
+      lowering_parameters=lowering_parameters)
   return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
                          shards=shards, tuple_args=tuple_args,
                          unordered_effects=unordered_effects,
@@ -1784,9 +1786,9 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
 @weakref_lru_cache
 def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             semantic_in_shardings, semantic_out_shardings,
-                            da_object, lowering_platform,
+                            da_object,
                             donated_invars, name_stack, all_default_mem_kind,
-                            override_lowering_rules):
+                            lowering_parameters: mlir.LoweringParameters):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings
@@ -1848,13 +1850,13 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name,
       closed_jaxpr,
-      ordered_effects,
-      backend,
+      ordered_effects=ordered_effects,
+      backend_or_name=backend,
       # Optionally, override the lowering platform
-      lowering_platform or backend.platform,
-      axis_ctx,
-      name_stack,
-      donated_invars,
+      platform=lowering_parameters.override_platform or backend.platform,
+      axis_context=axis_ctx,
+      name_stack=name_stack,
+      donated_args=donated_invars,
       replicated_args=replicated_args,
       arg_shardings=in_mlir_shardings,
       result_shardings=out_mlir_shardings,
@@ -1863,7 +1865,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       num_replicas=nreps,
       num_partitions=num_partitions,
      all_default_mem_kind=all_default_mem_kind,
-      override_lowering_rules=override_lowering_rules)
+      lowering_parameters=lowering_parameters)
   tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
   unordered_effects = list(
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
@@ -1969,9 +1971,7 @@ def lower_sharding_computation(
     keep_unused: bool,
     inline: bool,
     devices_from_context: Sequence[xc.Device] | None = None,
-    lowering_platform: str | None,
-    override_lowering_rules: None | (
-        tuple[tuple[core.Primitive, mlir.LoweringRule]]) = None,
+    lowering_parameters: mlir.LoweringParameters,
 ) -> MeshComputation:
   """Lowers a computation to XLA. It can take arbitrary shardings as input.
 
@@ -2048,8 +2048,9 @@ def lower_sharding_computation(
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
        closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
-       semantic_out_shardings, da_object, lowering_platform,
-       donated_invars, name_stack, all_default_mem_kind, override_lowering_rules)
+       semantic_out_shardings, da_object,
+       donated_invars, name_stack, all_default_mem_kind,
+       lowering_parameters=lowering_parameters)
 
   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way
@@ -2108,7 +2109,7 @@ def lower_mesh_computation(
     spmd_lowering: bool,
     global_in_avals: Sequence[core.ShapedArray],
     tiling_method: TilingMethod | None,
-    lowering_platform: str | None) -> MeshComputation:
+    lowering_parameters: mlir.LoweringParameters) -> MeshComputation:
   assert not mesh.empty
   backend = xb.get_device_backend(mesh.devices.flat[0])
   name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
@@ -2216,19 +2217,20 @@ def lower_mesh_computation(
     lowering_result = mlir.lower_jaxpr_to_module(
         module_name,
         closed_jaxpr,
-        ordered_effects,
-        backend,
-        lowering_platform or backend.platform,
-        axis_ctx,
-        name_stack,
-        donated_invars,
+        ordered_effects=ordered_effects,
+        backend_or_name=backend,
+        platform=lowering_parameters.platforms or backend.platform,
+        axis_context=axis_ctx,
+        name_stack=name_stack,
+        donated_args=donated_invars,
         replicated_args=replicated_args,
         arg_shardings=in_partitions,
         result_shardings=out_partitions,
         arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
         result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
         num_replicas=num_replicas,
-        num_partitions=num_partitions)
+        num_partitions=num_partitions,
+        lowering_parameters=lowering_parameters)
 
   return MeshComputation(
       str(name_stack),
diff --git a/jax/_src/maps.py b/jax/_src/maps.py
index 6d5ce6d6a362..f6a3cff4712d 100644
--- a/jax/_src/maps.py
+++ b/jax/_src/maps.py
@@ -605,8 +605,8 @@ def fun_mapped(*args):
 
   @decorate_serial
   def lower(*args, **kwargs):
-    _experimental_lowering_platform = kwargs.pop(
-        '_experimental_lowering_platform', None)
+    lowering_parameters = kwargs.pop(
+        '_experimental_lowering_platform', mlir.LoweringParameters())
     fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
     avals_flat = [shaped_abstractify(arg) for arg in args_flat]
     computation = make_xmap_callable(
@@ -614,7 +614,7 @@ def lower(*args, **kwargs):
         params['donated_invars'], params['global_axis_sizes'],
         params['axis_resources'], params['resource_env'],
         params['backend'], params['spmd_in_axes'], params['spmd_out_axes_thunk'],
-        _experimental_lowering_platform, *avals_flat)
+        lowering_parameters, *avals_flat)
 
     in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
     in_avals = in_tree.unflatten(avals_flat)
@@ -633,7 +633,7 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_
       fun, name, in_axes, out_axes_thunk, donated_invars,
       global_axis_sizes, axis_resources, resource_env, backend,
       spmd_in_axes, spmd_out_axes_thunk,
-      None, *in_avals).compile().unsafe_call
+      mlir.LoweringParameters(), *in_avals).compile().unsafe_call
   distributed_debug_log(("Running xmapped function", name),
                         ("python function", fun.f),
                         ("mesh", resource_env.physical_mesh),
@@ -646,7 +646,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
                        in_axes, out_axes_thunk, donated_invars,
                        global_axis_sizes, axis_resources, resource_env,
                        backend, spmd_in_axes, spmd_out_axes_thunk,
-                       lowering_platform: Optional[str],
+                       lowering_parameters: mlir.LoweringParameters,
                        *in_avals):
   plan = EvaluationPlan.from_axis_resources(
       axis_resources, resource_env, global_axis_sizes)
@@ -700,11 +700,11 @@ def make_xmap_callable(fun: lu.WrappedFun,
         in_shardings, out_shardings, donated_invars,
         use_spmd_lowering, in_avals,
         tiling_method=tiling_method,
-        lowering_platform=lowering_platform)
+        lowering_parameters=lowering_parameters)
   else:
     return dispatch.sharded_lowering(
         f, name, donated_invars, True, False, in_avals, (None,) * len(in_avals),
-        lowering_platform=lowering_platform)
+        lowering_parameters=lowering_parameters)
 
 
 class EvaluationPlan(NamedTuple):
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 59b657fecc8a..a415a8c45413 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -325,10 +325,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
 
   @api_boundary
   def lower(*args, **kwargs):
-    _experimental_lowering_platform = kwargs.pop(
-        '_experimental_lowering_platform', None)
-    _experimental_override_lowering_rules = kwargs.pop(
-        '_experimental_override_lowering_rules', None)
+    lowering_parameters = kwargs.pop(
+        '_experimental_lowering_parameters', mlir.LoweringParameters())
     (args_flat, flat_global_in_avals, params, in_tree, out_tree,
      donated_invars) = infer_params_fn(*args, **kwargs)
     resource_env = params['resource_env']
@@ -340,8 +338,7 @@ def lower(*args, **kwargs):
           params['jaxpr'], in_shardings, params['out_shardings'],
           params['resource_env'], params['donated_invars'], params['name'],
           params['keep_unused'], params['inline'],
-          lowering_platform=_experimental_lowering_platform,
-          override_lowering_rules=_experimental_override_lowering_rules)
+          lowering_parameters=lowering_parameters)
     except pxla.DeviceAssignmentMismatchError as e:
       fails, = e.args
       api_name = 'jit' if params['resource_env'] is None else 'pjit'
@@ -1131,7 +1128,7 @@ def _pjit_call_impl_python(
   compiled = _pjit_lower(
       jaxpr, in_shardings, out_shardings, resource_env,
       donated_invars, name, keep_unused, inline,
-      lowering_platform=None).compile()
+      lowering_parameters=mlir.LoweringParameters()).compile()
   _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
   # This check is expensive so only do it if enable_checks is on.
   if compiled._auto_spmd_lowering and config.jax_enable_checks:
@@ -1273,9 +1270,7 @@ def _pjit_lower_cached(
     keep_unused: bool,
     inline: bool,
     *,
-    lowering_platform: Optional[str],
-    override_lowering_rules: Optional[
-        tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None):
+    lowering_parameters: mlir.LoweringParameters):
   in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast(
       tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
   out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings
@@ -1298,7 +1293,7 @@ def _pjit_lower_cached(
         jaxpr, api_name, name, mesh,
         in_shardings, out_shardings, donated_invars,
         True, jaxpr.in_avals, tiling_method=None,
-        lowering_platform=lowering_platform)
+        lowering_parameters=lowering_parameters)
   else:
     return pxla.lower_sharding_computation(
         jaxpr, api_name, name, in_shardings, out_shardings,
@@ -1306,8 +1301,7 @@ def _pjit_lower_cached(
         keep_unused=keep_unused, inline=inline,
         devices_from_context=(
            None if mesh is None or mesh.empty else list(mesh.devices.flat)),
-        lowering_platform=lowering_platform,
-        override_lowering_rules=override_lowering_rules,
+        lowering_parameters=lowering_parameters,
     )
 
 
diff --git a/jax/experimental/export/export.py b/jax/experimental/export/export.py
index e0dcb434bf7b..5ed771b29a7b 100644
--- a/jax/experimental/export/export.py
+++ b/jax/experimental/export/export.py
@@ -421,7 +421,9 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
       shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions
       lowered = wrapped_fun_jax.lower(
           *args_specs, **kwargs_specs,
-          _experimental_lowering_platform=lowering_platforms)
+          _experimental_lowering_parameters=mlir.LoweringParameters(
+              platforms=lowering_platforms,
+          ))
 
       lowering = lowered._lowering  # type: ignore
       _check_lowering(lowering)
@@ -601,9 +603,12 @@ def is_token(attrs):
   entry_block = new_main_op.add_entry_block()
   with ir.InsertionPoint(entry_block):
     module_context = mlir.ModuleContext(
-        "cpu", "cpu", sharding_impls.ShardingContext([]),
-        source_info_util.new_name_stack(),
-        [], itertools.count(1), [], module=wrapped_module, context=context)
+        backend_or_name="cpu", platform="cpu",
+        axis_context=sharding_impls.ShardingContext([]),
+        name_stack=source_info_util.new_name_stack(),
+        keepalives=[], channel_iterator=itertools.count(1),
+        host_callbacks=[], module=wrapped_module, context=context,
+        lowering_parameters=mlir.LoweringParameters())
     ctx = mlir.LoweringRuleContext(
         module_context=module_context, primitive=None,
         avals_in=args_avals_flat, avals_out=None,
diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py
index 5fc2cabb9e2d..1804243063a4 100644
--- a/jax/interpreters/mlir.py
+++ b/jax/interpreters/mlir.py
@@ -16,6 +16,7 @@
     AxisContext as AxisContext,
     ConstantHandler as ConstantHandler,
     DEVICE_TO_DEVICE_TYPE as DEVICE_TO_DEVICE_TYPE,
+    LoweringParameters as LoweringParameters,
     LoweringResult as LoweringResult,
     LoweringRule as LoweringRule,
     LoweringRuleContext as LoweringRuleContext,
diff --git a/tests/api_test.py b/tests/api_test.py
index 60b8beb60c4e..27eb2aa1a06a 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -10347,7 +10347,8 @@ def wsc_as_noop(ctx, operand, *args, **kwargs):
 
     lowered_ir = (
         jax.jit(f)
        .lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16),
-               _experimental_override_lowering_rules=rules).as_text())
+               _experimental_lowering_parameters=mlir.LoweringParameters(
+                   override_lowering_rules=rules)).as_text())
     self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir)
 
diff --git a/tests/export_test.py b/tests/export_test.py
index 108ce21de842..907228abe8de 100644
--- a/tests/export_test.py
+++ b/tests/export_test.py
@@ -570,7 +570,7 @@ def test_multi_platform(self):
     x = np.arange(5, dtype=np.float32)
     # TODO: use a function with different behavior for different platforms
     exp = export.export(jnp.sin,
-                            lowering_platforms=('cpu', 'tpu'))(x)
+                        lowering_platforms=('cpu', 'tpu'))(x)
     self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu'))
     module_str = str(exp.mlir_module())
platform_index = re.findall(
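
Usage sketch (illustrative, not part of the patch): with this change, the separate
_experimental_lowering_platform and _experimental_override_lowering_rules keyword
arguments to .lower() are folded into a single mlir.LoweringParameters object. A
minimal sketch of the new calling convention, assuming a toy function f and input x
(both hypothetical placeholders, not names from the patch):

    import jax
    import jax.numpy as jnp
    from jax.interpreters import mlir  # re-exports LoweringParameters

    def f(x):
      return jnp.sin(x)

    x = jnp.arange(4.0)

    # Default lowering: an empty LoweringParameters() matches the old behavior
    # of passing no _experimental_lowering_platform.
    lowered = jax.jit(f).lower(
        x, _experimental_lowering_parameters=mlir.LoweringParameters())

    # Cross-platform lowering: platforms=('tpu',) plays the role of the old
    # _experimental_lowering_platform='tpu'.
    lowered_tpu = jax.jit(f).lower(
        x, _experimental_lowering_parameters=mlir.LoweringParameters(
            platforms=('tpu',)))
    print(lowered_tpu.as_text()[:200])

A platforms tuple with more than one entry requests multi-platform lowering, for
which only the first platform is used for now, per the TODO in
lower_jaxpr_to_module above.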