@@ -68,14 +68,49 @@ class PatternBenefit {
6868// / used to interface with the metadata of a pattern, such as the benefit or
6969// / root operation.
7070class Pattern {
71+ // / This enum represents the kind of value used to select the root operations
72+ // / that match this pattern.
73+ enum class RootKind {
74+ // / The pattern root matches "any" operation.
75+ Any,
76+ // / The pattern root is matched using a concrete operation name.
77+ OperationName,
78+ // / The pattern root is matched using an interface ID.
79+ InterfaceID,
80+ // / The patter root is matched using a trait ID.
81+ TraitID
82+ };
83+
7184public:
7285 // / Return a list of operations that may be generated when rewriting an
7386 // / operation instance with this pattern.
7487 ArrayRef<OperationName> getGeneratedOps () const { return generatedOps; }
7588
7689 // / Return the root node that this pattern matches. Patterns that can match
7790 // / multiple root types return None.
78- Optional<OperationName> getRootKind () const { return rootKind; }
91+ Optional<OperationName> getRootKind () const {
92+ if (rootKind == RootKind::OperationName)
93+ return OperationName::getFromOpaquePointer (rootValue);
94+ return llvm::None;
95+ }
96+
97+ // / Return the interface ID used to match the root operation of this pattern.
98+ // / If the pattern does not use an interface ID for deciding the root match,
99+ // / this returns None.
100+ Optional<TypeID> getRootInterfaceID () const {
101+ if (rootKind == RootKind::InterfaceID)
102+ return TypeID::getFromOpaquePointer (rootValue);
103+ return llvm::None;
104+ }
105+
106+ // / Return the trait ID used to match the root operation of this pattern.
107+ // / If the pattern does not use a trait ID for deciding the root match, this
108+ // / returns None.
109+ Optional<TypeID> getRootTraitID () const {
110+ if (rootKind == RootKind::TraitID)
111+ return TypeID::getFromOpaquePointer (rootValue);
112+ return llvm::None;
113+ }
79114
80115 // / Return the benefit (the inverse of "cost") of matching this pattern. The
81116 // / benefit of a Pattern is always static - rewrites that may have dynamic
@@ -88,56 +123,85 @@ class Pattern {
88123 // / i.e. this pattern may generate IR that also matches this pattern, but is
89124 // / known to bound the recursion. This signals to a rewrite driver that it is
90125 // / safe to apply this pattern recursively to generated IR.
91- bool hasBoundedRewriteRecursion () const { return hasBoundedRecursion; }
126+ bool hasBoundedRewriteRecursion () const {
127+ return contextAndHasBoundedRecursion.getInt ();
128+ }
129+
130+ // / Return the MLIRContext used to create this pattern.
131+ MLIRContext *getContext () const {
132+ return contextAndHasBoundedRecursion.getPointer ();
133+ }
92134
93135protected:
94136 // / This class acts as a special tag that makes the desire to match "any"
95137 // / operation type explicit. This helps to avoid unnecessary usages of this
96138 // / feature, and ensures that the user is making a conscious decision.
97139 struct MatchAnyOpTypeTag {};
140+ // / This class acts as a special tag that makes the desire to match any
141+ // / operation that implements a given interface explicit. This helps to avoid
142+ // / unnecessary usages of this feature, and ensures that the user is making a
143+ // / conscious decision.
144+ struct MatchInterfaceOpTypeTag {};
145+ // / This class acts as a special tag that makes the desire to match any
146+ // / operation that implements a given trait explicit. This helps to avoid
147+ // / unnecessary usages of this feature, and ensures that the user is making a
148+ // / conscious decision.
149+ struct MatchTraitOpTypeTag {};
98150
99151 // / Construct a pattern with a certain benefit that matches the operation
100152 // / with the given root name.
101- Pattern (StringRef rootName, PatternBenefit benefit, MLIRContext *context);
102- // / Construct a pattern with a certain benefit that matches any operation
103- // / type. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
104- // / behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
105- // / always be supplied here.
106- Pattern (PatternBenefit benefit, MatchAnyOpTypeTag tag);
107- // / Construct a pattern with a certain benefit that matches the operation with
108- // / the given root name. `generatedNames` contains the names of operations
109- // / that may be generated during a successful rewrite.
110- Pattern (StringRef rootName, ArrayRef<StringRef> generatedNames,
111- PatternBenefit benefit, MLIRContext *context);
153+ Pattern (StringRef rootName, PatternBenefit benefit, MLIRContext *context,
154+ ArrayRef<StringRef> generatedNames = {});
112155 // / Construct a pattern that may match any operation type. `generatedNames`
113156 // / contains the names of operations that may be generated during a successful
114157 // / rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
115158 // / behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
116159 // / always be supplied here.
117- Pattern (ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
118- MLIRContext *context, MatchAnyOpTypeTag tag);
160+ Pattern (MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
161+ ArrayRef<StringRef> generatedNames = {});
162+ // / Construct a pattern that may match any operation that implements the
163+ // / interface defined by the provided `interfaceID`. `generatedNames` contains
164+ // / the names of operations that may be generated during a successful rewrite.
165+ // / `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
166+ // / interface" behavior is what the user actually desired,
167+ // / `MatchInterfaceOpTypeTag()` should always be supplied here.
168+ Pattern (MatchInterfaceOpTypeTag tag, TypeID interfaceID,
169+ PatternBenefit benefit, MLIRContext *context,
170+ ArrayRef<StringRef> generatedNames = {});
171+ // / Construct a pattern that may match any operation that implements the
172+ // / trait defined by the provided `traitID`. `generatedNames` contains the
173+ // / names of operations that may be generated during a successful rewrite.
174+ // / `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
175+ // / behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
176+ // / always be supplied here.
177+ Pattern (MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
178+ MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
119179
120180 // / Set the flag detailing if this pattern has bounded rewrite recursion or
121181 // / not.
122182 void setHasBoundedRewriteRecursion (bool hasBoundedRecursionArg = true ) {
123- hasBoundedRecursion = hasBoundedRecursionArg;
183+ contextAndHasBoundedRecursion. setInt ( hasBoundedRecursionArg) ;
124184 }
125185
126186private:
127- // / A list of the potential operations that may be generated when rewriting
128- // / an op with this pattern.
129- SmallVector<OperationName, 2 > generatedOps ;
187+ Pattern ( const void *rootValue, RootKind rootKind,
188+ ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
189+ MLIRContext *context) ;
130190
131- // / The root operation of the pattern. If the pattern matches a specific
132- // / operation, this contains the name of that operation. Contains None
133- // / otherwise.
134- Optional<OperationName> rootKind;
191+ // / The value used to match the root operation of the pattern.
192+ const void *rootValue;
193+ RootKind rootKind;
135194
136195 // / The expected benefit of matching this pattern.
137196 const PatternBenefit benefit;
138197
139- // / A boolean flag of whether this pattern has bounded recursion or not.
140- bool hasBoundedRecursion = false ;
198+ // / The context this pattern was created from, and a boolean flag indicating
199+ // / whether this pattern has bounded recursion or not.
200+ llvm::PointerIntPair<MLIRContext *, 1 , bool > contextAndHasBoundedRecursion;
201+
202+ // / A list of the potential operations that may be generated when rewriting
203+ // / an op with this pattern.
204+ SmallVector<OperationName, 2 > generatedOps;
141205};
142206
143207// ===----------------------------------------------------------------------===//
@@ -188,15 +252,13 @@ class RewritePattern : public Pattern {
188252 virtual void anchor ();
189253};
190254
191- // / OpRewritePattern is a wrapper around RewritePattern that allows for
192- // / matching and rewriting against an instance of a derived operation class as
193- // / opposed to a raw Operation.
255+ namespace detail {
256+ // / OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
257+ // / allows for matching and rewriting against an instance of a derived operation
258+ // / class or Interface.
194259template <typename SourceOp>
195- struct OpRewritePattern : public RewritePattern {
196- // / Patterns must specify the root operation name they match against, and can
197- // / also specify the benefit of the pattern matching.
198- OpRewritePattern (MLIRContext *context, PatternBenefit benefit = 1 )
199- : RewritePattern(SourceOp::getOperationName(), benefit, context) {}
260+ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
261+ using RewritePattern::RewritePattern;
200262
201263 // / Wrappers around the RewritePattern methods that pass the derived op type.
202264 void rewrite (Operation *op, PatternRewriter &rewriter) const final {
@@ -227,6 +289,43 @@ struct OpRewritePattern : public RewritePattern {
227289 return failure ();
228290 }
229291};
292+ } // namespace detail
293+
294+ // / OpRewritePattern is a wrapper around RewritePattern that allows for
295+ // / matching and rewriting against an instance of a derived operation class as
296+ // / opposed to a raw Operation.
297+ template <typename SourceOp>
298+ struct OpRewritePattern
299+ : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
300+ // / Patterns must specify the root operation name they match against, and can
301+ // / also specify the benefit of the pattern matching.
302+ OpRewritePattern (MLIRContext *context, PatternBenefit benefit = 1 )
303+ : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
304+ SourceOp::getOperationName (), benefit, context) {}
305+ };
306+
307+ // / OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
308+ // / matching and rewriting against an instance of an operation interface instead
309+ // / of a raw Operation.
310+ template <typename SourceOp>
311+ struct OpInterfaceRewritePattern
312+ : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
313+ OpInterfaceRewritePattern (MLIRContext *context, PatternBenefit benefit = 1 )
314+ : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
315+ Pattern::MatchInterfaceOpTypeTag (), SourceOp::getInterfaceID(),
316+ benefit, context) {}
317+ };
318+
319+ // / OpTraitRewritePattern is a wrapper around RewritePattern that allows for
320+ // / matching and rewriting against instances of an operation that possess a
321+ // / given trait.
322+ template <template <typename > class TraitType >
323+ class OpTraitRewritePattern : public RewritePattern {
324+ public:
325+ OpTraitRewritePattern (MLIRContext *context, PatternBenefit benefit = 1 )
326+ : RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
327+ benefit, context) {}
328+ };
230329
231330// ===----------------------------------------------------------------------===//
232331// PDLPatternModule
0 commit comments