Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 41 additions & 41 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace {
/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::BroadcastOp op,
PatternRewriter &rewriter) const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1151,7 +1151,7 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::OuterProductOp op,
PatternRewriter &rewriter) const override {
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ namespace {
///
/// Supports vector types with a fixed leading dimension.
struct UnrollGather : OpRewritePattern<vector::GatherOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -98,7 +98,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
/// but should be fairly straightforward to extend beyond that.
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -164,7 +164,7 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class UnrollDeinterleaveOp final
/// : vector<7xi16>, vector<7xi16>
/// ```
struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::InterleaveOp op,
PatternRewriter &rewriter) const override {
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ namespace {
/// until a one-dimensional vector is reached.
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -100,7 +100,7 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
/// will be folded at LLVM IR level.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -184,7 +184,7 @@ namespace {
/// and actually match the traits of its the nested `MaskableOpInterface`.
template <class SourceOp>
struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
using OpRewritePattern<MaskOp>::OpRewritePattern;
using Base::Base;

private:
LogicalResult matchAndRewrite(MaskOp maskOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace {
class InnerOuterDimReductionConversion
: public OpRewritePattern<vector::MultiDimReductionOp> {
public:
using OpRewritePattern::OpRewritePattern;
using Base::Base;

explicit InnerOuterDimReductionConversion(
MLIRContext *context, vector::VectorMultiReductionLowering options,
Expand Down Expand Up @@ -136,7 +136,7 @@ class InnerOuterDimReductionConversion
class ReduceMultiDimReductionRank
: public OpRewritePattern<vector::MultiDimReductionOp> {
public:
using OpRewritePattern::OpRewritePattern;
using Base::Base;

explicit ReduceMultiDimReductionRank(
MLIRContext *context, vector::VectorMultiReductionLowering options,
Expand Down Expand Up @@ -304,7 +304,7 @@ class ReduceMultiDimReductionRank
/// and combines results
struct TwoDimMultiReductionToElementWise
: public OpRewritePattern<vector::MultiDimReductionOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -359,7 +359,7 @@ struct TwoDimMultiReductionToElementWise
/// a sequence of vector.reduction ops.
struct TwoDimMultiReductionToReduction
: public OpRewritePattern<vector::MultiDimReductionOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -420,7 +420,7 @@ struct TwoDimMultiReductionToReduction
/// separately.
struct OneDimMultiReductionToTwoDim
: public OpRewritePattern<vector::MultiDimReductionOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ namespace {
/// return %7, %8 : vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::ScanOp scanOp,
PatternRewriter &rewriter) const override {
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
}

public:
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -356,7 +356,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
class ScalableShapeCastOpRewritePattern
: public OpRewritePattern<vector::ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace {
///
struct MixedSizeInputShuffleOpRewrite final
: OpRewritePattern<vector::ShuffleOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp,
PatternRewriter &rewriter) const override {
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ using namespace mlir::vector;
namespace {

struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::StepOp stepOp,
PatternRewriter &rewriter) const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ getToElementsDefiningOps(FromElementsOp fromElemsOp,
struct ToFromElementsToShuffleTreeRewrite final
: OpRewritePattern<vector::FromElementsOp> {

using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp,
PatternRewriter &rewriter) const override {
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ namespace {
/// %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
using Base::Base;

TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering,
MLIRContext *context, PatternBenefit benefit = 1)
Expand Down Expand Up @@ -395,7 +395,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
class Transpose2DWithUnitDimToShapeCast
: public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
using Base::Base;

Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
PatternBenefit benefit = 1)
Expand Down Expand Up @@ -433,7 +433,7 @@ class Transpose2DWithUnitDimToShapeCast
class TransposeOp2DToShuffleLowering
: public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
using Base::Base;

TransposeOp2DToShuffleLowering(
vector::VectorTransposeLowering vectorTransposeLowering,
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ namespace {
// input by inserting vector.broadcast.
struct CastAwayExtractStridedSliceLeadingOneDim
: public OpRewritePattern<vector::ExtractStridedSliceOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -104,7 +104,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
// inputs by inserting vector.broadcast.
struct CastAwayInsertStridedSliceLeadingOneDim
: public OpRewritePattern<vector::InsertStridedSliceOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -145,7 +145,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
// Casts away leading one dimensions in vector.insert's vector inputs by
// inserting vector.broadcast.
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::InsertOp insertOp,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -221,7 +221,7 @@ static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
// 1 dimensions.
struct CastAwayTransferReadLeadingOneDim
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -275,7 +275,7 @@ struct CastAwayTransferReadLeadingOneDim
// 1 dimensions.
struct CastAwayTransferWriteLeadingOneDim
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -541,7 +541,7 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
// vector.broadcast back to the original shape.
struct CastAwayConstantMaskLeadingOneDim
: public OpRewritePattern<vector::ConstantMaskOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
PatternRewriter &rewriter) const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ namespace {
///
struct VectorMaskedLoadOpConverter final
: OpRewritePattern<vector::MaskedLoadOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -117,7 +117,7 @@ struct VectorMaskedLoadOpConverter final
///
struct VectorMaskedStoreOpConverter final
: OpRewritePattern<vector::MaskedStoreOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
PatternRewriter &rewriter) const override {
Expand Down
16 changes: 8 additions & 8 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ namespace {
// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
// `false` to generate non-atomic RMW sequences.
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
using Base::Base;

ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
: OpConversionPattern<vector::StoreOp>(context),
Expand Down Expand Up @@ -827,7 +827,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
/// adjusted mask .
struct ConvertVectorMaskedStore final
: OpConversionPattern<vector::MaskedStoreOp> {
using OpConversionPattern::OpConversionPattern;
using Base::Base;

LogicalResult
matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -950,7 +950,7 @@ struct ConvertVectorMaskedStore final
/// those cases, loads are converted to byte-aligned, byte-sized loads and the
/// target vector is extracted from the loaded vector.
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
using Base::Base;

LogicalResult
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -1059,7 +1059,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
/// bitcasting, since each `i8` container element holds two `i4` values.
struct ConvertVectorMaskedLoad final
: OpConversionPattern<vector::MaskedLoadOp> {
using OpConversionPattern::OpConversionPattern;
using Base::Base;

LogicalResult
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -1257,7 +1257,7 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
// TODO: Document-me
struct ConvertVectorTransferRead final
: OpConversionPattern<vector::TransferReadOp> {
using OpConversionPattern::OpConversionPattern;
using Base::Base;

LogicalResult
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -1942,7 +1942,7 @@ namespace {
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
using OpRewritePattern::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -2147,7 +2147,7 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
/// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
///
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -2200,7 +2200,7 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
/// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
///
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
using Base::Base;

RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ using namespace mlir::vector;
class DecomposeDifferentRankInsertStridedSlice
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -84,7 +84,7 @@ class DecomposeDifferentRankInsertStridedSlice
class ConvertSameRankInsertStridedSliceIntoShuffle
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
using Base::Base;

void initialize() {
// This pattern creates recursive InsertStridedSliceOp, but the recursion is
Expand Down Expand Up @@ -183,7 +183,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
class Convert1DExtractStridedSliceIntoShuffle
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
using Base::Base;

LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -271,7 +271,7 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
class DecomposeNDExtractStridedSlice
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
using Base::Base;

void initialize() {
// This pattern creates recursive ExtractStridedSliceOp, but the recursion
Expand Down
Loading