
Commit

Make mlir.custom_call() more general and expose it as jax.interpreters.mlir.custom_call().

This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general, JAX would be best served by a more user-friendly "custom kernel" API that doesn't require users to build IR directly, but for the moment the best we can do is migrate users from the classic HLO utilities to the MLIR/StableHLO utilities.

Since most users of custom kernels probably want to build a custom call, we can get most of the benefit by providing an ergonomic helper function that builds the IR for a custom call and can be invoked from external primitive lowering rules.

This function has two benefits over building the stablehlo.custom_call op directly:
a) it is a JAX API, so we can be more confident the API won't change because of upstream MLIR changes, and
b) the Python API that the MLIR bindings generate for stablehlo.custom_call isn't easy to use (e.g., it has no sensible defaults).

The next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules that use this helper.
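
As a rough illustration of the intended usage (not part of this change), a lowering rule for a hypothetical primitive might call the helper as follows. The primitive my_custom_op, the custom-call target name "my_custom_kernel", and the empty backend_config are made-up placeholders; the target itself would still need to be registered with XLA separately.

# Minimal sketch, assuming a hypothetical primitive "my_custom_op" and a
# custom-call target "my_custom_kernel" that is registered with XLA elsewhere.
from jax import core
from jax.interpreters import mlir

my_prim = core.Primitive("my_custom_op")
my_prim.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

def _my_custom_op_lowering(ctx, x):
  aval_out, = ctx.avals_out
  op = mlir.custom_call(
      "my_custom_kernel",                            # made-up target name
      result_types=[mlir.aval_to_ir_type(aval_out)],
      operands=[x],
      backend_config="",                             # opaque string passed to the kernel
  )
  return op.results

mlir.register_lowering(my_prim, _my_custom_op_lowering, platform="cpu")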

PiperOrigin-RevId: 561042402
hawkinsp authored and jax authors committed Aug 29, 2023
1 parent ba8c4c3 commit d0a6813
Showing 7 changed files with 165 additions and 34 deletions.
81 changes: 64 additions & 17 deletions jax/_src/interpreters/mlir.py
@@ -1157,9 +1157,8 @@ def wrap_with_memory_kind(
result_type = x.type
else:
result_type = aval_to_ir_type(aval_out)
op = custom_call("annotate_device_placement", [result_type], [x],
has_side_effect=False,
api_version=1)
op = custom_call("annotate_device_placement", result_types=[result_type],
operands=[x], api_version=1)
mka = get_compute_type(memory_kind)
dict_attr = {"_xla_compute_type": ir.StringAttr.get(mka)}
if is_input and mka == 'host':
@@ -1698,9 +1697,8 @@ def _wrap_with_spmd_op(name: str,
else:
result_shapes = [eval_dynamic_shape_as_tensor(ctx, out_shape)]

op = custom_call(name, [result_type], [x],
op = custom_call(name, result_types=[result_type], operands=[x],
backend_config=backend_config,
has_side_effect=False,
api_version=1,
result_shapes=result_shapes)
set_sharding(op, sharding_proto)
@@ -2153,27 +2151,42 @@ def build_xla_computation_helper(

def custom_call(
call_target_name: str,
out_types: Sequence[ir.Type],
operands: Sequence[ir.Value],
*,
backend_config: str | dict[str, ir.Attribute] = "",
result_types: Sequence[ir.Type],
operands: Sequence[ir.Value],
backend_config: str | bytes | dict[str, ir.Attribute] = "",
has_side_effect: bool = False,
result_shapes: Sequence[ir.Value] | None = None,
called_computations: Sequence[str] = (),
api_version: int = 2,
extra_attributes: dict[str, ir.Attribute] = {},
operand_output_aliases: dict[int, int] | None = None,
operand_layouts: Sequence[Sequence[int]] | None = None,
result_layouts: Sequence[Sequence[int]] | None = None,
extra_attributes: dict[str, ir.Attribute] | None = None,
) -> ir.Operation:
"""Wraps a hlo.CustomCall.
"""Helper function for building an hlo.CustomCall.
Args:
call_target_name: the name of the custom call target
result_types: the MLIR types of the results of the custom call
operands: the MLIR IR values that are arguments to the custom call
backend_config: an opaque string passed to the custom call kernel
has_side_effect: if True, marks the custom call as effectful
result_shapes: tensors that represent the result shapes, to be used when
the results have dynamic shapes. If not-None, its length must match the
number of the results.
called_computations: the list of function names called by the custom call.
api_version: the ABI contract version of the custom call
operand_output_aliases: a dict mapping operand numbers to outputs they alias
operand_layouts: a sequence of layouts (dimension orders) for each operand
result_layouts: a sequence of layouts (dimension orders) for each result
extra_attributes: additional IR attributes to apply to the custom_call.
"""
operands = list(operands)

if backend_config is None:
backend_config_attr = ir.StringAttr.get("")
elif isinstance(backend_config, str):
elif isinstance(backend_config, (str, bytes)):
backend_config_attr = ir.StringAttr.get(backend_config)
elif isinstance(backend_config, dict):
# TODO(necula): it seems that the CustomCallOp constructor requires that
@@ -2193,10 +2206,24 @@ def custom_call(
has_side_effect=ir.BoolAttr.get(has_side_effect),
backend_config=backend_config_attr,
api_version=i32_attr(api_version),
called_computations=ir.ArrayAttr.get([
ir.FlatSymbolRefAttr.get(name) for name in called_computations]),
called_computations=ir.ArrayAttr.get(
[ir.FlatSymbolRefAttr.get(name) for name in called_computations]
),
)
attributes.update(extra_attributes)
if operand_output_aliases is not None:
attributes["output_operand_aliases"] = ir.ArrayAttr.get([
hlo.OutputOperandAlias.get(
# if len(result_types) == 1 then the aliasing refers implicitly to
# the only output.
output_tuple_indices=[output_idx] if len(result_types) > 1 else [],
operand_index=input_idx,
operand_tuple_indices=[],
)
for input_idx, output_idx in (operand_output_aliases.items() or ())
])

if extra_attributes is not None:
attributes.update(extra_attributes)

if result_shapes is not None:
# We add the result_shapes at the end of the operands, and must pass
@@ -2205,9 +2232,29 @@ def custom_call(
attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get(
np.asarray(list(range(len(operands), len(operands) + len(result_shapes))),
dtype=np.int64))
if operand_layouts is not None:
assert len(operand_layouts) == len(operands), (operand_layouts, operands)
operand_layouts = list(operand_layouts) + [(0,)] * len(result_shapes)
operands = list(operands) + list(result_shapes)

op = hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes)
if operand_layouts is not None:
attributes["operand_layouts"] = ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(
np.atleast_1d(np.asarray(l, dtype=np.int64)),
type=ir.IndexType.get()) for l in operand_layouts
])
if result_layouts is not None:
assert result_layouts is not None
assert len(result_layouts) == len(result_types), (
result_layouts, result_types)
attributes["result_layouts"] = ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(
np.atleast_1d(np.asarray(l, dtype=np.int64)),
type=ir.IndexType.get()) for l in result_layouts
])

op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands,
attributes=attributes)
if isinstance(backend_config, dict):
backend_config_attr = ir.DictAttr.get(backend_config)
op.operation.attributes["mhlo.backend_config"] = backend_config_attr
@@ -2251,8 +2298,8 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):

rw = custom_call(
"stablehlo.dynamic_reduce_window",
list(map(aval_to_ir_type, out_avals)),
[
result_types=list(map(aval_to_ir_type, out_avals)),
operands=[
*operands, *init_values,
eval_dynamic_shape_as_tensor(ctx, window_dimensions),
eval_dynamic_shape_as_tensor(ctx, window_strides),
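
For reference, here is a standalone sketch (not part of the commit) that exercises several of the new keyword-only options together, following the same pattern as the filecheck test added at the end of this change. The target name "my_kernel", the layouts, and the operand/output aliasing are arbitrary illustrations.

# Illustrative sketch only: building a custom_call with the new keyword-only
# options inside a standalone MLIR module. Mirrors the filecheck test below;
# the target name and attribute values are arbitrary.
import numpy as np
import jax
from jax.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect

aval = jax.core.ShapedArray((2, 3), np.dtype(np.float32))

ctx = mlir.make_ir_context()
loc = ir.Location.unknown(context=ctx)
with ctx, loc:
  module = ir.Module.create(loc=loc)
  in_type = mlir.aval_to_ir_type(aval)
  ftype = ir.FunctionType.get([in_type], [in_type])
  fn = func_dialect.FuncOp("main", ftype, ip=ir.InsertionPoint(module.body))
  entry_block = fn.add_entry_block()
  with ir.InsertionPoint(entry_block):
    op = mlir.custom_call(
        "my_kernel",                      # arbitrary custom-call target name
        result_types=[in_type],
        operands=list(entry_block.arguments),
        backend_config="opaque-config",   # opaque string handed to the kernel
        api_version=1,
        operand_output_aliases={0: 0},    # result 0 aliases operand 0
        operand_layouts=[[1, 0]],         # minor-to-major layout for the operand
        result_layouts=[[1, 0]],          # minor-to-major layout for the result
    )
    func_dialect.ReturnOp(op.results)
  module.operation.verify()
  print(module)
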
4 changes: 2 additions & 2 deletions jax/_src/lax/ann.py
@@ -347,8 +347,8 @@ def _approx_top_k_lowering(ctx, operand, *, k,

out = mlir.custom_call(
"ApproxTopK",
[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
[operand, iota, init_val, init_arg],
result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
operands=[operand, iota, init_val, init_arg],
called_computations=[comparator.name.value],
backend_config=backend_config,
result_shapes=result_shapes)
10 changes: 5 additions & 5 deletions jax/_src/lax/lax.py
@@ -4233,9 +4233,9 @@ def _top_k_lower(ctx, operand, k):
out_values_aval, out_indices_aval, = ctx.avals_out
return mlir.custom_call(
"stablehlo.dynamic_top_k",
[mlir.aval_to_ir_type(out_values_aval),
result_types=[mlir.aval_to_ir_type(out_values_aval),
mlir.aval_to_ir_type(out_indices_aval)],
[operand, k_value]).results
operands=[operand, k_value]).results

mlir.register_lowering(top_k_p, _top_k_lower)
ad.primitive_jvps[top_k_p] = _top_k_jvp
@@ -4499,9 +4499,9 @@ def _rng_bit_generator_lowering(
mlir.eval_dynamic_shape(ctx, out_vals_aval.shape))
out_key, out_vals = mlir.custom_call(
"stablehlo.dynamic_rng_bit_generator",
[key.type,
mlir.aval_to_ir_type(core.ShapedArray(shape, rbg_dtype))],
[key, output_shape],
result_types=[key.type,
mlir.aval_to_ir_type(core.ShapedArray(shape, rbg_dtype))],
operands=[key, output_shape],
extra_attributes=dict(rng_algorithm=algorithm_attr)).results
else:
out_key, out_vals = hlo.RngBitGeneratorOp(
16 changes: 8 additions & 8 deletions jax/_src/lax/linalg.py
@@ -644,8 +644,8 @@ def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues):
result_shapes = None
op = mlir.custom_call(
"Eigh",
result_types,
[operand],
result_types=result_types,
operands=[operand],
backend_config=backend_config,
api_version=1,
result_shapes=result_shapes,
@@ -1301,8 +1301,8 @@ def _lu_tpu_lowering_rule(ctx, operand):
result_shapes = None
op = mlir.custom_call(
"LuDecomposition",
result_types,
[operand],
result_types=result_types,
operands=[operand],
result_shapes=result_shapes)
return op.results

@@ -1436,8 +1436,8 @@ def _geqrf_lowering_rule(ctx, operand):
result_shapes = None
op = mlir.custom_call(
"Qr",
result_types,
[operand],
result_types=result_types,
operands=[operand],
api_version=1,
result_shapes=result_shapes
)
@@ -1561,8 +1561,8 @@ def _householder_product_lowering_rule(ctx, a, taus):
result_shapes = None
op = mlir.custom_call(
"ProductOfElementaryHouseholderReflectors",
[mlir.aval_to_ir_type(aval_out)],
[a, taus],
result_types=[mlir.aval_to_ir_type(aval_out)],
operands=[a, taus],
api_version=1,
result_shapes=result_shapes)
return [op.result]
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/shape_poly.py
@@ -830,8 +830,8 @@ def _shape_assertion_lowering_rule(ctx: mlir.LoweringRuleContext,
error_message: str):
op = mlir.custom_call(
"shape_assertion",
[], # No results
[assert_what, *error_message_inputs],
result_types=[], # No results
operands=[assert_what, *error_message_inputs],
has_side_effect=True,
extra_attributes=dict(error_message=mlir.ir.StringAttr.get(error_message))
)
1 change: 1 addition & 0 deletions jax/interpreters/mlir.py
@@ -31,6 +31,7 @@
aval_to_ir_type as aval_to_ir_type,
aval_to_ir_types as aval_to_ir_types,
core_call_lowering as core_call_lowering,
custom_call as custom_call,
dense_bool_elements as dense_bool_elements,
dense_int_elements as dense_int_elements,
dtype_to_ir_type as dtype_to_ir_type,
83 changes: 83 additions & 0 deletions tests/filecheck/custom_call.filecheck.py
@@ -0,0 +1,83 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tests for mlir.custom_call().

# RUN: %PYTHON %s | FileCheck %s

from absl import app

import jax
from jax.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
import numpy as np

ShapedArray = jax.core.ShapedArray

def print_custom_call(name, arg_avals, result_avals, **kw):
print(f"TEST: {name}")
ctx = mlir.make_ir_context()
loc = ir.Location.unknown(context=ctx)
with ctx, loc:
module = ir.Module.create(loc=ir.Location.unknown())
ip = ir.InsertionPoint(module.body)
arg_types = [mlir.aval_to_ir_type(aval) for aval in arg_avals]
result_types = [mlir.aval_to_ir_type(aval) for aval in result_avals]
ftype = ir.FunctionType.get(arg_types, result_types)
func = func_dialect.FuncOp("func", ftype, ip=ip)
entry_block = func.add_entry_block()
with ir.InsertionPoint(entry_block):
outs = mlir.custom_call(
name, result_types=result_types, operands=entry_block.arguments, **kw
)
func_dialect.ReturnOp(outs)
module.operation.verify()
print(str(module))

def main(_):
aval1 = ShapedArray((2, 3), np.dtype(np.float32))
aval2 = ShapedArray((3, 4), np.dtype(np.int64))
# CHECK-LABEL: TEST: simple
# CHECK: stablehlo.custom_call @simple(%arg0) {api_version = 2 : i32} : (tensor<2x3xf32>) -> tensor<3x4xi64>
print_custom_call("simple", [aval1], [aval2])

# CHECK-LABEL: TEST: sideeffect
# CHECK: stablehlo.custom_call @sideeffect(%arg0) {has_side_effect = true} : (tensor<2x3xf32>) -> tensor<3x4xi64>
print_custom_call("sideeffect", [aval1], [aval2], api_version=1,
has_side_effect=True)

# CHECK-LABEL: TEST: backendconfig
# CHECK: stablehlo.custom_call @backendconfig(%arg0) {backend_config = "hello"} : (tensor<2x3xf32>) -> tensor<3x4xi64>
print_custom_call("backendconfig", [aval1], [aval2], api_version=1,
backend_config=b"hello")

# CHECK-LABEL: TEST: calledcomputations
# CHECK: stablehlo.custom_call @calledcomputations(%arg0) {called_computations = [@a, @b]} : (tensor<2x3xf32>) -> tensor<3x4xi64>
print_custom_call("calledcomputations", [aval1], [aval2], api_version=1,
called_computations=["a", "b"])

# CHECK-LABEL: TEST: aliases
# CHECK: stablehlo.custom_call @aliases(%arg0, %arg1) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 1, operand_tuple_indices = []>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
print_custom_call("aliases", [aval1, aval2], [aval2, aval1], api_version=1,
operand_output_aliases={1: 0})

# CHECK-LABEL: TEST: layouts
# CHECK: stablehlo.custom_call @layouts(%arg0, %arg1) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
print_custom_call("layouts", [aval1, aval2], [aval2, aval1], api_version=1,
operand_layouts=[[0, 1], [1, 0]],
result_layouts=[[1, 0], [0, 1]])

if __name__ == "__main__":
app.run(main)

