diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index b57e66a1c3580..fc1bc8e55cbbe 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -107,11 +107,10 @@ static OpType getSingleOpOfType(Block &block) { /// Helper function to extract the input slices after filter is unrolled along /// kw. -static SmallVector -extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, - int64_t nSize, int64_t wSize, int64_t cSize, - int64_t kwSize, int strideW, int dilationW, - int64_t wSizeStep, bool isSingleChanneled) { +static SmallVector extractConvInputSlices( + RewriterBase &rewriter, Location loc, Value input, int64_t nSize, + int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, + int64_t wSizeStep, bool isSingleChanneled, bool isNCWPooling = false) { SmallVector result; if (isSingleChanneled) { // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled @@ -125,6 +124,19 @@ extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, strides)); } } + } else if (isNCWPooling) { + // Extract lhs slice of size {n, c, wSizeStep} @ [0, 0, sw * w + dw * kw] + // for NCW pooling. + SmallVector sizes = {nSize, cSize, wSizeStep}; + SmallVector strides = {1, 1, 1}; + for (int64_t kw = 0; kw < kwSize; ++kw) { + for (int64_t w = 0; w < wSize; w += wSizeStep) { + result.push_back(vector::ExtractStridedSliceOp::create( + rewriter, loc, input, + /*offsets=*/ArrayRef{0, 0, w * strideW + kw * dilationW}, + sizes, strides)); + } + } } else { // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0] // for channeled convolution. @@ -162,7 +174,8 @@ static SmallVector extractConvFilterSlices(RewriterBase &rewriter, static SmallVector extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, - int64_t wSizeStep, bool isSingleChanneled) { + int64_t wSizeStep, bool isSingleChanneled, + bool isNCWPooling = false) { SmallVector result; if (isSingleChanneled) { // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution. @@ -173,6 +186,15 @@ extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, rewriter, loc, res, /*offsets=*/ArrayRef{w}, sizes, strides)); } + } else if (isNCWPooling) { + // Extract res slice: {n, f, wSizeStep} @ [0, 0, w] for NCW pooling. + SmallVector sizes = {nSize, fSize, wSizeStep}; + SmallVector strides = {1, 1, 1}; + for (int64_t w = 0; w < wSize; w += wSizeStep) { + result.push_back(vector::ExtractStridedSliceOp::create( + rewriter, loc, res, /*offsets=*/ArrayRef{0, 0, w}, sizes, + strides)); + } } else { // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled // convolution. @@ -191,7 +213,8 @@ extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl &resVals, - bool isSingleChanneled) { + bool isSingleChanneled, + bool isNCWPooling = false) { if (isSingleChanneled) { // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution. @@ -202,6 +225,14 @@ static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef{w}, strides); } + } else if (isNCWPooling) { + // Write back res slice: {n, f, wSizeStep} @ [0, 0, w] for NCW pooling. + SmallVector strides = {1, 1, 1}; + for (int64_t w = 0; w < wSize; w += wSizeStep) { + res = vector::InsertStridedSliceOp::create( + rewriter, loc, resVals[w], res, + /*offsets=*/ArrayRef{0, 0, w}, strides); + } } else { // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled // convolution. This does not depend on kw. @@ -3436,6 +3467,8 @@ struct Conv1DGenerator int64_t nSize, wSize, cSize, kwSize, fSize; SmallVector lhsShape, rhsShape, resShape; bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W); + bool isNCWPooling = (oper == ConvOperationKind::Pool && + conv1DOpOrder == Conv1DOpOrder::Ncw); switch (conv1DOpOrder) { case Conv1DOpOrder::W: // Initialize unused dimensions @@ -3555,6 +3588,8 @@ struct Conv1DGenerator // Base case, so no transposes necessary. break; case Conv1DOpOrder::Ncw: { + if (isNCWPooling) + break; // To match base vectorization case, we pre-transpose current case. // ncw -> nwc static constexpr std::array permLhs = {0, 2, 1}; @@ -3579,12 +3614,13 @@ struct Conv1DGenerator SmallVector lhsVals, rhsVals, resVals; lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize, kwSize, strideW, dilationW, wSizeStep, - isSingleChanneled); + isSingleChanneled, isNCWPooling); // Do not do for pooling. if (oper == ConvOperationKind::Conv) rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize); - resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize, - wSizeStep, isSingleChanneled); + resVals = + extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize, + wSizeStep, isSingleChanneled, isNCWPooling); auto linearIndex = [&](int64_t kw, int64_t w) { return kw * (wSize / wSizeStep) + w; @@ -3616,7 +3652,7 @@ struct Conv1DGenerator } res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals, - isSingleChanneled); + isSingleChanneled, isNCWPooling); //===------------------------------------------------------------------===// // End vector-only rewrite part //===------------------------------------------------------------------===// @@ -3630,6 +3666,8 @@ struct Conv1DGenerator // Base case, so no transposes necessary. break; case Conv1DOpOrder::Ncw: { + if (isNCWPooling) + break; // nwf -> nfw static constexpr std::array perm = {0, 2, 1}; res = vector::TransposeOp::create(rewriter, loc, res, perm); diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir index 97b27befd44e2..d443375e525fc 100644 --- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir @@ -1111,18 +1111,16 @@ func.func @pooling_ncw_sum_memref_1_2_1_3(%input: memref<4x3x4xf32>, %filter: me // CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[V0:.+]] = vector.transfer_read %[[INPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x4xf32>, vector<4x3x4xf32> // CHECK: %[[V1:.+]] = vector.transfer_read %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x2xf32>, vector<4x3x2xf32> -// CHECK: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x3x4xf32> to vector<4x4x3xf32> -// CHECK: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32> -// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32> -// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32> -// CHECK: %[[V6:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> -// CHECK: %[[V7:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> -// CHECK: %[[V8:.+]] = arith.addf %[[V4]], %[[V6]] : vector<4x1x3xf32> -// CHECK: %[[V9:.+]] = arith.addf %[[V5]], %[[V7]] : vector<4x1x3xf32> -// CHECK: %[[V10:.+]] = vector.insert_strided_slice %[[V8]], %[[V3]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> -// CHECK: %[[V11:.+]] = vector.insert_strided_slice %[[V9]], %[[V10]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> -// CHECK: %[[V12:.+]] = vector.transpose %[[V11]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32> -// CHECK: vector.transfer_write %[[V12:.+]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x3x2xf32>, memref<4x3x2xf32> +// CHECK-NOT: vector.transpose +// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 3, 1], strides = [1, 1, 1]} : vector<4x3x4xf32> to vector<4x3x1xf32> +// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 3], sizes = [4, 3, 1], strides = [1, 1, 1]} : vector<4x3x4xf32> to vector<4x3x1xf32> +// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 3, 1], strides = [1, 1, 1]} : vector<4x3x2xf32> to vector<4x3x1xf32> +// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 1], sizes = [4, 3, 1], strides = [1, 1, 1]} : vector<4x3x2xf32> to vector<4x3x1xf32> +// CHECK: %[[V6:.+]] = arith.addf %[[V2]], %[[V4]] : vector<4x3x1xf32> +// CHECK: %[[V7:.+]] = arith.addf %[[V3]], %[[V5]] : vector<4x3x1xf32> +// CHECK: %[[V8:.+]] = vector.insert_strided_slice %[[V6]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x3x1xf32> into vector<4x3x2xf32> +// CHECK: %[[V9:.+]] = vector.insert_strided_slice %[[V7]], %[[V8]] {offsets = [0, 0, 1], strides = [1, 1, 1]} : vector<4x3x1xf32> into vector<4x3x2xf32> +// CHECK: vector.transfer_write %[[V9]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x3x2xf32>, memref<4x3x2xf32> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { @@ -1212,22 +1210,20 @@ func.func @pooling_ncw_sum_memref_2_2_2_3(%input: memref<4x3x6xf32>, %filter: me // CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[V0:.+]] = vector.transfer_read %[[INPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x6xf32>, vector<4x3x6xf32> // CHECK: %[[V1:.+]] = vector.transfer_read %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x2xf32>, vector<4x3x2xf32> -// CHECK: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x3x6xf32> to vector<4x6x3xf32> -// CHECK: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32> -// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> -// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> -// CHECK: %[[V6:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> -// CHECK: %[[V7:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> -// CHECK: %[[V8:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> -// CHECK: %[[V9:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> -// CHECK: %[[V10:.+]] = arith.addf %[[V4]], %[[V8]] : vector<4x1x3xf32> -// CHECK: %[[V11:.+]] = arith.addf %[[V5]], %[[V9]] : vector<4x1x3xf32> -// CHECK: %[[V12:.+]] = arith.addf %[[V6]], %[[V10]] : vector<4x1x3xf32> -// CHECK: %[[V13:.+]] = arith.addf %[[V7]], %[[V11]] : vector<4x1x3xf32> -// CHECK: %[[V14:.+]] = vector.insert_strided_slice %[[V12]], %[[V3]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> -// CHECK: %[[V15:.+]] = vector.insert_strided_slice %[[V13]], %[[V14]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> -// CHECK: %[[V16:.+]] = vector.transpose %[[V15]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32> -// CHECK: vector.transfer_write %[[V16:.+]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x3x2xf32>, memref<4x3x2xf32> +// CHECK-NOT: vector.transpose +// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 3, 1], strides = [1, 1, 1]} : vector<4x3x6xf32> to vector<4x3x1xf32> +// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 3], sizes = [4, 3, 1], strides = [1, 1, 1]} : vector<4x3x6xf32> to vector<4x3x1xf32> +// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 2], sizes = [4, 3, 1], strides = [1, 1, 1]} : vector<4x3x6xf32> to vector<4x3x1xf32> +// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 5], sizes = [4, 3, 1], strides = [1, 1, 1]} : vector<4x3x6xf32> to vector<4x3x1xf32> +// CHECK: %[[V6:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 3, 1], strides = [1, 1, 1]} : vector<4x3x2xf32> to vector<4x3x1xf32> +// CHECK: %[[V7:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 1], sizes = [4, 3, 1], strides = [1, 1, 1]} : vector<4x3x2xf32> to vector<4x3x1xf32> +// CHECK: %[[V8:.+]] = arith.addf %[[V2]], %[[V6]] : vector<4x3x1xf32> +// CHECK: %[[V9:.+]] = arith.addf %[[V3]], %[[V7]] : vector<4x3x1xf32> +// CHECK: %[[V10:.+]] = arith.addf %[[V4]], %[[V8]] : vector<4x3x1xf32> +// CHECK: %[[V11:.+]] = arith.addf %[[V5]], %[[V9]] : vector<4x3x1xf32> +// CHECK: %[[V12:.+]] = vector.insert_strided_slice %[[V10]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x3x1xf32> into vector<4x3x2xf32> +// CHECK: %[[V13:.+]] = vector.insert_strided_slice %[[V11]], %[[V12]] {offsets = [0, 0, 1], strides = [1, 1, 1]} : vector<4x3x1xf32> into vector<4x3x2xf32> +// CHECK: vector.transfer_write %[[V13]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x3x2xf32>, memref<4x3x2xf32> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { @@ -1254,14 +1250,12 @@ func.func @pooling_ncw_sum_memref_2_3_2_1(%input: memref<4x2x5xf32>, %filter: me // CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[V0:.+]] = vector.transfer_read %[[INPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x5xf32>, vector<4x2x5xf32> // CHECK: %[[V1:.+]] = vector.transfer_read %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32> -// CHECK: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x2x5xf32> to vector<4x5x2xf32> -// CHECK: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32> -// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 3, 2], strides = [1, 1, 1]} : vector<4x5x2xf32> to vector<4x3x2xf32> -// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 2, 0], sizes = [4, 3, 2], strides = [1, 1, 1]} : vector<4x5x2xf32> to vector<4x3x2xf32> -// CHECK: %[[V6:.+]] = arith.addf %[[V4]], %[[V3]] : vector<4x3x2xf32> -// CHECK: %[[V7:.+]] = arith.addf %[[V5]], %[[V6]] : vector<4x3x2xf32> -// CHECK: %[[V8:.+]] = vector.transpose %[[V7]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32> -// CHECK: vector.transfer_write %[[V8:.+]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32> +// CHECK-NOT: vector.transpose +// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x2x5xf32> to vector<4x2x3xf32> +// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 2], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x2x5xf32> to vector<4x2x3xf32> +// CHECK: %[[V4:.+]] = arith.addf %[[V2]], %[[V1]] : vector<4x2x3xf32> +// CHECK: %[[V5:.+]] = arith.addf %[[V3]], %[[V4]] : vector<4x2x3xf32> +// CHECK: vector.transfer_write %[[V5]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {