diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td index 1bfa7f5cb894b..2f568e8b6c42a 100644 --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td @@ -13,7 +13,7 @@ #ifndef MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES #define MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES -include "IRDL.td" +include "mlir/Dialect/IRDL/IR/IRDL.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/EnumAttr.td" diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td index 4a83eb62fba32..3b6b09973645c 100644 --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td @@ -13,10 +13,10 @@ #ifndef MLIR_DIALECT_IRDL_IR_IRDLOPS #define MLIR_DIALECT_IRDL_IR_IRDLOPS -include "IRDL.td" -include "IRDLAttributes.td" -include "IRDLTypes.td" -include "IRDLInterfaces.td" +include "mlir/Dialect/IRDL/IR/IRDL.td" +include "mlir/Dialect/IRDL/IR/IRDLAttributes.td" +include "mlir/Dialect/IRDL/IR/IRDLTypes.td" +include "mlir/Dialect/IRDL/IR/IRDLInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/SymbolInterfaces.td" diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td index 9b17bf23df182..9cde433cf33a6 100644 --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td @@ -14,7 +14,7 @@ #define MLIR_DIALECT_IRDL_IR_IRDLTYPES include "mlir/IR/AttrTypeBase.td" -include "IRDL.td" +include "mlir/Dialect/IRDL/IR/IRDL.td" class IRDL_Type traits = []> : TypeDef { diff --git a/mlir/lib/Bindings/Python/DialectIRDL.cpp b/mlir/lib/Bindings/Python/DialectIRDL.cpp new file mode 100644 index 0000000000000..08bcab97c03ec --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectIRDL.cpp @@ -0,0 +1,35 @@ +//===--- DialectIRDL.cpp - Pybind module for IRDL dialect API support ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/IRDL.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" + +namespace nb = nanobind; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::nanobind_adaptors; + +static void populateDialectIRDLSubmodule(nb::module_ &m) { + m.def( + "load_dialects", + [](MlirModule module) { + if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module))) + throw std::runtime_error( + "failed to load IRDL dialects from the input module"); + }, + nb::arg("module"), "Load IRDL dialects from the given module."); +} + +NB_MODULE(_mlirDialectsIRDL, m) { + m.doc() = "MLIR IRDL dialect."; + + populateDialectIRDLSubmodule(m); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index c983914722ce1..7b2e1b8c36f25 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -470,6 +470,15 @@ declare_mlir_dialect_python_bindings( GEN_ENUM_BINDINGS_TD_FILE "dialects/VectorAttributes.td") +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/IRDLOps.td + SOURCES dialects/irdl.py + DIALECT_NAME irdl + GEN_ENUM_BINDINGS +) + ################################################################################ # Python extensions. # The sources for these are all in lib/Bindings/Python, but since they have to @@ -645,6 +654,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind MLIRCAPITransformDialect ) +declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind + MODULE_NAME _mlirDialectsIRDL + ADD_TO_PARENT MLIRPythonSources.Dialects.irdl + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectIRDL.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIIRDL +) + declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses ADD_TO_PARENT MLIRPythonSources.Dialects.async diff --git a/mlir/python/mlir/dialects/IRDLOps.td b/mlir/python/mlir/dialects/IRDLOps.td new file mode 100644 index 0000000000000..7b061fcf30836 --- /dev/null +++ b/mlir/python/mlir/dialects/IRDLOps.td @@ -0,0 +1,14 @@ +//===-- IRDLOps.td - Entry point for IRDL binding ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_IRDL_OPS +#define PYTHON_BINDINGS_IRDL_OPS + +include "mlir/Dialect/IRDL/IR/IRDLOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py new file mode 100644 index 0000000000000..1ec951b69b646 --- /dev/null +++ b/mlir/python/mlir/dialects/irdl.py @@ -0,0 +1,92 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._irdl_ops_gen import * +from ._irdl_ops_gen import _Dialect +from ._irdl_enum_gen import * +from .._mlir_libs._mlirDialectsIRDL import * +from ..ir import register_attribute_builder +from ._ods_common import _cext as _ods_cext +from typing import Union, Sequence + +_ods_ir = _ods_cext.ir + + +@_ods_cext.register_operation(_Dialect, replace=True) +class DialectOp(DialectOp): + __doc__ = DialectOp.__doc__ + + def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None): + super().__init__(sym_name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self) -> _ods_ir.Block: + return self.regions[0].blocks[0] + + +def dialect(sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None) -> DialectOp: + return DialectOp(sym_name=sym_name, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OperationOp(OperationOp): + __doc__ = OperationOp.__doc__ + + def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None): + super().__init__(sym_name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self) -> _ods_ir.Block: + return self.regions[0].blocks[0] + + +def operation_( + sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None +) -> OperationOp: + return OperationOp(sym_name=sym_name, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class TypeOp(TypeOp): + __doc__ = TypeOp.__doc__ + + def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None): + super().__init__(sym_name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self) -> _ods_ir.Block: + return self.regions[0].blocks[0] + + +def type_(sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None) -> TypeOp: + return TypeOp(sym_name=sym_name, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AttributeOp(AttributeOp): + __doc__ = AttributeOp.__doc__ + + def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None): + super().__init__(sym_name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self) -> _ods_ir.Block: + return self.regions[0].blocks[0] + + +def attribute( + sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None +) -> AttributeOp: + return AttributeOp(sym_name=sym_name, loc=loc, ip=ip) + + +@register_attribute_builder("VariadicityArrayAttr") +def _variadicity_array_attr(x: Sequence[Variadicity], context) -> _ods_ir.Attribute: + return _ods_ir.Attribute.parse( + f"#irdl", context + ) diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py new file mode 100644 index 0000000000000..ed62db9b69968 --- /dev/null +++ b/mlir/test/python/dialects/irdl.py @@ -0,0 +1,66 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +from mlir.ir import * +from mlir.dialects.irdl import * +import sys + + +def run(f): + print("\nTEST:", f.__name__, file=sys.stderr) + f() + + +# CHECK: TEST: testIRDL +@run +def testIRDL(): + with Context() as ctx, Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + irdl_test = dialect("irdl_test") + with InsertionPoint(irdl_test.body): + op = operation_("test_op") + with InsertionPoint(op.body): + f32 = is_(TypeAttr.get(F32Type.get())) + operands_([f32], ["input"], [Variadicity.single]) + type1 = type_("type1") + with InsertionPoint(type1.body): + f32 = is_(TypeAttr.get(F32Type.get())) + parameters([f32], ["val"]) + attr1 = attribute("attr1") + with InsertionPoint(attr1.body): + test = is_(StringAttr.get("test")) + parameters([test], ["val"]) + + # CHECK: module { + # CHECK: irdl.dialect @irdl_test { + # CHECK: irdl.operation @test_op { + # CHECK: %0 = irdl.is f32 + # CHECK: irdl.operands(input: %0) + # CHECK: } + # CHECK: irdl.type @type1 { + # CHECK: %0 = irdl.is f32 + # CHECK: irdl.parameters(val: %0) + # CHECK: } + # CHECK: irdl.attribute @attr1 { + # CHECK: %0 = irdl.is "test" + # CHECK: irdl.parameters(val: %0) + # CHECK: } + # CHECK: } + # CHECK: } + module.operation.verify() + module.dump() + + load_dialects(module) + + m = Module.parse( + """ + module { + %a = arith.constant 1.0 : f32 + "irdl_test.test_op"(%a) : (f32) -> () + } + """ + ) + # CHECK: module { + # CHECK: "irdl_test.test_op"(%cst) : (f32) -> () + # CHECK: } + m.dump()