Skip to content

Commit

Permalink
[mlir] Revert "Tighten access of RewritePattern methods."
Browse files Browse the repository at this point in the history
This reverts commit 02c9050.
Painted myself into a corner with -Wvirtual_overload, private access, and final.

Differential Revision: https://reviews.llvm.org/D92855
  • Loading branch information
chsigg committed Dec 8, 2020
1 parent 25f5df7 commit 2a98409
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 55 deletions.
Expand Up @@ -571,9 +571,11 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
&typeConverter.getContext(), typeConverter,
benefit) {}

private:
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
}
Expand All @@ -582,10 +584,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
Expand All @@ -605,6 +603,10 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
}
return failure();
}

private:
using ConvertToLLVMPattern::match;
using ConvertToLLVMPattern::matchAndRewrite;
};

namespace LLVM {
Expand Down Expand Up @@ -634,7 +636,6 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;

private:
/// Converts the type of the result to an LLVM type, pass operands as is,
/// preserve attributes.
LogicalResult
Expand All @@ -654,7 +655,6 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;

private:
LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Expand Down
27 changes: 12 additions & 15 deletions mlir/include/mlir/IR/PatternMatch.h
Expand Up @@ -156,6 +156,17 @@ class RewritePattern : public Pattern {
public:
virtual ~RewritePattern() {}

/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// builder. If an unexpected error is encountered (an internal
/// compiler error), it is emitted through the normal MLIR diagnostic
/// hooks and the IR is left in a valid state.
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
virtual LogicalResult match(Operation *op) const;

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). If successful, this
/// function will automatically perform the rewrite.
Expand All @@ -172,18 +183,6 @@ class RewritePattern : public Pattern {
/// Inherit the base constructors from `Pattern`.
using Pattern::Pattern;

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
virtual LogicalResult match(Operation *op) const;

private:
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// builder. If an unexpected error is encountered (an internal
/// compiler error), it is emitted through the normal MLIR diagnostic
/// hooks and the IR is left in a valid state.
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;

/// An anchor for the virtual table.
virtual void anchor();
};
Expand All @@ -192,14 +191,12 @@ class RewritePattern : public Pattern {
/// matching and rewriting against an instance of a derived operation class as
/// opposed to a raw Operation.
template <typename SourceOp>
class OpRewritePattern : public RewritePattern {
public:
struct OpRewritePattern : public RewritePattern {
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: RewritePattern(SourceOp::getOperationName(), benefit, context) {}

private:
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), rewriter);
Expand Down
61 changes: 30 additions & 31 deletions mlir/include/mlir/Transforms/DialectConversion.h
Expand Up @@ -313,30 +313,6 @@ class TypeConverter {
/// patterns of this type can only be used with the 'apply*' methods below.
class ConversionPattern : public RewritePattern {
public:
/// Return the type converter held by this pattern, or nullptr if the pattern
/// does not require type conversion.
TypeConverter *getTypeConverter() const { return typeConverter; }

protected:
/// See `RewritePattern::RewritePattern` for information on the other
/// available constructors.
using RewritePattern::RewritePattern;
/// Construct a conversion pattern that matches an operation with the given
/// root name. This constructor allows for providing a type converter to use
/// within the pattern.
ConversionPattern(StringRef rootName, PatternBenefit benefit,
TypeConverter &typeConverter, MLIRContext *ctx)
: RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {}
/// Construct a conversion pattern that matches any operation type. This
/// constructor allows for providing a type converter to use within the
/// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
/// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
/// always be supplied here.
ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter,
MatchAnyOpTypeTag tag)
: RewritePattern(benefit, tag), typeConverter(&typeConverter) {}

private:
/// Hook for derived classes to implement rewriting. `op` is the (first)
/// operation matched by the pattern, `operands` is a list of the rewritten
/// operand values that are passed to `op`, `rewriter` can be used to emit the
Expand All @@ -347,10 +323,6 @@ class ConversionPattern : public RewritePattern {
llvm_unreachable("unimplemented rewrite");
}

void rewrite(Operation *op, PatternRewriter &rewriter) const final {
llvm_unreachable("never called");
}

/// Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand All @@ -365,25 +337,49 @@ class ConversionPattern : public RewritePattern {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final;

/// Return the type converter held by this pattern, or nullptr if the pattern
/// does not require type conversion.
TypeConverter *getTypeConverter() const { return typeConverter; }

protected:
/// See `RewritePattern::RewritePattern` for information on the other
/// available constructors.
using RewritePattern::RewritePattern;
/// Construct a conversion pattern that matches an operation with the given
/// root name. This constructor allows for providing a type converter to use
/// within the pattern.
ConversionPattern(StringRef rootName, PatternBenefit benefit,
TypeConverter &typeConverter, MLIRContext *ctx)
: RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {}
/// Construct a conversion pattern that matches any operation type. This
/// constructor allows for providing a type converter to use within the
/// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
/// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
/// always be supplied here.
ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter,
MatchAnyOpTypeTag tag)
: RewritePattern(benefit, tag), typeConverter(&typeConverter) {}

protected:
/// An optional type converter for use by this pattern.
TypeConverter *typeConverter = nullptr;

private:
using RewritePattern::rewrite;
};

/// OpConversionPattern is a wrapper around ConversionPattern that allows for
/// matching and rewriting against an instance of a derived operation class as
/// opposed to a raw Operation.
template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
struct OpConversionPattern : public ConversionPattern {
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter,
context) {}

private:
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
void rewrite(Operation *op, ArrayRef<Value> operands,
Expand Down Expand Up @@ -413,6 +409,9 @@ class OpConversionPattern : public ConversionPattern {
rewrite(op, operands, rewriter);
return success();
}

private:
using ConversionPattern::matchAndRewrite;
};

/// Add a pattern to the given pattern list to convert the signature of a FuncOp
Expand Down

0 comments on commit 2a98409

Please sign in to comment.