230 changes: 67 additions & 163 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,55 +58,74 @@ class PatternBenefit {
};

//===----------------------------------------------------------------------===//
// Pattern class
// Pattern
//===----------------------------------------------------------------------===//

/// Instances of Pattern can be matched against SSA IR. These matches get used
/// in ways dependent on their subclasses and the driver doing the matching.
/// For example, RewritePatterns implement a rewrite from one matched pattern
/// to a replacement DAG tile.
/// This class contains all of the data related to a pattern, but does not
/// contain any methods or logic for the actual matching. This class is solely
/// used to interface with the metadata of a pattern, such as the benefit or
/// root operation.
class Pattern {
public:
/// Return a list of operations that may be generated when rewriting an
/// operation instance with this pattern.
ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }

/// Return the root node that this pattern matches. Patterns that can match
/// multiple root types return None.
Optional<OperationName> getRootKind() const { return rootKind; }

/// Return the benefit (the inverse of "cost") of matching this pattern. The
/// benefit of a Pattern is always static - rewrites that may have dynamic
/// benefit can be instantiated multiple times (different Pattern instances)
/// for each benefit that they may return, and be guarded by different match
/// condition predicates.
PatternBenefit getBenefit() const { return benefit; }

/// Return the root node that this pattern matches. Patterns that can match
/// multiple root types return None.
Optional<OperationName> getRootKind() const { return rootKind; }

//===--------------------------------------------------------------------===//
// Implementation hooks for patterns to implement.
//===--------------------------------------------------------------------===//

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
virtual LogicalResult match(Operation *op) const = 0;

virtual ~Pattern() {}
/// Returns true if this pattern is known to result in recursive application,
/// i.e. this pattern may generate IR that also matches this pattern, but is
/// known to bound the recursion. This signals to a rewrite driver that it is
/// safe to apply this pattern recursively to generated IR.
bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }

protected:
/// This class acts as a special tag that makes the desire to match "any"
/// operation type explicit. This helps to avoid unnecessary usages of this
/// feature, and ensures that the user is making a conscious decision.
struct MatchAnyOpTypeTag {};

/// This constructor is used for patterns that match against a specific
/// operation type. The `benefit` is the expected benefit of matching this
/// pattern.
/// Construct a pattern with a certain benefit that matches the operation
/// with the given root name.
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);

/// This constructor is used when a pattern may match against multiple
/// different types of operations. The `benefit` is the expected benefit of
/// matching this pattern. `MatchAnyOpTypeTag` is just a tag to ensure that
/// the "match any" behavior is what the user actually desired,
/// `MatchAnyOpTypeTag()` should always be supplied here.
Pattern(PatternBenefit benefit, MatchAnyOpTypeTag);
/// Construct a pattern with a certain benefit that matches any operation
/// type. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
/// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
/// always be supplied here.
Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag);
/// Construct a pattern with a certain benefit that matches the operation with
/// the given root name. `generatedNames` contains the names of operations
/// that may be generated during a successful rewrite.
Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context);
/// Construct a pattern that may match any operation type. `generatedNames`
/// contains the names of operations that may be generated during a successful
/// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
/// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
/// always be supplied here.
Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
MLIRContext *context, MatchAnyOpTypeTag tag);

/// Set the flag detailing if this pattern has bounded rewrite recursion or
/// not.
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
hasBoundedRecursion = hasBoundedRecursionArg;
}

private:
/// A list of the potential operations that may be generated when rewriting
/// an op with this pattern.
SmallVector<OperationName, 2> generatedOps;

/// The root operation of the pattern. If the pattern matches a specific
/// operation, this contains the name of that operation. Contains None
/// otherwise.
Expand All @@ -115,9 +134,14 @@ class Pattern {
/// The expected benefit of matching this pattern.
const PatternBenefit benefit;

virtual void anchor();
/// A boolean flag of whether this pattern has bounded recursion or not.
bool hasBoundedRecursion = false;
};

//===----------------------------------------------------------------------===//
// RewritePattern
//===----------------------------------------------------------------------===//

/// RewritePattern is the common base class for all DAG to DAG replacements.
/// There are two possible usages of this class:
/// * Multi-step RewritePattern with "match" and "rewrite"
Expand All @@ -129,6 +153,8 @@ class Pattern {
///
class RewritePattern : public Pattern {
public:
virtual ~RewritePattern() {}

/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// builder. If an unexpected error is encountered (an internal
Expand All @@ -138,7 +164,7 @@ class RewritePattern : public Pattern {

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
LogicalResult match(Operation *op) const override;
virtual LogicalResult match(Operation *op) const;

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). If successful, this
Expand All @@ -152,44 +178,12 @@ class RewritePattern : public Pattern {
return failure();
}

/// Returns true if this pattern is known to result in recursive application,
/// i.e. this pattern may generate IR that also matches this pattern, but is
/// known to bound the recursion. This signals to a rewriter that it is safe
/// to apply this pattern recursively to generated IR.
virtual bool hasBoundedRewriteRecursion() const { return false; }

/// Return a list of operations that may be generated when rewriting an
/// operation instance with this pattern.
ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }

protected:
/// Construct a rewrite pattern with a certain benefit that matches the
/// operation with the given root name.
RewritePattern(StringRef rootName, PatternBenefit benefit,
MLIRContext *context)
: Pattern(rootName, benefit, context) {}
/// Construct a rewrite pattern with a certain benefit that matches any
/// operation type. `MatchAnyOpTypeTag` is just a tag to ensure that the
/// "match any" behavior is what the user actually desired,
/// `MatchAnyOpTypeTag()` should always be supplied here.
RewritePattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
: Pattern(benefit, tag) {}
/// Construct a rewrite pattern with a certain benefit that matches the
/// operation with the given root name. `generatedNames` contains the names of
/// operations that may be generated during a successful rewrite.
RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context);
/// Construct a rewrite pattern that may match any operation type.
/// `generatedNames` contains the names of operations that may be generated
/// during a successful rewrite. `MatchAnyOpTypeTag` is just a tag to ensure
/// that the "match any" behavior is what the user actually desired,
/// `MatchAnyOpTypeTag()` should always be supplied here.
RewritePattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
MLIRContext *context, MatchAnyOpTypeTag tag);
/// Inherit the base constructors from `Pattern`.
using Pattern::Pattern;

/// A list of the potential operations that may be generated when rewriting
/// an op with this pattern.
SmallVector<OperationName, 2> generatedOps;
/// An anchor for the virtual table.
virtual void anchor();
};

/// OpRewritePattern is a wrapper around RewritePattern that allows for
Expand Down Expand Up @@ -232,7 +226,7 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
};

//===----------------------------------------------------------------------===//
// PatternRewriter class
// PatternRewriter
//===----------------------------------------------------------------------===//

/// This class coordinates the application of a pattern to the current function,
Expand Down Expand Up @@ -422,12 +416,9 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
};

//===----------------------------------------------------------------------===//
// Pattern-driven rewriters
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// OwningRewritePatternList
//===----------------------------------------------------------------------===//

class OwningRewritePatternList {
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
Expand All @@ -449,6 +440,11 @@ class OwningRewritePatternList {
PatternListT::size_type size() const { return patterns.size(); }
void clear() { patterns.clear(); }

/// Take ownership of the patterns held by this list.
std::vector<std::unique_ptr<RewritePattern>> takePatterns() {
return std::move(patterns);
}

//===--------------------------------------------------------------------===//
// Pattern Insertion
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -487,98 +483,6 @@ class OwningRewritePatternList {
PatternListT patterns;
};

//===----------------------------------------------------------------------===//
// PatternApplicator

/// This class manages the application of a group of rewrite patterns, with a
/// user-provided cost model.
class PatternApplicator {
public:
/// The cost model dynamically assigns a PatternBenefit to a particular
/// pattern. Users can query contained patterns and pass analysis results to
/// applyCostModel. Patterns to be discarded should have a benefit of
/// `impossibleToMatch`.
using CostModel = function_ref<PatternBenefit(const RewritePattern &)>;

explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
: owningPatternList(owningPatternList) {}

/// Attempt to match and rewrite the given op with any pattern, allowing a
/// predicate to decide if a pattern can be applied or not, and hooks for if
/// the pattern match was a success or failure.
///
/// canApply: called before each match and rewrite attempt; return false to
/// skip pattern.
/// onFailure: called when a pattern fails to match to perform cleanup.
/// onSuccess: called when a pattern match succeeds; return failure() to
/// invalidate the match and try another pattern.
LogicalResult matchAndRewrite(
Operation *op, PatternRewriter &rewriter,
function_ref<bool(const RewritePattern &)> canApply = {},
function_ref<void(const RewritePattern &)> onFailure = {},
function_ref<LogicalResult(const RewritePattern &)> onSuccess = {});

/// Apply a cost model to the patterns within this applicator.
void applyCostModel(CostModel model);

/// Apply the default cost model that solely uses the pattern's static
/// benefit.
void applyDefaultCostModel() {
applyCostModel(
[](const RewritePattern &pattern) { return pattern.getBenefit(); });
}

/// Walk all of the rewrite patterns within the applicator.
void walkAllPatterns(function_ref<void(const RewritePattern &)> walk);

private:
/// Attempt to match and rewrite the given op with the given pattern, allowing
/// a predicate to decide if a pattern can be applied or not, and hooks for if
/// the pattern match was a success or failure.
LogicalResult matchAndRewrite(
Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
function_ref<bool(const RewritePattern &)> canApply,
function_ref<void(const RewritePattern &)> onFailure,
function_ref<LogicalResult(const RewritePattern &)> onSuccess);

/// The list that owns the patterns used within this applicator.
const OwningRewritePatternList &owningPatternList;

/// The set of patterns to match for each operation, stable sorted by benefit.
DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
/// The set of patterns that may match against any operation type, stable
/// sorted by benefit.
SmallVector<RewritePattern *, 1> anyOpPatterns;
};

//===----------------------------------------------------------------------===//
// applyPatternsGreedily
//===----------------------------------------------------------------------===//

/// Rewrite the regions of the specified operation, which must be isolated from
/// above, by repeatedly applying the highest benefit patterns in a greedy
/// work-list driven manner. Return success if no more patterns can be matched
/// in the result operation regions.
/// Note: This does not apply patterns to the top-level operation itself. Note:
/// These methods also perform folding and simple dead-code elimination
/// before attempting to match any of the provided patterns.
///
LogicalResult
applyPatternsAndFoldGreedily(Operation *op,
const OwningRewritePatternList &patterns);
/// Rewrite the given regions, which must be isolated from above.
LogicalResult
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
const OwningRewritePatternList &patterns);

/// Applies the specified patterns on `op` alone while also trying to fold it,
/// by selecting the highest benefits patterns in a greedy manner. Returns
/// success if no more patterns can be matched. `erased` is set to true if `op`
/// was folded away or erased as a result of becoming dead. Note: This does not
/// apply any patterns recursively to the regions of `op`.
LogicalResult applyOpPatternsAndFold(Operation *op,
const OwningRewritePatternList &patterns,
bool *erased = nullptr);
} // end namespace mlir

#endif // MLIR_PATTERN_MATCH_H
38 changes: 38 additions & 0 deletions mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===- FrozenRewritePatternList.h - FrozenRewritePatternList ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H
#define MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H

#include "mlir/IR/PatternMatch.h"

namespace mlir {
/// This class represents a frozen set of patterns that can be processed by a
/// pattern applicator. This class is designed to enable caching pattern lists
/// such that they need not be continuously recomputed.
class FrozenRewritePatternList {
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;

public:
/// Freeze the patterns held in `patterns`, and take ownership.
FrozenRewritePatternList(OwningRewritePatternList &&patterns);

/// Return the patterns held by this list.
iterator_range<llvm::pointee_iterator<PatternListT::const_iterator>>
getPatterns() const {
return llvm::make_pointee_range(patterns);
}

private:
/// The patterns held by this list.
std::vector<std::unique_ptr<RewritePattern>> patterns;
};

} // end namespace mlir

#endif // MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H
84 changes: 84 additions & 0 deletions mlir/include/mlir/Rewrite/PatternApplicator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
//===- PatternApplicator.h - PatternApplicator -------==---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an applicator that applies pattern rewrites based upon a
// user defined cost model.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H
#define MLIR_REWRITE_PATTERNAPPLICATOR_H

#include "mlir/Rewrite/FrozenRewritePatternList.h"

namespace mlir {
class PatternRewriter;

/// This class manages the application of a group of rewrite patterns, with a
/// user-provided cost model.
class PatternApplicator {
public:
/// The cost model dynamically assigns a PatternBenefit to a particular
/// pattern. Users can query contained patterns and pass analysis results to
/// applyCostModel. Patterns to be discarded should have a benefit of
/// `impossibleToMatch`.
using CostModel = function_ref<PatternBenefit(const Pattern &)>;

explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList)
: frozenPatternList(frozenPatternList) {}

/// Attempt to match and rewrite the given op with any pattern, allowing a
/// predicate to decide if a pattern can be applied or not, and hooks for if
/// the pattern match was a success or failure.
///
/// canApply: called before each match and rewrite attempt; return false to
/// skip pattern.
/// onFailure: called when a pattern fails to match to perform cleanup.
/// onSuccess: called when a pattern match succeeds; return failure() to
/// invalidate the match and try another pattern.
LogicalResult
matchAndRewrite(Operation *op, PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply = {},
function_ref<void(const Pattern &)> onFailure = {},
function_ref<LogicalResult(const Pattern &)> onSuccess = {});

/// Apply a cost model to the patterns within this applicator.
void applyCostModel(CostModel model);

/// Apply the default cost model that solely uses the pattern's static
/// benefit.
void applyDefaultCostModel() {
applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
}

/// Walk all of the patterns within the applicator.
void walkAllPatterns(function_ref<void(const Pattern &)> walk);

private:
/// Attempt to match and rewrite the given op with the given pattern, allowing
/// a predicate to decide if a pattern can be applied or not, and hooks for if
/// the pattern match was a success or failure.
LogicalResult
matchAndRewrite(Operation *op, const RewritePattern &pattern,
PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess);

/// The list that owns the patterns used within this applicator.
const FrozenRewritePatternList &frozenPatternList;
/// The set of patterns to match for each operation, stable sorted by benefit.
DenseMap<OperationName, SmallVector<const RewritePattern *, 2>> patterns;
/// The set of patterns that may match against any operation type, stable
/// sorted by benefit.
SmallVector<const RewritePattern *, 1> anyOpPatterns;
};

} // end namespace mlir

#endif // MLIR_REWRITE_PATTERNAPPLICATOR_H
3 changes: 1 addition & 2 deletions mlir/include/mlir/Support/StorageUniquer.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ class StorageUniquer {
}
/// Utility override when the storage type represents the type id.
template <typename Storage>
void registerSingletonStorageType(
function_ref<void(Storage *)> initFn = llvm::None) {
void registerSingletonStorageType(function_ref<void(Storage *)> initFn = {}) {
registerSingletonStorageType<Storage>(TypeID::get<Storage>(), initFn);
}

Expand Down
16 changes: 7 additions & 9 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
#ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
#define MLIR_TRANSFORMS_DIALECTCONVERSION_H_

#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Rewrite/FrozenRewritePatternList.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/StringMap.h"

Expand Down Expand Up @@ -805,11 +803,11 @@ class ConversionTarget {
/// the `unconvertedOps` set will not necessarily be complete.)
LLVM_NODISCARD LogicalResult
applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
const OwningRewritePatternList &patterns,
const FrozenRewritePatternList &patterns,
DenseSet<Operation *> *unconvertedOps = nullptr);
LLVM_NODISCARD LogicalResult
applyPartialConversion(Operation *op, ConversionTarget &target,
const OwningRewritePatternList &patterns,
const FrozenRewritePatternList &patterns,
DenseSet<Operation *> *unconvertedOps = nullptr);

/// Apply a complete conversion on the given operations, and all nested
Expand All @@ -818,10 +816,10 @@ applyPartialConversion(Operation *op, ConversionTarget &target,
/// within 'ops'.
LLVM_NODISCARD LogicalResult
applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
const OwningRewritePatternList &patterns);
const FrozenRewritePatternList &patterns);
LLVM_NODISCARD LogicalResult
applyFullConversion(Operation *op, ConversionTarget &target,
const OwningRewritePatternList &patterns);
const FrozenRewritePatternList &patterns);

/// Apply an analysis conversion on the given operations, and all nested
/// operations. This method analyzes which operations would be successfully
Expand All @@ -833,11 +831,11 @@ applyFullConversion(Operation *op, ConversionTarget &target,
/// the regions nested within 'ops'.
LLVM_NODISCARD LogicalResult
applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
const OwningRewritePatternList &patterns,
const FrozenRewritePatternList &patterns,
DenseSet<Operation *> &convertedOps);
LLVM_NODISCARD LogicalResult
applyAnalysisConversion(Operation *op, ConversionTarget &target,
const OwningRewritePatternList &patterns,
const FrozenRewritePatternList &patterns,
DenseSet<Operation *> &convertedOps);
} // end namespace mlir

Expand Down
52 changes: 52 additions & 0 deletions mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares methods for applying a set of patterns greedily, choosing
// the patterns with the highest local benefit, until a fixed point is reached.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
#define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_

#include "mlir/Rewrite/FrozenRewritePatternList.h"

namespace mlir {

//===----------------------------------------------------------------------===//
// applyPatternsGreedily
//===----------------------------------------------------------------------===//

/// Rewrite the regions of the specified operation, which must be isolated from
/// above, by repeatedly applying the highest benefit patterns in a greedy
/// work-list driven manner. Return success if no more patterns can be matched
/// in the result operation regions.
/// Note: This does not apply patterns to the top-level operation itself. Note:
/// These methods also perform folding and simple dead-code elimination
/// before attempting to match any of the provided patterns.
///
LogicalResult
applyPatternsAndFoldGreedily(Operation *op,
const FrozenRewritePatternList &patterns);
/// Rewrite the given regions, which must be isolated from above.
LogicalResult
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
const FrozenRewritePatternList &patterns);

/// Applies the specified patterns on `op` alone while also trying to fold it,
/// by selecting the highest benefits patterns in a greedy manner. Returns
/// success if no more patterns can be matched. `erased` is set to true if `op`
/// was folded away or erased as a result of becoming dead. Note: This does not
/// apply any patterns recursively to the regions of `op`.
LogicalResult applyOpPatternsAndFold(Operation *op,
const FrozenRewritePatternList &patterns,
bool *erased = nullptr);

} // end namespace mlir

#endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
1 change: 1 addition & 0 deletions mlir/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_subdirectory(Interfaces)
add_subdirectory(Parser)
add_subdirectory(Pass)
add_subdirectory(Reducer)
add_subdirectory(Rewrite)
add_subdirectory(Support)
add_subdirectory(TableGen)
add_subdirectory(Target)
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ void ConvertAVX512ToLLVMPass::runOnOperation() {
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
target.addIllegalDialect<avx512::AVX512Dialect>();
if (failed(applyPartialConversion(getOperation(), target, patterns))) {
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
signalPassFailure();
}
}
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,8 @@ class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
ConversionTarget target(getContext());
target
.addLegalDialect<scf::SCFDialect, StandardOpsDialect, VectorDialect>();
if (failed(applyPartialConversion(getOperation(), target, patterns)))
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
target.addDynamicallyLegalOp<CallOp>(
[&](CallOp op) { return converter.isLegal(op.getResultTypes()); });

if (failed(applyPartialConversion(module, target, patterns)))
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
} // namespace
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ add_subdirectory(LinalgToLLVM)
add_subdirectory(LinalgToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(SCFToGPU)
add_subdirectory(SCFToSPIRV)
add_subdirectory(SCFToStandard)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ void GpuToLLVMConversionPass::runOnOperation() {
populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);

LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(getOperation(), target, patterns)))
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}

Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"

#include "../GPUCommon/GPUOpsLowering.h"
Expand Down Expand Up @@ -123,17 +124,16 @@ struct LowerGpuOpsToNVVMOpsPass
return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
});

OwningRewritePatternList patterns;
OwningRewritePatternList patterns, llvmPatterns;

// Apply in-dialect lowering first. In-dialect lowering will replace ops
// which need to be lowered further, which is not supported by a single
// conversion pass.
populateGpuRewritePatterns(m.getContext(), patterns);
applyPatternsAndFoldGreedily(m, patterns);
patterns.clear();
applyPatternsAndFoldGreedily(m, std::move(patterns));

populateStdToLLVMConversionPatterns(converter, patterns);
populateGpuToNVVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
LLVMConversionTarget target(getContext());
target.addIllegalDialect<gpu::GPUDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::FAbsOp, LLVM::FCeilOp,
Expand All @@ -143,7 +143,7 @@ struct LowerGpuOpsToNVVMOpsPass
target.addLegalDialect<NVVM::NVVMDialect>();
// TODO: Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
if (failed(applyPartialConversion(m, target, patterns)))
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
}
};
Expand Down
16 changes: 8 additions & 8 deletions mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"

#include "../GPUCommon/GPUOpsLowering.h"
Expand Down Expand Up @@ -58,16 +59,15 @@ struct LowerGpuOpsToROCDLOpsPass
/*useAlignedAlloc =*/false};
LLVMTypeConverter converter(m.getContext(), options);

OwningRewritePatternList patterns;
OwningRewritePatternList patterns, llvmPatterns;

populateGpuRewritePatterns(m.getContext(), patterns);
applyPatternsAndFoldGreedily(m, patterns);
patterns.clear();
applyPatternsAndFoldGreedily(m, std::move(patterns));

populateVectorToLLVMConversionPatterns(converter, patterns);
populateVectorToROCDLConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
populateGpuToROCDLConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
populateVectorToROCDLConversionPatterns(converter, llvmPatterns);
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToROCDLConversionPatterns(converter, llvmPatterns);
LLVMConversionTarget target(getContext());
target.addIllegalDialect<gpu::GPUDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::FAbsOp, LLVM::FCeilOp,
Expand All @@ -77,7 +77,7 @@ struct LowerGpuOpsToROCDLOpsPass
target.addLegalDialect<ROCDL::ROCDLDialect>();
// TODO: Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
if (failed(applyPartialConversion(m, target, patterns)))
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
}
};
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void GPUToSPIRVPass::runOnOperation() {
populateSCFToSPIRVPatterns(context, typeConverter,scfContext, patterns);
populateStandardToSPIRVPatterns(context, typeConverter, patterns);

if (failed(applyFullConversion(kernelModules, *target, patterns)))
if (failed(applyFullConversion(kernelModules, *target, std::move(patterns))))
return signalPassFailure();
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ void ConvertLinalgToLLVMPass::runOnOperation() {

LLVMConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
if (failed(applyFullConversion(module, target, patterns)))
if (failed(applyFullConversion(module, target, std::move(patterns))))
signalPassFailure();
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void LinalgToSPIRVPass::runOnOperation() {
typeConverter.isLegal(&op.getBody());
});

if (failed(applyFullConversion(module, *target, patterns)))
if (failed(applyFullConversion(module, *target, std::move(patterns))))
return signalPassFailure();
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
OwningRewritePatternList patterns;
populateLinalgToStandardConversionPatterns(patterns, &getContext());
if (failed(applyFullConversion(module, target, patterns)))
if (failed(applyFullConversion(module, target, std::move(patterns))))
signalPassFailure();
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
[&](omp::ParallelOp op) { return converter.isLegal(&op.getRegion()); });
target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
omp::BarrierOp, omp::TaskwaitOp>();
if (failed(applyPartialConversion(module, target, patterns)))
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}

Expand Down
18 changes: 18 additions & 0 deletions mlir/lib/Conversion/PDLToPDLInterp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
add_mlir_conversion_library(MLIRPDLToPDLInterp
PDLToPDLInterp.cpp
Predicate.cpp
PredicateTree.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/PDLToPDLInterp

DEPENDS
MLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRInferTypeOpInterface
MLIRPDL
MLIRPDLInterp
MLIRPass
)
694 changes: 694 additions & 0 deletions mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Large diffs are not rendered by default.

49 changes: 49 additions & 0 deletions mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//===- Predicate.cpp - Pattern predicates ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Predicate.h"

using namespace mlir;
using namespace mlir::pdl_to_pdl_interp;

//===----------------------------------------------------------------------===//
// Positions
//===----------------------------------------------------------------------===//

Position::~Position() {}

//===----------------------------------------------------------------------===//
// AttributePosition

AttributePosition::AttributePosition(const KeyTy &key) : Base(key) {
parent = key.first;
}

//===----------------------------------------------------------------------===//
// OperandPosition

OperandPosition::OperandPosition(const KeyTy &key) : Base(key) {
parent = key.first;
}

//===----------------------------------------------------------------------===//
// OperationPosition

OperationPosition *OperationPosition::get(StorageUniquer &uniquer,
ArrayRef<unsigned> index) {
assert(!index.empty() && "expected at least two indices");

// Set the parent position if this isn't the root.
Position *parent = nullptr;
if (index.size() > 1) {
auto *node = OperationPosition::get(uniquer, index.drop_back());
parent = OperandPosition::get(uniquer, std::make_pair(node, index.back()));
}
return uniquer.get<OperationPosition>(
[parent](OperationPosition *node) { node->parent = parent; }, index);
}
530 changes: 530 additions & 0 deletions mlir/lib/Conversion/PDLToPDLInterp/Predicate.h

Large diffs are not rendered by default.

462 changes: 462 additions & 0 deletions mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp

Large diffs are not rendered by default.

200 changes: 200 additions & 0 deletions mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
//===- PredicateTree.h - Predicate tree node definitions --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions for nodes of a tree structure for representing
// the general control flow within a pattern match.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
#define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_

#include "Predicate.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "llvm/ADT/MapVector.h"

namespace mlir {
namespace pdl_to_pdl_interp {

class MatcherNode;

/// A PositionalPredicate is a predicate that is associated with a specific
/// positional value.
struct PositionalPredicate {
PositionalPredicate(Position *pos,
const PredicateBuilder::Predicate &predicate)
: position(pos), question(predicate.first), answer(predicate.second) {}

/// The position the predicate is applied to.
Position *position;

/// The question that the predicate applies.
Qualifier *question;

/// The expected answer of the predicate.
Qualifier *answer;
};

//===----------------------------------------------------------------------===//
// MatcherNode
//===----------------------------------------------------------------------===//

/// This class represents the base of a predicate matcher node.
class MatcherNode {
public:
virtual ~MatcherNode() = default;

/// Given a module containing PDL pattern operations, generate a matcher tree
/// using the patterns within the given module and return the root matcher
/// node. `valueToPosition` is a map that is populated with the original
/// pdl values and their corresponding positions in the matcher tree.
static std::unique_ptr<MatcherNode>
generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
DenseMap<Value, Position *> &valueToPosition);

/// Returns the position on which the question predicate should be checked.
Position *getPosition() const { return position; }

/// Returns the predicate checked on this node.
Qualifier *getQuestion() const { return question; }

/// Returns the node that should be visited if this, or a subsequent node
/// fails.
std::unique_ptr<MatcherNode> &getFailureNode() { return failureNode; }

/// Sets the node that should be visited if this, or a subsequent node fails.
void setFailureNode(std::unique_ptr<MatcherNode> node) {
failureNode = std::move(node);
}

/// Returns the unique type ID of this matcher instance. This should not be
/// used directly, and is provided to support type casting.
TypeID getMatcherTypeID() const { return matcherTypeID; }

protected:
MatcherNode(TypeID matcherTypeID, Position *position = nullptr,
Qualifier *question = nullptr,
std::unique_ptr<MatcherNode> failureNode = nullptr);

private:
/// The position on which the predicate should be checked.
Position *position;

/// The predicate that is checked on the given position.
Qualifier *question;

/// The node to visit if this node fails.
std::unique_ptr<MatcherNode> failureNode;

/// An owning store for the failure node if it is owned by this node.
std::unique_ptr<MatcherNode> failureNodeStorage;

/// A unique identifier for the derived matcher node, used for type casting.
TypeID matcherTypeID;
};

//===----------------------------------------------------------------------===//
// BoolNode

/// A BoolNode denotes a question with a boolean-like result. These nodes branch
/// to a single node on a successful result, otherwise defaulting to the failure
/// node.
struct BoolNode : public MatcherNode {
BoolNode(Position *position, Qualifier *question, Qualifier *answer,
std::unique_ptr<MatcherNode> successNode,
std::unique_ptr<MatcherNode> failureNode = nullptr);

/// Returns if the given matcher node is an instance of this class, used to
/// support type casting.
static bool classof(const MatcherNode *node) {
return node->getMatcherTypeID() == TypeID::get<BoolNode>();
}

/// Returns the expected answer of this boolean node.
Qualifier *getAnswer() const { return answer; }

/// Returns the node that should be visited on success.
std::unique_ptr<MatcherNode> &getSuccessNode() { return successNode; }

private:
/// The expected answer of this boolean node.
Qualifier *answer;

/// The next node if this node succeeds. Otherwise, go to the failure node.
std::unique_ptr<MatcherNode> successNode;
};

//===----------------------------------------------------------------------===//
// ExitNode

/// An ExitNode is a special sentinel node that denotes the end of matcher.
struct ExitNode : public MatcherNode {
ExitNode() : MatcherNode(TypeID::get<ExitNode>()) {}

/// Returns if the given matcher node is an instance of this class, used to
/// support type casting.
static bool classof(const MatcherNode *node) {
return node->getMatcherTypeID() == TypeID::get<ExitNode>();
}
};

//===----------------------------------------------------------------------===//
// SuccessNode

/// A SuccessNode denotes that a given high level pattern has successfully been
/// matched. This does not terminate the matcher, as there may be multiple
/// successful matches.
struct SuccessNode : public MatcherNode {
explicit SuccessNode(pdl::PatternOp pattern,
std::unique_ptr<MatcherNode> failureNode);

/// Returns if the given matcher node is an instance of this class, used to
/// support type casting.
static bool classof(const MatcherNode *node) {
return node->getMatcherTypeID() == TypeID::get<SuccessNode>();
}

/// Return the high level pattern operation that is matched with this node.
pdl::PatternOp getPattern() const { return pattern; }

private:
/// The high level pattern operation that was successfully matched with this
/// node.
pdl::PatternOp pattern;
};

//===----------------------------------------------------------------------===//
// SwitchNode

/// A SwitchNode denotes a question with multiple potential results. These nodes
/// branch to a specific node based on the result of the question.
struct SwitchNode : public MatcherNode {
SwitchNode(Position *position, Qualifier *question);

/// Returns if the given matcher node is an instance of this class, used to
/// support type casting.
static bool classof(const MatcherNode *node) {
return node->getMatcherTypeID() == TypeID::get<SwitchNode>();
}

/// Returns the children of this switch node. The children are contained
/// within a mapping between the various case answers to destination matcher
/// nodes.
using ChildMapT = llvm::MapVector<Qualifier *, std::unique_ptr<MatcherNode>>;
ChildMapT &getChildren() { return children; }

private:
/// Switch predicate "answers" select the child. Answers that are not found
/// default to the failure node.
ChildMapT children;
};

} // end namespace pdl_to_pdl_interp
} // end namespace mlir

#endif // MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
4 changes: 4 additions & 0 deletions mlir/lib/Conversion/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ namespace NVVM {
class NVVMDialect;
} // end namespace NVVM

namespace pdl_interp {
class PDLInterpDialect;
} // end namespace pdl_interp

namespace ROCDL {
class ROCDLDialect;
} // end namespace ROCDL
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ struct ParallelLoopToGpuPass
target.addLegalDialect<gpu::GPUDialect>();
target.addLegalDialect<scf::SCFDialect>();
target.addIllegalOp<scf::ParallelOp>();
if (failed(applyPartialConversion(getOperation(), target, patterns)))
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,8 @@ void SCFToStandardPass::runOnOperation() {
ConversionTarget target(getContext());
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(applyPartialConversion(getOperation(), target, patterns)))
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class LowerHostCodeToLLVM

ConversionTarget target(*context);
target.addLegalDialect<LLVM::LLVMDialect>();
if (failed(applyPartialConversion(module, target, patterns)))
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();

// Finally, modify the kernel function in SPIR-V modules to avoid symbolic
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
// conversion.
target.addLegalOp<ModuleOp>();
target.addLegalOp<ModuleTerminatorOp>();
if (failed(applyPartialConversion(module, target, patterns)))
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

Expand Down Expand Up @@ -123,7 +124,7 @@ class ConvertShapeConstraints
OwningRewritePatternList patterns;
populateConvertShapeConstraintsConversionPatterns(patterns, context);

if (failed(applyPatternsAndFoldGreedily(func, patterns)))
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ void ConvertShapeToStandardPass::runOnOperation() {

// Apply conversion.
auto module = getOperation();
if (failed(applyPartialConversion(module, target, patterns)))
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3788,7 +3788,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
populateStdToLLVMConversionPatterns(typeConverter, patterns);

LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(m, target, patterns)))
if (failed(applyPartialConversion(m, target, std::move(patterns))))
signalPassFailure();
m.setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
StringAttr::get(this->dataLayout, m.getContext()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);

if (failed(applyPartialConversion(module, *target, patterns))) {
if (failed(applyPartialConversion(module, *target, std::move(patterns))))
return signalPassFailure();
}
}

std::unique_ptr<OperationPass<ModuleOp>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

Expand Down Expand Up @@ -203,7 +203,8 @@ void SPIRVLegalization::runOnOperation() {
OwningRewritePatternList patterns;
auto *context = &getContext();
populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns);
applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns));
}

std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
Expand Down
31 changes: 16 additions & 15 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,12 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -1042,7 +1038,12 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
class VectorInsertStridedSliceOpSameRankRewritePattern
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
: OpRewritePattern<InsertStridedSliceOp>(ctx) {
// This pattern creates recursive InsertStridedSliceOp, but the recursion is
// bounded as the rank is strictly decreasing.
setHasBoundedRewriteRecursion();
}

LogicalResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -1093,9 +1094,6 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
rewriter.replaceOp(op, res);
return success();
}
/// This pattern creates recursive InsertStridedSliceOp, but the recursion is
/// bounded as the rank is strictly decreasing.
bool hasBoundedRewriteRecursion() const final { return true; }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
Expand Down Expand Up @@ -1505,7 +1503,12 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
class VectorExtractStridedSliceOpConversion
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
: OpRewritePattern<ExtractStridedSliceOp>(ctx) {
// This pattern creates recursive ExtractStridedSliceOp, but the recursion
// is bounded as the rank is strictly decreasing.
setHasBoundedRewriteRecursion();
}

LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -1552,9 +1555,6 @@ class VectorExtractStridedSliceOpConversion
rewriter.replaceOp(op, res);
return success();
}
/// This pattern creates recursive ExtractStridedSliceOp, but the recursion is
/// bounded as the rank is strictly decreasing.
bool hasBoundedRewriteRecursion() const final { return true; }
};

} // namespace
Expand Down Expand Up @@ -1619,7 +1619,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
populateVectorSlicesLoweringPatterns(patterns, &getContext());
populateVectorContractLoweringPatterns(patterns, &getContext());
applyPatternsAndFoldGreedily(getOperation(), patterns);
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

// Convert to the LLVM IR dialect.
Expand All @@ -1632,7 +1632,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
populateStdToLLVMConversionPatterns(converter, patterns);

LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(getOperation(), target, patterns)))
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}

Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ void LowerVectorToROCDLPass::runOnOperation() {
LLVMConversionTarget target(getContext());
target.addLegalDialect<ROCDL::ROCDLDialect>();

if (failed(applyPartialConversion(getOperation(), target, patterns)))
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}

