diff --git a/mlir/lib/Bindings/Python/BuiltinOps.td b/mlir/lib/Bindings/Python/BuiltinOps.td new file mode 100644 index 00000000000000..ecbb8227d49006 --- /dev/null +++ b/mlir/lib/Bindings/Python/BuiltinOps.td @@ -0,0 +1,15 @@ +//===-- BuiltinOps.td - Entry point for builtin bindings ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_BUILTIN_OPS +#define PYTHON_BINDINGS_BUILTIN_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/IR/BuiltinOps.td" + +#endif diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index 1749ea2e5472a1..951aa7883c9019 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -11,6 +11,7 @@ set(PY_SRC_FILES mlir/ir.py mlir/dialects/__init__.py mlir/dialects/_linalg.py + mlir/dialects/_builtin.py mlir/ir.py mlir/passmanager.py mlir/transforms/__init__.py @@ -36,6 +37,11 @@ endforeach() # Generate dialect-specific bindings. ################################################################################ +add_mlir_dialect_python_bindings(MLIRBindingsPythonBuiltinOps + TD_FILE BuiltinOps.td + DIALECT_NAME builtin) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonBuiltinOps) + add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps TD_FILE LinalgOps.td DIALECT_NAME linalg diff --git a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py index 9c003b4154385a..f5a71bf8870034 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py @@ -43,7 +43,7 @@ def class_decorator(parent_opview_cls: type): except AttributeError: # Try to default resolve it. try: - select_mixin = getattr(ext_module, parent_opview_cls.__name__) + mixin_cls = getattr(ext_module, parent_opview_cls.__name__) except AttributeError: pass else: diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_builtin.py b/mlir/lib/Bindings/Python/mlir/dialects/_builtin.py new file mode 100644 index 00000000000000..8d430d5a50dae2 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/dialects/_builtin.py @@ -0,0 +1,93 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from mlir.ir import * + + +class ModuleOp: + """Specialization for the module op class.""" + + def __init__(self, loc=None, ip=None): + super().__init__( + self._ods_build_default(operands=[], results=[], loc=loc, ip=ip)) + body = self.regions[0].blocks.append() + with InsertionPoint(body): + Operation.create("module_terminator") + + @property + def body(self): + return self.regions[0].blocks[0] + + +class FuncOp: + """Specialization for the func op class.""" + + def __init__(self, + name, + type, + visibility, + body_builder=None, + loc=None, + ip=None): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. The + empty string implies a private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. + if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = StringAttr.get( + str(visibility)) if visibility is not None else None + super().__init__(sym_name, type, sym_visibility, loc, ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self): + return self.attributes["sym_name"] + + @property + def entry_block(self): + if self.is_external: + raise IndexError('External function does not have a body') + return self.regions[0].blocks[0] + + def add_entry_block(self): + ''' + Add an entry block to the function body using the function signature to infer block arguments + Returns the newly created block + ''' + if not self.is_external: + raise IndexError('The function already has an entry block!') + self.body.blocks.append(*self.type.inputs) + return self.body.blocks[0] diff --git a/mlir/test/Bindings/Python/.style.yapf b/mlir/test/Bindings/Python/.style.yapf new file mode 100644 index 00000000000000..9ef1dc15ba6209 --- /dev/null +++ b/mlir/test/Bindings/Python/.style.yapf @@ -0,0 +1,4 @@ +[style] + based_on_style = google + column_limit = 80 + indent_width = 2 diff --git a/mlir/test/Bindings/Python/dialects/builtin.py b/mlir/test/Bindings/Python/dialects/builtin.py new file mode 100644 index 00000000000000..447a255f60215c --- /dev/null +++ b/mlir/test/Bindings/Python/dialects/builtin.py @@ -0,0 +1,69 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +import mlir.dialects.builtin as builtin +import mlir.dialects.std as std + + +def run(f): + print("\nTEST:", f.__name__) + f() + + +# CHECK-LABEL: TEST: testBuildFuncOp +def testBuildFuncOp(): + ctx = Context() + with Location.unknown(ctx) as loc: + m = builtin.ModuleOp() + + f32 = F32Type.get() + tensor_type = RankedTensorType.get((2, 3, 4), f32) + with InsertionPoint.at_block_begin(m.body): + func = builtin.FuncOp(name="some_func", + type=FunctionType.get( + inputs=[tensor_type, tensor_type], + results=[tensor_type]), + visibility="nested") + # CHECK: Name is: "some_func" + print("Name is: ", func.name) + + # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + print("Type is: ", func.type) + + # CHECK: Visibility is: "nested" + print("Visibility is: ", func.visibility) + + try: + entry_block = func.entry_block + except IndexError as e: + # CHECK: External function does not have a body + print(e) + + with InsertionPoint(func.add_entry_block()): + std.ReturnOp([func.entry_block.arguments[0]]) + pass + + try: + func.add_entry_block() + except IndexError as e: + # CHECK: The function already has an entry block! + print(e) + + # Try the callback builder and passing type as tuple. + func = builtin.FuncOp(name="some_other_func", + type=([tensor_type, tensor_type], [tensor_type]), + visibility="nested", + body_builder=lambda func: std.ReturnOp( + [func.entry_block.arguments[0]])) + + # CHECK: module { + # CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + # CHECK: return %arg0 : tensor<2x3x4xf32> + # CHECK: } + # CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + # CHECK: return %arg0 : tensor<2x3x4xf32> + # CHECK: } + print(m) + + +run(testBuildFuncOp) diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 0197bfb155772f..94bfd58ab3a5ad 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -716,6 +716,10 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { os << llvm::formatv(fileHeader, clDialectName.getValue()); os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); + + if (clDialectName == "builtin") + clDialectName = ""; + for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { Operator op(rec); if (op.getDialectName() == clDialectName.getValue())