Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][python] Factor out standalone OpView._ods_build_default class …
…method. * This allows us to hoist trait level information for regions and sized-variadic to class level attributes (_ODS_REGIONS, _ODS_OPERAND_SEGMENTS, _ODS_RESULT_SEGMENTS). * Eliminates some splicey python generated code in favor of a native helper for it. * Makes it possible to implement custom, variadic and region based builders with one line of python, without needing to manually code access to the segment attributes. * Needs follow-on work for region based callbacks and support for SingleBlockImplicitTerminator. * A follow-up will actually add ODS support for generating custom Python builders that delegate to this new method. * Also includes the start of an e2e sample for constructing linalg ops where this limitation was discovered (working progressively through this example and cleaning up as I go). Differential Revision: https://reviews.llvm.org/D94738
- Loading branch information
1 parent
cbdde49
commit 71b6b01
Showing
7 changed files
with
713 additions
and
104 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
# This is a work in progress example to do end2end build and code generation | ||
# of a small linalg program with configuration options. It is currently non | ||
# functional and is being used to elaborate the APIs. | ||
|
||
from typing import Tuple | ||
|
||
from mlir.ir import * | ||
from mlir.dialects import linalg | ||
from mlir.dialects import std | ||
|
||
|
||
# TODO: This should be in the core API. | ||
def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]: | ||
"""Creates a |func| op. | ||
TODO: This should really be in the MLIR API. | ||
Returns: | ||
(operation, entry_block) | ||
""" | ||
attrs = { | ||
"type": TypeAttr.get(func_type), | ||
"sym_name": StringAttr.get(name), | ||
} | ||
op = Operation.create("func", regions=1, attributes=attrs) | ||
body_region = op.regions[0] | ||
entry_block = body_region.blocks.append(*func_type.inputs) | ||
return op, entry_block | ||
|
||
|
||
# TODO: Generate customs builder vs patching one in. | ||
def PatchMatmulOpInit(self, lhs, rhs, result, loc=None, ip=None): | ||
super(linalg.MatmulOp, self).__init__( | ||
self._ods_build_default(operands=[[lhs, rhs], [result]], | ||
results=[], | ||
loc=loc, | ||
ip=ip)) | ||
# TODO: Implement support for SingleBlockImplicitTerminator | ||
block = self.regions[0].blocks.append() | ||
with InsertionPoint(block): | ||
linalg.YieldOp(values=[]) | ||
|
||
linalg.MatmulOp.__init__ = PatchMatmulOpInit | ||
|
||
|
||
def build_matmul_func(func_name, m, k, n, dtype): | ||
lhs_type = MemRefType.get(dtype, [m, k]) | ||
rhs_type = MemRefType.get(dtype, [k, n]) | ||
result_type = MemRefType.get(dtype, [m, n]) | ||
# TODO: There should be a one-liner for this. | ||
func_type = FunctionType.get([lhs_type, rhs_type, result_type], []) | ||
_, entry = FuncOp(func_name, func_type) | ||
lhs, rhs, result = entry.arguments | ||
with InsertionPoint(entry): | ||
linalg.MatmulOp(lhs, rhs, result) | ||
std.ReturnOp([]) | ||
|
||
|
||
def run(): | ||
with Context() as c, Location.unknown(): | ||
module = Module.create() | ||
# TODO: This at_block_terminator vs default construct distinction feels | ||
# wrong and is error-prone. | ||
with InsertionPoint.at_block_terminator(module.body): | ||
build_matmul_func('main', 18, 32, 96, F32Type.get()) | ||
|
||
print(module) | ||
print(module.operation.get_asm(print_generic_op_form=True)) | ||
|
||
|
||
if __name__ == '__main__': run() |
Oops, something went wrong.