Skip to content

Commit

Permalink
[mlir] Tighten access of RewritePattern methods.
Browse files Browse the repository at this point in the history
In RewritePattern, only expose `matchAndRewrite` as a public function. `match` can be protected (but needs to be protected because we want to call it from an override of `matchAndRewrite`). `rewrite` can be private.

For classes deriving from RewritePattern, all 3 functions can be private.

Side note: I didn't understand the need for the `using RewritePattern::matchAndRewrite` in derived classes, and started poking around. They are gone now, and I think the result is (only very slightly) cleaner.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D92670
  • Loading branch information
chsigg committed Dec 8, 2020
1 parent 2812c15 commit 02c9050
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 51 deletions.
Expand Up @@ -571,11 +571,9 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
&typeConverter.getContext(), typeConverter,
benefit) {}

/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
private:
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
}
Expand All @@ -584,6 +582,10 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
Expand All @@ -603,10 +605,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
}
return failure();
}

private:
using ConvertToLLVMPattern::match;
using ConvertToLLVMPattern::matchAndRewrite;
};

namespace LLVM {
Expand Down Expand Up @@ -636,6 +634,7 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;

private:
/// Converts the type of the result to an LLVM type, pass operands as is,
/// preserve attributes.
LogicalResult
Expand All @@ -655,6 +654,7 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;

private:
LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Expand Down
28 changes: 16 additions & 12 deletions mlir/include/mlir/IR/PatternMatch.h
Expand Up @@ -156,17 +156,6 @@ class RewritePattern : public Pattern {
public:
virtual ~RewritePattern() {}

/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// builder. If an unexpected error is encountered (an internal
/// compiler error), it is emitted through the normal MLIR diagnostic
/// hooks and the IR is left in a valid state.
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
virtual LogicalResult match(Operation *op) const;

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). If successful, this
/// function will automatically perform the rewrite.
Expand All @@ -183,19 +172,34 @@ class RewritePattern : public Pattern {
/// Inherit the base constructors from `Pattern`.
using Pattern::Pattern;

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
virtual LogicalResult match(Operation *op) const;

private:
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// builder. If an unexpected error is encountered (an internal
/// compiler error), it is emitted through the normal MLIR diagnostic
/// hooks and the IR is left in a valid state.
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;

/// An anchor for the virtual table.
virtual void anchor();
};

/// OpRewritePattern is a wrapper around RewritePattern that allows for
/// matching and rewriting against an instance of a derived operation class as
/// opposed to a raw Operation.
template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
template <typename SourceOp>
class OpRewritePattern : public RewritePattern {
public:
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: RewritePattern(SourceOp::getOperationName(), benefit, context) {}

private:
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), rewriter);
Expand Down
61 changes: 31 additions & 30 deletions mlir/include/mlir/Transforms/DialectConversion.h
Expand Up @@ -313,6 +313,30 @@ class TypeConverter {
/// patterns of this type can only be used with the 'apply*' methods below.
class ConversionPattern : public RewritePattern {
public:
/// Return the type converter held by this pattern, or nullptr if the pattern
/// does not require type conversion.
TypeConverter *getTypeConverter() const { return typeConverter; }

protected:
/// See `RewritePattern::RewritePattern` for information on the other
/// available constructors.
using RewritePattern::RewritePattern;
/// Construct a conversion pattern that matches an operation with the given
/// root name. This constructor allows for providing a type converter to use
/// within the pattern.
ConversionPattern(StringRef rootName, PatternBenefit benefit,
TypeConverter &typeConverter, MLIRContext *ctx)
: RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {}
/// Construct a conversion pattern that matches any operation type. This
/// constructor allows for providing a type converter to use within the
/// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
/// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
/// always be supplied here.
ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter,
MatchAnyOpTypeTag tag)
: RewritePattern(benefit, tag), typeConverter(&typeConverter) {}

private:
/// Hook for derived classes to implement rewriting. `op` is the (first)
/// operation matched by the pattern, `operands` is a list of the rewritten
/// operand values that are passed to `op`, `rewriter` can be used to emit the
Expand All @@ -323,6 +347,10 @@ class ConversionPattern : public RewritePattern {
llvm_unreachable("unimplemented rewrite");
}

void rewrite(Operation *op, PatternRewriter &rewriter) const final {
llvm_unreachable("never called");
}

/// Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand All @@ -337,49 +365,25 @@ class ConversionPattern : public RewritePattern {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final;

/// Return the type converter held by this pattern, or nullptr if the pattern
/// does not require type conversion.
TypeConverter *getTypeConverter() const { return typeConverter; }

protected:
/// See `RewritePattern::RewritePattern` for information on the other
/// available constructors.
using RewritePattern::RewritePattern;
/// Construct a conversion pattern that matches an operation with the given
/// root name. This constructor allows for providing a type converter to use
/// within the pattern.
ConversionPattern(StringRef rootName, PatternBenefit benefit,
TypeConverter &typeConverter, MLIRContext *ctx)
: RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {}
/// Construct a conversion pattern that matches any operation type. This
/// constructor allows for providing a type converter to use within the
/// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
/// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
/// always be supplied here.
ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter,
MatchAnyOpTypeTag tag)
: RewritePattern(benefit, tag), typeConverter(&typeConverter) {}

protected:
/// An optional type converter for use by this pattern.
TypeConverter *typeConverter = nullptr;

private:
using RewritePattern::rewrite;
};

/// OpConversionPattern is a wrapper around ConversionPattern that allows for
/// matching and rewriting against an instance of a derived operation class as
/// opposed to a raw Operation.
template <typename SourceOp>
struct OpConversionPattern : public ConversionPattern {
class OpConversionPattern : public ConversionPattern {
public:
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter,
context) {}

private:
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
void rewrite(Operation *op, ArrayRef<Value> operands,
Expand Down Expand Up @@ -409,9 +413,6 @@ struct OpConversionPattern : public ConversionPattern {
rewrite(op, operands, rewriter);
return success();
}

private:
using ConversionPattern::matchAndRewrite;
};

/// Add a pattern to the given pattern list to convert the signature of a FuncOp
Expand Down

0 comments on commit 02c9050

Please sign in to comment.