From d0a6813ea29ecf36ddd3444b298b4725612e3106 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 29 Aug 2023 08:49:30 -0700 Subject: [PATCH] Make mlir.custom_call() more general and expose it as jax.interpreters.mlir.custom_call(). This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities. Since most users of custom kernels probably want to build a custom-call we can get most of the benefit by providing an ergonomic helper function for building the IR for custom calls that can be called by external primitive lowering rules. This function has two benefits over just building the stablehlo directly: a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults). Next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules using this helper. PiperOrigin-RevId: 561042402 --- jax/_src/interpreters/mlir.py | 81 ++++++++++++++++++----- jax/_src/lax/ann.py | 4 +- jax/_src/lax/lax.py | 10 +-- jax/_src/lax/linalg.py | 16 ++--- jax/experimental/jax2tf/shape_poly.py | 4 +- jax/interpreters/mlir.py | 1 + tests/filecheck/custom_call.filecheck.py | 83 ++++++++++++++++++++++++ 7 files changed, 165 insertions(+), 34 deletions(-) create mode 100644 tests/filecheck/custom_call.filecheck.py diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index e60f0f4b87e1..6c4750209940 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1157,9 +1157,8 @@ def wrap_with_memory_kind( result_type = x.type else: result_type = aval_to_ir_type(aval_out) - op = custom_call("annotate_device_placement", [result_type], [x], - has_side_effect=False, - api_version=1) + op = custom_call("annotate_device_placement", result_types=[result_type], + operands=[x], api_version=1) mka = get_compute_type(memory_kind) dict_attr = {"_xla_compute_type": ir.StringAttr.get(mka)} if is_input and mka == 'host': @@ -1698,9 +1697,8 @@ def _wrap_with_spmd_op(name: str, else: result_shapes = [eval_dynamic_shape_as_tensor(ctx, out_shape)] - op = custom_call(name, [result_type], [x], + op = custom_call(name, result_types=[result_type], operands=[x], backend_config=backend_config, - has_side_effect=False, api_version=1, result_shapes=result_shapes) set_sharding(op, sharding_proto) @@ -2153,27 +2151,42 @@ def build_xla_computation_helper( def custom_call( call_target_name: str, - out_types: Sequence[ir.Type], - operands: Sequence[ir.Value], *, - backend_config: str | dict[str, ir.Attribute] = "", + result_types: Sequence[ir.Type], + operands: Sequence[ir.Value], + backend_config: str | bytes | dict[str, ir.Attribute] = "", has_side_effect: bool = False, result_shapes: Sequence[ir.Value] | None = None, called_computations: Sequence[str] = (), api_version: int = 2, - extra_attributes: dict[str, ir.Attribute] = {}, + operand_output_aliases: dict[int, int] | None = None, + operand_layouts: Sequence[Sequence[int]] | None = None, + result_layouts: Sequence[Sequence[int]] | None = None, + extra_attributes: dict[str, ir.Attribute] | None = None, ) -> ir.Operation: - """Wraps a hlo.CustomCall. + """Helper function for building an hlo.CustomCall. Args: + call_target_name: the name of the custom call target + result_types: the MLIR types of the results of the custom call + operands: the MLIR IR values that are arguments to the custom call + backend_config: an opaque string passed to the custom call kernel + has_side_effect: if True, marks the custom call as effectful result_shapes: tensors that represent the result shapes, to be used when the results have dynamic shapes. If not-None, its length must match the number of the results. called_computations: the list of function names called by the custom call. + api_version: the ABI contract version of the custom call + operand_output_aliases: a dict mapping operand numbers to outputs they alias + operand_layouts: a sequence of layouts (dimension orders) for each operand + result_layouts: a sequence of layouts (dimension orders) for each result + extra_attributes: additional IR attributes to apply to the custom_call. """ + operands = list(operands) + if backend_config is None: backend_config_attr = ir.StringAttr.get("") - elif isinstance(backend_config, str): + elif isinstance(backend_config, (str, bytes)): backend_config_attr = ir.StringAttr.get(backend_config) elif isinstance(backend_config, dict): # TODO(necula): it seems that the CustomCallOp constructor requires that @@ -2193,10 +2206,24 @@ def custom_call( has_side_effect=ir.BoolAttr.get(has_side_effect), backend_config=backend_config_attr, api_version=i32_attr(api_version), - called_computations=ir.ArrayAttr.get([ - ir.FlatSymbolRefAttr.get(name) for name in called_computations]), + called_computations=ir.ArrayAttr.get( + [ir.FlatSymbolRefAttr.get(name) for name in called_computations] + ), ) - attributes.update(extra_attributes) + if operand_output_aliases is not None: + attributes["output_operand_aliases"] = ir.ArrayAttr.get([ + hlo.OutputOperandAlias.get( + # if len(result_types) == 1 then the aliasing refers implicitly to + # the only output. + output_tuple_indices=[output_idx] if len(result_types) > 1 else [], + operand_index=input_idx, + operand_tuple_indices=[], + ) + for input_idx, output_idx in (operand_output_aliases.items() or ()) + ]) + + if extra_attributes is not None: + attributes.update(extra_attributes) if result_shapes is not None: # We add the result_shapes at the end of the operands, and must pass @@ -2205,9 +2232,29 @@ def custom_call( attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get( np.asarray(list(range(len(operands), len(operands) + len(result_shapes))), dtype=np.int64)) + if operand_layouts is not None: + assert len(operand_layouts) == len(operands), (operand_layouts, operands) + operand_layouts = list(operand_layouts) + [(0,)] * len(result_shapes) operands = list(operands) + list(result_shapes) - op = hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes) + if operand_layouts is not None: + attributes["operand_layouts"] = ir.ArrayAttr.get([ + ir.DenseIntElementsAttr.get( + np.atleast_1d(np.asarray(l, dtype=np.int64)), + type=ir.IndexType.get()) for l in operand_layouts + ]) + if result_layouts is not None: + assert result_layouts is not None + assert len(result_layouts) == len(result_types), ( + result_layouts, result_types) + attributes["result_layouts"] = ir.ArrayAttr.get([ + ir.DenseIntElementsAttr.get( + np.atleast_1d(np.asarray(l, dtype=np.int64)), + type=ir.IndexType.get()) for l in result_layouts + ]) + + op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands, + attributes=attributes) if isinstance(backend_config, dict): backend_config_attr = ir.DictAttr.get(backend_config) op.operation.attributes["mhlo.backend_config"] = backend_config_attr @@ -2251,8 +2298,8 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]): rw = custom_call( "stablehlo.dynamic_reduce_window", - list(map(aval_to_ir_type, out_avals)), - [ + result_types=list(map(aval_to_ir_type, out_avals)), + operands=[ *operands, *init_values, eval_dynamic_shape_as_tensor(ctx, window_dimensions), eval_dynamic_shape_as_tensor(ctx, window_strides), diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 2c59d4470caf..b00fb0fb30ea 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -347,8 +347,8 @@ def _approx_top_k_lowering(ctx, operand, *, k, out = mlir.custom_call( "ApproxTopK", - [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - [operand, iota, init_val, init_arg], + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=[operand, iota, init_val, init_arg], called_computations=[comparator.name.value], backend_config=backend_config, result_shapes=result_shapes) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 3dc7566adc62..64621f99c742 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4233,9 +4233,9 @@ def _top_k_lower(ctx, operand, k): out_values_aval, out_indices_aval, = ctx.avals_out return mlir.custom_call( "stablehlo.dynamic_top_k", - [mlir.aval_to_ir_type(out_values_aval), + result_types=[mlir.aval_to_ir_type(out_values_aval), mlir.aval_to_ir_type(out_indices_aval)], - [operand, k_value]).results + operands=[operand, k_value]).results mlir.register_lowering(top_k_p, _top_k_lower) ad.primitive_jvps[top_k_p] = _top_k_jvp @@ -4499,9 +4499,9 @@ def _rng_bit_generator_lowering( mlir.eval_dynamic_shape(ctx, out_vals_aval.shape)) out_key, out_vals = mlir.custom_call( "stablehlo.dynamic_rng_bit_generator", - [key.type, - mlir.aval_to_ir_type(core.ShapedArray(shape, rbg_dtype))], - [key, output_shape], + result_types=[key.type, + mlir.aval_to_ir_type(core.ShapedArray(shape, rbg_dtype))], + operands=[key, output_shape], extra_attributes=dict(rng_algorithm=algorithm_attr)).results else: out_key, out_vals = hlo.RngBitGeneratorOp( diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index e75a2f573a93..907fe6cf45cd 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -644,8 +644,8 @@ def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues): result_shapes = None op = mlir.custom_call( "Eigh", - result_types, - [operand], + result_types=result_types, + operands=[operand], backend_config=backend_config, api_version=1, result_shapes=result_shapes, @@ -1301,8 +1301,8 @@ def _lu_tpu_lowering_rule(ctx, operand): result_shapes = None op = mlir.custom_call( "LuDecomposition", - result_types, - [operand], + result_types=result_types, + operands=[operand], result_shapes=result_shapes) return op.results @@ -1436,8 +1436,8 @@ def _geqrf_lowering_rule(ctx, operand): result_shapes = None op = mlir.custom_call( "Qr", - result_types, - [operand], + result_types=result_types, + operands=[operand], api_version=1, result_shapes=result_shapes ) @@ -1561,8 +1561,8 @@ def _householder_product_lowering_rule(ctx, a, taus): result_shapes = None op = mlir.custom_call( "ProductOfElementaryHouseholderReflectors", - [mlir.aval_to_ir_type(aval_out)], - [a, taus], + result_types=[mlir.aval_to_ir_type(aval_out)], + operands=[a, taus], api_version=1, result_shapes=result_shapes) return [op.result] diff --git a/jax/experimental/jax2tf/shape_poly.py b/jax/experimental/jax2tf/shape_poly.py index d97b5d6e3086..a66497b7413a 100644 --- a/jax/experimental/jax2tf/shape_poly.py +++ b/jax/experimental/jax2tf/shape_poly.py @@ -830,8 +830,8 @@ def _shape_assertion_lowering_rule(ctx: mlir.LoweringRuleContext, error_message: str): op = mlir.custom_call( "shape_assertion", - [], # No results - [assert_what, *error_message_inputs], + result_types=[], # No results + operands=[assert_what, *error_message_inputs], has_side_effect=True, extra_attributes=dict(error_message=mlir.ir.StringAttr.get(error_message)) ) diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 0d31aab5a5c1..5fc2cabb9e2d 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -31,6 +31,7 @@ aval_to_ir_type as aval_to_ir_type, aval_to_ir_types as aval_to_ir_types, core_call_lowering as core_call_lowering, + custom_call as custom_call, dense_bool_elements as dense_bool_elements, dense_int_elements as dense_int_elements, dtype_to_ir_type as dtype_to_ir_type, diff --git a/tests/filecheck/custom_call.filecheck.py b/tests/filecheck/custom_call.filecheck.py new file mode 100644 index 000000000000..1f33f2391b7c --- /dev/null +++ b/tests/filecheck/custom_call.filecheck.py @@ -0,0 +1,83 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tests for mlir.custom_call(). + +# RUN: %PYTHON %s | FileCheck %s + +from absl import app + +import jax +from jax.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import func as func_dialect +import numpy as np + +ShapedArray = jax.core.ShapedArray + +def print_custom_call(name, arg_avals, result_avals, **kw): + print(f"TEST: {name}") + ctx = mlir.make_ir_context() + loc = ir.Location.unknown(context=ctx) + with ctx, loc: + module = ir.Module.create(loc=ir.Location.unknown()) + ip = ir.InsertionPoint(module.body) + arg_types = [mlir.aval_to_ir_type(aval) for aval in arg_avals] + result_types = [mlir.aval_to_ir_type(aval) for aval in result_avals] + ftype = ir.FunctionType.get(arg_types, result_types) + func = func_dialect.FuncOp("func", ftype, ip=ip) + entry_block = func.add_entry_block() + with ir.InsertionPoint(entry_block): + outs = mlir.custom_call( + name, result_types=result_types, operands=entry_block.arguments, **kw + ) + func_dialect.ReturnOp(outs) + module.operation.verify() + print(str(module)) + +def main(_): + aval1 = ShapedArray((2, 3), np.dtype(np.float32)) + aval2 = ShapedArray((3, 4), np.dtype(np.int64)) + # CHECK-LABEL: TEST: simple + # CHECK: stablehlo.custom_call @simple(%arg0) {api_version = 2 : i32} : (tensor<2x3xf32>) -> tensor<3x4xi64> + print_custom_call("simple", [aval1], [aval2]) + + # CHECK-LABEL: TEST: sideeffect + # CHECK: stablehlo.custom_call @sideeffect(%arg0) {has_side_effect = true} : (tensor<2x3xf32>) -> tensor<3x4xi64> + print_custom_call("sideeffect", [aval1], [aval2], api_version=1, + has_side_effect=True) + + # CHECK-LABEL: TEST: backendconfig + # CHECK: stablehlo.custom_call @backendconfig(%arg0) {backend_config = "hello"} : (tensor<2x3xf32>) -> tensor<3x4xi64> + print_custom_call("backendconfig", [aval1], [aval2], api_version=1, + backend_config=b"hello") + + # CHECK-LABEL: TEST: calledcomputations + # CHECK: stablehlo.custom_call @calledcomputations(%arg0) {called_computations = [@a, @b]} : (tensor<2x3xf32>) -> tensor<3x4xi64> + print_custom_call("calledcomputations", [aval1], [aval2], api_version=1, + called_computations=["a", "b"]) + + # CHECK-LABEL: TEST: aliases + # CHECK: stablehlo.custom_call @aliases(%arg0, %arg1) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>) + print_custom_call("aliases", [aval1, aval2], [aval2, aval1], api_version=1, + operand_output_aliases={1: 0}) + + # CHECK-LABEL: TEST: layouts + # CHECK: stablehlo.custom_call @layouts(%arg0, %arg1) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>) + print_custom_call("layouts", [aval1, aval2], [aval2, aval1], api_version=1, + operand_layouts=[[0, 1], [1, 0]], + result_layouts=[[1, 0], [0, 1]]) + +if __name__ == "__main__": + app.run(main)