-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR][Python] Add a DSL for defining IRDL dialects in Python bindings #169045
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
🐧 Linux x64 Test Results
|
338605e to
3ec4809
Compare
|
@llvm/pr-subscribers-mlir Author: Twice (PragmaTwice) ChangesPython bindings for the IRDL dialect were introduced in #158488. They are currently usable—for constructing IR and dynamically loading modules that contain
To address these issues, I propose creating a wrapper (effectively a small “DSL”) on top of the existing IRDL Python bindings. This wrapper aims to simplify IR construction and automatically generate the corresponding Currently, using the IRDL bindings looks like this: m = Module.create()
with InsertionPoint(m.body):
myint = irdl.dialect("myint")
with InsertionPoint(myint.body):
constant = irdl.operation_("constant")
with InsertionPoint(constant.body):
iattr = irdl.base(base_name="#builtin.integer")
i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
irdl.attributes_([iattr], ["value"])
irdl.results_([i32], ["cst"], [irdl.Variadicity.single])
add = irdl.operation_("add")
with InsertionPoint(add.body):
i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
irdl.operands_(
[i32, i32],
["lhs", "rhs"],
[irdl.Variadicity.single, irdl.Variadicity.single],
)
irdl.results_([i32], ["res"], [irdl.Variadicity.single])
irdl.load_dialects(m)With the proposed DSL, the equivalent implementation becomes: myint = irdsl.Dialect("myint")
iattr = irdsl.BaseName("#builtin.integer")
i32 = irdsl.IsType(IntegerType.get_signless(32))
@<!-- -->myint.op("constant")
class ConstantOp:
value = irdsl.Attribute(iattr)
cst = irdsl.Result(i32)
@<!-- -->myint.op("add")
class AddOp:
lhs = irdsl.Operand(i32)
rhs = irdsl.Operand(i32)
res = irdsl.Result(i32)
myint.load()Compared with the current IRDL Python bindings, this DSL mainly adds the following:
The current DSL does not yet cover all IRDL operations. Several features are not supported at the moment:
Patch is 26.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169045.diff 4 Files Affected:
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 2acb6ee6cfda5..27b87a8b70144 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -511,7 +511,9 @@ declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/IRDLOps.td
- SOURCES dialects/irdl.py
+ SOURCES
+ dialects/irdl/__init__.py
+ dialects/irdl/dsl.py
DIALECT_NAME irdl
GEN_ENUM_BINDINGS
)
diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl/__init__.py
similarity index 91%
rename from mlir/python/mlir/dialects/irdl.py
rename to mlir/python/mlir/dialects/irdl/__init__.py
index 1ec951b69b646..6b2787ed7966c 100644
--- a/mlir/python/mlir/dialects/irdl.py
+++ b/mlir/python/mlir/dialects/irdl/__init__.py
@@ -2,13 +2,14 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from ._irdl_ops_gen import *
-from ._irdl_ops_gen import _Dialect
-from ._irdl_enum_gen import *
-from .._mlir_libs._mlirDialectsIRDL import *
-from ..ir import register_attribute_builder
-from ._ods_common import _cext as _ods_cext
+from .._irdl_ops_gen import *
+from .._irdl_ops_gen import _Dialect
+from .._irdl_enum_gen import *
+from ..._mlir_libs._mlirDialectsIRDL import *
+from ...ir import register_attribute_builder
+from .._ods_common import _cext as _ods_cext
from typing import Union, Sequence
+from . import dsl
_ods_ir = _ods_cext.ir
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
new file mode 100644
index 0000000000000..41520f0e0f68c
--- /dev/null
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -0,0 +1,359 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ...dialects import irdl as _irdl
+from .._ods_common import (
+ _cext as _ods_cext,
+ segmented_accessor as _ods_segmented_accessor,
+)
+from . import Variadicity
+from typing import Dict, List, Union, Callable, Tuple
+from dataclasses import dataclass
+from inspect import Parameter as _Parameter, Signature as _Signature
+from types import SimpleNamespace as _SimpleNameSpace
+
+_ods_ir = _ods_cext.ir
+
+
+class ConstraintExpr:
+ def _lower(self, ctx: "ConstraintLoweringContext") -> _ods_ir.Value:
+ raise NotImplementedError()
+
+ def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+ return AnyOf(self, other)
+
+ def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+ return AllOf(self, other)
+
+
+class ConstraintLoweringContext:
+ def __init__(self):
+ # Cache so that the same ConstraintExpr instance reuses its SSA value.
+ self._cache: Dict[int, _ods_ir.Value] = {}
+
+ def lower(self, expr: ConstraintExpr) -> _ods_ir.Value:
+ key = id(expr)
+ if key in self._cache:
+ return self._cache[key]
+ v = expr._lower(self)
+ self._cache[key] = v
+ return v
+
+
+class Is(ConstraintExpr):
+ def __init__(self, attr: _ods_ir.Attribute):
+ self.attr = attr
+
+ def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+ return _irdl.is_(self.attr)
+
+
+class IsType(Is):
+ def __init__(self, typ: _ods_ir.Type):
+ super().__init__(_ods_ir.TypeAttr.get(typ))
+
+
+class AnyOf(ConstraintExpr):
+ def __init__(self, *exprs: ConstraintExpr):
+ self.exprs = exprs
+
+ def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+ return _irdl.any_of(ctx.lower(expr) for expr in self.exprs)
+
+
+class AllOf(ConstraintExpr):
+ def __init__(self, *exprs: ConstraintExpr):
+ self.exprs = exprs
+
+ def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+ return _irdl.all_of(ctx.lower(expr) for expr in self.exprs)
+
+
+class Any(ConstraintExpr):
+ def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+ return _irdl.any()
+
+
+class BaseName(ConstraintExpr):
+ def __init__(self, name: str):
+ self.name = name
+
+ def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+ return _irdl.base(base_name=self.name)
+
+
+class BaseRef(ConstraintExpr):
+ def __init__(self, ref):
+ self.ref = ref
+
+ def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+ return _irdl.base(base_ref=self.ref)
+
+
+class _FieldDef:
+ def __set_name__(self, owner, name: str):
+ self.name = name
+
+
+@dataclass
+class Operand(_FieldDef):
+ constraint: ConstraintExpr
+ variadicity: Variadicity = Variadicity.single
+
+
+@dataclass
+class Result(_FieldDef):
+ constraint: ConstraintExpr
+ variadicity: Variadicity = Variadicity.single
+
+
+@dataclass
+class Attribute(_FieldDef):
+ constraint: ConstraintExpr
+
+ def __post_init__(self):
+ # just for unified processing,
+ # currently optional attribute is not supported by IRDL
+ self.variadicity = Variadicity.single
+
+
+@dataclass
+class Operation:
+ dialect_name: str
+ name: str
+ # We store operands and attributes into one list to maintain relative orders
+ # among them for generating OpView class.
+ fields: List[Union[Operand, Attribute, Result]]
+
+ def _partition_fields(self) -> Tuple[List[Operand], List[Attribute], List[Result]]:
+ operands = [i for i in self.fields if isinstance(i, Operand)]
+ attrs = [i for i in self.fields if isinstance(i, Attribute)]
+ results = [i for i in self.fields if isinstance(i, Result)]
+ return operands, attrs, results
+
+ def _emit(self) -> None:
+ ctx = ConstraintLoweringContext()
+ operands, attrs, results = self._partition_fields()
+
+ op = _irdl.operation_(self.name)
+ with _ods_ir.InsertionPoint(op.body):
+ if operands:
+ _irdl.operands_(
+ [ctx.lower(i.constraint) for i in operands],
+ [i.name for i in operands],
+ [i.variadicity for i in operands],
+ )
+ if attrs:
+ _irdl.attributes_(
+ [ctx.lower(i.constraint) for i in attrs],
+ [i.name for i in attrs],
+ )
+ if results:
+ _irdl.results_(
+ [ctx.lower(i.constraint) for i in results],
+ [i.name for i in results],
+ [i.variadicity for i in results],
+ )
+
+ @staticmethod
+ def _variadicity_to_segment(variadicity: Variadicity) -> int:
+ if variadicity == Variadicity.variadic:
+ return -1
+ if variadicity == Variadicity.optional:
+ return 0
+ return 1
+
+ @staticmethod
+ def _generate_segments(
+ operands_or_results: List[Union[Operand, Result]],
+ ) -> List[int]:
+ if any(i.variadicity != Variadicity.single for i in operands_or_results):
+ return [
+ Operation._variadicity_to_segment(i.variadicity)
+ for i in operands_or_results
+ ]
+ return None
+
+ def _generate_init_params(self) -> List[_Parameter]:
+ # results are placed at the beginning of the parameter list,
+ # but operands and attributes can appear in any relative order.
+ args = [i for i in self.fields if isinstance(i, Result)] + [
+ i for i in self.fields if not isinstance(i, Result)
+ ]
+ positional_args = [
+ i.name for i in args if i.variadicity != Variadicity.optional
+ ]
+ optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
+
+ params = [_Parameter("self", _Parameter.POSITIONAL_ONLY)]
+ for i in positional_args:
+ params.append(_Parameter(i, _Parameter.POSITIONAL_OR_KEYWORD))
+ for i in optional_args:
+ params.append(_Parameter(i, _Parameter.KEYWORD_ONLY, default=None))
+ params.append(_Parameter("loc", _Parameter.KEYWORD_ONLY, default=None))
+ params.append(_Parameter("ip", _Parameter.KEYWORD_ONLY, default=None))
+
+ return params
+
+ def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
+ operands, attrs, results = self._partition_fields()
+
+ operand_segments = Operation._generate_segments(operands)
+ result_segments = Operation._generate_segments(results)
+
+ params = self._generate_init_params()
+ init_sig = _Signature(params)
+ op = self
+
+ class _OpView(_ods_ir.OpView):
+ OPERATION_NAME = f"{op.dialect_name}.{op.name}"
+ _ODS_REGIONS = (0, True)
+ _ODS_OPERAND_SEGMENTS = operand_segments
+ _ODS_RESULT_SEGMENTS = result_segments
+
+ def __init__(*args, **kwargs):
+ bound = init_sig.bind(*args, **kwargs)
+ bound.apply_defaults()
+ args = bound.arguments
+
+ _operands = [args[operand.name] for operand in operands]
+ _results = [args[result.name] for result in results]
+ _attributes = dict(
+ (attr.name, args[attr.name])
+ for attr in attrs
+ if args[attr.name] is not None
+ )
+ _regions = None
+ _ods_successors = None
+ self = args["self"]
+ super(_OpView, self).__init__(
+ self.OPERATION_NAME,
+ self._ODS_REGIONS,
+ self._ODS_OPERAND_SEGMENTS,
+ self._ODS_RESULT_SEGMENTS,
+ attributes=_attributes,
+ results=_results,
+ operands=_operands,
+ successors=_ods_successors,
+ regions=_regions,
+ loc=args["loc"],
+ ip=args["ip"],
+ )
+
+ __init__.__signature__ = init_sig
+
+ for attr in attrs:
+ setattr(
+ _OpView,
+ attr.name,
+ property(lambda self, name=attr.name: self.attributes[name]),
+ )
+
+ def value_range_getter(
+ value_range: Union[_ods_ir.OpOperandList, _ods_ir.OpResultList],
+ variadicity: Variadicity,
+ ):
+ if variadicity == Variadicity.single:
+ return value_range[0]
+ if variadicity == Variadicity.optional:
+ return value_range[0] if len(value_range) > 0 else None
+ return value_range
+
+ for i, operand in enumerate(operands):
+ if operand_segments:
+
+ def getter(self, i=i, operand=operand):
+ operand_range = _ods_segmented_accessor(
+ self.operation.operands,
+ self.operation.attributes["operandSegmentSizes"],
+ i,
+ )
+ return value_range_getter(operand_range, operand.variadicity)
+
+ setattr(_OpView, operand.name, property(getter))
+ else:
+ setattr(
+ _OpView, operand.name, property(lambda self, i=i: self.operands[i])
+ )
+ for i, result in enumerate(results):
+ if result_segments:
+
+ def getter(self, i=i, result=result):
+ result_range = _ods_segmented_accessor(
+ self.operation.results,
+ self.operation.attributes["resultSegmentSizes"],
+ i,
+ )
+ return value_range_getter(result_range, result.variadicity)
+
+ setattr(_OpView, result.name, property(getter))
+ else:
+ setattr(
+ _OpView, result.name, property(lambda self, i=i: self.results[i])
+ )
+
+ def _builder(*args, **kwargs) -> _OpView:
+ return _OpView(*args, **kwargs)
+
+ _builder.__signature__ = _Signature(params[1:])
+
+ return _OpView, _builder
+
+
+class Dialect:
+ def __init__(self, name: str):
+ self.name = name
+ self.operations: List[Operation] = []
+ self.namespace = _SimpleNameSpace()
+
+ def _emit(self) -> None:
+ d = _irdl.dialect(self.name)
+ with _ods_ir.InsertionPoint(d.body):
+ for op in self.operations:
+ op._emit()
+
+ def _make_module(self) -> _ods_ir.Module:
+ with _ods_ir.Location.unknown():
+ m = _ods_ir.Module.create()
+ with _ods_ir.InsertionPoint(m.body):
+ self._emit()
+ return m
+
+ def _make_dialect_class(self) -> type:
+ class _Dialect(_ods_ir.Dialect):
+ DIALECT_NAMESPACE = self.name
+
+ return _Dialect
+
+ def load(self) -> _SimpleNameSpace:
+ _irdl.load_dialects(self._make_module())
+ dialect_class = self._make_dialect_class()
+ _ods_cext.register_dialect(dialect_class)
+ for op in self.operations:
+ _ods_cext.register_operation(dialect_class)(op.op_view)
+ return self.namespace
+
+ @staticmethod
+ def _op_name_to_python_id(name: str) -> str:
+ return name.replace(".", "_").replace("$", "_")
+
+ def op(self, name: str) -> Callable[[type], type]:
+ def decorator(cls: type) -> type:
+ fields = [
+ field for field in cls.__dict__.values() if isinstance(field, _FieldDef)
+ ]
+ op_def = Operation(self.name, name, fields)
+
+ op_view, builder = op_def._make_op_view_and_builder()
+ op_view.__name__ = cls.__name__
+ setattr(op_def, "op_view", op_view)
+ setattr(op_def, "builder", builder)
+ self.operations.append(op_def)
+
+ self.namespace.__dict__[cls.__name__] = op_view
+ self.namespace.__dict__[Dialect._op_name_to_python_id(name)] = builder
+
+ return cls
+
+ return decorator
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
new file mode 100644
index 0000000000000..8ef30ae0a4c13
--- /dev/null
+++ b/mlir/test/python/dialects/irdsl.py
@@ -0,0 +1,308 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects.irdl import dsl as irdsl
+from mlir.dialects import arith
+import sys
+
+
+def run(f):
+ print("\nTEST:", f.__name__, file=sys.stderr)
+ with Context():
+ f()
+
+
+# CHECK: TEST: testMyInt
+@run
+def testMyInt():
+ myint = irdsl.Dialect("myint")
+ iattr = irdsl.BaseName("#builtin.integer")
+ i32 = irdsl.IsType(IntegerType.get_signless(32))
+
+ @myint.op("constant")
+ class ConstantOp:
+ value = irdsl.Attribute(iattr)
+ cst = irdsl.Result(i32)
+
+ @myint.op("add")
+ class AddOp:
+ lhs = irdsl.Operand(i32)
+ rhs = irdsl.Operand(i32)
+ res = irdsl.Result(i32)
+
+ # CHECK: irdl.dialect @myint {
+ # CHECK: irdl.operation @constant {
+ # CHECK: %0 = irdl.base "#builtin.integer"
+ # CHECK: irdl.attributes {"value" = %0}
+ # CHECK: %1 = irdl.is i32
+ # CHECK: irdl.results(cst: %1)
+ # CHECK: }
+ # CHECK: irdl.operation @add {
+ # CHECK: %0 = irdl.is i32
+ # CHECK: irdl.operands(lhs: %0, rhs: %0)
+ # CHECK: irdl.results(res: %0)
+ # CHECK: }
+ # CHECK: }
+ print(myint._make_module())
+ myint = myint.load()
+
+ # CHECK: ['ConstantOp', 'constant', 'AddOp', 'add']
+ print([i for i in myint.__dict__.keys()])
+
+ i32 = IntegerType.get_signless(32)
+ with Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ two = myint.constant(i32, IntegerAttr.get(i32, 2))
+ three = myint.constant(i32, IntegerAttr.get(i32, 3))
+ add1 = myint.add(i32, two, three)
+ add2 = myint.add(i32, add1, two)
+ add3 = myint.add(i32, add2, three)
+
+ # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+ # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+ # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+ # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
+ # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
+ print(module)
+ assert module.operation.verify()
+
+ # CHECK: AddOp
+ print(type(add1).__name__)
+ # CHECK: ConstantOp
+ print(type(two).__name__)
+ # CHECK: myint.add
+ print(add1.OPERATION_NAME)
+ # CHECK: None
+ print(add1._ODS_OPERAND_SEGMENTS)
+ # CHECK: None
+ print(add1._ODS_RESULT_SEGMENTS)
+ # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+ print(add1.lhs.owner)
+ # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+ print(add1.rhs.owner)
+ # CHECK: 2 : i32
+ print(two.value)
+ # CHECK: Value(%0
+ print(two.cst)
+ # CHECK: (res, lhs, rhs, *, loc=None, ip=None)
+ print(myint.add.__signature__)
+ # CHECK: (cst, value, *, loc=None, ip=None)
+ print(myint.constant.__signature__)
+
+
+@run
+def testIRDSL():
+ test = irdsl.Dialect("irdsl_test")
+ i32 = irdsl.IsType(IntegerType.get_signless(32))
+ i64 = irdsl.IsType(IntegerType.get_signless(64))
+ i32or64 = i32 | i64
+ any = irdsl.Any()
+ f32 = irdsl.IsType(F32Type.get())
+ iattr = irdsl.BaseName("#builtin.integer")
+ fattr = irdsl.BaseName("#builtin.float")
+
+ @test.op("constraint")
+ class ConstraintOp:
+ a = irdsl.Operand(i32or64)
+ b = irdsl.Operand(any)
+ c = irdsl.Operand(f32 | i32)
+ d = irdsl.Operand(any)
+ x = irdsl.Attribute(iattr)
+ y = irdsl.Attribute(fattr)
+
+ @test.op("optional")
+ class OptionalOp:
+ a = irdsl.Operand(i32)
+ b = irdsl.Operand(i32, irdsl.Variadicity.optional)
+ out1 = irdsl.Result(i32)
+ out2 = irdsl.Result(i32, irdsl.Variadicity.optional)
+ out3 = irdsl.Result(i32)
+
+ @test.op("optional2")
+ class Optional2Op:
+ a = irdsl.Operand(i32, irdsl.Variadicity.optional)
+ b = irdsl.Result(i32, irdsl.Variadicity.optional)
+
+ @test.op("variadic")
+ class VariadicOp:
+ a = irdsl.Operand(i32)
+ b = irdsl.Operand(i32, irdsl.Variadicity.optional)
+ c = irdsl.Operand(i32, irdsl.Variadicity.variadic)
+ out1 = irdsl.Result(i32, irdsl.Variadicity.variadic)
+ out2 = irdsl.Result(i32, irdsl.Variadicity.variadic)
+ out3 = irdsl.Result(i32, irdsl.Variadicity.optional)
+ out4 = irdsl.Result(i32)
+
+ @test.op("variadic2")
+ class Variadic2Op:
+ a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
+ b = irdsl.Result(i32, irdsl.Variadicity.variadic)
+
+ @test.op("mixed")
+ class MixedOp:
+ out = irdsl.Result(i32)
+ in1 = irdsl.Operand(i32)
+ in2 = irdsl.Attribute(iattr)
+ in3 = irdsl.Operand(i32, irdsl.Variadicity.optional)
+ in4 = irdsl.Attribute(iattr)
+ in5 = irdsl.Operand(i32)
+
+ # CHECK: irdl.dialect @irdsl_test {
+ # CHECK: irdl.operation @constraint {
+ # CHECK: %0 = irdl.is i32
+ # CHECK: %1 = irdl.is i64
+ # CHECK: %2 = irdl.any_of(%0, %1)
+ # CHECK: %3 = irdl.any
+ # CHECK: %4 = irdl.is f32
+ # CHECK: %5 = irdl.any_of(%4, %0)
+ # CHECK: irdl.operands(a: %2, b: %3, c: %5, d: %3)
+ # CHECK: %6 = irdl.base "#builtin.integer"
+ # CHECK: %7 = irdl.base "#builtin.float"
+ # CHECK: irdl.attributes {"x" = %6, "y" = %7}
+ # CHECK: }
+ # CHECK: irdl.operation @optional {
+ # CHECK: %0 = irdl.is i32
+ # CHECK: irdl.operands(a: %0, b: optional %0)
+ # CHECK: irdl.results(out1: %0, out2: optional %0, out3: %0)
+ # CHECK: }
+ # CHECK: irdl.operation @optional2 {
+ # CHECK: %0 = irdl.is i32
+ # CHECK: irdl.operands(a: optional %0)
+ # CHECK: irdl.results(b: optional %0)
+ # CHECK: }
+ # CHECK: irdl.operation @variadic {
+ # CHECK: %0 = irdl.is i32
+ # CHECK: irdl.operands(a: %0, b: optional %0, c: variadic %0)
+ # CHECK: irdl.results(out1...
[truncated]
|
|
Nice -- this has been on my wishlist! Could we get a comparison with the xDSL API for defining ops/dialects? E.g. versus https://github.com/xdslproject/xdsl/blob/main/tests/irdl/test_operation_definition.py. Surely they have "lessons learnt" we should consider and there might be some utility in having API compatibility. (P.S. @PragmaTwice, I have a local proof-of-concept for exposing MLIR C++ Interfaces' ExternalModels to Python so that their methods can be implemented in Python and the external model attached to any op in Python. For now the API is rather low-level, i.e. actually implement the ExternalModel as a separate class and attach it to a op. The long-term vision is that such a op-definition DSL would allow for multiple inheritance so that you just specify the interfaces as additional parent classes and implement their interface's methods alongside other Python methods on the op. Some Python metaprogramming magic would take care of still making the appropriate ExternalModels and attaching them to the op, upon the op getting registered.) |
Absolutely, I’ll try to put together a comparison table. Because of some differences in the underlying mechanisms (for example, MLIR Python’s
Does this (eventually) imply extending the IRDL dialect to introduce concepts like interfaces and traits? I’m wondering if this overlaps with this issue: #158066. I’m really looking forward to seeing this — it sounds like very cool work. I’ve also been thinking about how to extend IRDL to support a more flexible constraint system (#158049), but that’s still in a very early stage. We can keep these possible IRDL evolutions in mind so that, when we introduce things like interfaces/traits, the DSL in this PR won’t need to change too drastically and users won’t have to rewrite large amounts of code. |
Python bindings for the IRDL dialect were introduced in #158488. They are currently usable—for constructing IR and dynamically loading modules that contain
irdl.dialectinto MLIR. However, there are still several pain points when working with them:OpViewclasses or builder functions for IRDL-defined operations.To address these issues, I propose creating a wrapper (effectively a small “DSL”) on top of the existing IRDL Python bindings. This wrapper aims to simplify IR construction and automatically generate the corresponding
OpViewtypes. A simple example is shown below.Currently, using the IRDL bindings looks like this:
With the proposed DSL, the equivalent implementation becomes:
Compared with the current IRDL Python bindings, this DSL mainly adds the following:
OpViewclasses—including__init__methods and property getters—as well as builder functions for each defined operation. Similar to TableGen’sins, operands and attributes can be interleaved in arbitrary order. Special handling is also implemented for optional and variadic operands/results (such as computing segment sizes) so that they feel as natural to use as native operations.Dialect.load()is called, which makes it unnecessary to specify an MLIR context immediately when defining an IRDL dialect.The current DSL does not yet cover all IRDL operations. Several features are not supported at the moment: