Navigation Menu

Skip to content

Commit

Permalink
[mlir][PDL] Define a new PDLInterp::FuncOp operation and drop uses of…
Browse files Browse the repository at this point in the history
… FuncOp

Defining our own function operation allows for the PDL interpreter
to be more self contained, and also removes any dependency on FuncOp;
which is moving out of the Builtin dialect.

Differential Revision: https://reviews.llvm.org/D121253
  • Loading branch information
River707 committed Mar 15, 2022
1 parent e9c9ee9 commit f96a867
Show file tree
Hide file tree
Showing 15 changed files with 245 additions and 155 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h
Expand Up @@ -16,6 +16,8 @@

#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

Expand Down
60 changes: 60 additions & 0 deletions mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
Expand Up @@ -14,6 +14,8 @@
#define MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS

include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -627,6 +629,64 @@ def PDLInterp_ForEachOp
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// pdl_interp::FuncOp
//===----------------------------------------------------------------------===//

def PDLInterp_FuncOp : PDLInterp_Op<"func", [
FunctionOpInterface, IsolatedFromAbove, Symbol
]> {
let summary = "PDL Interpreter Function Operation";
let description = [{
`pdl_interp.func` operations act as interpreter functions. These are
callable SSA-region operations that contain other interpreter operations.
Interpreter functions are used for both the matching and the rewriting
portion of the interpreter.

Example:

```mlir
pdl_interp.func @rewriter(%root: !pdl.operation) {
%op = pdl_interp.create_operation "foo.new_operation"
pdl_interp.erase %root
pdl_interp.finalize
}
```
}];

let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$type
);
let regions = (region MinSizedRegion<1>:$body);

// Create the function with the given name and type. This also automatically
// inserts the entry block for the function.
let builders = [OpBuilder<(ins
"StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}

//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//

/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }

/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
// pdl_interp::GetAttributeOp
//===----------------------------------------------------------------------===//
Expand Down
11 changes: 0 additions & 11 deletions mlir/include/mlir/IR/BuiltinOps.td
Expand Up @@ -133,24 +133,13 @@ def FuncOp : Builtin_Op<"func", [
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return getType().getResults(); }

/// Verify the type attribute of this function. Returns failure and emits
/// an error if the attribute is invalid.
LogicalResult verifyType() {
auto type = getTypeAttr().getValue();
if (!type.isa<FunctionType>())
return emitOpError("requires '" + FunctionOpInterface::getTypeAttrName() +
"' attribute of function type");
return success();
}

//===------------------------------------------------------------------===//
// SymbolOpInterface Methods
//===------------------------------------------------------------------===//

bool isDeclaration() { return isExternal(); }
}];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 2 additions & 4 deletions mlir/include/mlir/IR/FunctionImplementation.h
Expand Up @@ -86,10 +86,8 @@ ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
bool allowVariadic,
FuncTypeBuilder funcTypeBuilder);

/// Printer implementation for function-like operations. Accepts lists of
/// argument and result types to use while printing.
void printFunctionOp(OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes,
bool isVariadic, ArrayRef<Type> resultTypes);
/// Printer implementation for function-like operations.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic);

/// Prints the signature of the function-like operation `op`. Assumes `op` has
/// is a FunctionOpInterface and has passed verification.
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/FunctionInterfaces.h
Expand Up @@ -14,6 +14,7 @@
#ifndef MLIR_IR_FUNCTIONINTERFACES_H
#define MLIR_IR_FUNCTIONINTERFACES_H

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
Expand Down
53 changes: 47 additions & 6 deletions mlir/include/mlir/IR/FunctionInterfaces.td
Expand Up @@ -79,25 +79,41 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
InterfaceMethod<[{
Verify the contents of the body of this function.

Note: The default implementation merely checks that of the entry block
exists, it has the same number arguments as the function type.
Note: The default implementation merely checks that if the entry block
exists, it has the same number and type of arguments as the function type.
}],
"::mlir::LogicalResult", "verifyBody", (ins),
/*methodBody=*/[{}], /*defaultImplementation=*/[{
if ($_op.isExternal())
return success();

unsigned numArguments = $_op.getNumArguments();
if ($_op.front().getNumArguments() != numArguments)
ArrayRef<Type> fnInputTypes = $_op.getArgumentTypes();
Block &entryBlock = $_op.front();

unsigned numArguments = fnInputTypes.size();
if (entryBlock.getNumArguments() != numArguments)
return $_op.emitOpError("entry block must have ")
<< numArguments << " arguments to match function signature";

for (unsigned i = 0, e = fnInputTypes.size(); i != e; ++i) {
Type argType = entryBlock.getArgument(i).getType();
if (fnInputTypes[i] != argType) {
return $_op.emitOpError("type of entry block argument #")
<< i << '(' << argType
<< ") must match the type of the corresponding argument in "
<< "function signature(" << fnInputTypes[i] << ')';
}
}

return success();
}]>,
InterfaceMethod<[{
Verify the type attribute of the function for derived op-specific
invariants.
}],
"::mlir::LogicalResult", "verifyType">,
"::mlir::LogicalResult", "verifyType", (ins),
/*methodBody=*/[{}], /*defaultImplementation=*/[{
return success();
}]>,
];

let extraClassDeclaration = [{
Expand All @@ -108,6 +124,31 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
/// Return the name of the function.
StringRef getName() { return SymbolTable::getSymbolName(*this); }
}];
let extraTraitClassDeclaration = [{
//===------------------------------------------------------------------===//
// Builders
//===------------------------------------------------------------------===//

/// Build the function with the given name, attributes, and type. This
/// builder also inserts an entry block into the function body with the
/// given argument types.
static void buildWithEntryBlock(
OpBuilder &builder, OperationState &state, StringRef name, Type type,
ArrayRef<NamedAttribute> attrs, ArrayRef<Type> inputTypes) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
state.addAttribute(function_interface_impl::getTypeAttrName(),
TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());

// Add the function body.
Region *bodyRegion = state.addRegion();
Block *body = new Block();
bodyRegion->push_back(body);
for (Type input : inputTypes)
body->addArgument(input, state.location);
}
}];
let extraSharedClassDeclaration = [{
/// Block list iterator types.
using BlockListType = Region::BlockListType;
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/IR/OpBase.td
Expand Up @@ -1942,6 +1942,11 @@ class SizedRegion<int numBlocks> : Region<
CPred<"::llvm::hasNItems($_self, " # numBlocks # ")">,
"region with " # numBlocks # " blocks">;

// A region with at least the given number of blocks.
class MinSizedRegion<int numBlocks> : Region<
CPred<"::llvm::hasNItemsOrMore($_self, " # numBlocks # ")">,
"region with at least " # numBlocks # " blocks">;

// A variadic region constraint. It expands to zero or more of the base region.
class VariadicRegion<Region region>
: Region<region.predicate, region.summary>;
Expand Down
20 changes: 11 additions & 9 deletions mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
Expand Up @@ -32,7 +32,7 @@ namespace {
/// given module containing PDL pattern operations.
struct PatternLowering {
public:
PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule);
PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule);

/// Generate code for matching and rewriting based on the pattern operations
/// within the module.
Expand Down Expand Up @@ -110,7 +110,7 @@ struct PatternLowering {
OpBuilder builder;

/// The matcher function used for all match related logic within PDL patterns.
FuncOp matcherFunc;
pdl_interp::FuncOp matcherFunc;

/// The rewriter module containing the all rewrite related logic within PDL
/// patterns.
Expand All @@ -137,7 +137,8 @@ struct PatternLowering {
};
} // namespace

PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule)
PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc,
ModuleOp rewriterModule)
: builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}

Expand All @@ -150,7 +151,7 @@ void PatternLowering::lower(ModuleOp module) {

// Insert the root operation, i.e. argument to the matcher, at the root
// position.
Block *matcherEntryBlock = matcherFunc.addEntryBlock();
Block *matcherEntryBlock = &matcherFunc.front();
values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));

// Generate a root matcher node from the provided PDL module.
Expand Down Expand Up @@ -590,13 +591,14 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {

SymbolRefAttr PatternLowering::generateRewriter(
pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
FuncOp rewriterFunc =
FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter",
builder.getFunctionType(llvm::None, llvm::None));
builder.setInsertionPointToEnd(rewriterModule.getBody());
auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
pattern.getLoc(), "pdl_generated_rewriter",
builder.getFunctionType(llvm::None, llvm::None));
rewriterSymbolTable.insert(rewriterFunc);

// Generate the rewriter function body.
builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock());
builder.setInsertionPointToEnd(&rewriterFunc.front());

// Map an input operand of the pattern to a generated interpreter value.
DenseMap<Value, Value> rewriteValues;
Expand Down Expand Up @@ -902,7 +904,7 @@ void PDLToPDLInterpPass::runOnOperation() {
// Create the main matcher function This function contains all of the match
// related functionality from patterns in the module.
OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
FuncOp matcherFunc = builder.create<FuncOp>(
auto matcherFunc = builder.create<pdl_interp::FuncOp>(
module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
builder.getFunctionType(builder.getType<pdl::OperationType>(),
/*results=*/llvm::None),
Expand Down
9 changes: 3 additions & 6 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Expand Up @@ -1997,8 +1997,8 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
// Returns a null type if any of the types provided are non-LLVM types, or if
// there is more than one output type.
static Type
buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc,
ArrayRef<Type> inputs, ArrayRef<Type> outputs,
buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
ArrayRef<Type> outputs,
function_interface_impl::VariadicFlag variadicFlag) {
Builder &b = parser.getBuilder();
if (outputs.size() > 1) {
Expand Down Expand Up @@ -2159,7 +2159,7 @@ LogicalResult LLVMFuncOp::verify() {
}

/// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
/// - entry block arguments are of LLVM types and match the function signature.
/// - entry block arguments are of LLVM types.
LogicalResult LLVMFuncOp::verifyRegions() {
if (isExternal())
return success();
Expand All @@ -2171,9 +2171,6 @@ LogicalResult LLVMFuncOp::verifyRegions() {
if (!isCompatibleType(argType))
return emitOpError("entry block argument #")
<< i << " is not of LLVM type";
if (getType().getParamType(i) != argType)
return emitOpError("the type of entry block argument #")
<< i << " does not match the function signature";
}

return success();
Expand Down
24 changes: 24 additions & 0 deletions mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"

using namespace mlir;
using namespace mlir::pdl_interp;
Expand Down Expand Up @@ -161,6 +162,29 @@ LogicalResult ForEachOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// pdl_interp::FuncOp
//===----------------------------------------------------------------------===//

void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
FunctionType type, ArrayRef<NamedAttribute> attrs) {
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}

ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
auto buildFuncType =
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };

return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, buildFuncType);
}

void FuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}

//===----------------------------------------------------------------------===//
// pdl_interp::GetValueTypeOp
//===----------------------------------------------------------------------===//
Expand Down
24 changes: 1 addition & 23 deletions mlir/lib/IR/BuiltinDialect.cpp
Expand Up @@ -124,29 +124,7 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
}

void FuncOp::print(OpAsmPrinter &p) {
FunctionType fnType = getType();
function_interface_impl::printFunctionOp(
p, *this, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
}

LogicalResult FuncOp::verify() {
// If this function is external there is nothing to do.
if (isExternal())
return success();

// Verify that the argument list of the function and the arg list of the entry
// block line up. The trait already verified that the number of arguments is
// the same between the signature and the block.
auto fnInputTypes = getType().getInputs();
Block &entryBlock = front();
for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
return emitOpError("type of entry block argument #")
<< i << '(' << entryBlock.getArgument(i).getType()
<< ") must match the type of the corresponding argument in "
<< "function signature(" << fnInputTypes[i] << ')';

return success();
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}

/// Clone the internal blocks from this function into dest and all attributes
Expand Down

0 comments on commit f96a867

Please sign in to comment.