Skip to content

Conversation

@PragmaTwice
Copy link
Member

@PragmaTwice PragmaTwice commented Nov 21, 2025

Python bindings for the IRDL dialect were introduced in #158488. They are currently usable—for constructing IR and dynamically loading modules that contain irdl.dialect into MLIR. However, there are still several pain points when working with them:

  • The IRDL IR-building interface is not very intuitive and tends to be quite verbose.
  • We do not yet have the corresponding OpView classes or builder functions for IRDL-defined operations.

To address these issues, I propose creating a wrapper (effectively a small “DSL”) on top of the existing IRDL Python bindings. This wrapper aims to simplify IR construction and automatically generate the corresponding OpView types. A simple example is shown below.

Currently, using the IRDL bindings looks like this:

m = Module.create()
with InsertionPoint(m.body):
    myint = irdl.dialect("myint")
    with InsertionPoint(myint.body):
        constant = irdl.operation_("constant")
        with InsertionPoint(constant.body):
            iattr = irdl.base(base_name="#builtin.integer")
            i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
            irdl.attributes_([iattr], ["value"])
            irdl.results_([i32], ["cst"], [irdl.Variadicity.single])

        add = irdl.operation_("add")
        with InsertionPoint(add.body):
            i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
            irdl.operands_(
                [i32, i32],
                ["lhs", "rhs"],
                [irdl.Variadicity.single, irdl.Variadicity.single],
            )
            irdl.results_([i32], ["res"], [irdl.Variadicity.single])

irdl.load_dialects(m)

With the proposed DSL, the equivalent implementation becomes:

myint = irdsl.Dialect("myint")
iattr = irdsl.BaseName("#builtin.integer")
i32 = irdsl.IsType(IntegerType.get_signless(32))

@myint.op("constant")
class ConstantOp:
    value = irdsl.Attribute(iattr)
    cst = irdsl.Result(i32)

@myint.op("add")
class AddOp:
    lhs = irdsl.Operand(i32)
    rhs = irdsl.Operand(i32)
    res = irdsl.Result(i32)

myint.load()

Compared with the current IRDL Python bindings, this DSL mainly adds the following:

  • A more intuitive interface for constructing IRDL definitions (as shown in the example).
  • Automatic generation of the corresponding OpView classes—including __init__ methods and property getters—as well as builder functions for each defined operation. Similar to TableGen’s ins, operands and attributes can be interleaved in arbitrary order. Special handling is also implemented for optional and variadic operands/results (such as computing segment sizes) so that they feel as natural to use as native operations.
  • Lazy insertion of ops: all ops are created and inserted only when Dialect.load() is called, which makes it unnecessary to specify an MLIR context immediately when defining an IRDL dialect.

The current DSL does not yet cover all IRDL operations. Several features are not supported at the moment:

  • Defining new types or attributes
  • Parametric constraints
  • Adding regions to operations
  • Any form of "type inference"

@github-actions
Copy link

github-actions bot commented Nov 21, 2025

🐧 Linux x64 Test Results

  • 7121 tests passed
  • 594 tests skipped

@PragmaTwice PragmaTwice marked this pull request as ready for review November 24, 2025 06:17
@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Nov 24, 2025
@llvmbot
Copy link
Member

llvmbot commented Nov 24, 2025

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes

Python bindings for the IRDL dialect were introduced in #158488. They are currently usable—for constructing IR and dynamically loading modules that contain irdl.dialect into MLIR. However, there are still several pain points when working with them:

  • The IRDL IR-building interface is not very intuitive and tends to be quite verbose.
  • We do not yet have the corresponding OpView classes or builder functions for IRDL-defined operations.

To address these issues, I propose creating a wrapper (effectively a small “DSL”) on top of the existing IRDL Python bindings. This wrapper aims to simplify IR construction and automatically generate the corresponding OpView types. A simple example is shown below.

Currently, using the IRDL bindings looks like this:

