Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
template <typename SourceOp>
struct OpRewritePattern
: public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
/// Type alias to allow derived classes to inherit constructors with
/// `using Base::Base;`.
using Base = OpRewritePattern;

/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching and a list of generated
Expand All @@ -328,6 +331,9 @@ struct OpRewritePattern
template <typename SourceOp>
struct OpInterfaceRewritePattern
: public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
/// Type alias to allow derived classes to inherit constructors with
/// `using Base::Base;`.
using Base = OpInterfaceRewritePattern;

OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
Expand All @@ -341,6 +347,10 @@ struct OpInterfaceRewritePattern
template <template <typename> class TraitType>
class OpTraitRewritePattern : public RewritePattern {
public:
/// Type alias to allow derived classes to inherit constructors with
/// `using Base::Base;`.
using Base = OpTraitRewritePattern;

OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
benefit, context) {}
Expand Down
16 changes: 16 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class Value;
/// registered using addConversion and addMaterialization, respectively.
class TypeConverter {
public:
/// Type alias to allow derived classes to inherit constructors with
/// `using Base::Base;`.
using Base = TypeConverter;

virtual ~TypeConverter() = default;
TypeConverter() = default;
// Copy the registered conversions, but not the caches
Expand Down Expand Up @@ -679,6 +683,10 @@ class ConversionPattern : public RewritePattern {
template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
/// Type alias to allow derived classes to inherit constructors with
/// `using Base::Base;`.
using Base = OpConversionPattern;

using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
Expand Down Expand Up @@ -729,6 +737,10 @@ class OpConversionPattern : public ConversionPattern {
template <typename SourceOp>
class OpInterfaceConversionPattern : public ConversionPattern {
public:
/// Type alias to allow derived classes to inherit constructors with
/// `using Base::Base;`.
using Base = OpInterfaceConversionPattern;

OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
SourceOp::getInterfaceID(), benefit, context) {}
Expand Down Expand Up @@ -773,6 +785,10 @@ class OpInterfaceConversionPattern : public ConversionPattern {
template <template <typename> class TraitType>
class OpTraitConversionPattern : public ConversionPattern {
public:
/// Type alias to allow derived classes to inherit constructors with
/// `using Base::Base;`.
using Base = OpTraitConversionPattern;

OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(Pattern::MatchTraitOpTypeTag(),
TypeID::get<TraitType>(), benefit, context) {}
Expand Down
11 changes: 8 additions & 3 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
Expand Down Expand Up @@ -114,7 +115,8 @@ struct FoldingPattern : public RewritePattern {
struct FolderInsertBeforePreviouslyFoldedConstantPattern
: public OpRewritePattern<TestCastOp> {
public:
using OpRewritePattern<TestCastOp>::OpRewritePattern;
static_assert(std::is_same_v<Base, OpRewritePattern<TestCastOp>>);
using Base::Base;

LogicalResult matchAndRewrite(TestCastOp op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -1306,7 +1308,8 @@ class TestReplaceWithValidConsumer : public ConversionPattern {
/// b) or: drops all block arguments and replaces each with 2x the first
/// operand.
class TestConvertBlockArgs : public OpConversionPattern<ConvertBlockArgsOp> {
using OpConversionPattern<ConvertBlockArgsOp>::OpConversionPattern;
static_assert(std::is_same_v<Base, OpConversionPattern<ConvertBlockArgsOp>>);
using Base::Base;

LogicalResult
matchAndRewrite(ConvertBlockArgsOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -1431,7 +1434,9 @@ class TestTypeConsumerOpPattern

namespace {
struct TestTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
static_assert(std::is_same_v<Base, TypeConverter>);
using Base::Base;

TestTypeConverter() {
addConversion(convertType);
addSourceMaterialization(materializeCast);
Expand Down
Loading