diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6984bc2dff498..5f7cf30335e99 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -491,9 +491,6 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
     if (!controlFn(padOp))
       return failure();
 
-    if (!padOp.getResult().hasOneUse())
-      return failure();
-
     // TODO: Enable padding when the padding values are the same.
     if (packOp.getPaddingValue())
       return failure();
@@ -510,7 +507,6 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
       return failure();
 
     ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
-    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
 
     // Bail out if one of the padded dimension is a tiled one.
     llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
@@ -524,11 +520,13 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(padOp);
 
+    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+    SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
     auto empty = tensor::PackOp::createDestinationTensor(
-        rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
+        rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
         outerDimsPerm);
-    Value packedSource = rewriter.create<tensor::PackOp>(
-        loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
+    auto sourcePack = rewriter.create<tensor::PackOp>(
+        loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
         /*padding=*/std::nullopt, outerDimsPerm);
 
     // If we have `outer_dims_perms` we need to adjust the padded dimensions.
@@ -545,9 +543,22 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
     highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
 
     auto newPadOp = rewriter.create<tensor::PadOp>(
-        loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
+        loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
         padOp.getNofold());
+
+    // If the pad has more than one user, create an unpack on the new pad to
+    // replace the other uses.
+    if (!padOp->hasOneUse()) {
+      auto unpackEmpty = tensor::UnPackOp::createDestinationTensor(
+          rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
+      Value unpackedPad = rewriter.create<tensor::UnPackOp>(
+          loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
+      rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
+    }
+
+    // Replace the pack with the new pad.
     rewriter.replaceOp(packOp, newPadOp.getResult());
+
     return success();
   }
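Note: with the hasOneUse() bail-out removed, the pattern also fires on pads that have consumers besides the pack. The pack is still bubbled above the pad, and a tensor.unpack of the new pad is created to stand in for the pad's remaining uses. A minimal before/after sketch of the rewrite, mirroring the @multi_use_pad_pack_propagation test added below (value names are illustrative; pad bodies elided):

  // Before: %padded feeds a tensor.pack and at least one other user.
  %padded = tensor.pad %src low[0, 0, 1, 1] high[0, 0, 1, 1] {...}
      : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
  %packed = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32]
      into %dest : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>

  // After: pack the pad's source, pad in the packed layout, and unpack the
  // new pad to replace the remaining uses of %padded.
  %source_pack = tensor.pack %src inner_dims_pos = [1] inner_tiles = [32]
      into %pack_dest : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
  %new_pad = tensor.pad %source_pack low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0] {...}
      : tensor<1x2x56x56x32xf32> to tensor<1x2x58x58x32xf32>
  %unpacked = tensor.unpack %new_pad inner_dims_pos = [1] inner_tiles = [32]
      into %unpack_dest : tensor<1x2x58x58x32xf32> -> tensor<1x64x58x58xf32>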
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 626dd8b697e59..d9206432379fb 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -458,23 +458,23 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
 // CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]] 
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
 // CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME: ins(%[[ARG0_PACK]]
 // CHECK-SAME: outs(%[[ARG1_PACK]]
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]] 
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
 
 // -----
@@ -537,20 +537,20 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
 // CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME: ins(%[[PACKED_ARG0]]
 // CHECK-SAME: outs(%[[DEST]]
 // CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
 
 // -----
@@ -571,8 +571,8 @@ func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tens
 // CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
 // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]] 
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
 
 // -----
@@ -614,8 +614,8 @@ func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x5
 // CHECK: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
 // CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
 // CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
 
@@ -687,6 +687,29 @@ func.func @pad_along_packed_dim(%arg0: tensor<1x60x56x56xf32>) -> tensor<1x2x58x
 
 // -----
 
+func.func @multi_use_pad_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> (tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
+  %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
+  %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
+  return %padded, %1 : tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>
+}
+
+// CHECK-LABEL: func.func @multi_use_pad_pack_propagation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x56x56x32xf32>
+// CHECK: %[[PACKED:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[PADDED]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK: return %[[UNPACKED]], %[[PADDED]]
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
   %init = tensor.empty() : tensor<128x256xi32>
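Note: the expected shapes in the new test follow directly from the pack parameters. Packing tensor<1x64x56x56xf32> with inner_dims_pos = [1] and inner_tiles = [32] splits dim 1 into 64 / 32 = 2 outer slices plus a new inner dim of size 32, giving tensor<1x2x56x56x32xf32>. Because only the untiled spatial dims are padded, the original low/high pads [0, 0, 1, 1] carry over unchanged and gain a trailing 0 for the inner dim, and the trailing tensor.unpack restores the pad's original result type tensor<1x64x58x58xf32> for the first return value.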
@@ -713,7 +736,7 @@ func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x3
 // CHECK-SAME: outs(%[[EMPTY]]
 // CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
 // CHECK-NEXT: %{{.+}} = tensor.pack %[[GEN]]
-// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32] 
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
 // CHECK-SAME: into %[[ALLOC]]
 
 // -----
@@ -760,19 +783,19 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
 // CHECK-LABEL: func.func @unpack_empty_inner_dims
 // CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: ins(%[[PACKED_ARG0]]
 // CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 
 // -----
 
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>, 
+func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
   %arg1: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
   %elem = linalg.generic {indexing_maps = [#map0, #map1],
                           iterator_types = ["parallel", "parallel", "reduction"]}
                           ins(%arg0 : tensor<128x256x32xi32>)
@@ -810,7 +833,7 @@ func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
 
 // -----
 
-func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, 
+func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>,
   %arg2: tensor<128xi32>, %init_reduction: tensor<100x128x256xi32>) -> tensor<4x16x100x16x32xi32>
 {
   %reduction = linalg.generic {
@@ -867,7 +890,7 @@ func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %a
 #map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d3)>
-func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>, 
+func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>,
   %filter: tensor<2x2xi32>) -> tensor<16x540x960xi32>{
   %init = tensor.empty() : tensor<16x540x960xi32>
   %empty = tensor.empty() : tensor<1x16x1080x1920xi32>