From 8cbbb6b7a3b21775498cf12bf64ebd7295d4cdf8 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sun, 14 Sep 2025 22:16:38 +0800 Subject: [PATCH 01/10] [MLIR][Python] Add python bindings for IRDL dialect --- .../mlir/Dialect/IRDL/IR/IRDLAttributes.td | 2 +- mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td | 8 ++-- .../include/mlir/Dialect/IRDL/IR/IRDLTypes.td | 2 +- mlir/lib/Bindings/Python/DialectIRDL.cpp | 36 +++++++++++++++ mlir/python/CMakeLists.txt | 23 ++++++++++ mlir/python/mlir/dialects/IRDLOps.td | 14 ++++++ mlir/python/mlir/dialects/irdl.py | 43 ++++++++++++++++++ mlir/test/python/dialects/irdl.py | 45 +++++++++++++++++++ 8 files changed, 167 insertions(+), 6 deletions(-) create mode 100644 mlir/lib/Bindings/Python/DialectIRDL.cpp create mode 100644 mlir/python/mlir/dialects/IRDLOps.td create mode 100644 mlir/python/mlir/dialects/irdl.py create mode 100644 mlir/test/python/dialects/irdl.py diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td index 1bfa7f5cb894b..2f568e8b6c42a 100644 --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td @@ -13,7 +13,7 @@ #ifndef MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES #define MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES -include "IRDL.td" +include "mlir/Dialect/IRDL/IR/IRDL.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/EnumAttr.td" diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td index 4a83eb62fba32..3b6b09973645c 100644 --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td @@ -13,10 +13,10 @@ #ifndef MLIR_DIALECT_IRDL_IR_IRDLOPS #define MLIR_DIALECT_IRDL_IR_IRDLOPS -include "IRDL.td" -include "IRDLAttributes.td" -include "IRDLTypes.td" -include "IRDLInterfaces.td" +include "mlir/Dialect/IRDL/IR/IRDL.td" +include "mlir/Dialect/IRDL/IR/IRDLAttributes.td" +include "mlir/Dialect/IRDL/IR/IRDLTypes.td" +include "mlir/Dialect/IRDL/IR/IRDLInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/SymbolInterfaces.td" diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td index 9b17bf23df182..9cde433cf33a6 100644 --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td @@ -14,7 +14,7 @@ #define MLIR_DIALECT_IRDL_IR_IRDLTYPES include "mlir/IR/AttrTypeBase.td" -include "IRDL.td" +include "mlir/Dialect/IRDL/IR/IRDL.td" class IRDL_Type traits = []> : TypeDef { diff --git a/mlir/lib/Bindings/Python/DialectIRDL.cpp b/mlir/lib/Bindings/Python/DialectIRDL.cpp new file mode 100644 index 0000000000000..8264d21d4fa03 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectIRDL.cpp @@ -0,0 +1,36 @@ +//===--- DialectIRDL.cpp - Pybind module for IRDL dialect API support ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/IRDL.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" + +namespace nb = nanobind; +using namespace llvm; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::nanobind_adaptors; + +static void populateDialectIRDLSubmodule(nb::module_ &m) { + m.def( + "load_dialects", + [](MlirModule module) { + if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module))) + throw std::runtime_error( + "failed to load IRDL dialects from the input module"); + }, + nb::arg("module"), "Load IRDL dialects from the given module."); +} + +NB_MODULE(_mlirDialectsIRDL, m) { + m.doc() = "MLIR IRDL dialect."; + + populateDialectIRDLSubmodule(m); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index c983914722ce1..7b2e1b8c36f25 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -470,6 +470,15 @@ declare_mlir_dialect_python_bindings( GEN_ENUM_BINDINGS_TD_FILE "dialects/VectorAttributes.td") +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/IRDLOps.td + SOURCES dialects/irdl.py + DIALECT_NAME irdl + GEN_ENUM_BINDINGS +) + ################################################################################ # Python extensions. # The sources for these are all in lib/Bindings/Python, but since they have to @@ -645,6 +654,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind MLIRCAPITransformDialect ) +declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind + MODULE_NAME _mlirDialectsIRDL + ADD_TO_PARENT MLIRPythonSources.Dialects.irdl + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectIRDL.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIIRDL +) + declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses ADD_TO_PARENT MLIRPythonSources.Dialects.async diff --git a/mlir/python/mlir/dialects/IRDLOps.td b/mlir/python/mlir/dialects/IRDLOps.td new file mode 100644 index 0000000000000..695839f1aa08b --- /dev/null +++ b/mlir/python/mlir/dialects/IRDLOps.td @@ -0,0 +1,14 @@ +//===-- IRDLOps.td - Entry point for IRDL bind ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_IRDL_OPS +#define PYTHON_BINDINGS_IRDL_OPS + +include "mlir/Dialect/IRDL/IR/IRDLOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py new file mode 100644 index 0000000000000..2314ee99950e0 --- /dev/null +++ b/mlir/python/mlir/dialects/irdl.py @@ -0,0 +1,43 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._irdl_ops_gen import * +from ._irdl_ops_gen import _Dialect +from ._irdl_enum_gen import * +from .._mlir_libs._mlirDialectsIRDL import * +from ..ir import register_attribute_builder +from ._ods_common import ( + get_op_result_or_value as _get_value, + get_op_results_or_values as _get_values, + _cext as _ods_cext, +) +from ..extras.meta import region_op + +@_ods_cext.register_operation(_Dialect, replace=True) +class DialectOp(DialectOp): + """Specialization for the dialect op class.""" + + def __init__(self, sym_name, *, loc=None, ip=None): + super().__init__(sym_name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self): + return self.regions[0].blocks[0] + +@_ods_cext.register_operation(_Dialect, replace=True) +class OperationOp(OperationOp): + """Specialization for the operation op class.""" + + def __init__(self, sym_name, *, loc=None, ip=None): + super().__init__(sym_name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self): + return self.regions[0].blocks[0] + +@register_attribute_builder("VariadicityArrayAttr") +def _variadicity_array_attr(x, context): + return _ods_cext.ir.Attribute.parse(f"#irdl") diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py new file mode 100644 index 0000000000000..30983af302a52 --- /dev/null +++ b/mlir/test/python/dialects/irdl.py @@ -0,0 +1,45 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +from mlir.ir import * +from mlir.dialects import irdl +import sys + + +def run(f): + print("\nTEST:", f.__name__, file=sys.stderr) + f() + + +# CHECK: TEST: testIRDL +@run +def testIRDL(): + with Context() as ctx, Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + dialect = irdl.DialectOp("irdl_test") + with InsertionPoint(dialect.body): + op = irdl.OperationOp("test_op") + with InsertionPoint(op.body): + f32 = irdl.is_(TypeAttr.get(F32Type.get())) + irdl.operands_([f32], ["input"], [irdl.Variadicity.single]) + + # CHECK: module { + # CHECK: irdl.dialect @irdl_test { + # CHECK: irdl.operation @test_op { + # CHECK: %0 = irdl.is f32 + # CHECK: irdl.operands(input: %0) + # CHECK: } + # CHECK: } + # CHECK: } + module.dump() + + irdl.load_dialects(module) + + m = Module.parse(""" + module { + %a = arith.constant 1.0 : f32 + "irdl_test.test_op"(%a) : (f32) -> () + } + """) + # CHECK: "irdl_test.test_op"(%cst) : (f32) -> () + m.dump() From 99375299298093d26b020e86e69b9715f16b26ea Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sun, 14 Sep 2025 22:43:44 +0800 Subject: [PATCH 02/10] format --- mlir/python/mlir/dialects/irdl.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py index 2314ee99950e0..743870162a8bb 100644 --- a/mlir/python/mlir/dialects/irdl.py +++ b/mlir/python/mlir/dialects/irdl.py @@ -14,6 +14,7 @@ ) from ..extras.meta import region_op + @_ods_cext.register_operation(_Dialect, replace=True) class DialectOp(DialectOp): """Specialization for the dialect op class.""" @@ -26,6 +27,7 @@ def __init__(self, sym_name, *, loc=None, ip=None): def body(self): return self.regions[0].blocks[0] + @_ods_cext.register_operation(_Dialect, replace=True) class OperationOp(OperationOp): """Specialization for the operation op class.""" @@ -38,6 +40,9 @@ def __init__(self, sym_name, *, loc=None, ip=None): def body(self): return self.regions[0].blocks[0] + @register_attribute_builder("VariadicityArrayAttr") def _variadicity_array_attr(x, context): - return _ods_cext.ir.Attribute.parse(f"#irdl") + return _ods_cext.ir.Attribute.parse( + f"#irdl" + ) From 5bafd4247920ec9db95468c130d3cba885e50ca1 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sun, 14 Sep 2025 22:50:03 +0800 Subject: [PATCH 03/10] format --- mlir/test/python/dialects/irdl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py index 30983af302a52..0c4007b4d2a95 100644 --- a/mlir/test/python/dialects/irdl.py +++ b/mlir/test/python/dialects/irdl.py @@ -35,11 +35,13 @@ def testIRDL(): irdl.load_dialects(module) - m = Module.parse(""" + m = Module.parse( + """ module { %a = arith.constant 1.0 : f32 "irdl_test.test_op"(%a) : (f32) -> () } - """) + """ + ) # CHECK: "irdl_test.test_op"(%cst) : (f32) -> () m.dump() From fc7728d98cac3a3c155680a0030b5e18a00a07f3 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sun, 14 Sep 2025 22:56:46 +0800 Subject: [PATCH 04/10] refine test case --- mlir/test/python/dialects/irdl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py index 0c4007b4d2a95..307a3f90af8b6 100644 --- a/mlir/test/python/dialects/irdl.py +++ b/mlir/test/python/dialects/irdl.py @@ -43,5 +43,7 @@ def testIRDL(): } """ ) - # CHECK: "irdl_test.test_op"(%cst) : (f32) -> () + # CHECK: module { + # CHECK: "irdl_test.test_op"(%cst) : (f32) -> () + # CHECK: } m.dump() From df14b403ad05b9dea48e6206dad5df0a5a75339f Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Wed, 17 Sep 2025 10:54:43 +0800 Subject: [PATCH 05/10] add class for typeop and attributeop --- mlir/python/mlir/dialects/irdl.py | 26 ++++++++++++++++++++++++++ mlir/test/python/dialects/irdl.py | 17 +++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py index 743870162a8bb..8fb2e21bbe0d3 100644 --- a/mlir/python/mlir/dialects/irdl.py +++ b/mlir/python/mlir/dialects/irdl.py @@ -41,6 +41,32 @@ def body(self): return self.regions[0].blocks[0] +@_ods_cext.register_operation(_Dialect, replace=True) +class TypeOp(TypeOp): + """Specialization for the type op class.""" + + def __init__(self, sym_name, *, loc=None, ip=None): + super().__init__(sym_name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self): + return self.regions[0].blocks[0] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AttributeOp(AttributeOp): + """Specialization for the attribute op class.""" + + def __init__(self, sym_name, *, loc=None, ip=None): + super().__init__(sym_name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self): + return self.regions[0].blocks[0] + + @register_attribute_builder("VariadicityArrayAttr") def _variadicity_array_attr(x, context): return _ods_cext.ir.Attribute.parse( diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py index 307a3f90af8b6..e01c5be238976 100644 --- a/mlir/test/python/dialects/irdl.py +++ b/mlir/test/python/dialects/irdl.py @@ -22,6 +22,14 @@ def testIRDL(): with InsertionPoint(op.body): f32 = irdl.is_(TypeAttr.get(F32Type.get())) irdl.operands_([f32], ["input"], [irdl.Variadicity.single]) + type1 = irdl.TypeOp("type1") + with InsertionPoint(type1.body): + f32 = irdl.is_(TypeAttr.get(F32Type.get())) + irdl.parameters([f32], ["val"]) + attr1 = irdl.AttributeOp("attr1") + with InsertionPoint(attr1.body): + test = irdl.is_(StringAttr.get("test")) + irdl.parameters([test], ["val"]) # CHECK: module { # CHECK: irdl.dialect @irdl_test { @@ -29,8 +37,17 @@ def testIRDL(): # CHECK: %0 = irdl.is f32 # CHECK: irdl.operands(input: %0) # CHECK: } + # CHECK: irdl.type @type1 { + # CHECK: %0 = irdl.is f32 + # CHECK: irdl.parameters(val: %0) + # CHECK: } + # CHECK: irdl.attribute @attr1 { + # CHECK: %0 = irdl.is "test" + # CHECK: irdl.parameters(val: %0) + # CHECK: } # CHECK: } # CHECK: } + module.operation.verify() module.dump() irdl.load_dialects(module) From 4357853afcca1ba3f7d18cd871f004d6e04a101d Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Thu, 18 Sep 2025 10:37:09 +0800 Subject: [PATCH 06/10] apply review suggestions --- mlir/python/mlir/dialects/irdl.py | 31 +++++++++++++++++++++---------- mlir/test/python/dialects/irdl.py | 26 +++++++++++++------------- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py index 8fb2e21bbe0d3..f4748375d3db5 100644 --- a/mlir/python/mlir/dialects/irdl.py +++ b/mlir/python/mlir/dialects/irdl.py @@ -7,17 +7,12 @@ from ._irdl_enum_gen import * from .._mlir_libs._mlirDialectsIRDL import * from ..ir import register_attribute_builder -from ._ods_common import ( - get_op_result_or_value as _get_value, - get_op_results_or_values as _get_values, - _cext as _ods_cext, -) -from ..extras.meta import region_op +from ._ods_common import _cext as _ods_cext @_ods_cext.register_operation(_Dialect, replace=True) class DialectOp(DialectOp): - """Specialization for the dialect op class.""" + __doc__ = DialectOp.__doc__ def __init__(self, sym_name, *, loc=None, ip=None): super().__init__(sym_name, loc=loc, ip=ip) @@ -28,9 +23,13 @@ def body(self): return self.regions[0].blocks[0] +def dialect(sym_name, *, loc=None, ip=None) -> DialectOp: + return DialectOp(sym_name=sym_name, loc=loc, ip=ip) + + @_ods_cext.register_operation(_Dialect, replace=True) class OperationOp(OperationOp): - """Specialization for the operation op class.""" + __doc__ = OperationOp.__doc__ def __init__(self, sym_name, *, loc=None, ip=None): super().__init__(sym_name, loc=loc, ip=ip) @@ -41,9 +40,13 @@ def body(self): return self.regions[0].blocks[0] +def operation_(sym_name, *, loc=None, ip=None) -> OperationOp: + return OperationOp(sym_name=sym_name, loc=loc, ip=ip) + + @_ods_cext.register_operation(_Dialect, replace=True) class TypeOp(TypeOp): - """Specialization for the type op class.""" + __doc__ = TypeOp.__doc__ def __init__(self, sym_name, *, loc=None, ip=None): super().__init__(sym_name, loc=loc, ip=ip) @@ -54,9 +57,13 @@ def body(self): return self.regions[0].blocks[0] +def type_(sym_name, *, loc=None, ip=None) -> TypeOp: + return TypeOp(sym_name=sym_name, loc=loc, ip=ip) + + @_ods_cext.register_operation(_Dialect, replace=True) class AttributeOp(AttributeOp): - """Specialization for the attribute op class.""" + __doc__ = AttributeOp.__doc__ def __init__(self, sym_name, *, loc=None, ip=None): super().__init__(sym_name, loc=loc, ip=ip) @@ -67,6 +74,10 @@ def body(self): return self.regions[0].blocks[0] +def attribute(sym_name, *, loc=None, ip=None) -> AttributeOp: + return AttributeOp(sym_name=sym_name, loc=loc, ip=ip) + + @register_attribute_builder("VariadicityArrayAttr") def _variadicity_array_attr(x, context): return _ods_cext.ir.Attribute.parse( diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py index e01c5be238976..ed62db9b69968 100644 --- a/mlir/test/python/dialects/irdl.py +++ b/mlir/test/python/dialects/irdl.py @@ -1,7 +1,7 @@ # RUN: %PYTHON %s 2>&1 | FileCheck %s from mlir.ir import * -from mlir.dialects import irdl +from mlir.dialects.irdl import * import sys @@ -16,20 +16,20 @@ def testIRDL(): with Context() as ctx, Location.unknown(): module = Module.create() with InsertionPoint(module.body): - dialect = irdl.DialectOp("irdl_test") - with InsertionPoint(dialect.body): - op = irdl.OperationOp("test_op") + irdl_test = dialect("irdl_test") + with InsertionPoint(irdl_test.body): + op = operation_("test_op") with InsertionPoint(op.body): - f32 = irdl.is_(TypeAttr.get(F32Type.get())) - irdl.operands_([f32], ["input"], [irdl.Variadicity.single]) - type1 = irdl.TypeOp("type1") + f32 = is_(TypeAttr.get(F32Type.get())) + operands_([f32], ["input"], [Variadicity.single]) + type1 = type_("type1") with InsertionPoint(type1.body): - f32 = irdl.is_(TypeAttr.get(F32Type.get())) - irdl.parameters([f32], ["val"]) - attr1 = irdl.AttributeOp("attr1") + f32 = is_(TypeAttr.get(F32Type.get())) + parameters([f32], ["val"]) + attr1 = attribute("attr1") with InsertionPoint(attr1.body): - test = irdl.is_(StringAttr.get("test")) - irdl.parameters([test], ["val"]) + test = is_(StringAttr.get("test")) + parameters([test], ["val"]) # CHECK: module { # CHECK: irdl.dialect @irdl_test { @@ -50,7 +50,7 @@ def testIRDL(): module.operation.verify() module.dump() - irdl.load_dialects(module) + load_dialects(module) m = Module.parse( """ From 0cbf5203a0c564547f093233aaa96738b867ef59 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Thu, 18 Sep 2025 10:41:04 +0800 Subject: [PATCH 07/10] remove useless using namespace --- mlir/lib/Bindings/Python/DialectIRDL.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/DialectIRDL.cpp b/mlir/lib/Bindings/Python/DialectIRDL.cpp index 8264d21d4fa03..08bcab97c03ec 100644 --- a/mlir/lib/Bindings/Python/DialectIRDL.cpp +++ b/mlir/lib/Bindings/Python/DialectIRDL.cpp @@ -13,7 +13,6 @@ #include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; -using namespace llvm; using namespace mlir; using namespace mlir::python; using namespace mlir::python::nanobind_adaptors; From e76fa6520c243274fcea51fd8f5da47959b24fe1 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Thu, 18 Sep 2025 10:43:06 +0800 Subject: [PATCH 08/10] make it 80 characters --- mlir/python/mlir/dialects/IRDLOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/IRDLOps.td b/mlir/python/mlir/dialects/IRDLOps.td index 695839f1aa08b..7b061fcf30836 100644 --- a/mlir/python/mlir/dialects/IRDLOps.td +++ b/mlir/python/mlir/dialects/IRDLOps.td @@ -1,4 +1,4 @@ -//===-- IRDLOps.td - Entry point for IRDL bind ---------*- tablegen -*-===// +//===-- IRDLOps.td - Entry point for IRDL binding ----------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. From 0c7af71ad2a51abbcdf7cabf117118b02a4d63ae Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Thu, 18 Sep 2025 11:12:11 +0800 Subject: [PATCH 09/10] format --- mlir/python/mlir/dialects/irdl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py index f4748375d3db5..40780ece8828f 100644 --- a/mlir/python/mlir/dialects/irdl.py +++ b/mlir/python/mlir/dialects/irdl.py @@ -24,7 +24,7 @@ def body(self): def dialect(sym_name, *, loc=None, ip=None) -> DialectOp: - return DialectOp(sym_name=sym_name, loc=loc, ip=ip) + return DialectOp(sym_name=sym_name, loc=loc, ip=ip) @_ods_cext.register_operation(_Dialect, replace=True) @@ -41,7 +41,7 @@ def body(self): def operation_(sym_name, *, loc=None, ip=None) -> OperationOp: - return OperationOp(sym_name=sym_name, loc=loc, ip=ip) + return OperationOp(sym_name=sym_name, loc=loc, ip=ip) @_ods_cext.register_operation(_Dialect, replace=True) @@ -58,7 +58,7 @@ def body(self): def type_(sym_name, *, loc=None, ip=None) -> TypeOp: - return TypeOp(sym_name=sym_name, loc=loc, ip=ip) + return TypeOp(sym_name=sym_name, loc=loc, ip=ip) @_ods_cext.register_operation(_Dialect, replace=True) @@ -75,7 +75,7 @@ def body(self): def attribute(sym_name, *, loc=None, ip=None) -> AttributeOp: - return AttributeOp(sym_name=sym_name, loc=loc, ip=ip) + return AttributeOp(sym_name=sym_name, loc=loc, ip=ip) @register_attribute_builder("VariadicityArrayAttr") From 4be1ac3c50d375b781b351ce7919d89c49257c2f Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Thu, 18 Sep 2025 11:24:08 +0800 Subject: [PATCH 10/10] add type annotations --- mlir/python/mlir/dialects/irdl.py | 37 ++++++++++++++++++------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py index 40780ece8828f..1ec951b69b646 100644 --- a/mlir/python/mlir/dialects/irdl.py +++ b/mlir/python/mlir/dialects/irdl.py @@ -8,22 +8,25 @@ from .._mlir_libs._mlirDialectsIRDL import * from ..ir import register_attribute_builder from ._ods_common import _cext as _ods_cext +from typing import Union, Sequence + +_ods_ir = _ods_cext.ir @_ods_cext.register_operation(_Dialect, replace=True) class DialectOp(DialectOp): __doc__ = DialectOp.__doc__ - def __init__(self, sym_name, *, loc=None, ip=None): + def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None): super().__init__(sym_name, loc=loc, ip=ip) self.regions[0].blocks.append() @property - def body(self): + def body(self) -> _ods_ir.Block: return self.regions[0].blocks[0] -def dialect(sym_name, *, loc=None, ip=None) -> DialectOp: +def dialect(sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None) -> DialectOp: return DialectOp(sym_name=sym_name, loc=loc, ip=ip) @@ -31,16 +34,18 @@ def dialect(sym_name, *, loc=None, ip=None) -> DialectOp: class OperationOp(OperationOp): __doc__ = OperationOp.__doc__ - def __init__(self, sym_name, *, loc=None, ip=None): + def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None): super().__init__(sym_name, loc=loc, ip=ip) self.regions[0].blocks.append() @property - def body(self): + def body(self) -> _ods_ir.Block: return self.regions[0].blocks[0] -def operation_(sym_name, *, loc=None, ip=None) -> OperationOp: +def operation_( + sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None +) -> OperationOp: return OperationOp(sym_name=sym_name, loc=loc, ip=ip) @@ -48,16 +53,16 @@ def operation_(sym_name, *, loc=None, ip=None) -> OperationOp: class TypeOp(TypeOp): __doc__ = TypeOp.__doc__ - def __init__(self, sym_name, *, loc=None, ip=None): + def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None): super().__init__(sym_name, loc=loc, ip=ip) self.regions[0].blocks.append() @property - def body(self): + def body(self) -> _ods_ir.Block: return self.regions[0].blocks[0] -def type_(sym_name, *, loc=None, ip=None) -> TypeOp: +def type_(sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None) -> TypeOp: return TypeOp(sym_name=sym_name, loc=loc, ip=ip) @@ -65,21 +70,23 @@ def type_(sym_name, *, loc=None, ip=None) -> TypeOp: class AttributeOp(AttributeOp): __doc__ = AttributeOp.__doc__ - def __init__(self, sym_name, *, loc=None, ip=None): + def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None): super().__init__(sym_name, loc=loc, ip=ip) self.regions[0].blocks.append() @property - def body(self): + def body(self) -> _ods_ir.Block: return self.regions[0].blocks[0] -def attribute(sym_name, *, loc=None, ip=None) -> AttributeOp: +def attribute( + sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None +) -> AttributeOp: return AttributeOp(sym_name=sym_name, loc=loc, ip=ip) @register_attribute_builder("VariadicityArrayAttr") -def _variadicity_array_attr(x, context): - return _ods_cext.ir.Attribute.parse( - f"#irdl" +def _variadicity_array_attr(x: Sequence[Variadicity], context) -> _ods_ir.Attribute: + return _ods_ir.Attribute.parse( + f"#irdl", context )