Expand Down
8 changes: 2 additions & 6 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,10 @@
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
Expand Down Expand Up @@ -714,7 +710,7 @@ struct ConvertVectorToSCFPass
auto *context = getFunction().getContext();
populateVectorToSCFConversionPatterns(
patterns, context, VectorTransferToSCFOptions().setUnroll(fullUnroll));
applyPatternsAndFoldGreedily(getFunction(), patterns);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ void LowerVectorToSPIRVPass::runOnOperation() {
target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
target->addLegalOp<FuncOp>();

if (failed(applyFullConversion(module, *target, patterns)))
if (failed(applyFullConversion(module, *target, std::move(patterns))))
return signalPassFailure();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/CommandLine.h"
Expand Down Expand Up @@ -228,6 +228,7 @@ void AffineDataCopyGeneration::runOnFunction() {
OwningRewritePatternList patterns;
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
FrozenRewritePatternList frozenPatterns(std::move(patterns));
for (Operation *op : copyOps)
applyOpPatternsAndFold(op, std::move(patterns));
applyOpPatternsAndFold(op, frozenPatterns);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Utils.h"

#define DEBUG_TYPE "simplify-affine-structure"
Expand Down Expand Up @@ -83,6 +83,7 @@ void SimplifyAffineStructures::runOnFunction() {
AffineForOp::getCanonicalizationPatterns(patterns, func.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext());
AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext());
FrozenRewritePatternList frozenPatterns(std::move(patterns));
func.walk([&](Operation *op) {
for (auto attr : op->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())
Expand All @@ -94,6 +95,6 @@ void SimplifyAffineStructures::runOnFunction() {
// The simplification of the attribute will likely simplify the op. Try to
// fold / apply canonicalization patterns when we have affine dialect ops.
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
applyOpPatternsAndFold(op, patterns);
applyOpPatternsAndFold(op, frozenPatterns);
});
}
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/Affine/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

Expand Down Expand Up @@ -159,7 +159,8 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
OwningRewritePatternList patterns;
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
bool erased;
applyOpPatternsAndFold(ifOp, patterns, &erased);
FrozenRewritePatternList frozenPatterns(std::move(patterns));
applyOpPatternsAndFold(ifOp, frozenPatterns, &erased);
if (erased) {
if (folded)
*folded = true;
Expand Down Expand Up @@ -189,7 +190,7 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
// a sequence of affine.fors that are all perfectly nested).
applyPatternsAndFoldGreedily(
hoistedIfOp.getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
std::move(patterns));
frozenPatterns);

return success();
}
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
linalg::CopyOp>(
&context, converter, patterns);
if (failed(applyFullConversion(this->getOperation(), target, patterns)))
if (failed(applyFullConversion(this->getOperation(), target,
std::move(patterns))))
this->signalPassFailure();
}
};
Expand Down
12 changes: 7 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"

Expand All @@ -30,7 +31,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
// Emplace patterns one at a time while also maintaining a simple chained
// state transition.
unsigned stepCount = 0;
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
SmallVector<FrozenRewritePatternList, 4> stage1Patterns;
auto zeroState = Identifier::get(std::to_string(stepCount), context);
auto currentState = zeroState;
for (const std::unique_ptr<Transformation> &t : transformationSequence) {
Expand Down Expand Up @@ -59,7 +60,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
hoistRedundantVectorTransfers(cast<FuncOp>(op));
return success();
};
linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns,
linalg::applyStagedPatterns(func, stage1Patterns, std::move(stage2Patterns),
stage3Transforms);

//===--------------------------------------------------------------------===//
Expand All @@ -72,21 +73,22 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
OwningRewritePatternList patterns;
patterns.insert<vector::VectorTransferFullPartialRewriter>(
context, vectorTransformsOptions);
applyPatternsAndFoldGreedily(module, patterns);
applyPatternsAndFoldGreedily(module, std::move(patterns));

// Programmatic controlled lowering of vector.contract only.
OwningRewritePatternList vectorContractLoweringPatterns;
vectorContractLoweringPatterns
.insert<ContractionOpToOuterProductOpLowering,
ContractionOpToMatmulOpLowering, ContractionOpLowering>(
vectorTransformsOptions, context);
applyPatternsAndFoldGreedily(module, vectorContractLoweringPatterns);
applyPatternsAndFoldGreedily(module,
std::move(vectorContractLoweringPatterns));

// Programmatic controlled lowering of vector.transfer only.
OwningRewritePatternList vectorToLoopsPatterns;
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
vectorToSCFOptions);
applyPatternsAndFoldGreedily(module, vectorToLoopsPatterns);
applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));

// Ensure we drop the marker in the end.
module.walk([](LinalgOp op) {
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

Expand Down Expand Up @@ -519,7 +518,7 @@ struct LinalgFoldUnitExtentDimsPass
FoldUnitDimLoops<IndexedGenericOp>>(context);
else
populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);
applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}
};
} // namespace
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;
Expand Down Expand Up @@ -912,7 +913,7 @@ struct FusionOfTensorOpsPass
OwningRewritePatternList patterns;
Operation *op = getOperation();
populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
applyPatternsAndFoldGreedily(op->getRegions(), patterns);
applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};

Expand All @@ -925,7 +926,7 @@ struct FoldReshapeOpsByLinearizationPass
OwningRewritePatternList patterns;
Operation *op = getOperation();
populateFoldReshapeOpsByLinearizationPatterns(op->getContext(), patterns);
applyPatternsAndFoldGreedily(op->getRegions(), patterns);
applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};

Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/TypeSwitch.h"

Expand Down Expand Up @@ -592,7 +593,7 @@ static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldAffineOp>(context);
// Just apply the patterns greedily.
applyPatternsAndFoldGreedily(funcOp, patterns);
applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

namespace {
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/CommandLine.h"

Expand Down Expand Up @@ -595,7 +595,7 @@ static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType,
MLIRContext *ctx = funcOp.getContext();
OwningRewritePatternList patterns;
insertTilingPatterns(patterns, options, ctx);
applyPatternsAndFoldGreedily(funcOp, patterns);
applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
applyPatternsAndFoldGreedily(funcOp,
getLinalgTilingCanonicalizationPatterns(ctx));
// Drop the marker.
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
Expand Down Expand Up @@ -257,8 +257,8 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
}

LogicalResult mlir::linalg::applyStagedPatterns(
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
const OwningRewritePatternList &stage2Patterns,
Operation *op, ArrayRef<FrozenRewritePatternList> stage1Patterns,
const FrozenRewritePatternList &stage2Patterns,
function_ref<LogicalResult(Operation *)> stage3Lambda) {
unsigned iteration = 0;
(void)iteration;
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
#include "mlir/Dialect/Quant/QuantizeUtils.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::quant;
Expand Down Expand Up @@ -97,7 +96,7 @@ void ConvertConstPass::runOnFunction() {
auto func = getFunction();
auto *context = &getContext();
patterns.insert<QuantizedConstRewrite>(context);
applyPatternsAndFoldGreedily(func, patterns);
applyPatternsAndFoldGreedily(func, std::move(patterns));
}

std::unique_ptr<OperationPass<FuncOp>> mlir::quant::createConvertConstPass() {
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
#include "mlir/Dialect/Quant/Passes.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::quant;
Expand Down Expand Up @@ -130,7 +129,7 @@ void ConvertSimulatedQuantPass::runOnFunction() {
auto ctx = func.getContext();
patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
ctx, &hadFailure);
applyPatternsAndFoldGreedily(func, patterns);
applyPatternsAndFoldGreedily(func, std::move(patterns));
if (hadFailure)
signalPassFailure();
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> {
populateBufferizeMaterializationLegality(target);
populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
patterns, target);
if (failed(applyPartialConversion(func, target, patterns)))
if (failed(applyPartialConversion(func, target, std::move(patterns))))
return signalPassFailure();
};
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,10 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {

// TODO: Change the type for the indirect users such as spv.Load, spv.Store,
// spv.FunctionCall and so on.

for (auto spirvModule : module.getOps<spirv::ModuleOp>()) {
if (failed(applyFullConversion(spirvModule, target, patterns))) {
FrozenRewritePatternList frozenPatterns(std::move(patterns));
for (auto spirvModule : module.getOps<spirv::ModuleOp>())
if (failed(applyFullConversion(spirvModule, target, frozenPatterns)))
signalPassFailure();
}
}
}

std::unique_ptr<OperationPass<ModuleOp>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ void LowerABIAttributesPass::runOnOperation() {
return op->getDialect()->getNamespace() ==
spirv::SPIRVDialect::getDialectNamespace();
});
if (failed(applyPartialConversion(module, target, patterns)))
if (failed(applyPartialConversion(module, target, std::move(patterns))))
return signalPassFailure();

// Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
populateShapeStructuralTypeConversionsAndLegality(&ctx, typeConverter,
patterns, target);

if (failed(applyPartialConversion(getFunction(), target, patterns)))
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
}
};
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

Expand Down Expand Up @@ -48,7 +49,7 @@ class RemoveShapeConstraintsPass
OwningRewritePatternList patterns;
populateRemoveShapeConstraintsPatterns(patterns, &ctx);

applyPatternsAndFoldGreedily(getFunction(), patterns);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};

Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ void ShapeToShapeLowering::runOnFunction() {
ConversionTarget target(getContext());
target.addLegalDialect<ShapeDialect, StandardOpsDialect>();
target.addIllegalOp<NumElementsOp>();
if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
if (failed(mlir::applyPartialConversion(getFunction(), target,
std::move(patterns))))
signalPassFailure();
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
populateStdBufferizePatterns(context, typeConverter, patterns);
target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
TensorCastOp, TensorFromElementsOp>();

if (failed(applyPartialConversion(getFunction(), target, patterns)))
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
}
};
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ struct ExpandAtomic : public ExpandAtomicBase<ExpandAtomic> {
return op.kind() != AtomicRMWKind::maxf &&
op.kind() != AtomicRMWKind::minf;
});
if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
if (failed(mlir::applyPartialConversion(getFunction(), target,
std::move(patterns))))
signalPassFailure();
}
};
Expand Down
190 changes: 29 additions & 161 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@

#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

using namespace mlir;

#define DEBUG_TYPE "pattern-match"
//===----------------------------------------------------------------------===//
// PatternBenefit
//===----------------------------------------------------------------------===//

PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
Expand All @@ -27,44 +26,25 @@ unsigned short PatternBenefit::getBenefit() const {
}

//===----------------------------------------------------------------------===//
// Pattern implementation
// Pattern
//===----------------------------------------------------------------------===//

Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
MLIRContext *context)
: rootKind(OperationName(rootName, context)), benefit(benefit) {}
Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag)
Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
: benefit(benefit) {}

// Out-of-line vtable anchor.
void Pattern::anchor() {}

//===----------------------------------------------------------------------===//
// RewritePattern and PatternRewriter implementation
//===----------------------------------------------------------------------===//

void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
llvm_unreachable("need to implement either matchAndRewrite or one of the "
"rewrite functions!");
}

LogicalResult RewritePattern::match(Operation *op) const {
llvm_unreachable("need to implement either match or matchAndRewrite!");
}

RewritePattern::RewritePattern(StringRef rootName,
ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context)
Pattern::Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context)
: Pattern(rootName, benefit, context) {
generatedOps.reserve(generatedNames.size());
std::transform(generatedNames.begin(), generatedNames.end(),
std::back_inserter(generatedOps), [context](StringRef name) {
return OperationName(name, context);
});
}
RewritePattern::RewritePattern(ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context,
MatchAnyOpTypeTag tag)
Pattern::Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
MLIRContext *context, MatchAnyOpTypeTag tag)
: Pattern(benefit, tag) {
generatedOps.reserve(generatedNames.size());
std::transform(generatedNames.begin(), generatedNames.end(),
Expand All @@ -73,6 +53,26 @@ RewritePattern::RewritePattern(ArrayRef<StringRef> generatedNames,
});
}

//===----------------------------------------------------------------------===//
// RewritePattern
//===----------------------------------------------------------------------===//

void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
llvm_unreachable("need to implement either matchAndRewrite or one of the "
"rewrite functions!");
}

LogicalResult RewritePattern::match(Operation *op) const {
llvm_unreachable("need to implement either match or matchAndRewrite!");
}

/// Out-of-line vtable anchor.
void RewritePattern::anchor() {}

//===----------------------------------------------------------------------===//
// PatternRewriter
//===----------------------------------------------------------------------===//

PatternRewriter::~PatternRewriter() {
// Out of line to provide a vtable anchor for the class.
}
Expand Down Expand Up @@ -200,135 +200,3 @@ void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}

//===----------------------------------------------------------------------===//
// PatternMatcher implementation
//===----------------------------------------------------------------------===//

void PatternApplicator::applyCostModel(CostModel model) {
// Separate patterns by root kind to simplify lookup later on.
patterns.clear();
anyOpPatterns.clear();
for (const auto &pat : owningPatternList) {
// If the pattern is always impossible to match, just ignore it.
if (pat->getBenefit().isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs()
<< "Ignoring pattern '" << pat->getRootKind()
<< "' because it is impossible to match (by pattern benefit)\n";
});
continue;
}
if (Optional<OperationName> opName = pat->getRootKind())
patterns[*opName].push_back(pat.get());
else
anyOpPatterns.push_back(pat.get());
}

// Sort the patterns using the provided cost model.
llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
return benefits[lhs] > benefits[rhs];
};
auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
// Special case for one pattern in the list, which is the most common case.
if (list.size() == 1) {
if (model(*list.front()).isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
<< "' because it is impossible to match or cannot lead "
"to legal IR (by cost model)\n";
});
list.clear();
}
return;
}

// Collect the dynamic benefits for the current pattern list.
benefits.clear();
for (RewritePattern *pat : list)
benefits.try_emplace(pat, model(*pat));

// Sort patterns with highest benefit first, and remove those that are
// impossible to match.
std::stable_sort(list.begin(), list.end(), cmp);
while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
<< "' because it is impossible to match or cannot lead to "
"legal IR (by cost model)\n";
});
list.pop_back();
}
};
for (auto &it : patterns)
processPatternList(it.second);
processPatternList(anyOpPatterns);
}

void PatternApplicator::walkAllPatterns(
function_ref<void(const RewritePattern &)> walk) {
for (auto &it : owningPatternList)
walk(*it);
}

LogicalResult PatternApplicator::matchAndRewrite(
Operation *op, PatternRewriter &rewriter,
function_ref<bool(const RewritePattern &)> canApply,
function_ref<void(const RewritePattern &)> onFailure,
function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
// Check to see if there are patterns matching this specific operation type.
MutableArrayRef<RewritePattern *> opPatterns;
auto patternIt = patterns.find(op->getName());
if (patternIt != patterns.end())
opPatterns = patternIt->second;

// Process the patterns for that match the specific operation type, and any
// operation type in an interleaved fashion.
// FIXME: It'd be nice to just write an llvm::make_merge_range utility
// and pass in a comparison function. That would make this code trivial.
auto opIt = opPatterns.begin(), opE = opPatterns.end();
auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
while (opIt != opE && anyIt != anyE) {
// Try to match the pattern providing the most benefit.
RewritePattern *pattern;
if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
pattern = *(opIt++);
else
pattern = *(anyIt++);

// Otherwise, try to match the generic pattern.
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
onSuccess)))
return success();
}
// If we break from the loop, then only one of the ranges can still have
// elements. Loop over both without checking given that we don't need to
// interleave anymore.
for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
onSuccess)))
return success();
}
return failure();
}

LogicalResult PatternApplicator::matchAndRewrite(
Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
function_ref<bool(const RewritePattern &)> canApply,
function_ref<void(const RewritePattern &)> onFailure,
function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
// Check that the pattern can be applied.
if (canApply && !canApply(pattern))
return failure();

// Try to match and rewrite this pattern. The patterns are sorted by
// benefit, so if we match we can immediately rewrite.
rewriter.setInsertionPoint(op);
if (succeeded(pattern.matchAndRewrite(op, rewriter)))
return success(!onSuccess || succeeded(onSuccess(pattern)));

if (onFailure)
onFailure(pattern);
return failure();
}
13 changes: 13 additions & 0 deletions mlir/lib/Rewrite/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
add_mlir_library(MLIRRewrite
FrozenRewritePatternList.cpp
PatternApplicator.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite

DEPENDS
mlir-generic-headers

LINK_LIBS PUBLIC
MLIRIR
)
19 changes: 19 additions & 0 deletions mlir/lib/Rewrite/FrozenRewritePatternList.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===- FrozenRewritePatternList.cpp - Frozen Pattern List -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Rewrite/FrozenRewritePatternList.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// FrozenRewritePatternList
//===----------------------------------------------------------------------===//

FrozenRewritePatternList::FrozenRewritePatternList(
OwningRewritePatternList &&patterns)
: patterns(patterns.takePatterns()) {}
148 changes: 148 additions & 0 deletions mlir/lib/Rewrite/PatternApplicator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
//===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an applicator that applies pattern rewrites based upon a
// user defined cost model.
//
//===----------------------------------------------------------------------===//

#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/Support/Debug.h"

using namespace mlir;

#define DEBUG_TYPE "pattern-match"

void PatternApplicator::applyCostModel(CostModel model) {
// Separate patterns by root kind to simplify lookup later on.
patterns.clear();
anyOpPatterns.clear();
for (const auto &pat : frozenPatternList.getPatterns()) {
// If the pattern is always impossible to match, just ignore it.
if (pat.getBenefit().isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs()
<< "Ignoring pattern '" << pat.getRootKind()
<< "' because it is impossible to match (by pattern benefit)\n";
});
continue;
}
if (Optional<OperationName> opName = pat.getRootKind())
patterns[*opName].push_back(&pat);
else
anyOpPatterns.push_back(&pat);
}

// Sort the patterns using the provided cost model.
llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
return benefits[lhs] > benefits[rhs];
};
auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
// Special case for one pattern in the list, which is the most common case.
if (list.size() == 1) {
if (model(*list.front()).isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
<< "' because it is impossible to match or cannot lead "
"to legal IR (by cost model)\n";
});
list.clear();
}
return;
}

// Collect the dynamic benefits for the current pattern list.
benefits.clear();
for (const Pattern *pat : list)
benefits.try_emplace(pat, model(*pat));

// Sort patterns with highest benefit first, and remove those that are
// impossible to match.
std::stable_sort(list.begin(), list.end(), cmp);
while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
<< "' because it is impossible to match or cannot lead to "
"legal IR (by cost model)\n";
});
list.pop_back();
}
};
for (auto &it : patterns)
processPatternList(it.second);
processPatternList(anyOpPatterns);
}

void PatternApplicator::walkAllPatterns(
function_ref<void(const Pattern &)> walk) {
for (auto &it : frozenPatternList.getPatterns())
walk(it);
}

LogicalResult PatternApplicator::matchAndRewrite(
Operation *op, PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess) {
// Check to see if there are patterns matching this specific operation type.
MutableArrayRef<const RewritePattern *> opPatterns;
auto patternIt = patterns.find(op->getName());
if (patternIt != patterns.end())
opPatterns = patternIt->second;

// Process the patterns for that match the specific operation type, and any
// operation type in an interleaved fashion.
// FIXME: It'd be nice to just write an llvm::make_merge_range utility
// and pass in a comparison function. That would make this code trivial.
auto opIt = opPatterns.begin(), opE = opPatterns.end();
auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
while (opIt != opE && anyIt != anyE) {
// Try to match the pattern providing the most benefit.
const RewritePattern *pattern;
if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
pattern = *(opIt++);
else
pattern = *(anyIt++);

// Otherwise, try to match the generic pattern.
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
onSuccess)))
return success();
}
// If we break from the loop, then only one of the ranges can still have
// elements. Loop over both without checking given that we don't need to
// interleave anymore.
for (const RewritePattern *pattern : llvm::concat<const RewritePattern *>(
llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
onSuccess)))
return success();
}
return failure();
}

LogicalResult PatternApplicator::matchAndRewrite(
Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess) {
// Check that the pattern can be applied.
if (canApply && !canApply(pattern))
return failure();

// Try to match and rewrite this pattern. The patterns are sorted by
// benefit, so if we match we can immediately rewrite.
rewriter.setInsertionPoint(op);
if (succeeded(pattern.matchAndRewrite(op, rewriter)))
return success(!onSuccess || succeeded(onSuccess(pattern)));

if (onFailure)
onFailure(pattern);
return failure();
}
1 change: 0 additions & 1 deletion mlir/lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ add_mlir_library(MLIRTransforms
Canonicalizer.cpp
CopyRemoval.cpp
CSE.cpp
DialectConversion.cpp
Inliner.cpp
LocationSnapshot.cpp
LoopCoalescing.cpp
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Transforms/Canonicalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
Expand All @@ -32,7 +32,7 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
op->getCanonicalizationPatterns(patterns, context);

Operation *op = getOperation();
applyPatternsAndFoldGreedily(op->getRegions(), patterns);
applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
} // end anonymous namespace
Expand Down
10 changes: 6 additions & 4 deletions mlir/lib/Transforms/Inliner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
Expand Down Expand Up @@ -502,7 +503,7 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
/// canonicalization patterns.
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
CallGraphSCC &currentSCC, MLIRContext *context,
const OwningRewritePatternList &canonPatterns) {
const FrozenRewritePatternList &canonPatterns) {
// Collect the sets of nodes to canonicalize.
SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
for (auto *node : currentSCC) {
Expand Down Expand Up @@ -573,7 +574,7 @@ struct InlinerPass : public InlinerBase<InlinerPass> {
/// the inlining of newly devirtualized calls.
void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC,
MLIRContext *context,
const OwningRewritePatternList &canonPatterns);
const FrozenRewritePatternList &canonPatterns);
};
} // end anonymous namespace

Expand All @@ -595,13 +596,14 @@ void InlinerPass::runOnOperation() {
OwningRewritePatternList canonPatterns;
for (auto *op : context->getRegisteredOperations())
op->getCanonicalizationPatterns(canonPatterns, context);
FrozenRewritePatternList frozenCanonPatterns(std::move(canonPatterns));

// Run the inline transform in post-order over the SCCs in the callgraph.
SymbolTableCollection symbolTable;
Inliner inliner(context, cg, symbolTable);
CGUseList useList(getOperation(), cg, symbolTable);
runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
inlineSCC(inliner, useList, scc, context, canonPatterns);
inlineSCC(inliner, useList, scc, context, frozenCanonPatterns);
});

// After inlining, make sure to erase any callables proven to be dead.
Expand All @@ -610,7 +612,7 @@ void InlinerPass::runOnOperation() {

void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
CallGraphSCC &currentSCC, MLIRContext *context,
const OwningRewritePatternList &canonPatterns) {
const FrozenRewritePatternList &canonPatterns) {
// If we successfully inlined any calls, run some simplifications on the
// nodes of the scc. Continue attempting to inline until we reach a fixed
// point, or a maximum iteration count. We canonicalize here as it may
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Transforms/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_library(MLIRTransformUtils
DialectConversion.cpp
FoldUtils.cpp
GreedyPatternRewriteDriver.cpp
InliningUtils.cpp
Expand All @@ -19,5 +20,6 @@ add_mlir_library(MLIRTransformUtils
MLIRLoopAnalysis
MLIRSCF
MLIRPass
MLIRRewrite
MLIRStandard
)
Loading