Skip to content

Commit 76f3c2f

Browse files
committed
[mlir][Pattern] Add better support for using interfaces/traits to match root operations in rewrite patterns
To match an interface or trait, users currently have to use the `MatchAny` tag. This tag can be quite problematic for compile time for things like the canonicalizer, as the `MatchAny` patterns may get applied to *every* operation. This revision adds better support by bucketing interface/trait patterns based on which registered operations have them registered. This means that moving forward we will only attempt to match these patterns to operations that have this interface registered. Two simplify defining patterns that match traits and interfaces, two new utility classes have been added: OpTraitRewritePattern and OpInterfaceRewritePattern. Differential Revision: https://reviews.llvm.org/D98986
1 parent 782c534 commit 76f3c2f

File tree

33 files changed

+462
-254
lines changed

33 files changed

+462
-254
lines changed

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,7 @@ static bool isOne(mlir::Value v) { return checkIsIntegerConstant(v, 1); }
697697
template <typename FltOp, typename CpxOp>
698698
struct UndoComplexPattern : public mlir::RewritePattern {
699699
UndoComplexPattern(mlir::MLIRContext *ctx)
700-
: mlir::RewritePattern("fir.insert_value", {}, 2, ctx) {}
700+
: mlir::RewritePattern("fir.insert_value", 2, ctx) {}
701701

702702
mlir::LogicalResult
703703
matchAndRewrite(mlir::Operation *op,

mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ namespace linalg {
3030
// or in an externally linked library.
3131
// This is a generic entry point for all LinalgOp, except for CopyOp and
3232
// IndexedGenericOp, for which omre specialized patterns are provided.
33-
class LinalgOpToLibraryCallRewrite : public RewritePattern {
33+
class LinalgOpToLibraryCallRewrite
34+
: public OpInterfaceRewritePattern<LinalgOp> {
3435
public:
35-
LinalgOpToLibraryCallRewrite()
36-
: RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
36+
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
3737

38-
LogicalResult matchAndRewrite(Operation *op,
38+
LogicalResult matchAndRewrite(LinalgOp op,
3939
PatternRewriter &rewriter) const override;
4040
};
4141

mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ void enqueue(RewritePatternSet &patternList, OptionsType options,
6060
if (!opName.empty())
6161
patternList.add<PatternType>(opName, patternList.getContext(), options, m);
6262
else
63-
patternList.add<PatternType>(m.addOpFilter<OpType>(), options);
63+
patternList.add<PatternType>(patternList.getContext(),
64+
m.addOpFilter<OpType>(), options);
6465
}
6566

6667
/// Promotion transformation enqueues a particular stage-1 pattern for

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
452452
struct LinalgBaseTilingPattern : public RewritePattern {
453453
// Entry point to match any LinalgOp OpInterface.
454454
LinalgBaseTilingPattern(
455-
LinalgTilingOptions options,
455+
MLIRContext *context, LinalgTilingOptions options,
456456
LinalgTransformationFilter filter = LinalgTransformationFilter(),
457457
PatternBenefit benefit = 1);
458458
// Entry point to match a specific Linalg op.
@@ -644,7 +644,8 @@ struct LinalgVectorizationOptions {};
644644

645645
struct LinalgBaseVectorizationPattern : public RewritePattern {
646646
/// MatchAnyOpTag-based constructor with a mandatory `filter`.
647-
LinalgBaseVectorizationPattern(LinalgTransformationFilter filter,
647+
LinalgBaseVectorizationPattern(MLIRContext *context,
648+
LinalgTransformationFilter filter,
648649
PatternBenefit benefit = 1);
649650
/// Name-based constructor with an optional `filter`.
650651
LinalgBaseVectorizationPattern(
@@ -663,10 +664,10 @@ struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
663664
/// These constructors are available to anyone.
664665
/// MatchAnyOpTag-based constructor with a mandatory `filter`.
665666
LinalgVectorizationPattern(
666-
LinalgTransformationFilter filter,
667+
MLIRContext *context, LinalgTransformationFilter filter,
667668
LinalgVectorizationOptions options = LinalgVectorizationOptions(),
668669
PatternBenefit benefit = 1)
669-
: LinalgBaseVectorizationPattern(filter, benefit) {}
670+
: LinalgBaseVectorizationPattern(context, filter, benefit) {}
670671
/// Name-based constructor with an optional `filter`.
671672
LinalgVectorizationPattern(
672673
StringRef opName, MLIRContext *context,
@@ -702,8 +703,8 @@ template <typename OpType, typename = std::enable_if_t<
702703
void insertVectorizationPatternImpl(RewritePatternSet &patternList,
703704
linalg::LinalgVectorizationOptions options,
704705
linalg::LinalgTransformationFilter f) {
705-
patternList.add<linalg::LinalgVectorizationPattern>(f.addOpFilter<OpType>(),
706-
options);
706+
patternList.add<linalg::LinalgVectorizationPattern>(
707+
patternList.getContext(), f.addOpFilter<OpType>(), options);
707708
}
708709

709710
/// Variadic helper function to insert vectorization patterns for C++ ops.
@@ -737,7 +738,7 @@ struct LinalgLoweringPattern : public RewritePattern {
737738
MLIRContext *context, LinalgLoweringType loweringType,
738739
LinalgTransformationFilter filter = LinalgTransformationFilter(),
739740
ArrayRef<unsigned> interchangeVector = {}, PatternBenefit benefit = 1)
740-
: RewritePattern(OpTy::getOperationName(), {}, benefit, context),
741+
: RewritePattern(OpTy::getOperationName(), benefit, context),
741742
filter(filter), loweringType(loweringType),
742743
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
743744

mlir/include/mlir/Dialect/Vector/VectorTransforms.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ struct UnrollVectorOptions {
123123
struct UnrollVectorPattern : public RewritePattern {
124124
using FilterConstraintType = std::function<LogicalResult(Operation *op)>;
125125
UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
126-
: RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), options(options) {}
126+
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
127+
options(options) {}
127128
LogicalResult matchAndRewrite(Operation *op,
128129
PatternRewriter &rewriter) const override {
129130
if (options.filterConstraint && failed(options.filterConstraint(op)))
@@ -216,7 +217,7 @@ struct VectorTransferFullPartialRewriter : public RewritePattern {
216217
FilterConstraintType filter =
217218
[](VectorTransferOpInterface op) { return success(); },
218219
PatternBenefit benefit = 1)
219-
: RewritePattern(benefit, MatchAnyOpTypeTag()), options(options),
220+
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
220221
filter(filter) {}
221222

222223
/// Performs the rewrite.

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,6 +1516,13 @@ class Op : public OpState, public Traits<ConcreteType>... {
15161516
#endif
15171517
return false;
15181518
}
1519+
/// Provide `classof` support for other OpBase derived classes, such as
1520+
/// Interfaces.
1521+
template <typename T>
1522+
static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
1523+
classof(const T *op) {
1524+
return classof(const_cast<T *>(op)->getOperation());
1525+
}
15191526

15201527
/// Expose the type we are instantiated on to template machinery that may want
15211528
/// to introspect traits on this operation.

mlir/include/mlir/IR/OperationSupport.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,20 @@ class AbstractOperation {
142142
return interfaceMap.lookup<T>();
143143
}
144144

145+
/// Returns true if this operation has the given interface registered to it.
146+
bool hasInterface(TypeID interfaceID) const {
147+
return interfaceMap.contains(interfaceID);
148+
}
149+
145150
/// Returns true if the operation has a particular trait.
146151
template <template <typename T> class Trait>
147152
bool hasTrait() const {
148153
return hasTraitFn(TypeID::get<Trait>());
149154
}
150155

156+
/// Returns true if the operation has a particular trait.
157+
bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
158+
151159
/// Look up the specified operation in the specified MLIRContext and return a
152160
/// pointer to it if present. Otherwise, return a null pointer.
153161
static const AbstractOperation *lookup(StringRef opName,

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 132 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,49 @@ class PatternBenefit {
6868
/// used to interface with the metadata of a pattern, such as the benefit or
6969
/// root operation.
7070
class Pattern {
71+
/// This enum represents the kind of value used to select the root operations
72+
/// that match this pattern.
73+
enum class RootKind {
74+
/// The pattern root matches "any" operation.
75+
Any,
76+
/// The pattern root is matched using a concrete operation name.
77+
OperationName,
78+
/// The pattern root is matched using an interface ID.
79+
InterfaceID,
80+
/// The patter root is matched using a trait ID.
81+
TraitID
82+
};
83+
7184
public:
7285
/// Return a list of operations that may be generated when rewriting an
7386
/// operation instance with this pattern.
7487
ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
7588

7689
/// Return the root node that this pattern matches. Patterns that can match
7790
/// multiple root types return None.
78-
Optional<OperationName> getRootKind() const { return rootKind; }
91+
Optional<OperationName> getRootKind() const {
92+
if (rootKind == RootKind::OperationName)
93+
return OperationName::getFromOpaquePointer(rootValue);
94+
return llvm::None;
95+
}
96+
97+
/// Return the interface ID used to match the root operation of this pattern.
98+
/// If the pattern does not use an interface ID for deciding the root match,
99+
/// this returns None.
100+
Optional<TypeID> getRootInterfaceID() const {
101+
if (rootKind == RootKind::InterfaceID)
102+
return TypeID::getFromOpaquePointer(rootValue);
103+
return llvm::None;
104+
}
105+
106+
/// Return the trait ID used to match the root operation of this pattern.
107+
/// If the pattern does not use a trait ID for deciding the root match, this
108+
/// returns None.
109+
Optional<TypeID> getRootTraitID() const {
110+
if (rootKind == RootKind::TraitID)
111+
return TypeID::getFromOpaquePointer(rootValue);
112+
return llvm::None;
113+
}
79114

80115
/// Return the benefit (the inverse of "cost") of matching this pattern. The
81116
/// benefit of a Pattern is always static - rewrites that may have dynamic
@@ -88,56 +123,85 @@ class Pattern {
88123
/// i.e. this pattern may generate IR that also matches this pattern, but is
89124
/// known to bound the recursion. This signals to a rewrite driver that it is
90125
/// safe to apply this pattern recursively to generated IR.
91-
bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
126+
bool hasBoundedRewriteRecursion() const {
127+
return contextAndHasBoundedRecursion.getInt();
128+
}
129+
130+
/// Return the MLIRContext used to create this pattern.
131+
MLIRContext *getContext() const {
132+
return contextAndHasBoundedRecursion.getPointer();
133+
}
92134

93135
protected:
94136
/// This class acts as a special tag that makes the desire to match "any"
95137
/// operation type explicit. This helps to avoid unnecessary usages of this
96138
/// feature, and ensures that the user is making a conscious decision.
97139
struct MatchAnyOpTypeTag {};
140+
/// This class acts as a special tag that makes the desire to match any
141+
/// operation that implements a given interface explicit. This helps to avoid
142+
/// unnecessary usages of this feature, and ensures that the user is making a
143+
/// conscious decision.
144+
struct MatchInterfaceOpTypeTag {};
145+
/// This class acts as a special tag that makes the desire to match any
146+
/// operation that implements a given trait explicit. This helps to avoid
147+
/// unnecessary usages of this feature, and ensures that the user is making a
148+
/// conscious decision.
149+
struct MatchTraitOpTypeTag {};
98150

99151
/// Construct a pattern with a certain benefit that matches the operation
100152
/// with the given root name.
101-
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
102-
/// Construct a pattern with a certain benefit that matches any operation
103-
/// type. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
104-
/// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
105-
/// always be supplied here.
106-
Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag);
107-
/// Construct a pattern with a certain benefit that matches the operation with
108-
/// the given root name. `generatedNames` contains the names of operations
109-
/// that may be generated during a successful rewrite.
110-
Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
111-
PatternBenefit benefit, MLIRContext *context);
153+
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
154+
ArrayRef<StringRef> generatedNames = {});
112155
/// Construct a pattern that may match any operation type. `generatedNames`
113156
/// contains the names of operations that may be generated during a successful
114157
/// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
115158
/// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
116159
/// always be supplied here.
117-
Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
118-
MLIRContext *context, MatchAnyOpTypeTag tag);
160+
Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
161+
ArrayRef<StringRef> generatedNames = {});
162+
/// Construct a pattern that may match any operation that implements the
163+
/// interface defined by the provided `interfaceID`. `generatedNames` contains
164+
/// the names of operations that may be generated during a successful rewrite.
165+
/// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
166+
/// interface" behavior is what the user actually desired,
167+
/// `MatchInterfaceOpTypeTag()` should always be supplied here.
168+
Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
169+
PatternBenefit benefit, MLIRContext *context,
170+
ArrayRef<StringRef> generatedNames = {});
171+
/// Construct a pattern that may match any operation that implements the
172+
/// trait defined by the provided `traitID`. `generatedNames` contains the
173+
/// names of operations that may be generated during a successful rewrite.
174+
/// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
175+
/// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
176+
/// always be supplied here.
177+
Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
178+
MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
119179

120180
/// Set the flag detailing if this pattern has bounded rewrite recursion or
121181
/// not.
122182
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
123-
hasBoundedRecursion = hasBoundedRecursionArg;
183+
contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
124184
}
125185

126186
private:
127-
/// A list of the potential operations that may be generated when rewriting
128-
/// an op with this pattern.
129-
SmallVector<OperationName, 2> generatedOps;
187+
Pattern(const void *rootValue, RootKind rootKind,
188+
ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
189+
MLIRContext *context);
130190

131-
/// The root operation of the pattern. If the pattern matches a specific
132-
/// operation, this contains the name of that operation. Contains None
133-
/// otherwise.
134-
Optional<OperationName> rootKind;
191+
/// The value used to match the root operation of the pattern.
192+
const void *rootValue;
193+
RootKind rootKind;
135194

136195
/// The expected benefit of matching this pattern.
137196
const PatternBenefit benefit;
138197

139-
/// A boolean flag of whether this pattern has bounded recursion or not.
140-
bool hasBoundedRecursion = false;
198+
/// The context this pattern was created from, and a boolean flag indicating
199+
/// whether this pattern has bounded recursion or not.
200+
llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
201+
202+
/// A list of the potential operations that may be generated when rewriting
203+
/// an op with this pattern.
204+
SmallVector<OperationName, 2> generatedOps;
141205
};
142206

143207
//===----------------------------------------------------------------------===//
@@ -188,15 +252,13 @@ class RewritePattern : public Pattern {
188252
virtual void anchor();
189253
};
190254

191-
/// OpRewritePattern is a wrapper around RewritePattern that allows for
192-
/// matching and rewriting against an instance of a derived operation class as
193-
/// opposed to a raw Operation.
255+
namespace detail {
256+
/// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
257+
/// allows for matching and rewriting against an instance of a derived operation
258+
/// class or Interface.
194259
template <typename SourceOp>
195-
struct OpRewritePattern : public RewritePattern {
196-
/// Patterns must specify the root operation name they match against, and can
197-
/// also specify the benefit of the pattern matching.
198-
OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
199-
: RewritePattern(SourceOp::getOperationName(), benefit, context) {}
260+
struct OpOrInterfaceRewritePatternBase : public RewritePattern {
261+
using RewritePattern::RewritePattern;
200262

201263
/// Wrappers around the RewritePattern methods that pass the derived op type.
202264
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
@@ -227,6 +289,43 @@ struct OpRewritePattern : public RewritePattern {
227289
return failure();
228290
}
229291
};
292+
} // namespace detail
293+
294+
/// OpRewritePattern is a wrapper around RewritePattern that allows for
295+
/// matching and rewriting against an instance of a derived operation class as
296+
/// opposed to a raw Operation.
297+
template <typename SourceOp>
298+
struct OpRewritePattern
299+
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
300+
/// Patterns must specify the root operation name they match against, and can
301+
/// also specify the benefit of the pattern matching.
302+
OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
303+
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
304+
SourceOp::getOperationName(), benefit, context) {}
305+
};
306+
307+
/// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
308+
/// matching and rewriting against an instance of an operation interface instead
309+
/// of a raw Operation.
310+
template <typename SourceOp>
311+
struct OpInterfaceRewritePattern
312+
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
313+
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
314+
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
315+
Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
316+
benefit, context) {}
317+
};
318+
319+
/// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
320+
/// matching and rewriting against instances of an operation that possess a
321+
/// given trait.
322+
template <template <typename> class TraitType>
323+
class OpTraitRewritePattern : public RewritePattern {
324+
public:
325+
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
326+
: RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
327+
benefit, context) {}
328+
};
230329

231330
//===----------------------------------------------------------------------===//
232331
// PDLPatternModule

0 commit comments

Comments
 (0)