diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index e60f0f4b87e1..6c4750209940 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1157,9 +1157,8 @@ def wrap_with_memory_kind(
     result_type = x.type
   else:
     result_type = aval_to_ir_type(aval_out)
-  op = custom_call("annotate_device_placement", [result_type], [x],
-                   has_side_effect=False,
-                   api_version=1)
+  op = custom_call("annotate_device_placement", result_types=[result_type],
+                   operands=[x], api_version=1)
   mka = get_compute_type(memory_kind)
   dict_attr = {"_xla_compute_type": ir.StringAttr.get(mka)}
   if is_input and mka == 'host':
@@ -1698,9 +1697,8 @@ def _wrap_with_spmd_op(name: str,
   else:
     result_shapes = [eval_dynamic_shape_as_tensor(ctx, out_shape)]

-  op = custom_call(name, [result_type], [x],
+  op = custom_call(name, result_types=[result_type], operands=[x],
                    backend_config=backend_config,
-                   has_side_effect=False,
                    api_version=1,
                    result_shapes=result_shapes)
   set_sharding(op, sharding_proto)
@@ -2153,27 +2151,42 @@ def build_xla_computation_helper(

 def custom_call(
     call_target_name: str,
-    out_types: Sequence[ir.Type],
-    operands: Sequence[ir.Value],
     *,
-    backend_config: str | dict[str, ir.Attribute] = "",
+    result_types: Sequence[ir.Type],
+    operands: Sequence[ir.Value],
+    backend_config: str | bytes | dict[str, ir.Attribute] = "",
     has_side_effect: bool = False,
     result_shapes: Sequence[ir.Value] | None = None,
     called_computations: Sequence[str] = (),
     api_version: int = 2,
-    extra_attributes: dict[str, ir.Attribute] = {},
+    operand_output_aliases: dict[int, int] | None = None,
+    operand_layouts: Sequence[Sequence[int]] | None = None,
+    result_layouts: Sequence[Sequence[int]] | None = None,
+    extra_attributes: dict[str, ir.Attribute] | None = None,
 ) -> ir.Operation:
-  """Wraps a hlo.CustomCall.
+  """Helper function for building an hlo.CustomCall.

   Args:
+    call_target_name: the name of the custom call target
+    result_types: the MLIR types of the results of the custom call
+    operands: the MLIR IR values that are arguments to the custom call
+    backend_config: an opaque string passed to the custom call kernel
+    has_side_effect: if True, marks the custom call as effectful
     result_shapes: tensors that represent the result shapes, to be used when
       the results have dynamic shapes. If not-None, its length must match the
       number of the results.
     called_computations: the list of function names called by the custom call.
+    api_version: the ABI contract version of the custom call
+    operand_output_aliases: a dict mapping operand numbers to outputs they alias
+    operand_layouts: a sequence of layouts (dimension orders) for each operand
+    result_layouts: a sequence of layouts (dimension orders) for each result
+    extra_attributes: additional IR attributes to apply to the custom_call.
""" + operands = list(operands) + if backend_config is None: backend_config_attr = ir.StringAttr.get("") - elif isinstance(backend_config, str): + elif isinstance(backend_config, (str, bytes)): backend_config_attr = ir.StringAttr.get(backend_config) elif isinstance(backend_config, dict): # TODO(necula): it seems that the CustomCallOp constructor requires that @@ -2193,10 +2206,24 @@ def custom_call( has_side_effect=ir.BoolAttr.get(has_side_effect), backend_config=backend_config_attr, api_version=i32_attr(api_version), - called_computations=ir.ArrayAttr.get([ - ir.FlatSymbolRefAttr.get(name) for name in called_computations]), + called_computations=ir.ArrayAttr.get( + [ir.FlatSymbolRefAttr.get(name) for name in called_computations] + ), ) - attributes.update(extra_attributes) + if operand_output_aliases is not None: + attributes["output_operand_aliases"] = ir.ArrayAttr.get([ + hlo.OutputOperandAlias.get( + # if len(result_types) == 1 then the aliasing refers implicitly to + # the only output. + output_tuple_indices=[output_idx] if len(result_types) > 1 else [], + operand_index=input_idx, + operand_tuple_indices=[], + ) + for input_idx, output_idx in (operand_output_aliases.items() or ()) + ]) + + if extra_attributes is not None: + attributes.update(extra_attributes) if result_shapes is not None: # We add the result_shapes at the end of the operands, and must pass @@ -2205,9 +2232,29 @@ def custom_call( attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get( np.asarray(list(range(len(operands), len(operands) + len(result_shapes))), dtype=np.int64)) + if operand_layouts is not None: + assert len(operand_layouts) == len(operands), (operand_layouts, operands) + operand_layouts = list(operand_layouts) + [(0,)] * len(result_shapes) operands = list(operands) + list(result_shapes) - op = hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes) + if operand_layouts is not None: + attributes["operand_layouts"] = ir.ArrayAttr.get([ + ir.DenseIntElementsAttr.get( + np.atleast_1d(np.asarray(l, dtype=np.int64)), + type=ir.IndexType.get()) for l in operand_layouts + ]) + if result_layouts is not None: + assert result_layouts is not None + assert len(result_layouts) == len(result_types), ( + result_layouts, result_types) + attributes["result_layouts"] = ir.ArrayAttr.get([ + ir.DenseIntElementsAttr.get( + np.atleast_1d(np.asarray(l, dtype=np.int64)), + type=ir.IndexType.get()) for l in result_layouts + ]) + + op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands, + attributes=attributes) if isinstance(backend_config, dict): backend_config_attr = ir.DictAttr.get(backend_config) op.operation.attributes["mhlo.backend_config"] = backend_config_attr @@ -2251,8 +2298,8 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]): rw = custom_call( "stablehlo.dynamic_reduce_window", - list(map(aval_to_ir_type, out_avals)), - [ + result_types=list(map(aval_to_ir_type, out_avals)), + operands=[ *operands, *init_values, eval_dynamic_shape_as_tensor(ctx, window_dimensions), eval_dynamic_shape_as_tensor(ctx, window_strides), diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 2c59d4470caf..b00fb0fb30ea 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -347,8 +347,8 @@ def _approx_top_k_lowering(ctx, operand, *, k, out = mlir.custom_call( "ApproxTopK", - [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - [operand, iota, init_val, init_arg], + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + 
+      operands=[operand, iota, init_val, init_arg],
       called_computations=[comparator.name.value],
       backend_config=backend_config,
       result_shapes=result_shapes)
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 3dc7566adc62..64621f99c742 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -4233,9 +4233,9 @@ def _top_k_lower(ctx, operand, k):
   out_values_aval, out_indices_aval, = ctx.avals_out
   return mlir.custom_call(
       "stablehlo.dynamic_top_k",
-      [mlir.aval_to_ir_type(out_values_aval),
+      result_types=[mlir.aval_to_ir_type(out_values_aval),
       mlir.aval_to_ir_type(out_indices_aval)],
-      [operand, k_value]).results
+      operands=[operand, k_value]).results

 mlir.register_lowering(top_k_p, _top_k_lower)
 ad.primitive_jvps[top_k_p] = _top_k_jvp
@@ -4499,9 +4499,9 @@ def _rng_bit_generator_lowering(
         mlir.eval_dynamic_shape(ctx, out_vals_aval.shape))
     out_key, out_vals = mlir.custom_call(
         "stablehlo.dynamic_rng_bit_generator",
-        [key.type,
-         mlir.aval_to_ir_type(core.ShapedArray(shape, rbg_dtype))],
-        [key, output_shape],
+        result_types=[key.type,
+          mlir.aval_to_ir_type(core.ShapedArray(shape, rbg_dtype))],
+        operands=[key, output_shape],
         extra_attributes=dict(rng_algorithm=algorithm_attr)).results
   else:
     out_key, out_vals = hlo.RngBitGeneratorOp(
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index e75a2f573a93..907fe6cf45cd 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -644,8 +644,8 @@ def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues):
     result_shapes = None
   op = mlir.custom_call(
       "Eigh",
-      result_types,
-      [operand],
+      result_types=result_types,
+      operands=[operand],
       backend_config=backend_config,
       api_version=1,
       result_shapes=result_shapes,
@@ -1301,8 +1301,8 @@ def _lu_tpu_lowering_rule(ctx, operand):
     result_shapes = None
   op = mlir.custom_call(
       "LuDecomposition",
-      result_types,
-      [operand],
+      result_types=result_types,
+      operands=[operand],
       result_shapes=result_shapes)
   return op.results
@@ -1436,8 +1436,8 @@ def _geqrf_lowering_rule(ctx, operand):
     result_shapes = None
   op = mlir.custom_call(
       "Qr",
-      result_types,
-      [operand],
+      result_types=result_types,
+      operands=[operand],
       api_version=1,
       result_shapes=result_shapes
   )
   return op.results
@@ -1561,8 +1561,8 @@ def _householder_product_lowering_rule(ctx, a, taus):
     result_shapes = None
   op = mlir.custom_call(
       "ProductOfElementaryHouseholderReflectors",
-      [mlir.aval_to_ir_type(aval_out)],
-      [a, taus],
+      result_types=[mlir.aval_to_ir_type(aval_out)],
+      operands=[a, taus],
       api_version=1,
       result_shapes=result_shapes)
   return [op.result]
diff --git a/jax/experimental/jax2tf/shape_poly.py b/jax/experimental/jax2tf/shape_poly.py
index d97b5d6e3086..a66497b7413a 100644
--- a/jax/experimental/jax2tf/shape_poly.py
+++ b/jax/experimental/jax2tf/shape_poly.py
@@ -830,8 +830,8 @@ def _shape_assertion_lowering_rule(ctx: mlir.LoweringRuleContext,
                                    error_message: str):
   op = mlir.custom_call(
       "shape_assertion",
-      [],  # No results
-      [assert_what, *error_message_inputs],
+      result_types=[],  # No results
+      operands=[assert_what, *error_message_inputs],
       has_side_effect=True,
       extra_attributes=dict(error_message=mlir.ir.StringAttr.get(error_message))
   )
diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py
index 0d31aab5a5c1..5fc2cabb9e2d 100644
--- a/jax/interpreters/mlir.py
+++ b/jax/interpreters/mlir.py
@@ -31,6 +31,7 @@
   aval_to_ir_type as aval_to_ir_type,
   aval_to_ir_types as aval_to_ir_types,
   core_call_lowering as core_call_lowering,
+  custom_call as custom_call,
   dense_bool_elements as dense_bool_elements,
   dense_int_elements as dense_int_elements,
   dtype_to_ir_type as dtype_to_ir_type,
diff --git a/tests/filecheck/custom_call.filecheck.py b/tests/filecheck/custom_call.filecheck.py
new file mode 100644
index 000000000000..1f33f2391b7c
--- /dev/null
+++ b/tests/filecheck/custom_call.filecheck.py
@@ -0,0 +1,83 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Tests for mlir.custom_call().
+
+# RUN: %PYTHON %s | FileCheck %s
+
+from absl import app
+
+import jax
+from jax.interpreters import mlir
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func as func_dialect
+import numpy as np
+
+ShapedArray = jax.core.ShapedArray
+
+def print_custom_call(name, arg_avals, result_avals, **kw):
+  print(f"TEST: {name}")
+  ctx = mlir.make_ir_context()
+  loc = ir.Location.unknown(context=ctx)
+  with ctx, loc:
+    module = ir.Module.create(loc=ir.Location.unknown())
+    ip = ir.InsertionPoint(module.body)
+    arg_types = [mlir.aval_to_ir_type(aval) for aval in arg_avals]
+    result_types = [mlir.aval_to_ir_type(aval) for aval in result_avals]
+    ftype = ir.FunctionType.get(arg_types, result_types)
+    func = func_dialect.FuncOp("func", ftype, ip=ip)
+    entry_block = func.add_entry_block()
+    with ir.InsertionPoint(entry_block):
+      outs = mlir.custom_call(
+          name, result_types=result_types, operands=entry_block.arguments, **kw
+      )
+      func_dialect.ReturnOp(outs)
+  module.operation.verify()
+  print(str(module))
+
+def main(_):
+  aval1 = ShapedArray((2, 3), np.dtype(np.float32))
+  aval2 = ShapedArray((3, 4), np.dtype(np.int64))
+  # CHECK-LABEL: TEST: simple
+  # CHECK: stablehlo.custom_call @simple(%arg0) {api_version = 2 : i32} : (tensor<2x3xf32>) -> tensor<3x4xi64>
+  print_custom_call("simple", [aval1], [aval2])
+
+  # CHECK-LABEL: TEST: sideeffect
+  # CHECK: stablehlo.custom_call @sideeffect(%arg0) {has_side_effect = true} : (tensor<2x3xf32>) -> tensor<3x4xi64>
+  print_custom_call("sideeffect", [aval1], [aval2], api_version=1,
+                    has_side_effect=True)
+
+  # CHECK-LABEL: TEST: backendconfig
+  # CHECK: stablehlo.custom_call @backendconfig(%arg0) {backend_config = "hello"} : (tensor<2x3xf32>) -> tensor<3x4xi64>
+  print_custom_call("backendconfig", [aval1], [aval2], api_version=1,
+                    backend_config=b"hello")
+
+  # CHECK-LABEL: TEST: calledcomputations
+  # CHECK: stablehlo.custom_call @calledcomputations(%arg0) {called_computations = [@a, @b]} : (tensor<2x3xf32>) -> tensor<3x4xi64>
+  print_custom_call("calledcomputations", [aval1], [aval2], api_version=1,
+                    called_computations=["a", "b"])
+
+  # CHECK-LABEL: TEST: aliases
+  # CHECK: stablehlo.custom_call @aliases(%arg0, %arg1) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
+  print_custom_call("aliases", [aval1, aval2], [aval2, aval1], api_version=1,
+                    operand_output_aliases={1: 0})
+
+  # CHECK-LABEL: TEST: layouts
+  # CHECK: stablehlo.custom_call @layouts(%arg0, %arg1) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
+  print_custom_call("layouts", [aval1, aval2], [aval2, aval1], api_version=1,
+                    operand_layouts=[[0, 1], [1, 0]],
+                    result_layouts=[[1, 0], [0, 1]])
+
+if __name__ == "__main__":
+  app.run(main)
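
Note (not part of the patch): a minimal sketch of how a lowering rule calls the consolidated helper after this change. The call target name "my_custom_target" and the lowering rule below are hypothetical; the point is that result_types and operands are now keyword-only, has_side_effect defaults to False, and the helper is re-exported from jax.interpreters.mlir.

from jax.interpreters import mlir

def _my_lowering_rule(ctx, x):
  # Hypothetical lowering rule illustrating the keyword-only custom_call API.
  out_type = mlir.aval_to_ir_type(ctx.avals_out[0])
  op = mlir.custom_call(
      "my_custom_target",        # hypothetical custom call target name
      result_types=[out_type],   # formerly the positional `out_types` argument
      operands=[x],              # formerly positional; now keyword-only
      backend_config="",         # opaque string passed to the kernel
      api_version=2,
  )
  return op.results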