diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 33e3d94f02b1c..8b76930aed35a 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -145,6 +145,10 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
 /// ```
 void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
+/// Patterns for flattening multi-dimensional memref operations into
+/// one-dimensional memref operations.
+void populateFlattenVectorOpsOnMemrefPatterns(RewritePatternSet &patterns);
+void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns);
 void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
 
 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 0138f477cadea..08f439222a9a0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -383,6 +383,16 @@ void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns, bool disableAtomicRMW = false);
 
+/// Populates patterns for both MemRef flattening and Vector narrow type
+/// emulation.
+///
+/// Patterns for narrow-type emulation require "flattened" MemRef(s), so this
+/// composite populate* method can be used for narrow-type emulation of Ops
+/// operating on MemRef(s) of rank >= 2.
+void populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns);
+
 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
 /// Warning: these patterns currently only work for little endian targets.
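
Editorial note (not part of the patch): the composite entry point declared above is intended to be driven by a standard dialect-conversion setup, as the test pass at the end of this patch does. Below is a minimal sketch of such a driver. The helper name `runFlattenAndEmulate`, the legality choices, and the include list are illustrative assumptions, not APIs added by this change.

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Illustrative driver (hypothetical helper, not added by this patch).
static LogicalResult runFlattenAndEmulate(Operation *rootOp, MLIRContext *ctx) {
  // Emulate sub-byte element types (e.g. i4) with an 8-bit container type.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

  ConversionTarget target(*ctx);
  target.addDynamicallyLegalDialect<vector::VectorDialect,
                                    memref::MemRefDialect>(
      [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });

  RewritePatternSet patterns(ctx);
  // Memref flattening for vector ops and narrow-type emulation go into the
  // same pattern set, so one partial conversion first linearizes rank >= 2
  // memrefs feeding vector ops and then rewrites the sub-byte element types.
  vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
      typeConverter, patterns);

  return applyPartialConversion(rootOp, target, std::move(patterns));
}
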
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 42be847811d52..1208fddf37e0b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -271,12 +271,9 @@ struct FlattenMemrefsPass
 
 } // namespace
 
-void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
-  patterns.insert,
-                  MemRefRewritePattern,
-                  MemRefRewritePattern,
-                  MemRefRewritePattern,
-                  MemRefRewritePattern,
+void memref::populateFlattenVectorOpsOnMemrefPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert,
                   MemRefRewritePattern,
                   MemRefRewritePattern,
                   MemRefRewritePattern,
@@ -284,3 +281,16 @@ void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
                   MemRefRewritePattern>(
       patterns.getContext());
 }
+
+void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
+  patterns.insert,
+                  MemRefRewritePattern,
+                  MemRefRewritePattern,
+                  MemRefRewritePattern>(
+      patterns.getContext());
+}
+
+void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
+  populateFlattenMemrefOpsPatterns(patterns);
+  populateFlattenVectorOpsOnMemrefPatterns(patterns);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 852eff2d2b909..264cbc1869b9a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -38,6 +38,8 @@
 #include
 #include
 
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+
 using namespace mlir;
 
 #define DEBUG_TYPE "vector-narrow-type-emulation"
@@ -556,7 +558,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // See #115653
     if (op.getValueToStore().getType().getRank() != 1)
       return rewriter.notifyMatchFailure(op,
                                          "only 1-D vectors are supported ATM");
@@ -817,7 +818,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+/// Converts `vector.maskedstore` operations on narrow element types to work
+/// with wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
+/// (each `i8` container element holds two `i4` values) and storing with an
+/// adjusted mask.
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -826,10 +833,10 @@ struct ConvertVectorMaskedStore final
   matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // See #115653
+    // Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
     if (op.getValueToStore().getType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in vector.maskedstore op must be flattened beforehand.");
 
     auto loc = op.getLoc();
     auto containerElemTy =
@@ -931,18 +938,27 @@ struct ConvertVectorMaskedStore final
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+/// Converts `vector.load` on narrow element types to work with wider,
+/// byte-aligned container types by adjusting load sizes and using bitcasting.
+///
+/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
+/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` (each `i8`
+/// container holds two `i4` values) and bitcasting back.
+///
+/// There are cases where the number of elements to load is not byte-aligned.
+/// In those cases, loads are converted to byte-aligned, byte-sized loads and
+/// the target vector is extracted from the loaded vector.
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    // See #115653
+    // Prerequisite: memref in the vector.load op is flattened into 1-D.
     if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in emulated vector ops must be flattened beforehand.");
 
     auto loc = op.getLoc();
     auto containerElemTy =
@@ -961,8 +977,6 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
-    // Here only the 1-D vector load is considered, and the N-D memref types
-    // should be linearized.
     // For example, to emulate i4 to i8, the following op:
     //
     //   %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
@@ -1037,7 +1051,12 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+/// Converts `vector.maskedload` operations on narrow element types to work
+/// with wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
+/// bitcasting, since each `i8` container element holds two `i4` values.
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -1045,10 +1064,9 @@ struct ConvertVectorMaskedLoad final
   LogicalResult
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // See #115653
     if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in emulated vector ops must be flattened beforehand.");
 
     auto loc = op.getLoc();
 
@@ -1229,7 +1247,6 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
 
   int elemsPerMultiByte = multiByteBits / subByteBits;
 
-  // TODO: This is a bit too restrictive for vectors rank > 1.
   return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
 }
 
@@ -1246,10 +1263,11 @@ struct ConvertVectorTransferRead final
   matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // See #115653
+    // Prerequisite: memref in the vector.transfer_read op is flattened into
+    // 1-D.
     if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in emulated vector ops must be flattened beforehand.");
 
     auto loc = op.getLoc();
     auto containerElemTy =
@@ -2227,7 +2245,6 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns, bool disableAtomicRMW) {
-  // Populate `vector.*` conversion patterns.
   // TODO: #119553 support atomicity
   patterns.add(patterns.getContext(), benefit);
 }
+
+void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns) {
+  memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
+  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+}
diff --git a/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir b/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
new file mode 100644
index 0000000000000..222e613f5c18a
--- /dev/null
+++ b/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt --test-memref-flatten-and-vector-narrow-type-emulation --split-input-file %s | FileCheck %s
+
+// This test verifies that narrow-type emulation works correctly for
+// rank > 1 memrefs by combining memref flattening with vector narrow type
+// emulation patterns.
+//
+// The patterns tested here demonstrate the composition of two transformations:
+// memref flattening for vector ops and vector op narrow type emulation.
+//
+// TODO: Support the `vector.transfer_write` operation.
+
+func.func @vector_load_2d_i4(%arg0: index) -> vector<8xi4> {
+  %0 = memref.alloc() : memref<4x8xi4>
+  %1 = vector.load %0[%arg0, %arg0] : memref<4x8xi4>, vector<8xi4>
+  return %1 : vector<8xi4>
+}
+// CHECK-LABEL: func @vector_load_2d_i4
+// CHECK: vector.load {{.*}} memref<16xi8>
+
+// -----
+
+func.func @vector_maskedload_2d_i4(%arg0: index, %passthru: vector<8xi4>) -> vector<8xi4> {
+  %0 = memref.alloc() : memref<4x8xi4>
+  %mask = vector.constant_mask [6] : vector<8xi1>
+  %1 = vector.maskedload %0[%arg0, %arg0], %mask, %passthru :
+    memref<4x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+  return %1 : vector<8xi4>
+}
+// CHECK-LABEL: func @vector_maskedload_2d_i4(
+// CHECK: vector.maskedload {{.*}} memref<16xi8>
+
+// -----
+
+func.func @vector_maskedstore_2d_i4(%arg0: index, %value: vector<8xi4>) {
+  %0 = memref.alloc() : memref<4x8xi4>
+  %mask = vector.constant_mask [5] : vector<8xi1>
+  vector.maskedstore %0[%arg0, %arg0], %mask, %value :
+    memref<4x8xi4>, vector<8xi1>, vector<8xi4>
+  return
+}
+// CHECK-LABEL: func @vector_maskedstore_2d_i4(
+// CHECK: vector.maskedstore {{.*}} memref<16xi8>
+
+// -----
+
+func.func @vector_store_2d_i4(%arg0: index, %value: vector<8xi4>) {
+  %0 = memref.alloc() : memref<4x8xi4>
+  vector.store %value, %0[%arg0, %arg0] : memref<4x8xi4>, vector<8xi4>
+  return
+}
+// CHECK-LABEL: func @vector_store_2d_i4(
+// CHECK: vector.store {{.*}} memref<16xi8>
+
+// -----
+
+func.func @vector_transfer_read_2d_i4(%arg0: index, %padding: i4) -> vector<8xi4> {
+  %0 = memref.alloc() : memref<4x8xi4>
+  %1 = vector.transfer_read %0[%arg0, %arg0], %padding {in_bounds = [true]} : memref<4x8xi4>, vector<8xi4>
+  return %1 : vector<8xi4>
+}
+// CHECK-LABEL: func @vector_transfer_read_2d_i4(
+// CHECK-SAME: %{{.*}}: index, %[[PADDING_I4:.*]]: i4)
+// CHECK: %[[PADDING_I8:.*]] = arith.extui %[[PADDING_I4]] : i4 to i8
+// CHECK: vector.transfer_read {{.*}}, %[[PADDING_I8]] : memref<16xi8>, vector<4xi8>
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index ba2ea40e83d96..b5f015aff19b4 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 
@@ -126,10 +127,73 @@ struct TestEmulateNarrowTypePass
                      "normal sequence"),
       llvm::cl::init(false)};
 };
+
+struct TestMemRefFlattenAndVectorNarrowTypeEmulationPass
+    : public PassWrapper> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestMemRefFlattenAndVectorNarrowTypeEmulationPass)
+
+  TestMemRefFlattenAndVectorNarrowTypeEmulationPass() = default;
+  TestMemRefFlattenAndVectorNarrowTypeEmulationPass(
+      const TestMemRefFlattenAndVectorNarrowTypeEmulationPass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert();
+  }
+
+  StringRef getArgument() const final {
+    return "test-memref-flatten-and-vector-narrow-type-emulation";
+  }
+
+  StringRef getDescription() const final {
+    return "Test MemRef flattening and vector narrow type emulation patterns";
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = &getContext();
+
+    // Create a type converter for narrow type emulation (8-bit).
+    arith::NarrowTypeEmulationConverter typeConverter(8);
+
+    // Add conversions for memref types with i4 elements.
+    memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+
+    ConversionTarget target(*ctx);
+    target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
+      return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
+    });
+    auto opLegalCallback = [&typeConverter](Operation *op) {
+      return typeConverter.isLegal(op);
+    };
+    target.addDynamicallyLegalOp(opLegalCallback);
+    target.addDynamicallyLegalDialect<
+        arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
+        affine::AffineDialect>(opLegalCallback);
+
+    RewritePatternSet patterns(ctx);
+
+    // This is necessary for the purpose of emulating `memref.alloc` and
+    // function boundaries.
+    memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+
+    vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
+        typeConverter, patterns);
+
+    // Apply partial conversion.
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
 } // namespace
 
 namespace mlir::test {
 void registerTestEmulateNarrowTypePass() {
   PassRegistration<TestEmulateNarrowTypePass>();
+  PassRegistration<TestMemRefFlattenAndVectorNarrowTypeEmulationPass>();
 }
 } // namespace mlir::test
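
Editorial note (not part of the patch): the new doc comments and the CHECK lines in the test all rely on the same arithmetic, namely flattening a rank-2 i4 memref and then packing two i4 values into each i8 container element. The standalone snippet below just reproduces that index math for the 4x8 i4 buffer and 8-element accesses used in the tests; it is an illustrative sketch, not code from this change.

#include <cassert>
#include <cstdint>

// Arithmetic behind the emulation of a rank-2 i4 access on a flattened i8
// buffer (mirrors the memref<4x8xi4> -> memref<16xi8> CHECK lines above).
int main() {
  constexpr int64_t rows = 4, cols = 8;                          // memref<4x8xi4>
  constexpr int64_t subByteBits = 4, containerBits = 8;
  constexpr int64_t elemsPerByte = containerBits / subByteBits;  // two i4 per i8

  // Flattening: memref<4x8xi4> -> memref<32xi4>; emulation then views the
  // same buffer as memref<16xi8>.
  constexpr int64_t flatI4Size = rows * cols;                    // 32
  constexpr int64_t emulatedI8Size = flatI4Size / elemsPerByte;  // 16
  assert(emulatedI8Size == 16);

  // A vector<8xi4> load at (row, col) becomes a vector<4xi8> load at the
  // scaled linear index, provided col is a multiple of elemsPerByte.
  int64_t row = 1, col = 4;
  int64_t linearI4Index = row * cols + col;        // 12
  int64_t i8Index = linearI4Index / elemsPerByte;  // 6
  int64_t numI8Elems = 8 / elemsPerByte;           // vector<4xi8>
  assert(i8Index == 6 && numI8Elems == 4);
  return 0;
}

The divisibility requirement in the last step is the same condition checked by `fitsInMultiByteContainerTy` in VectorEmulateNarrowType.cpp (the trailing vector dimension must be a multiple of the number of sub-byte elements per container element).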