m = Module.create()
with InsertionPoint(m.body):
    myint = irdl.dialect("myint")
    with InsertionPoint(myint.body):
        constant = irdl.operation_("constant")
        with InsertionPoint(constant.body):
            iattr = irdl.base(base_name="#builtin.integer")
            i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
            irdl.attributes_([iattr], ["value"])
            irdl.results_([i32], ["cst"], [irdl.Variadicity.single])

        add = irdl.operation_("add")
        with InsertionPoint(add.body):
            i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
            irdl.operands_(
                [i32, i32],
                ["lhs", "rhs"],
                [irdl.Variadicity.single, irdl.Variadicity.single],
            )
            irdl.results_([i32], ["res"], [irdl.Variadicity.single])

irdl.load_dialects(m)

With the proposed DSL, the equivalent implementation becomes:

myint = irdsl.Dialect("myint")
iattr = irdsl.BaseName("#builtin.integer")
i32 = irdsl.IsType(IntegerType.get_signless(32))

@<!-- -->myint.op("constant")
class ConstantOp:
    value = irdsl.Attribute(iattr)
    cst = irdsl.Result(i32)

@<!-- -->myint.op("add")
class AddOp:
    lhs = irdsl.Operand(i32)
    rhs = irdsl.Operand(i32)
    res = irdsl.Result(i32)

myint.load()

Compared with the current IRDL Python bindings, this DSL mainly adds the following:

  • A more intuitive interface for constructing IRDL definitions (as shown in the example).
  • Automatic generation of the corresponding OpView classes—including __init__ methods and property getters—as well as builder functions for each defined operation. Similar to TableGen’s ins, operands and attributes can be interleaved in arbitrary order. Special handling is also implemented for optional and variadic operands/results (such as computing segment sizes) so that they feel as natural to use as native operations.
  • Lazy insertion of ops: all ops are created and inserted only when Dialect.load() is called, which makes it unnecessary to specify an MLIR context immediately when defining an IRDL dialect.

The current DSL does not yet cover all IRDL operations. Several features are not supported at the moment:

  • Defining new types or attributes
  • Parametric constraints
  • Adding regions to operations

Patch is 26.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169045.diff

4 Files Affected:

  • (modified) mlir/python/CMakeLists.txt (+3-1)
  • (renamed) mlir/python/mlir/dialects/irdl/init.py (+7-6)
  • (added) mlir/python/mlir/dialects/irdl/dsl.py (+359)
  • (added) mlir/test/python/dialects/irdsl.py (+308)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 2acb6ee6cfda5..27b87a8b70144 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -511,7 +511,9 @@ declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/IRDLOps.td
-  SOURCES dialects/irdl.py
+  SOURCES
+    dialects/irdl/__init__.py
+    dialects/irdl/dsl.py
   DIALECT_NAME irdl
   GEN_ENUM_BINDINGS
 )
diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl/__init__.py
similarity index 91%
rename from mlir/python/mlir/dialects/irdl.py
rename to mlir/python/mlir/dialects/irdl/__init__.py
index 1ec951b69b646..6b2787ed7966c 100644
--- a/mlir/python/mlir/dialects/irdl.py
+++ b/mlir/python/mlir/dialects/irdl/__init__.py
@@ -2,13 +2,14 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from ._irdl_ops_gen import *
-from ._irdl_ops_gen import _Dialect
-from ._irdl_enum_gen import *
-from .._mlir_libs._mlirDialectsIRDL import *
-from ..ir import register_attribute_builder
-from ._ods_common import _cext as _ods_cext
+from .._irdl_ops_gen import *
+from .._irdl_ops_gen import _Dialect
+from .._irdl_enum_gen import *
+from ..._mlir_libs._mlirDialectsIRDL import *
+from ...ir import register_attribute_builder
+from .._ods_common import _cext as _ods_cext
 from typing import Union, Sequence
