diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index 75d71754535830..fc461f8fe7e50e 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -129,7 +129,7 @@ struct function_traits { /// Overload for class function types. template struct function_traits - : function_traits {}; + : public function_traits {}; /// Overload for non-class function types. template struct function_traits { @@ -143,6 +143,9 @@ struct function_traits { template using arg_t = typename std::tuple_element>::type; }; +template +struct function_traits + : public function_traits {}; /// Overload for non-class function type references. template struct function_traits diff --git a/mlir/docs/PDLL.md b/mlir/docs/PDLL.md index ab24a680d37b82..33293505d37140 100644 --- a/mlir/docs/PDLL.md +++ b/mlir/docs/PDLL.md @@ -1006,17 +1006,11 @@ External constraints are those registered explicitly with the `RewritePatternSet the C++ PDL API. For example, the constraints above may be registered as: ```c++ -// TODO: Cleanup when we allow more accessible wrappers around PDL functions. -static LogicalResult hasOneUseImpl(PDLValue pdlValue, PatternRewriter &rewriter) { - Value value = pdlValue.cast(); - +static LogicalResult hasOneUseImpl(PatternRewriter &rewriter, Value value) { return success(value.hasOneUse()); } -static LogicalResult hasSameElementTypeImpl(ArrayRef pdlValues, - PatternRewriter &rewriter) { - Value value1 = pdlValues[0].cast(); - Value value2 = pdlValues[1].cast(); - +static LogicalResult hasSameElementTypeImpl(PatternRewriter &rewriter, + Value value1, Value Value2) { return success(value1.getType().cast().getElementType() == value2.getType().cast().getElementType()); } @@ -1307,14 +1301,10 @@ External rewrites are those registered explicitly with the `RewritePatternSet` v the C++ PDL API. For example, the rewrite above may be registered as: ```c++ -// TODO: Cleanup when we allow more accessible wrappers around PDL functions. -static void buildOpImpl(ArrayRef args, PatternRewriter &rewriter, - PDLResultList &results) { - Value value = args[0].cast(); - +static Operation *buildOpImpl(PDLResultList &results, Value value) { // insert special rewrite logic here. Operation *resultOp = ...; - results.push_back(resultOp); + return resultOp; } void registerNativeRewrite(RewritePatternSet &patterns) { diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td index 7f0253e59a32bf..4e43b2677f4a81 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -68,18 +68,14 @@ def PDL_ApplyNativeRewriteOp ```mlir // Apply a native rewrite method that returns an attribute. - %ret = pdl.apply_native_rewrite "myNativeFunc"(%arg0, %arg1) : !pdl.attribute + %ret = pdl.apply_native_rewrite "myNativeFunc"(%arg0, %attr1) : !pdl.attribute ``` ```c++ // The native rewrite as defined in C++: - static void myNativeFunc(ArrayRef args, PatternRewriter &rewriter, - PDLResultList &results) { - Value arg0 = args[0].cast(); - Value arg1 = args[1].cast(); - - // Just push back the first param attribute. - results.push_back(param0); + static Attribute myNativeFunc(PatternRewriter &rewriter, Value arg0, Attribute arg1) { + // Just return the second arg. + return arg1; } void registerNativeRewrite(PDLPatternModule &pdlModule) { diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 1bd6187a78d480..f4c8863624740d 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -409,7 +409,8 @@ class OpBuilder : public Builder { /// Creates an operation with the given fields. Operation *create(Location loc, StringAttr opName, ValueRange operands, - TypeRange types, ArrayRef attributes = {}, + TypeRange types = {}, + ArrayRef attributes = {}, BlockRange successors = {}, MutableArrayRef> regions = {}); diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 11f85ee38bef82..478fa2ae97b1c2 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -386,6 +386,222 @@ class OpTraitRewritePattern : public RewritePattern { benefit, context) {} }; +//===----------------------------------------------------------------------===// +// RewriterBase +//===----------------------------------------------------------------------===// + +/// This class coordinates the application of a rewrite on a set of IR, +/// providing a way for clients to track mutations and create new operations. +/// This class serves as a common API for IR mutation between pattern rewrites +/// and non-pattern rewrites, and facilitates the development of shared +/// IR transformation utilities. +class RewriterBase : public OpBuilder, public OpBuilder::Listener { +public: + /// Move the blocks that belong to "region" before the given position in + /// another region "parent". The two regions must be different. The caller + /// is responsible for creating or updating the operation transferring flow + /// of control to the region and passing it the correct block arguments. + virtual void inlineRegionBefore(Region ®ion, Region &parent, + Region::iterator before); + void inlineRegionBefore(Region ®ion, Block *before); + + /// Clone the blocks that belong to "region" before the given position in + /// another region "parent". The two regions must be different. The caller is + /// responsible for creating or updating the operation transferring flow of + /// control to the region and passing it the correct block arguments. + virtual void cloneRegionBefore(Region ®ion, Region &parent, + Region::iterator before, + BlockAndValueMapping &mapping); + void cloneRegionBefore(Region ®ion, Region &parent, + Region::iterator before); + void cloneRegionBefore(Region ®ion, Block *before); + + /// This method replaces the uses of the results of `op` with the values in + /// `newValues` when the provided `functor` returns true for a specific use. + /// The number of values in `newValues` is required to match the number of + /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of + /// the uses of `op` were replaced. Note that in some rewriters, the given + /// 'functor' may be stored beyond the lifetime of the rewrite being applied. + /// As such, the function should not capture by reference and instead use + /// value capture as necessary. + virtual void + replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, + llvm::unique_function functor); + void replaceOpWithIf(Operation *op, ValueRange newValues, + llvm::unique_function functor) { + replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr, + std::move(functor)); + } + + /// This method replaces the uses of the results of `op` with the values in + /// `newValues` when a use is nested within the given `block`. The number of + /// values in `newValues` is required to match the number of results of `op`. + /// If all uses of this operation are replaced, the operation is erased. + void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, + bool *allUsesReplaced = nullptr); + + /// This method replaces the results of the operation with the specified list + /// of values. The number of provided values must match the number of results + /// of the operation. + virtual void replaceOp(Operation *op, ValueRange newValues); + + /// Replaces the result op with a new op that is created without verification. + /// The result values of the two ops must be the same types. + template + OpTy replaceOpWithNewOp(Operation *op, Args &&...args) { + auto newOp = create(op->getLoc(), std::forward(args)...); + replaceOpWithResultsOfAnotherOp(op, newOp.getOperation()); + return newOp; + } + + /// This method erases an operation that is known to have no uses. + virtual void eraseOp(Operation *op); + + /// This method erases all operations in a block. + virtual void eraseBlock(Block *block); + + /// Merge the operations of block 'source' into the end of block 'dest'. + /// 'source's predecessors must either be empty or only contain 'dest`. + /// 'argValues' is used to replace the block arguments of 'source' after + /// merging. + virtual void mergeBlocks(Block *source, Block *dest, + ValueRange argValues = llvm::None); + + // Merge the operations of block 'source' before the operation 'op'. Source + // block should not have existing predecessors or successors. + void mergeBlockBefore(Block *source, Operation *op, + ValueRange argValues = llvm::None); + + /// Split the operations starting at "before" (inclusive) out of the given + /// block into a new block, and return it. + virtual Block *splitBlock(Block *block, Block::iterator before); + + /// This method is used to notify the rewriter that an in-place operation + /// modification is about to happen. A call to this function *must* be + /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`. + /// This is a minor efficiency win (it avoids creating a new operation and + /// removing the old one) but also often allows simpler code in the client. + virtual void startRootUpdate(Operation *op) {} + + /// This method is used to signal the end of a root update on the given + /// operation. This can only be called on operations that were provided to a + /// call to `startRootUpdate`. + virtual void finalizeRootUpdate(Operation *op) {} + + /// This method cancels a pending root update. This can only be called on + /// operations that were provided to a call to `startRootUpdate`. + virtual void cancelRootUpdate(Operation *op) {} + + /// This method is a utility wrapper around a root update of an operation. It + /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given + /// callable. + template + void updateRootInPlace(Operation *root, CallableT &&callable) { + startRootUpdate(root); + callable(); + finalizeRootUpdate(root); + } + + /// Used to notify the rewriter that the IR failed to be rewritten because of + /// a match failure, and provide a callback to populate a diagnostic with the + /// reason why the failure occurred. This method allows for derived rewriters + /// to optionally hook into the reason why a rewrite failed, and display it to + /// users. + template + std::enable_if_t::value, LogicalResult> + notifyMatchFailure(Location loc, CallbackT &&reasonCallback) { +#ifndef NDEBUG + return notifyMatchFailure(loc, + function_ref(reasonCallback)); +#else + return failure(); +#endif + } + template + std::enable_if_t::value, LogicalResult> + notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) { + return notifyMatchFailure(op->getLoc(), + function_ref(reasonCallback)); + } + template + LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) { + return notifyMatchFailure(std::forward(arg), + [&](Diagnostic &diag) { diag << msg; }); + } + template + LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) { + return notifyMatchFailure(std::forward(arg), Twine(msg)); + } + +protected: + /// Initialize the builder with this rewriter as the listener. + explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {} + explicit RewriterBase(const OpBuilder &otherBuilder) + : OpBuilder(otherBuilder) { + setListener(this); + } + ~RewriterBase() override; + + /// These are the callback methods that subclasses can choose to implement if + /// they would like to be notified about certain types of mutations. + + /// Notify the rewriter that the specified operation is about to be replaced + /// with another set of operations. This is called before the uses of the + /// operation have been changed. + virtual void notifyRootReplaced(Operation *op) {} + + /// This is called on an operation that a rewrite is removing, right before + /// the operation is deleted. At this point, the operation has zero uses. + virtual void notifyOperationRemoved(Operation *op) {} + + /// Notify the rewriter that the pattern failed to match the given operation, + /// and provide a callback to populate a diagnostic with the reason why the + /// failure occurred. This method allows for derived rewriters to optionally + /// hook into the reason why a rewrite failed, and display it to users. + virtual LogicalResult + notifyMatchFailure(Location loc, + function_ref reasonCallback) { + return failure(); + } + +private: + void operator=(const RewriterBase &) = delete; + RewriterBase(const RewriterBase &) = delete; + + /// 'op' and 'newOp' are known to have the same number of results, replace the + /// uses of op with uses of newOp. + void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp); +}; + +//===----------------------------------------------------------------------===// +// IRRewriter +//===----------------------------------------------------------------------===// + +/// This class coordinates rewriting a piece of IR outside of a pattern rewrite, +/// providing a way to keep track of the mutations made to the IR. This class +/// should only be used in situations where another `RewriterBase` instance, +/// such as a `PatternRewriter`, is not available. +class IRRewriter : public RewriterBase { +public: + explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {} + explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {} +}; + +//===----------------------------------------------------------------------===// +// PatternRewriter +//===----------------------------------------------------------------------===// + +/// A special type of `RewriterBase` that coordinates the application of a +/// rewrite pattern on the current IR being matched, providing a way to keep +/// track of any mutations made. This class should be used to perform all +/// necessary IR mutations within a rewrite pattern, as the pattern driver may +/// be tracking various state that would be invalidated when a mutation takes +/// place. +class PatternRewriter : public RewriterBase { +public: + using RewriterBase::RewriterBase; +}; + //===----------------------------------------------------------------------===// // PDLPatternModule //===----------------------------------------------------------------------===// @@ -587,291 +803,561 @@ class PDLResultList { /// constraint to a given set of opaque PDLValue entities. Returns success if /// the constraint successfully held, failure otherwise. using PDLConstraintFunction = - std::function, PatternRewriter &)>; + std::function)>; /// A native PDL rewrite function. This function performs a rewrite on the /// given set of values. Any results from this rewrite that should be passed /// back to PDL should be added to the provided result list. This method is only /// invoked when the corresponding match was successful. using PDLRewriteFunction = - std::function, PatternRewriter &, PDLResultList &)>; - -/// This class contains all of the necessary data for a set of PDL patterns, or -/// pattern rewrites specified in the form of the PDL dialect. This PDL module -/// contained by this pattern may contain any number of `pdl.pattern` -/// operations. -class PDLPatternModule { -public: - PDLPatternModule() = default; + std::function)>; - /// Construct a PDL pattern with the given module. - PDLPatternModule(OwningOpRef pdlModule) - : pdlModule(std::move(pdlModule)) {} +namespace detail { +namespace pdl_function_builder { +/// A utility variable that always resolves to false. This is useful for static +/// asserts that are always false, but only should fire in certain templated +/// constructs. For example, if a templated function should never be called, the +/// function could be defined as: +/// +/// template +/// void foo() { +/// static_assert(always_false, "This function should never be called"); +/// } +/// +template +constexpr bool always_false = false; - /// Merge the state in `other` into this pattern module. - void mergeIn(PDLPatternModule &&other); +//===----------------------------------------------------------------------===// +// PDL Function Builder: Type Processing +//===----------------------------------------------------------------------===// - /// Return the internal PDL module of this pattern. - ModuleOp getModule() { return pdlModule.get(); } +/// This struct provides a convenient way to determine how to process a given +/// type as either a PDL parameter, or a result value. This allows for +/// supporting complex types in constraint and rewrite functions, without +/// requiring the user to hand-write the necessary glue code themselves. +/// Specializations of this class should implement the following methods to +/// enable support as a PDL argument or result type: +/// +/// static LogicalResult verifyAsArg( +/// function_ref errorFn, PDLValue pdlValue, +/// size_t argIdx); +/// +/// * This method verifies that the given PDLValue is valid for use as a +/// value of `T`. +/// +/// static T processAsArg(PDLValue pdlValue); +/// +/// * This method processes the given PDLValue as a value of `T`. +/// +/// static void processAsResult(PatternRewriter &, PDLResultList &results, +/// const T &value); +/// +/// * This method processes the given value of `T` as the result of a +/// function invocation. The method should package the value into an +/// appropriate form and append it to the given result list. +/// +/// If the type `T` is based on a higher order value, consider using +/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify +/// the implementation. +/// +template +struct ProcessPDLValue; + +/// This struct provides a simplified model for processing types that are based +/// on another type, e.g. APInt is based on the handling for IntegerAttr. This +/// allows for building the necessary processing functions on top of the base +/// value instead of a PDLValue. Derived users should implement the following +/// (which subsume the ProcessPDLValue variants): +/// +/// static LogicalResult verifyAsArg( +/// function_ref errorFn, +/// const BaseT &baseValue, size_t argIdx); +/// +/// * This method verifies that the given PDLValue is valid for use as a +/// value of `T`. +/// +/// static T processAsArg(BaseT baseValue); +/// +/// * This method processes the given base value as a value of `T`. +/// +template +struct ProcessPDLValueBasedOn { + static LogicalResult + verifyAsArg(function_ref errorFn, + PDLValue pdlValue, size_t argIdx) { + // Verify the base class before continuing. + if (failed(ProcessPDLValue::verifyAsArg(errorFn, pdlValue, argIdx))) + return failure(); + return ProcessPDLValue::verifyAsArg( + errorFn, ProcessPDLValue::processAsArg(pdlValue), argIdx); + } + static T processAsArg(PDLValue pdlValue) { + return ProcessPDLValue::processAsArg( + ProcessPDLValue::processAsArg(pdlValue)); + } - //===--------------------------------------------------------------------===// - // Function Registry + /// Explicitly add the expected parent API to ensure the parent class + /// implements the necessary API (and doesn't implicitly inherit it from + /// somewhere else). + static LogicalResult + verifyAsArg(function_ref errorFn, BaseT value, + size_t argIdx) { + return success(); + } + static T processAsArg(BaseT baseValue); +}; - /// Register a constraint function. - void registerConstraintFunction(StringRef name, - PDLConstraintFunction constraintFn); - /// Register a single entity constraint function. - template - std::enable_if_t, - PatternRewriter &>::value> - registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) { - registerConstraintFunction( - name, [constraintFn = std::forward(constraintFn)]( - ArrayRef values, PatternRewriter &rewriter) { - assert(values.size() == 1 && - "expected values to have a single entity"); - return constraintFn(values[0], rewriter); - }); +/// This struct provides a simplified model for processing types that have +/// "builtin" PDLValue support: +/// * Attribute, Operation *, Type, TypeRange, ValueRange +template +struct ProcessBuiltinPDLValue { + static LogicalResult + verifyAsArg(function_ref errorFn, + PDLValue pdlValue, size_t argIdx) { + if (pdlValue) + return success(); + return errorFn("expected a non-null value for argument " + Twine(argIdx) + + " of type: " + llvm::getTypeName()); } - /// Register a rewrite function. - void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn); + static T processAsArg(PDLValue pdlValue) { return pdlValue.cast(); } + static void processAsResult(PatternRewriter &, PDLResultList &results, + T value) { + results.push_back(value); + } +}; - /// Return the set of the registered constraint functions. - const llvm::StringMap &getConstraintFunctions() const { - return constraintFunctions; +/// This struct provides a simplified model for processing types that inherit +/// from builtin PDLValue types. For example, derived attributes like +/// IntegerAttr, derived types like IntegerType, derived operations like +/// ModuleOp, Interfaces, etc. +template +struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn { + static LogicalResult + verifyAsArg(function_ref errorFn, + BaseT baseValue, size_t argIdx) { + return TypeSwitch(baseValue) + .Case([&](T) { return success(); }) + .Default([&](BaseT) { + return errorFn("expected argument " + Twine(argIdx) + + " to be of type: " + llvm::getTypeName()); + }); } - llvm::StringMap takeConstraintFunctions() { - return constraintFunctions; + static T processAsArg(BaseT baseValue) { + return baseValue.template cast(); } - /// Return the set of the registered rewrite functions. - const llvm::StringMap &getRewriteFunctions() const { - return rewriteFunctions; - } - llvm::StringMap takeRewriteFunctions() { - return rewriteFunctions; + static void processAsResult(PatternRewriter &, PDLResultList &results, + T value) { + results.push_back(value); } +}; - /// Clear out the patterns and functions within this module. - void clear() { - pdlModule = nullptr; - constraintFunctions.clear(); - rewriteFunctions.clear(); +//===----------------------------------------------------------------------===// +// Attribute + +template <> +struct ProcessPDLValue : public ProcessBuiltinPDLValue {}; +template +struct ProcessPDLValue::value>> + : public ProcessDerivedPDLValue {}; + +/// Handling for various Attribute value types. +template <> +struct ProcessPDLValue + : public ProcessPDLValueBasedOn { + static StringRef processAsArg(StringAttr value) { return value.getValue(); } + static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, + StringRef value) { + results.push_back(rewriter.getStringAttr(value)); } +}; +template <> +struct ProcessPDLValue + : public ProcessPDLValueBasedOn { + template + static std::string processAsArg(T value) { + static_assert(always_false, + "`std::string` arguments require a string copy, use " + "`StringRef` for string-like arguments instead"); + } + static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, + StringRef value) { + results.push_back(rewriter.getStringAttr(value)); + } +}; -private: - /// The module containing the `pdl.pattern` operations. - OwningOpRef pdlModule; - - /// The external functions referenced from within the PDL module. - llvm::StringMap constraintFunctions; - llvm::StringMap rewriteFunctions; +//===----------------------------------------------------------------------===// +// Operation + +template <> +struct ProcessPDLValue + : public ProcessBuiltinPDLValue {}; +template +struct ProcessPDLValue::value>> + : public ProcessDerivedPDLValue { + static T processAsArg(Operation *value) { return cast(value); } }; //===----------------------------------------------------------------------===// -// RewriterBase +// Type + +template <> +struct ProcessPDLValue : public ProcessBuiltinPDLValue {}; +template +struct ProcessPDLValue::value>> + : public ProcessDerivedPDLValue {}; + //===----------------------------------------------------------------------===// +// TypeRange + +template <> +struct ProcessPDLValue : public ProcessBuiltinPDLValue {}; +template <> +struct ProcessPDLValue> { + static void processAsResult(PatternRewriter &, PDLResultList &results, + ValueTypeRange types) { + results.push_back(types); + } +}; +template <> +struct ProcessPDLValue> { + static void processAsResult(PatternRewriter &, PDLResultList &results, + ValueTypeRange types) { + results.push_back(types); + } +}; -/// This class coordinates the application of a rewrite on a set of IR, -/// providing a way for clients to track mutations and create new operations. -/// This class serves as a common API for IR mutation between pattern rewrites -/// and non-pattern rewrites, and facilitates the development of shared -/// IR transformation utilities. -class RewriterBase : public OpBuilder, public OpBuilder::Listener { -public: - /// Move the blocks that belong to "region" before the given position in - /// another region "parent". The two regions must be different. The caller - /// is responsible for creating or updating the operation transferring flow - /// of control to the region and passing it the correct block arguments. - virtual void inlineRegionBefore(Region ®ion, Region &parent, - Region::iterator before); - void inlineRegionBefore(Region ®ion, Block *before); +//===----------------------------------------------------------------------===// +// Value - /// Clone the blocks that belong to "region" before the given position in - /// another region "parent". The two regions must be different. The caller is - /// responsible for creating or updating the operation transferring flow of - /// control to the region and passing it the correct block arguments. - virtual void cloneRegionBefore(Region ®ion, Region &parent, - Region::iterator before, - BlockAndValueMapping &mapping); - void cloneRegionBefore(Region ®ion, Region &parent, - Region::iterator before); - void cloneRegionBefore(Region ®ion, Block *before); +template <> +struct ProcessPDLValue : public ProcessBuiltinPDLValue {}; - /// This method replaces the uses of the results of `op` with the values in - /// `newValues` when the provided `functor` returns true for a specific use. - /// The number of values in `newValues` is required to match the number of - /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of - /// the uses of `op` were replaced. Note that in some rewriters, the given - /// 'functor' may be stored beyond the lifetime of the rewrite being applied. - /// As such, the function should not capture by reference and instead use - /// value capture as necessary. - virtual void - replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, - llvm::unique_function functor); - void replaceOpWithIf(Operation *op, ValueRange newValues, - llvm::unique_function functor) { - replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr, - std::move(functor)); +//===----------------------------------------------------------------------===// +// ValueRange + +template <> +struct ProcessPDLValue : public ProcessBuiltinPDLValue { +}; +template <> +struct ProcessPDLValue { + static void processAsResult(PatternRewriter &, PDLResultList &results, + OperandRange values) { + results.push_back(values); } +}; +template <> +struct ProcessPDLValue { + static void processAsResult(PatternRewriter &, PDLResultList &results, + ResultRange values) { + results.push_back(values); + } +}; - /// This method replaces the uses of the results of `op` with the values in - /// `newValues` when a use is nested within the given `block`. The number of - /// values in `newValues` is required to match the number of results of `op`. - /// If all uses of this operation are replaced, the operation is erased. - void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, - bool *allUsesReplaced = nullptr); +//===----------------------------------------------------------------------===// +// PDL Function Builder: Argument Handling +//===----------------------------------------------------------------------===// - /// This method replaces the results of the operation with the specified list - /// of values. The number of provided values must match the number of results - /// of the operation. - virtual void replaceOp(Operation *op, ValueRange newValues); +/// Validate the given PDLValues match the constraints defined by the argument +/// types of the given function. In the case of failure, a match failure +/// diagnostic is emitted. +/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL +/// does not currently preserve Constraint application ordering. +template +LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef values, + std::index_sequence) { + using FnTraitsT = llvm::function_traits; + + auto errorFn = [&](const Twine &msg) { + return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg); + }; + LogicalResult result = success(); + (void)std::initializer_list{ + (result = + succeeded(result) + ? ProcessPDLValue>:: + verifyAsArg(errorFn, values[I], I) + : failure(), + 0)...}; + return result; +} - /// Replaces the result op with a new op that is created without verification. - /// The result values of the two ops must be the same types. - template - OpTy replaceOpWithNewOp(Operation *op, Args &&... args) { - auto newOp = create(op->getLoc(), std::forward(args)...); - replaceOpWithResultsOfAnotherOp(op, newOp.getOperation()); - return newOp; - } +/// Assert that the given PDLValues match the constraints defined by the +/// arguments of the given function. In the case of failure, a fatal error +/// is emitted. +template +void assertArgs(PatternRewriter &rewriter, ArrayRef values, + std::index_sequence) { + using FnTraitsT = llvm::function_traits; + + // We only want to do verification in debug builds, same as with `assert`. +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + auto errorFn = [&](const Twine &msg) -> LogicalResult { + llvm::report_fatal_error(msg); + }; + (void)std::initializer_list{ + (assert(succeeded(ProcessPDLValue>::verifyAsArg(errorFn, values[I], I))), + 0)...}; +#endif +} - /// This method erases an operation that is known to have no uses. - virtual void eraseOp(Operation *op); +//===----------------------------------------------------------------------===// +// PDL Function Builder: Results Handling +//===----------------------------------------------------------------------===// - /// This method erases all operations in a block. - virtual void eraseBlock(Block *block); +/// Store a single result within the result list. +template +static void processResults(PatternRewriter &rewriter, PDLResultList &results, + T &&value) { + ProcessPDLValue::processAsResult(rewriter, results, + std::forward(value)); +} - /// Merge the operations of block 'source' into the end of block 'dest'. - /// 'source's predecessors must either be empty or only contain 'dest`. - /// 'argValues' is used to replace the block arguments of 'source' after - /// merging. - virtual void mergeBlocks(Block *source, Block *dest, - ValueRange argValues = llvm::None); +/// Store a std::pair<> as individual results within the result list. +template +static void processResults(PatternRewriter &rewriter, PDLResultList &results, + std::pair &&pair) { + processResults(rewriter, results, std::move(pair.first)); + processResults(rewriter, results, std::move(pair.second)); +} - // Merge the operations of block 'source' before the operation 'op'. Source - // block should not have existing predecessors or successors. - void mergeBlockBefore(Block *source, Operation *op, - ValueRange argValues = llvm::None); +/// Store a std::tuple<> as individual results within the result list. +template +static void processResults(PatternRewriter &rewriter, PDLResultList &results, + std::tuple &&tuple) { + auto applyFn = [&](auto &&...args) { + // TODO: Use proper fold expressions when we have C++17. For now we use a + // bogus std::initializer_list to work around C++14 limitations. + (void)std::initializer_list{ + (processResults(rewriter, results, std::move(args)), 0)...}; + }; + llvm::apply_tuple(applyFn, std::move(tuple)); +} - /// Split the operations starting at "before" (inclusive) out of the given - /// block into a new block, and return it. - virtual Block *splitBlock(Block *block, Block::iterator before); +//===----------------------------------------------------------------------===// +// PDL Constraint Builder +//===----------------------------------------------------------------------===// - /// This method is used to notify the rewriter that an in-place operation - /// modification is about to happen. A call to this function *must* be - /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`. - /// This is a minor efficiency win (it avoids creating a new operation and - /// removing the old one) but also often allows simpler code in the client. - virtual void startRootUpdate(Operation *op) {} +/// Process the arguments of a native constraint and invoke it. +template > +typename FnTraitsT::result_t +processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter, + ArrayRef values, + std::index_sequence) { + return fn( + rewriter, + (ProcessPDLValue>::processAsArg( + values[I]))...); +} - /// This method is used to signal the end of a root update on the given - /// operation. This can only be called on operations that were provided to a - /// call to `startRootUpdate`. - virtual void finalizeRootUpdate(Operation *op) {} +/// Build a constraint function from the given function `ConstraintFnT`. This +/// allows for enabling the user to define simpler, more direct constraint +/// functions without needing to handle the low-level PDL goop. +/// +/// If the constraint function is already in the correct form, we just forward +/// it directly. +template +std::enable_if_t< + std::is_convertible::value, + PDLConstraintFunction> +buildConstraintFn(ConstraintFnT &&constraintFn) { + return std::forward(constraintFn); +} +/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form +/// we desire. +template +std::enable_if_t< + !std::is_convertible::value, + PDLConstraintFunction> +buildConstraintFn(ConstraintFnT &&constraintFn) { + return [constraintFn = std::forward(constraintFn)]( + PatternRewriter &rewriter, + ArrayRef values) -> LogicalResult { + auto argIndices = std::make_index_sequence< + llvm::function_traits::num_args - 1>(); + if (failed(verifyAsArgs(rewriter, values, argIndices))) + return failure(); + return processArgsAndInvokeConstraint(constraintFn, rewriter, values, + argIndices); + }; +} - /// This method cancels a pending root update. This can only be called on - /// operations that were provided to a call to `startRootUpdate`. - virtual void cancelRootUpdate(Operation *op) {} +//===----------------------------------------------------------------------===// +// PDL Rewrite Builder +//===----------------------------------------------------------------------===// - /// This method is a utility wrapper around a root update of an operation. It - /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given - /// callable. - template - void updateRootInPlace(Operation *root, CallableT &&callable) { - startRootUpdate(root); - callable(); - finalizeRootUpdate(root); - } +/// Process the arguments of a native rewrite and invoke it. +/// This overload handles the case of no return values. +template > +std::enable_if_t::value> +processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, + PDLResultList &, ArrayRef values, + std::index_sequence) { + fn(rewriter, + (ProcessPDLValue>::processAsArg( + values[I]))...); +} +/// This overload handles the case of return values, which need to be packaged +/// into the result list. +template > +std::enable_if_t::value> +processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, + PDLResultList &results, ArrayRef values, + std::index_sequence) { + processResults( + rewriter, results, + fn(rewriter, (ProcessPDLValue>:: + processAsArg(values[I]))...)); +} - /// Used to notify the rewriter that the IR failed to be rewritten because of - /// a match failure, and provide a callback to populate a diagnostic with the - /// reason why the failure occurred. This method allows for derived rewriters - /// to optionally hook into the reason why a rewrite failed, and display it to - /// users. - template - std::enable_if_t::value, LogicalResult> - notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) { -#ifndef NDEBUG - return notifyMatchFailure(op, - function_ref(reasonCallback)); -#else - return failure(); -#endif - } - LogicalResult notifyMatchFailure(Operation *op, const Twine &msg) { - return notifyMatchFailure(op, [&](Diagnostic &diag) { diag << msg; }); - } - LogicalResult notifyMatchFailure(Operation *op, const char *msg) { - return notifyMatchFailure(op, Twine(msg)); - } +/// Build a rewrite function from the given function `RewriteFnT`. This +/// allows for enabling the user to define simpler, more direct rewrite +/// functions without needing to handle the low-level PDL goop. +/// +/// If the rewrite function is already in the correct form, we just forward +/// it directly. +template +std::enable_if_t::value, + PDLRewriteFunction> +buildRewriteFn(RewriteFnT &&rewriteFn) { + return std::forward(rewriteFn); +} +/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form +/// we desire. +template +std::enable_if_t::value, + PDLRewriteFunction> +buildRewriteFn(RewriteFnT &&rewriteFn) { + return [rewriteFn = std::forward(rewriteFn)]( + PatternRewriter &rewriter, PDLResultList &results, + ArrayRef values) { + auto argIndices = + std::make_index_sequence::num_args - + 1>(); + assertArgs(rewriter, values, argIndices); + processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values, + argIndices); + }; +} -protected: - /// Initialize the builder with this rewriter as the listener. - explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {} - explicit RewriterBase(const OpBuilder &otherBuilder) - : OpBuilder(otherBuilder) { - setListener(this); - } - ~RewriterBase() override; +} // namespace pdl_function_builder +} // namespace detail - /// These are the callback methods that subclasses can choose to implement if - /// they would like to be notified about certain types of mutations. +/// This class contains all of the necessary data for a set of PDL patterns, or +/// pattern rewrites specified in the form of the PDL dialect. This PDL module +/// contained by this pattern may contain any number of `pdl.pattern` +/// operations. +class PDLPatternModule { +public: + PDLPatternModule() = default; - /// Notify the rewriter that the specified operation is about to be replaced - /// with another set of operations. This is called before the uses of the - /// operation have been changed. - virtual void notifyRootReplaced(Operation *op) {} + /// Construct a PDL pattern with the given module. + PDLPatternModule(OwningOpRef pdlModule) + : pdlModule(std::move(pdlModule)) {} - /// This is called on an operation that a rewrite is removing, right before - /// the operation is deleted. At this point, the operation has zero uses. - virtual void notifyOperationRemoved(Operation *op) {} + /// Merge the state in `other` into this pattern module. + void mergeIn(PDLPatternModule &&other); - /// Notify the rewriter that the pattern failed to match the given operation, - /// and provide a callback to populate a diagnostic with the reason why the - /// failure occurred. This method allows for derived rewriters to optionally - /// hook into the reason why a rewrite failed, and display it to users. - virtual LogicalResult - notifyMatchFailure(Operation *op, - function_ref reasonCallback) { - return failure(); - } + /// Return the internal PDL module of this pattern. + ModuleOp getModule() { return pdlModule.get(); } -private: - void operator=(const RewriterBase &) = delete; - RewriterBase(const RewriterBase &) = delete; + //===--------------------------------------------------------------------===// + // Function Registry - /// 'op' and 'newOp' are known to have the same number of results, replace the - /// uses of op with uses of newOp. - void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp); -}; + /// Register a constraint function with PDL. A constraint function may be + /// specified in one of two ways: + /// + /// * `LogicalResult (PatternRewriter &, ArrayRef)` + /// + /// In this overload the arguments of the constraint function are passed via + /// the low-level PDLValue form. + /// + /// * `LogicalResult (PatternRewriter &, ValueTs... values)` + /// + /// In this form the arguments of the constraint function are passed via the + /// expected high level C++ type. In this form, the framework will + /// automatically unwrap PDLValues and convert them to the expected ValueTs. + /// For example, if the constraint function accepts a `Operation *`, the + /// framework will automatically cast the input PDLValue. In the case of a + /// `StringRef`, the framework will automatically unwrap the argument as a + /// StringAttr and pass the underlying string value. To see the full list of + /// supported types, or to see how to add handling for custom types, view + /// the definition of `ProcessPDLValue` above. + void registerConstraintFunction(StringRef name, + PDLConstraintFunction constraintFn); + template + void registerConstraintFunction(StringRef name, + ConstraintFnT &&constraintFn) { + registerConstraintFunction(name, + detail::pdl_function_builder::buildConstraintFn( + std::forward(constraintFn))); + } -//===----------------------------------------------------------------------===// -// IRRewriter -//===----------------------------------------------------------------------===// + /// Register a rewrite function with PDL. A rewrite function may be specified + /// in one of two ways: + /// + /// * `void (PatternRewriter &, PDLResultList &, ArrayRef)` + /// + /// In this overload the arguments of the constraint function are passed via + /// the low-level PDLValue form, and the results are manually appended to + /// the given result list. + /// + /// * `ResultT (PatternRewriter &, ValueTs... values)` + /// + /// In this form the arguments and result of the rewrite function are passed + /// via the expected high level C++ type. In this form, the framework will + /// automatically unwrap the PDLValues arguments and convert them to the + /// expected ValueTs. It will also automatically handle the processing and + /// packaging of the result value to the result list. For example, if the + /// rewrite function takes a `Operation *`, the framework will automatically + /// cast the input PDLValue. In the case of a `StringRef`, the framework + /// will automatically unwrap the argument as a StringAttr and pass the + /// underlying string value. In the reverse case, if the rewrite returns a + /// StringRef or std::string, it will automatically package this as a + /// StringAttr and append it to the result list. To see the full list of + /// supported types, or to see how to add handling for custom types, view + /// the definition of `ProcessPDLValue` above. + void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn); + template + void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) { + registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn( + std::forward(rewriteFn))); + } -/// This class coordinates rewriting a piece of IR outside of a pattern rewrite, -/// providing a way to keep track of the mutations made to the IR. This class -/// should only be used in situations where another `RewriterBase` instance, -/// such as a `PatternRewriter`, is not available. -class IRRewriter : public RewriterBase { -public: - explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {} - explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {} -}; + /// Return the set of the registered constraint functions. + const llvm::StringMap &getConstraintFunctions() const { + return constraintFunctions; + } + llvm::StringMap takeConstraintFunctions() { + return constraintFunctions; + } + /// Return the set of the registered rewrite functions. + const llvm::StringMap &getRewriteFunctions() const { + return rewriteFunctions; + } + llvm::StringMap takeRewriteFunctions() { + return rewriteFunctions; + } -//===----------------------------------------------------------------------===// -// PatternRewriter -//===----------------------------------------------------------------------===// + /// Clear out the patterns and functions within this module. + void clear() { + pdlModule = nullptr; + constraintFunctions.clear(); + rewriteFunctions.clear(); + } -/// A special type of `RewriterBase` that coordinates the application of a -/// rewrite pattern on the current IR being matched, providing a way to keep -/// track of any mutations made. This class should be used to perform all -/// necessary IR mutations within a rewrite pattern, as the pattern driver may -/// be tracking various state that would be invalidated when a mutation takes -/// place. -class PatternRewriter : public RewriterBase { -public: - using RewriterBase::RewriterBase; +private: + /// The module containing the `pdl.pattern` operations. + OwningOpRef pdlModule; + + /// The external functions referenced from within the PDL module. + llvm::StringMap constraintFunctions; + llvm::StringMap rewriteFunctions; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 5b409d695baedd..d0be98a307d70d 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -629,7 +629,7 @@ class ConversionPatternRewriter final : public PatternRewriter { /// PatternRewriter hook for notifying match failure reasons. LogicalResult - notifyMatchFailure(Operation *op, + notifyMatchFailure(Location loc, function_ref reasonCallback) override; using PatternRewriter::notifyMatchFailure; diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 367f51ad601a84..c2dc41a81c6f8a 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -1340,7 +1340,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { }); // Invoke the constraint and jump to the proper destination. - selectJump(succeeded(constraintFn(args, rewriter))); + selectJump(succeeded(constraintFn(rewriter, args))); } void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { @@ -1357,7 +1357,7 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { // Execute the rewrite function. ByteCodeField numResults = read(); ByteCodeRewriteResultList results(numResults); - rewriteFn(args, rewriter, results); + rewriteFn(rewriter, results, args); assert(results.getResults().size() == numResults && "native PDL rewrite function returned unexpected number of results"); diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp index c6937aa736a11f..c3b5c957007e8d 100644 --- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp @@ -184,9 +184,9 @@ void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint, .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; }); }; os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name - << "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, " - "::mlir::PatternRewriter &rewriter" - << (isConstraint ? "" : ", ::mlir::PDLResultList &results") << ") {\n"; + << "PDLFn(::mlir::PatternRewriter &rewriter, " + << (isConstraint ? "" : "::mlir::PDLResultList &results, ") + << "::llvm::ArrayRef<::mlir::PDLValue> values) {\n"; const char *argumentInitStr = R"( {0} {1} = {{}; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index bdeb0fa222b31a..575b3cbd3c3358 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1673,8 +1673,8 @@ void ConversionPatternRewriter::cancelRootUpdate(Operation *op) { } LogicalResult ConversionPatternRewriter::notifyMatchFailure( - Operation *op, function_ref reasonCallback) { - return impl->notifyMatchFailure(op->getLoc(), reasonCallback); + Location loc, function_ref reasonCallback) { + return impl->notifyMatchFailure(loc, reasonCallback); } detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 7fd46c711db01e..81b57c420a7263 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -76,7 +76,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter { /// PatternRewriter hook for notifying match failure reasons. LogicalResult - notifyMatchFailure(Operation *op, + notifyMatchFailure(Location loc, function_ref reasonCallback) override; /// The low-level pattern applicator. @@ -348,9 +348,9 @@ void GreedyPatternRewriteDriver::eraseOp(Operation *op) { } LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure( - Operation *op, function_ref reasonCallback) { + Location loc, function_ref reasonCallback) { LLVM_DEBUG({ - Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); + Diagnostic diag(loc, DiagnosticSeverity::Remark); reasonCallback(diag); logger.startLine() << "** Failure : " << diag.str() << "\n"; }); diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir index d06c500241b0b3..e1a8c6081d4e51 100644 --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -181,8 +181,9 @@ module @patterns { module @rewriters { pdl_interp.func @success(%root : !pdl.operation) { + %attr = pdl_interp.apply_rewrite "str_creator" : !pdl.attribute %type = pdl_interp.apply_rewrite "type_creator" : !pdl.type - %newOp = pdl_interp.create_operation "test.success" -> (%type : !pdl.type) + %newOp = pdl_interp.create_operation "test.success" {"attr" = %attr} -> (%type : !pdl.type) pdl_interp.erase %root pdl_interp.finalize } @@ -190,7 +191,7 @@ module @patterns { } // CHECK-LABEL: test.apply_rewrite_4 -// CHECK: "test.success"() : () -> f32 +// CHECK: "test.success"() {attr = "test.str"} : () -> f32 module @ir attributes { test.apply_rewrite_4 } { "test.op"() : () -> () } diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp index 13465ba2865e02..daa1c371f27c92 100644 --- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -14,53 +14,42 @@ using namespace mlir; /// Custom constraint invoked from PDL. -static LogicalResult customSingleEntityConstraint(PDLValue value, - PatternRewriter &rewriter) { - Operation *rootOp = value.cast(); +static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter, + Operation *rootOp) { return success(rootOp->getName().getStringRef() == "test.op"); } -static LogicalResult customMultiEntityConstraint(ArrayRef values, - PatternRewriter &rewriter) { - return customSingleEntityConstraint(values[1], rewriter); +static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter, + Operation *root, + Operation *rootCopy) { + return customSingleEntityConstraint(rewriter, rootCopy); } -static LogicalResult -customMultiEntityVariadicConstraint(ArrayRef values, - PatternRewriter &rewriter) { - if (llvm::any_of(values, [](const PDLValue &value) { return !value; })) - return failure(); - ValueRange operandValues = values[0].cast(); - TypeRange typeValues = values[1].cast(); +static LogicalResult customMultiEntityVariadicConstraint( + PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) { if (operandValues.size() != 2 || typeValues.size() != 2) return failure(); return success(); } // Custom creator invoked from PDL. -static void customCreate(ArrayRef args, PatternRewriter &rewriter, - PDLResultList &results) { - results.push_back(rewriter.create( - OperationState(args[0].cast()->getLoc(), "test.success"))); +static Operation *customCreate(PatternRewriter &rewriter, Operation *op) { + return rewriter.create(OperationState(op->getLoc(), "test.success")); } - -static void customVariadicResultCreate(ArrayRef args, - PatternRewriter &rewriter, - PDLResultList &results) { - Operation *root = args[0].cast(); - results.push_back(root->getOperands()); - results.push_back(root->getOperands().getTypes()); +static auto customVariadicResultCreate(PatternRewriter &rewriter, + Operation *root) { + return std::make_pair(root->getOperands(), root->getOperands().getTypes()); +} +static Type customCreateType(PatternRewriter &rewriter) { + return rewriter.getF32Type(); } -static void customCreateType(ArrayRef args, PatternRewriter &rewriter, - PDLResultList &results) { - results.push_back(rewriter.getF32Type()); +static std::string customCreateStrAttr(PatternRewriter &rewriter) { + return "test.str"; } /// Custom rewriter invoked from PDL. -static void customRewriter(ArrayRef args, PatternRewriter &rewriter, - PDLResultList &results) { - Operation *root = args[0].cast(); - OperationState successOpState(root->getLoc(), "test.success"); - successOpState.addOperands(args[1].cast()); - rewriter.create(successOpState); +static void customRewriter(PatternRewriter &rewriter, Operation *root, + Value input) { + rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"), + input); rewriter.eraseOp(root); } @@ -117,6 +106,7 @@ struct TestPDLByteCodePass pdlPattern.registerRewriteFunction("var_creator", customVariadicResultCreate); pdlPattern.registerRewriteFunction("type_creator", customCreateType); + pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr); pdlPattern.registerRewriteFunction("rewriter", customRewriter); patternList.add(std::move(pdlPattern)); diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll index 9f0ea1386322de..802958a3872f22 100644 --- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll +++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll @@ -43,8 +43,8 @@ Pattern => erase op; // Check the generation of native constraints and rewrites. -// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, -// CHECK-SAME: ::mlir::PatternRewriter &rewriter) { +// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::mlir::PatternRewriter &rewriter, +// CHECK-SAME: ::llvm::ArrayRef<::mlir::PDLValue> values) { // CHECK: ::mlir::Attribute attr = {}; // CHECK: if (values[0]) // CHECK: attr = values[0].cast<::mlir::Attribute>(); @@ -69,8 +69,8 @@ Pattern => erase op; // CHECK-NOT: TestUnusedCst -// CHECK: static void TestRewritePDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, -// CHECK-SAME: ::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results) { +// CHECK: static void TestRewritePDLFn(::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results, +// CHECK-SAME: ::llvm::ArrayRef<::mlir::PDLValue> values) { // CHECK: ::mlir::Attribute attr = {}; // CHECK: ::mlir::Operation * op = {}; // CHECK: ::mlir::Type type = {};