-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir] Add base class type aliases for rewrites/conversions. NFC. #158433
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This is to simplify writing rewrite/conversion patterns that usually start with: ```c++ struct MyPattern : public OpRewritePattern<MyPattern> { using OpRewritePattern::OpRewritePattern; ``` and allow for: ```c++ struct MyPattern : public OpRewritePattern<MyPattern> { using Base::Base; ``` similar to pass classes.
@llvm/pr-subscribers-mlir Author: Jakub Kuderski (kuhar) ChangesThis is to simplify writing rewrite/conversion patterns that usually start with: struct MyPattern : public OpRewritePattern<MyOp> {
using OpRewritePattern::OpRewritePattern; and allow for: struct MyPattern : public OpRewritePattern<MyOp> {
using Base::Base; similar to how we enable it for pass classes. Full diff: https://github.com/llvm/llvm-project/pull/158433.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 7b0b9cef9c5bd..576481a6e7215 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -312,6 +312,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
template <typename SourceOp>
struct OpRewritePattern
: public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpRewritePattern;
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching and a list of generated
@@ -328,6 +331,9 @@ struct OpRewritePattern
template <typename SourceOp>
struct OpInterfaceRewritePattern
: public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpInterfaceRewritePattern;
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
@@ -341,6 +347,10 @@ struct OpInterfaceRewritePattern
template <template <typename> class TraitType>
class OpTraitRewritePattern : public RewritePattern {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpTraitRewritePattern;
+
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
benefit, context) {}
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index bfbe12d2a5668..6ef649e8fc13a 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -40,6 +40,10 @@ class Value;
/// registered using addConversion and addMaterialization, respectively.
class TypeConverter {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = TypeConverter;
+
virtual ~TypeConverter() = default;
TypeConverter() = default;
// Copy the registered conversions, but not the caches
@@ -679,6 +683,10 @@ class ConversionPattern : public RewritePattern {
template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpConversionPattern;
+
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
@@ -729,6 +737,10 @@ class OpConversionPattern : public ConversionPattern {
template <typename SourceOp>
class OpInterfaceConversionPattern : public ConversionPattern {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpInterfaceConversionPattern;
+
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
SourceOp::getInterfaceID(), benefit, context) {}
@@ -773,6 +785,10 @@ class OpInterfaceConversionPattern : public ConversionPattern {
template <template <typename> class TraitType>
class OpTraitConversionPattern : public ConversionPattern {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpTraitConversionPattern;
+
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(Pattern::MatchTraitOpTypeTag(),
TypeID::get<TraitType>(), benefit, context) {}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 93b007c792ad9..f8b5144e3acb2 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -114,7 +115,8 @@ struct FoldingPattern : public RewritePattern {
struct FolderInsertBeforePreviouslyFoldedConstantPattern
: public OpRewritePattern<TestCastOp> {
public:
- using OpRewritePattern<TestCastOp>::OpRewritePattern;
+ static_assert(std::is_same_v<Base, OpRewritePattern<TestCastOp>>);
+ using Base::Base;
LogicalResult matchAndRewrite(TestCastOp op,
PatternRewriter &rewriter) const override {
@@ -1306,7 +1308,8 @@ class TestReplaceWithValidConsumer : public ConversionPattern {
/// b) or: drops all block arguments and replaces each with 2x the first
/// operand.
class TestConvertBlockArgs : public OpConversionPattern<ConvertBlockArgsOp> {
- using OpConversionPattern<ConvertBlockArgsOp>::OpConversionPattern;
+ static_assert(std::is_same_v<Base, OpConversionPattern<ConvertBlockArgsOp>>);
+ using Base::Base;
LogicalResult
matchAndRewrite(ConvertBlockArgsOp op, OpAdaptor adaptor,
@@ -1431,7 +1434,9 @@ class TestTypeConsumerOpPattern
namespace {
struct TestTypeConverter : public TypeConverter {
- using TypeConverter::TypeConverter;
+ static_assert(std::is_same_v<Base, TypeConverter>);
+ using Base::Base;
+
TestTypeConverter() {
addConversion(convertType);
addSourceMaterialization(materializeCast);
|
@llvm/pr-subscribers-mlir-core Author: Jakub Kuderski (kuhar) ChangesThis is to simplify writing rewrite/conversion patterns that usually start with: struct MyPattern : public OpRewritePattern<MyOp> {
using OpRewritePattern::OpRewritePattern; and allow for: struct MyPattern : public OpRewritePattern<MyOp> {
using Base::Base; similar to how we enable it for pass classes. Full diff: https://github.com/llvm/llvm-project/pull/158433.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 7b0b9cef9c5bd..576481a6e7215 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -312,6 +312,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
template <typename SourceOp>
struct OpRewritePattern
: public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpRewritePattern;
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching and a list of generated
@@ -328,6 +331,9 @@ struct OpRewritePattern
template <typename SourceOp>
struct OpInterfaceRewritePattern
: public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpInterfaceRewritePattern;
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
@@ -341,6 +347,10 @@ struct OpInterfaceRewritePattern
template <template <typename> class TraitType>
class OpTraitRewritePattern : public RewritePattern {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpTraitRewritePattern;
+
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
benefit, context) {}
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index bfbe12d2a5668..6ef649e8fc13a 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -40,6 +40,10 @@ class Value;
/// registered using addConversion and addMaterialization, respectively.
class TypeConverter {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = TypeConverter;
+
virtual ~TypeConverter() = default;
TypeConverter() = default;
// Copy the registered conversions, but not the caches
@@ -679,6 +683,10 @@ class ConversionPattern : public RewritePattern {
template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpConversionPattern;
+
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
@@ -729,6 +737,10 @@ class OpConversionPattern : public ConversionPattern {
template <typename SourceOp>
class OpInterfaceConversionPattern : public ConversionPattern {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpInterfaceConversionPattern;
+
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
SourceOp::getInterfaceID(), benefit, context) {}
@@ -773,6 +785,10 @@ class OpInterfaceConversionPattern : public ConversionPattern {
template <template <typename> class TraitType>
class OpTraitConversionPattern : public ConversionPattern {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpTraitConversionPattern;
+
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(Pattern::MatchTraitOpTypeTag(),
TypeID::get<TraitType>(), benefit, context) {}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 93b007c792ad9..f8b5144e3acb2 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -114,7 +115,8 @@ struct FoldingPattern : public RewritePattern {
struct FolderInsertBeforePreviouslyFoldedConstantPattern
: public OpRewritePattern<TestCastOp> {
public:
- using OpRewritePattern<TestCastOp>::OpRewritePattern;
+ static_assert(std::is_same_v<Base, OpRewritePattern<TestCastOp>>);
+ using Base::Base;
LogicalResult matchAndRewrite(TestCastOp op,
PatternRewriter &rewriter) const override {
@@ -1306,7 +1308,8 @@ class TestReplaceWithValidConsumer : public ConversionPattern {
/// b) or: drops all block arguments and replaces each with 2x the first
/// operand.
class TestConvertBlockArgs : public OpConversionPattern<ConvertBlockArgsOp> {
- using OpConversionPattern<ConvertBlockArgsOp>::OpConversionPattern;
+ static_assert(std::is_same_v<Base, OpConversionPattern<ConvertBlockArgsOp>>);
+ using Base::Base;
LogicalResult
matchAndRewrite(ConvertBlockArgsOp op, OpAdaptor adaptor,
@@ -1431,7 +1434,9 @@ class TestTypeConsumerOpPattern
namespace {
struct TestTypeConverter : public TypeConverter {
- using TypeConverter::TypeConverter;
+ static_assert(std::is_same_v<Base, TypeConverter>);
+ using Base::Base;
+
TestTypeConverter() {
addConversion(convertType);
addSourceMaterialization(materializeCast);
|
Looks good, but I'll let others have time to review. Thanks! |
Use the `Base` type alias from llvm#158433.
…. NFC. (#159681) Use the `Base` type alias from llvm/llvm-project#158433.
Use the `Base` type alias from llvm#158433.
…s. NFC. (#159682) Use the `Base` type alias from llvm/llvm-project#158433.
…vm#158433) This is to simplify writing rewrite/conversion patterns that usually start with: ```c++ struct MyPattern : public OpRewritePattern<MyOp> { using OpRewritePattern::OpRewritePattern; ``` and allow for: ```c++ struct MyPattern : public OpRewritePattern<MyOp> { using Base::Base; ``` similar to how we enable it for pass classes.
…m#159681) Use the `Base` type alias from llvm#158433.
…vm#159682) Use the `Base` type alias from llvm#158433.
Use the `Base` type alias added in llvm/llvm-project#158433.
Use the `Base` type alias added in llvm/llvm-project#158433.
…C. (#22143) Use the `Base` type alias added in llvm/llvm-project#158433.
…22142) Use the `Base` type alias added in llvm/llvm-project#158433.
Use the `Base` type alias from llvm#158433.
…tructors. NFC. (#161670) Use the `Base` type alias from llvm/llvm-project#158433.
…FC. (llvm#161670) Use the `Base` type alias from llvm#158433.
…FC. (llvm#161670) Use the `Base` type alias from llvm#158433.
Use the `Base` type alias from llvm#158433.
…ctors. NFC. (#161966) Use the `Base` type alias from llvm/llvm-project#158433.
…llvm#161966) Use the `Base` type alias from llvm#158433.
This is to simplify writing rewrite/conversion patterns that usually start with:
and allow for:
similar to how we enable it for pass classes.