+from . import dsl
 
 _ods_ir = _ods_cext.ir
 
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
new file mode 100644
index 0000000000000..41520f0e0f68c
--- /dev/null
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -0,0 +1,359 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ...dialects import irdl as _irdl
+from .._ods_common import (
+    _cext as _ods_cext,
+    segmented_accessor as _ods_segmented_accessor,
+)
+from . import Variadicity
+from typing import Dict, List, Union, Callable, Tuple
+from dataclasses import dataclass
+from inspect import Parameter as _Parameter, Signature as _Signature
+from types import SimpleNamespace as _SimpleNameSpace
+
+_ods_ir = _ods_cext.ir
+
+
+class ConstraintExpr:
+    def _lower(self, ctx: "ConstraintLoweringContext") -> _ods_ir.Value:
+        raise NotImplementedError()
+
+    def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+        return AnyOf(self, other)
+
+    def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+        return AllOf(self, other)
+
+
+class ConstraintLoweringContext:
+    def __init__(self):
+        # Cache so that the same ConstraintExpr instance reuses its SSA value.
+        self._cache: Dict[int, _ods_ir.Value] = {}
+
+    def lower(self, expr: ConstraintExpr) -> _ods_ir.Value:
+        key = id(expr)
+        if key in self._cache:
+            return self._cache[key]
+        v = expr._lower(self)
+        self._cache[key] = v
+        return v
+
+
+class Is(ConstraintExpr):
+    def __init__(self, attr: _ods_ir.Attribute):
+        self.attr = attr
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.is_(self.attr)
+
+
+class IsType(Is):
+    def __init__(self, typ: _ods_ir.Type):
+        super().__init__(_ods_ir.TypeAttr.get(typ))
+
+
+class AnyOf(ConstraintExpr):
+    def __init__(self, *exprs: ConstraintExpr):
+        self.exprs = exprs
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.any_of(ctx.lower(expr) for expr in self.exprs)
+
+
+class AllOf(ConstraintExpr):
+    def __init__(self, *exprs: ConstraintExpr):
+        self.exprs = exprs
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.all_of(ctx.lower(expr) for expr in self.exprs)
+
+
+class Any(ConstraintExpr):
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.any()
+
+
+class BaseName(ConstraintExpr):
+    def __init__(self, name: str):
+        self.name = name
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.base(base_name=self.name)
+
+
+class BaseRef(ConstraintExpr):
+    def __init__(self, ref):
+        self.ref = ref
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.base(base_ref=self.ref)
+
+
+class _FieldDef:
+    def __set_name__(self, owner, name: str):
+        self.name = name
+
+
+@dataclass
+class Operand(_FieldDef):
+    constraint: ConstraintExpr
+    variadicity: Variadicity = Variadicity.single
+
+
+@dataclass
+class Result(_FieldDef):
+    constraint: ConstraintExpr
+    variadicity: Variadicity = Variadicity.single
+
+
+@dataclass
+class Attribute(_FieldDef):
+    constraint: ConstraintExpr
+
+    def __post_init__(self):
+        # just for unified processing,
+        # currently optional attribute is not supported by IRDL
+        self.variadicity = Variadicity.single
+
+
+@dataclass
+class Operation:
+    dialect_name: str
+    name: str
+    # We store operands and attributes into one list to maintain relative orders
+    # among them for generating OpView class.
+    fields: List[Union[Operand, Attribute, Result]]
+
+    def _partition_fields(self) -> Tuple[List[Operand], List[Attribute], List[Result]]:
+        operands = [i for i in self.fields if isinstance(i, Operand)]
+        attrs = [i for i in self.fields if isinstance(i, Attribute)]
+        results = [i for i in self.fields if isinstance(i, Result)]
+        return operands, attrs, results
+
+    def _emit(self) -> None:
+        ctx = ConstraintLoweringContext()
+        operands, attrs, results = self._partition_fields()
+
+        op = _irdl.operation_(self.name)
+        with _ods_ir.InsertionPoint(op.body):
+            if operands:
+                _irdl.operands_(
+                    [ctx.lower(i.constraint) for i in operands],
+                    [i.name for i in operands],
+                    [i.variadicity for i in operands],
+                )
+            if attrs:
+                _irdl.attributes_(
+                    [ctx.lower(i.constraint) for i in attrs],
+                    [i.name for i in attrs],
+                )
+            if results:
+                _irdl.results_(
+                    [ctx.lower(i.constraint) for i in results],
+                    [i.name for i in results],
+                    [i.variadicity for i in results],
+                )
+
+    @staticmethod
+    def _variadicity_to_segment(variadicity: Variadicity) -> int:
+        if variadicity == Variadicity.variadic:
+            return -1
+        if variadicity == Variadicity.optional:
+            return 0
+        return 1
+
+    @staticmethod
+    def _generate_segments(
+        operands_or_results: List[Union[Operand, Result]],
+    ) -> List[int]:
+        if any(i.variadicity != Variadicity.single for i in operands_or_results):
+            return [
+                Operation._variadicity_to_segment(i.variadicity)
+                for i in operands_or_results
+            ]
+        return None
+
+    def _generate_init_params(self) -> List[_Parameter]:
+        # results are placed at the beginning of the parameter list,
+        # but operands and attributes can appear in any relative order.
+        args = [i for i in self.fields if isinstance(i, Result)] + [
+            i for i in self.fields if not isinstance(i, Result)
+        ]
+        positional_args = [
+            i.name for i in args if i.variadicity != Variadicity.optional
+        ]
+        optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
+
+        params = [_Parameter("self", _Parameter.POSITIONAL_ONLY)]
+        for i in positional_args:
+            params.append(_Parameter(i, _Parameter.POSITIONAL_OR_KEYWORD))
+        for i in optional_args:
+            params.append(_Parameter(i, _Parameter.KEYWORD_ONLY, default=None))
+        params.append(_Parameter("loc", _Parameter.KEYWORD_ONLY, default=None))
+        params.append(_Parameter("ip", _Parameter.KEYWORD_ONLY, default=None))
+
+        return params
+
+    def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
+        operands, attrs, results = self._partition_fields()
+
+        operand_segments = Operation._generate_segments(operands)
+        result_segments = Operation._generate_segments(results)
+
+        params = self._generate_init_params()
+        init_sig = _Signature(params)
+        op = self
+
+        class _OpView(_ods_ir.OpView):
+            OPERATION_NAME = f"{op.dialect_name}.{op.name}"
+            _ODS_REGIONS = (0, True)
+            _ODS_OPERAND_SEGMENTS = operand_segments
+            _ODS_RESULT_SEGMENTS = result_segments
+
+            def __init__(*args, **kwargs):
+                bound = init_sig.bind(*args, **kwargs)
+                bound.apply_defaults()
+                args = bound.arguments
+
+                _operands = [args[operand.name] for operand in operands]
+                _results = [args[result.name] for result in results]
+                _attributes = dict(
+                    (attr.name, args[attr.name])
+                    for attr in attrs
+                    if args[attr.name] is not None
+                )
+                _regions = None
+                _ods_successors = None
+                self = args["self"]
+                super(_OpView, self).__init__(
+                    self.OPERATION_NAME,
+                    self._ODS_REGIONS,
+                    self._ODS_OPERAND_SEGMENTS,
+                    self._ODS_RESULT_SEGMENTS,
+                    attributes=_attributes,
+                    results=_results,
+                    operands=_operands,
+                    successors=_ods_successors,
+                    regions=_regions,
+                    loc=args["loc"],
+                    ip=args["ip"],
+                )
+
+            __init__.__signature__ = init_sig
+
+        for attr in attrs:
+            setattr(
+                _OpView,
+                attr.name,
+                property(lambda self, name=attr.name: self.attributes[name]),
+            )
+
+        def value_range_getter(
+            value_range: Union[_ods_ir.OpOperandList, _ods_ir.OpResultList],
+            variadicity: Variadicity,
+        ):
+            if variadicity == Variadicity.single:
+                return value_range[0]
+            if variadicity == Variadicity.optional:
+                return value_range[0] if len(value_range) > 0 else None
+            return value_range
+
+        for i, operand in enumerate(operands):
+            if operand_segments:
+
+                def getter(self, i=i, operand=operand):
+                    operand_range = _ods_segmented_accessor(
+                        self.operation.operands,
+                        self.operation.attributes["operandSegmentSizes"],
+                        i,
+                    )
+                    return value_range_getter(operand_range, operand.variadicity)
+
+                setattr(_OpView, operand.name, property(getter))
+            else:
+                setattr(
+                    _OpView, operand.name, property(lambda self, i=i: self.operands[i])
+                )
+        for i, result in enumerate(results):
+            if result_segments:
+
+                def getter(self, i=i, result=result):
+                    result_range = _ods_segmented_accessor(
+                        self.operation.results,
+                        self.operation.attributes["resultSegmentSizes"],
+                        i,
+                    )
+                    return value_range_getter(result_range, result.variadicity)
+
+                setattr(_OpView, result.name, property(getter))
+            else:
+                setattr(
+                    _OpView, result.name, property(lambda self, i=i: self.results[i])
+                )
+
+        def _builder(*args, **kwargs) -> _OpView:
+            return _OpView(*args, **kwargs)
+
+        _builder.__signature__ = _Signature(params[1:])
+
+        return _OpView, _builder
+
+
+class Dialect:
+    def __init__(self, name: str):
+        self.name = name
+        self.operations: List[Operation] = []
+        self.namespace = _SimpleNameSpace()
+
+    def _emit(self) -> None:
+        d = _irdl.dialect(self.name)
+        with _ods_ir.InsertionPoint(d.body):
+            for op in self.operations:
+                op._emit()
+
+    def _make_module(self) -> _ods_ir.Module:
+        with _ods_ir.Location.unknown():
+            m = _ods_ir.Module.create()
+            with _ods_ir.InsertionPoint(m.body):
+                self._emit()
+        return m
+
+    def _make_dialect_class(self) -> type:
+        class _Dialect(_ods_ir.Dialect):
+            DIALECT_NAMESPACE = self.name
+
+        return _Dialect
+
+    def load(self) -> _SimpleNameSpace:
+        _irdl.load_dialects(self._make_module())
+        dialect_class = self._make_dialect_class()
+        _ods_cext.register_dialect(dialect_class)
+        for op in self.operations:
+            _ods_cext.register_operation(dialect_class)(op.op_view)
+        return self.namespace
+
+    @staticmethod
+    def _op_name_to_python_id(name: str) -> str:
+        return name.replace(".", "_").replace("$", "_")
+
+    def op(self, name: str) -> Callable[[type], type]:
+        def decorator(cls: type) -> type:
+            fields = [
+                field for field in cls.__dict__.values() if isinstance(field, _FieldDef)
+            ]
+            op_def = Operation(self.name, name, fields)
+
+            op_view, builder = op_def._make_op_view_and_builder()
+            op_view.__name__ = cls.__name__
+            setattr(op_def, "op_view", op_view)
+            setattr(op_def, "builder", builder)
+            self.operations.append(op_def)
+
+            self.namespace.__dict__[cls.__name__] = op_view
+            self.namespace.__dict__[Dialect._op_name_to_python_id(name)] = builder
+
+            return cls
+
+        return decorator
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
new file mode 100644
index 0000000000000..8ef30ae0a4c13
--- /dev/null
+++ b/mlir/test/python/dialects/irdsl.py
@@ -0,0 +1,308 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects.irdl import dsl as irdsl
+from mlir.dialects import arith
+import sys
+
+
+def run(f):
+    print("\nTEST:", f.__name__, file=sys.stderr)
+    with Context():
+        f()
+
+
+# CHECK: TEST: testMyInt
+@run
+def testMyInt():
+    myint = irdsl.Dialect("myint")
+    iattr = irdsl.BaseName("#builtin.integer")
+    i32 = irdsl.IsType(IntegerType.get_signless(32))
+
+    @myint.op("constant")
+    class ConstantOp:
+        value = irdsl.Attribute(iattr)
+        cst = irdsl.Result(i32)
+
+    @myint.op("add")
+    class AddOp:
+        lhs = irdsl.Operand(i32)
+        rhs = irdsl.Operand(i32)
+        res = irdsl.Result(i32)
+
+    # CHECK: irdl.dialect @myint {
+    # CHECK:   irdl.operation @constant {
+    # CHECK:     %0 = irdl.base "#builtin.integer"
+    # CHECK:     irdl.attributes {"value" = %0}
+    # CHECK:     %1 = irdl.is i32
+    # CHECK:     irdl.results(cst: %1)
+    # CHECK:   }
+    # CHECK:   irdl.operation @add {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(lhs: %0, rhs: %0)
+    # CHECK:     irdl.results(res: %0)
+    # CHECK:   }
+    # CHECK: }
+    print(myint._make_module())
+    myint = myint.load()
+
+    # CHECK: ['ConstantOp', 'constant', 'AddOp', 'add']
+    print([i for i in myint.__dict__.keys()])
+
+    i32 = IntegerType.get_signless(32)
+    with Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            two = myint.constant(i32, IntegerAttr.get(i32, 2))
+            three = myint.constant(i32, IntegerAttr.get(i32, 3))
+            add1 = myint.add(i32, two, three)
+            add2 = myint.add(i32, add1, two)
+            add3 = myint.add(i32, add2, three)
+
+    # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+    # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+    # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+    # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
+    # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
+    print(module)
+    assert module.operation.verify()
+
+    # CHECK: AddOp
+    print(type(add1).__name__)
+    # CHECK: ConstantOp
+    print(type(two).__name__)
+    # CHECK: myint.add
+    print(add1.OPERATION_NAME)
+    # CHECK: None
+    print(add1._ODS_OPERAND_SEGMENTS)
+    # CHECK: None
+    print(add1._ODS_RESULT_SEGMENTS)
+    # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+    print(add1.lhs.owner)
+    # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+    print(add1.rhs.owner)
+    # CHECK: 2 : i32
+    print(two.value)
+    # CHECK: Value(%0
+    print(two.cst)
+    # CHECK: (res, lhs, rhs, *, loc=None, ip=None)
+    print(myint.add.__signature__)
+    # CHECK: (cst, value, *, loc=None, ip=None)
+    print(myint.constant.__signature__)
+
+
+@run
+def testIRDSL():
+    test = irdsl.Dialect("irdsl_test")
+    i32 = irdsl.IsType(IntegerType.get_signless(32))
+    i64 = irdsl.IsType(IntegerType.get_signless(64))
+    i32or64 = i32 | i64
+    any = irdsl.Any()
+    f32 = irdsl.IsType(F32Type.get())
+    iattr = irdsl.BaseName("#builtin.integer")
+    fattr = irdsl.BaseName("#builtin.float")
+
+    @test.op("constraint")
+    class ConstraintOp:
+        a = irdsl.Operand(i32or64)
+        b = irdsl.Operand(any)
+        c = irdsl.Operand(f32 | i32)
+        d = irdsl.Operand(any)
+        x = irdsl.Attribute(iattr)
+        y = irdsl.Attribute(fattr)
+
+    @test.op("optional")
+    class OptionalOp:
+        a = irdsl.Operand(i32)
+        b = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        out1 = irdsl.Result(i32)
+        out2 = irdsl.Result(i32, irdsl.Variadicity.optional)
+        out3 = irdsl.Result(i32)
+
+    @test.op("optional2")
+    class Optional2Op:
+        a = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        b = irdsl.Result(i32, irdsl.Variadicity.optional)
+
+    @test.op("variadic")
+    class VariadicOp:
+        a = irdsl.Operand(i32)
+        b = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        c = irdsl.Operand(i32, irdsl.Variadicity.variadic)
+        out1 = irdsl.Result(i32, irdsl.Variadicity.variadic)
+        out2 = irdsl.Result(i32, irdsl.Variadicity.variadic)
+        out3 = irdsl.Result(i32, irdsl.Variadicity.optional)
+        out4 = irdsl.Result(i32)
+
+    @test.op("variadic2")
+    class Variadic2Op:
+        a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
+        b = irdsl.Result(i32, irdsl.Variadicity.variadic)
+
+    @test.op("mixed")
+    class MixedOp:
+        out = irdsl.Result(i32)
+        in1 = irdsl.Operand(i32)
+        in2 = irdsl.Attribute(iattr)
+        in3 = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        in4 = irdsl.Attribute(iattr)
+        in5 = irdsl.Operand(i32)
+
+    # CHECK: irdl.dialect @irdsl_test {
+    # CHECK:   irdl.operation @constraint {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     %1 = irdl.is i64
+    # CHECK:     %2 = irdl.any_of(%0, %1)
+    # CHECK:     %3 = irdl.any
+    # CHECK:     %4 = irdl.is f32
+    # CHECK:     %5 = irdl.any_of(%4, %0)
+    # CHECK:     irdl.operands(a: %2, b: %3, c: %5, d: %3)
+    # CHECK:     %6 = irdl.base "#builtin.integer"
+    # CHECK:     %7 = irdl.base "#builtin.float"
+    # CHECK:     irdl.attributes {"x" = %6, "y" = %7}
+    # CHECK:   }
+    # CHECK:   irdl.operation @optional {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: %0, b: optional %0)
+    # CHECK:     irdl.results(out1: %0, out2: optional %0, out3: %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @optional2 {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: optional %0)
+    # CHECK:     irdl.results(b: optional %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @variadic {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: %0, b: optional %0, c: variadic %0)
+    # CHECK:     irdl.results(out1...
[truncated]

@rolfmorel
Copy link
Contributor

Nice -- this has been on my wishlist!

Could we get a comparison with the xDSL API for defining ops/dialects? E.g. versus https://github.com/xdslproject/xdsl/blob/main/tests/irdl/test_operation_definition.py. Surely they have "lessons learnt" we should consider and there might be some utility in having API compatibility.

(P.S. @PragmaTwice, I have a local proof-of-concept for exposing MLIR C++ Interfaces' ExternalModels to Python so that their methods can be implemented in Python and the external model attached to any op in Python. For now the API is rather low-level, i.e. actually implement the ExternalModel as a separate class and attach it to a op. The long-term vision is that such a op-definition DSL would allow for multiple inheritance so that you just specify the interfaces as additional parent classes and implement their interface's methods alongside other Python methods on the op. Some Python metaprogramming magic would take care of still making the appropriate ExternalModels and attaching them to the op, upon the op getting registered.)

@PragmaTwice
Copy link
Member Author

PragmaTwice commented Nov 25, 2025

Could we get a comparison with the xDSL API for defining ops/dialects?

Absolutely, I’ll try to put together a comparison table. Because of some differences in the underlying mechanisms (for example, MLIR Python’s OpView/builder story is a bit different, also the implementation of IRDL I think?), the syntax will likely diverge in a few places, but if we can incorporate the nicer parts of their design, that would definitely be ideal — and some level of API compatibility would be a great bonus.

I have a local proof-of-concept for exposing MLIR C++ Interfaces' ExternalModels to Python so that their methods can be implemented in Python and the external model attached to any op in Python.

Does this (eventually) imply extending the IRDL dialect to introduce concepts like interfaces and traits? I’m wondering if this overlaps with this issue: #158066. I’m really looking forward to seeing this — it sounds like very cool work. I’ve also been thinking about how to extend IRDL to support a more flexible constraint system (#158049), but that’s still in a very early stage.

We can keep these possible IRDL evolutions in mind so that, when we introduce things like interfaces/traits, the DSL in this PR won’t need to change too drastically and users won’t have to rewrite large amounts of code.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:python MLIR Python bindings mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants