Skip to content

[MLIR][Python] Support op adaptor for Python-defined operations#183528

Merged
PragmaTwice merged 2 commits into
llvm:mainfrom
PragmaTwice:mlir-python-ext-adaptor
Feb 27, 2026
Merged

[MLIR][Python] Support op adaptor for Python-defined operations#183528
PragmaTwice merged 2 commits into
llvm:mainfrom
PragmaTwice:mlir-python-ext-adaptor

Conversation

@PragmaTwice
Copy link
Copy Markdown
Member

Previously, in #177782, we added support for dialect conversion and generated an OpAdaptor subtype for every ODS-defined operation. In this PR, we will also generate OpAdaptor subtypes for Python-defined operations, so that they can be applied in dialect conversion as well.

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Feb 26, 2026

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes

Previously, in #177782, we added support for dialect conversion and generated an OpAdaptor subtype for every ODS-defined operation. In this PR, we will also generate OpAdaptor subtypes for Python-defined operations, so that they can be applied in dialect conversion as well.


Full diff: https://github.com/llvm/llvm-project/pull/183528.diff

2 Files Affected:

  • (modified) mlir/python/mlir/dialects/ext.py (+49-1)
  • (modified) mlir/test/python/dialects/ext.py (+10)
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 39aacf32dabb9..d88e25cced8f6 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -41,7 +41,17 @@
 Region = ir.Region
 
 register_dialect = _cext.register_dialect
-register_operation = _cext.register_operation
+
+
+def register_operation(dialect_cls: type) -> Callable[[type], type]:
+    register = _cext.register_operation(dialect_cls)
+
+    def decorator(op_cls: type) -> type:
+        register(op_cls)
+        _cext.register_op_adaptor(op_cls)(op_cls.Adaptor)
+        return op_cls
+
+    return decorator
 
 
 def construct_instance(origin, args):
@@ -307,6 +317,13 @@ def __init_subclass__(
         cls._generate_result_properties(results)
         cls._generate_region_properties(regions)
 
+        cls.Adaptor = type(
+            "Adaptor",
+            (OperationAdator,),
+            dict(),
+            operation=cls,
+        )
+
         dialect_obj.operations.append(cls)
 
     @staticmethod
@@ -507,6 +524,37 @@ def _emit_operation(cls) -> None:
                 )
 
 
+class OperationAdator(ir.OpAdaptor):
+    @classmethod
+    def __init_subclass__(cls, *, operation: type):
+        cls.OPERATION_NAME = operation.OPERATION_NAME
+        cls._operation_cls = operation
+
+        operands, attrs, results, regions = partition_fields(operation._fields)
+
+        for attr in attrs:
+            setattr(
+                cls,
+                attr.name,
+                property(lambda self, name=attr.name: self.attributes[name]),
+            )
+
+        for i, operand in enumerate(operands):
+            if operation._ODS_OPERAND_SEGMENTS:
+
+                def getter(self, i=i, operand=operand):
+                    operand_range = segmented_accessor(
+                        self.operands,
+                        self.attributes["operandSegmentSizes"],
+                        i,
+                    )
+                    return normalize_value_range(operand_range, operand.variadicity)
+
+                setattr(cls, operand.name, property(getter))
+            else:
+                setattr(cls, operand.name, property(lambda self, i=i: self.operands[i]))
+
+
 @dataclass
 class ParamDef:
     name: str
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index f9252bad37a39..2921615e75d54 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -91,6 +91,16 @@ class AddOp(Operation, dialect=MyInt, name="add"):
         # CHECK: (self, /, value, *, loc=None, ip=None)
         print(ConstantOp.__init__.__signature__)
 
+        # CHECK: True
+        print(issubclass(AddOp.Adaptor, OpAdaptor))
+        adaptor1 = AddOp.Adaptor(list(add1.operands), add1)
+        # CHECK: myint.add
+        print(adaptor1.OPERATION_NAME)
+        # CHECK: OpResult(%0 = "myint.constant"() {value = 2 : i32} : () -> i32)
+        print(adaptor1.lhs)
+        # CHECK: OpResult(%1 = "myint.constant"() {value = 3 : i32} : () -> i32)
+        print(adaptor1.rhs)
+
 
 # CHECK: TEST: testExtDialect
 @run

Copy link
Copy Markdown
Contributor

@rolfmorel rolfmorel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@PragmaTwice PragmaTwice merged commit 361e235 into llvm:main Feb 27, 2026
13 of 14 checks passed
sujianIBM pushed a commit to sujianIBM/llvm-project that referenced this pull request Mar 5, 2026
…#183528)

Previously, in llvm#177782, we added support for dialect conversion and
generated an `OpAdaptor` subtype for every ODS-defined operation. In
this PR, we will also generate `OpAdaptor` subtypes for Python-defined
operations, so that they can be applied in dialect conversion as well.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:python MLIR Python bindings mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants