diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td index f36b41ccf6745..3390f380c7eb8 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td @@ -239,6 +239,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ ArrayRef outerDimsPerm, ArrayRef innerTiles); + // Same as above function but here dynamic dimensions are assumed + // to require padding. + static bool requirePaddingValueStrict(ArrayRef inputShape, + ArrayRef innerDimsPos, + ArrayRef outputShape, + ArrayRef outerDimsPerm, + ArrayRef innerTiles); + static Value createDestinationTensor(OpBuilder &b, Location loc, Value source, ArrayRef innerTileSizes, ArrayRef innerDimsPos, ArrayRef outerDimsPerm); diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 64d3a2448b409..41670249936e6 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1914,9 +1914,12 @@ void populateElementwiseOpsFusionPatterns( using ControlPropagationFn = std::function; /// Patterns to bubble up or down data layout ops across other operations. +/// The function also has an option to allow the patterns to propagate with +/// poison padding if requested by the caller. void populateDataLayoutPropagationPatterns( RewritePatternSet &patterns, - const ControlPropagationFn &controlPackUnPackPropagation); + const ControlPropagationFn &controlPackUnPackPropagation, + bool PoisonPaddingOk = false); /// Patterns to sink extract slice across other operations. void populateExtractSliceSinkingPatterns( diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 578931e1351c6..49c2b54748c29 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -5310,6 +5310,32 @@ bool PackOp::requirePaddingValue(ArrayRef inputShape, return false; } +bool PackOp::requirePaddingValueStrict(ArrayRef inputShape, + ArrayRef innerDimsPos, + ArrayRef outputShape, + ArrayRef outerDimsPerm, + ArrayRef innerTiles) { + SmallVector outputTileSizes( + outputShape.take_front(inputShape.size())); + if (!outerDimsPerm.empty()) { + assert(outerDimsPerm.size() == outputTileSizes.size() && + "expected output and outer_dims_perm to have same size"); + applyPermutationToVector(outputTileSizes, + invertPermutationVector(outerDimsPerm)); + } + for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) { + if (ShapedType::isDynamic(inputShape[pos]) || + ShapedType::isDynamic(outputTileSizes[pos])) + return true; + std::optional constantTile = getConstantIntValue(tileSize); + if (!constantTile) + return true; + if (inputShape[pos] % (*constantTile) != 0) + return true; + } + return false; +} + LogicalResult PackOp::verify() { if (failed(commonVerifierPackAndUnPackOp(*this))) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 6c17c3c2d0cab..3bb5f8af821c0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" @@ -189,40 +190,20 @@ static SmallVector computeOuterDims(ArrayRef perm, return outerDimsPerm; } -/// Returns a tuple for packed operand and indexing_map with the assumptions: -/// 1) The generic op is the producer of the pack op. -/// 2) The generic op has only one result. -/// If the operand is a scalar or packing dimensions are all irrelevant to the -/// operand, the operand and the updated indexing map will be returned. -/// Otherwise, it returns the packed operand and the updated indexing map. E.g., -/// -/// #map0 = affine_map<(d0, d1) -> (d0, d1)> -/// #map1 = affine_map<(d0, d1) -> (d0)> -/// #map2 = affine_map<(d0, d1) -> (d1)> -/// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0], -/// iterator_types = ["parallel", "parallel"]} -/// ins(%arg0, %arg1 : tensor, tensor) -/// outs(%init : tensor) { -/// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): -/// %4 = arith.addf %arg3, %arg4 : f32 -/// linalg.yield %4 : f32 -/// } -> tensor -/// %1 = linalg.pack %0 -/// inner_dims_pos = [0, 1] -/// inner_tiles = [8, 2] -/// into %dest : tensor -> tensor -/// -/// Taking the first input operand as an example, the inner tile size of d1 is -/// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3)> -> -/// affine_map<(d1, d3)>` will be returned. -/// -/// %pack = linalg.pack %arg0 -/// inner_dims_pos = [0] -/// inner_tiles = [8] -/// into %init : tensor -> tensor -static std::tuple -getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, - GenericOp genericOp, OpOperand *opOperand) { +struct PackedOperandDetails { + SmallVector innerTileSizes; + SmallVector innerDimsPos; + SmallVector outerDimsPerm; + AffineMap indexingMap; +}; + +/// Helper function for getOrCreatePackedViewOfOperand that populates +/// the details of the packedOperand that needs to be formed and also +/// returns if the packing would require padding. +static bool getPackedOperandDetails( + OpBuilder &b, PackInfo packInfo, GenericOp genericOp, OpOperand *opOperand, + DenseMap &packedOperandMap) { + PackedOperandDetails currOperandDetails; int64_t numOrigLoops = genericOp.getNumLoops(); int64_t numInnerLoops = packInfo.getNumTiledLoops(); int64_t numLoops = numOrigLoops + numInnerLoops; @@ -231,9 +212,12 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, SmallVector exprs(origIndexingMap.getResults()); // If the OpOperand is a scalar or a zero-rank tensor, no need to pack. - if (genericOp.isScalar(opOperand) || exprs.empty()) - return std::make_tuple(opOperand->get(), - AffineMap::get(numLoops, 0, exprs, b.getContext())); + if (genericOp.isScalar(opOperand) || exprs.empty()) { + currOperandDetails.indexingMap = + AffineMap::get(numLoops, 0, exprs, b.getContext()); + packedOperandMap[opOperand] = currOperandDetails; + return false; + } // Step 1. Construct the information of packing data dimensions; append inner // dimensions to the indexing maps for the operand. @@ -281,18 +265,86 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, exprs = auxVec; } } - auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext()); + currOperandDetails.indexingMap = + AffineMap::get(numLoops, 0, exprs, b.getContext()); // The operand does not have dimensions that relates to pack op. + if (innerDimsPos.empty() && outerDimsPerm.empty()) { + packedOperandMap[opOperand] = currOperandDetails; + return false; + } + auto inputType = cast(opOperand->get().getType()); + + auto maybeIntInnerTileSizes = + llvm::map_to_vector(innerTileSizes, [](OpFoldResult ofr) -> int64_t { + std::optional maybeCst = getConstantIntValue(ofr); + return maybeCst.value_or(ShapedType::kDynamic); + }); + bool requirePadding = linalg::PackOp::requirePaddingValueStrict( + inputType.getShape(), innerDimsPos, + linalg::PackOp::inferPackedType(inputType, maybeIntInnerTileSizes, + innerDimsPos, outerDimsPerm) + .getShape(), + outerDimsPerm, innerTileSizes); + currOperandDetails.innerDimsPos = innerDimsPos; + currOperandDetails.innerTileSizes = innerTileSizes; + currOperandDetails.outerDimsPerm = outerDimsPerm; + packedOperandMap[opOperand] = currOperandDetails; + + return requirePadding; +} + +/// Returns a tuple for packed operand and indexing_map with the assumptions: +/// 1) The generic op is the producer of the pack op. +/// 2) The generic op has only one result. +/// If the operand is a scalar or packing dimensions are all irrelevant to the +/// operand, the operand and the updated indexing map will be returned. +/// Otherwise, it returns the packed operand and the updated indexing map. E.g., +/// +/// #map0 = affine_map<(d0, d1) -> (d0, d1)> +/// #map1 = affine_map<(d0, d1) -> (d0)> +/// #map2 = affine_map<(d0, d1) -> (d1)> +/// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0], +/// iterator_types = ["parallel", "parallel"]} +/// ins(%arg0, %arg1 : tensor, tensor) +/// outs(%init : tensor) { +/// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): +/// %4 = arith.addf %arg3, %arg4 : f32 +/// linalg.yield %4 : f32 +/// } -> tensor +/// %1 = linalg.pack %0 +/// inner_dims_pos = [0, 1] +/// inner_tiles = [8, 2] +/// into %dest : tensor -> tensor +/// +/// Taking the first input operand as an example, the inner tile size of d1 is +/// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3)> -> +/// affine_map<(d1, d3)>` will be returned. +/// +/// %pack = linalg.pack %arg0 +/// inner_dims_pos = [0] +/// inner_tiles = [8] +/// into %init : tensor -> tensor +static std::tuple getOrCreatePackedViewOfOperand( + OpBuilder &b, Location loc, OpOperand *opOperand, + const DenseMap &packedOperandMap) { + assert(packedOperandMap.contains(opOperand) && + "packed operand details expected to be populated"); + auto currOperandDetails = packedOperandMap.at(opOperand); + auto innerDimsPos = currOperandDetails.innerDimsPos; + auto outerDimsPerm = currOperandDetails.outerDimsPerm; + auto innerTileSizes = currOperandDetails.innerTileSizes; if (innerDimsPos.empty() && outerDimsPerm.empty()) - return std::make_tuple(opOperand->get(), indexingMap); + return std::make_tuple(opOperand->get(), currOperandDetails.indexingMap); auto empty = linalg::PackOp::createDestinationTensor( b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm); - auto packedOperand = linalg::PackOp::create( - b, loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, - /*padding=*/std::nullopt, outerDimsPerm); - return std::make_tuple(packedOperand, indexingMap); + auto poison = ub::PoisonOp::create( + b, loc, getElementTypeOrSelf(opOperand->get().getType())); + Value packedOperand = + linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos, + innerTileSizes, poison, outerDimsPerm); + return std::make_tuple(packedOperand, currOperandDetails.indexingMap); } /// This function is a helper subroutine to pack a genericOp and return it. It @@ -301,10 +353,10 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, /// around it. Implicitly this will only work when a packInfo can be obtained. /// This make sure that we are only using this function on parallel permuted /// dimensions. -static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, - Value dest, AffineMap packedOutIndexingMap, - const PackInfo &packInfo, - bool isFoldableUnpackPack) { +static FailureOr +packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest, + AffineMap packedOutIndexingMap, const PackInfo &packInfo, + bool isFoldableUnpackPack, bool poisonPaddingOk) { Location loc = genericOp.getLoc(); SmallVector inputOperands; SmallVector inputOperandsFromUnpackedSource; @@ -314,9 +366,18 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() && llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles()); }; + DenseMap packedOperandMap; + bool requiresPadding = false; + for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { + requiresPadding |= getPackedOperandDetails(rewriter, packInfo, genericOp, + inputOperand, packedOperandMap); + } + if (requiresPadding && !poisonPaddingOk) + return failure(); + for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand( - rewriter, loc, packInfo, genericOp, inputOperand); + rewriter, loc, inputOperand, packedOperandMap); auto unpackOp = inputOperand->get().getDefiningOp(); auto packOp = packedOperand.getDefiningOp(); if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) { @@ -407,7 +468,8 @@ static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) { /// } -> tensor static FailureOr bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp, - const ControlPropagationFn &controlFn) { + const ControlPropagationFn &controlFn, + bool poisonPaddingOk) { auto genericOp = packOp.getSource().getDefiningOp(); if (!genericOp) return failure(); @@ -470,10 +532,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp, } // Rebuild the indexing map for the corresponding init operand. - auto [packedOutOperand, packedOutIndexingMap] = - getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, - genericOp, opOperand); + DenseMap packedOperandMap; + bool requiresPadding = getPackedOperandDetails(rewriter, *packInfo, genericOp, + opOperand, packedOperandMap); + if (requiresPadding && !poisonPaddingOk) + return failure(); + auto [packedOutOperand, packedOutIndexingMap] = + getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), opOperand, + packedOperandMap); // Forward the new tensor.empty as a destination if it is one of the following // situations: // 1) The dps init operand is a tensor.empty. @@ -488,7 +555,8 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp, // pack(unpack) isn't naively foldable because the unpack op can be from // an arbitrary domain so we need to keep both. return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, - *packInfo, /*isFoldableUnpackPack=*/false); + *packInfo, /*isFoldableUnpackPack=*/false, + poisonPaddingOk); } /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method. @@ -496,13 +564,15 @@ struct BubbleUpPackOpThroughGenericOpPattern : public OpRewritePattern { public: BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context, - ControlPropagationFn fun) - : OpRewritePattern(context), controlFn(std::move(fun)) {} + ControlPropagationFn fun, + bool poisonPaddingOk) + : OpRewritePattern(context), controlFn(std::move(fun)), + poisonPaddingOk(std::move(poisonPaddingOk)) {} LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override { - auto genericOp = - bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn); + auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn, + poisonPaddingOk); if (failed(genericOp)) return failure(); rewriter.replaceOp(packOp, genericOp->getResults()); @@ -511,6 +581,7 @@ struct BubbleUpPackOpThroughGenericOpPattern private: ControlPropagationFn controlFn; + bool poisonPaddingOk; }; /// Propagate a linalg.pack operation up through a tensor.pad. The idea is to @@ -1080,7 +1151,8 @@ static FailureOr getUnPackedOperand(GenericOp genericOp) { /// static FailureOr> pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, - ControlPropagationFn controlFn) { + ControlPropagationFn controlFn, + bool poisonPaddingOk) { if (genericOp.getNumResults() != 1) return failure(); @@ -1107,9 +1179,17 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, return failure(); // Rebuild the indexing map for the corresponding init operand. + DenseMap packedOperandMap; + bool requiresPadding = + getPackedOperandDetails(rewriter, *packInfo, genericOp, + genericOp.getDpsInitOperand(0), packedOperandMap); + if (requiresPadding && !poisonPaddingOk) + return failure(); + auto [packedOutOperand, packedOutIndexingMap] = - getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, - genericOp, genericOp.getDpsInitOperand(0)); + getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), + genericOp.getDpsInitOperand(0), + packedOperandMap); auto destPack = packedOutOperand.getDefiningOp(); // Forward the new tensor.empty as a destination if it is one of the following @@ -1129,9 +1209,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, // pack(unpack) is foldable in this case. This is because in pushing down the // unpack, by default we will populate an additional pack op after the unpack. // This guarantees them to be foldable. - GenericOp newGenericOp = + auto maybeGenericOp = packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo, - /*isFoldableUnpackPack=*/true); + /*isFoldableUnpackPack=*/true, poisonPaddingOk); + if (failed(maybeGenericOp)) + return failure(); + GenericOp newGenericOp = *maybeGenericOp; Value newResult = newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)); @@ -1157,13 +1240,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern { public: PushDownUnPackOpThroughGenericOp(MLIRContext *context, - ControlPropagationFn fun) - : OpRewritePattern(context), controlFn(std::move(fun)) {} + ControlPropagationFn fun, + bool poisonPaddingOk) + : OpRewritePattern(context), controlFn(std::move(fun)), + poisonPaddingOk(std::move(poisonPaddingOk)) {} LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - auto genericAndRepl = - pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn); + auto genericAndRepl = pushDownUnPackOpThroughGenericOp( + rewriter, genericOp, controlFn, poisonPaddingOk); if (failed(genericAndRepl)) return failure(); rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl)); @@ -1172,6 +1257,7 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern { private: ControlPropagationFn controlFn; + bool poisonPaddingOk; }; /// Propagate a linalg.unpack operation through a tensor.pad. The idea is to @@ -1522,12 +1608,14 @@ class PushDownExtractSliceOpThroughGenericOp final void mlir::linalg::populateDataLayoutPropagationPatterns( RewritePatternSet &patterns, - const ControlPropagationFn &controlPackUnPackPropagation) { - patterns - .insert( - patterns.getContext(), controlPackUnPackPropagation); + const ControlPropagationFn &controlPackUnPackPropagation, + bool PoisonPaddingOk) { + patterns.insert( + patterns.getContext(), controlPackUnPackPropagation); + patterns.insert( + patterns.getContext(), controlPackUnPackPropagation, PoisonPaddingOk); } void mlir::linalg::populateExtractSliceSinkingPatterns( diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index a5f8d63a3e912..7a16bc0a4faee 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -1450,6 +1450,33 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar // ----- +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @push_unpack_in_padded_domain_multiple_inputs(%arg0: tensor<1x4x16x16xf32>, %arg1: tensor<8x64xf32>, %arg2: tensor<8x64xf32>) -> tensor<8x64xf32> { + %0 = tensor.empty() : tensor<8x64xf32> + %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<1x4x16x16xf32> -> tensor<8x64xf32> + %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg1, %unpack : tensor<8x64xf32>, tensor<8x64xf32>) outs(%arg2 : tensor<8x64xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %2 = arith.addf %in, %in_0 : f32 + linalg.yield %2 : f32 + } -> tensor<8x64xf32> + return %1 : tensor<8x64xf32> +} +// CHECK-LABEL: func.func @push_unpack_in_padded_domain_multiple_inputs +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[POISON:.+]] = ub.poison : f32 +// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG1]] padding_value(%[[POISON]] : f32) +// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [16, 16] +// CHECK: %[[ELEM:.+]] = linalg.generic +// CHECK: ins(%[[PACK]], %[[ARG0]] +// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ELEM]] +// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [16, 16] +// CHECK-SAME: into %[[ARG2]] +// CHECK: return %[[UNPACK]] + +// ----- + module { func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor, %arg2: tensor, %arg3: index) -> tensor { %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32> @@ -1473,7 +1500,7 @@ module { // CHECK: } : tensor to tensor // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16> // CHECK: %[[GENERIC:.+]] = linalg.generic -// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]] +// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]] // CHECK-SAME: outs(%[[EMPTY]] // CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %3[%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor // CHECK: return %[[EXTRACT]] @@ -1492,7 +1519,7 @@ func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32 // CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1 // CHECK: %[[GENERIC:.+]] = linalg.generic -// CHECK: return %[[GENERIC]] +// CHECK: return %[[GENERIC]] // ----- @@ -1508,7 +1535,7 @@ func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32 // CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2 // CHECK: %[[GENERIC:.+]] = linalg.generic -// CHECK: return %[[GENERIC]] +// CHECK: return %[[GENERIC]] // ----- @@ -1575,7 +1602,7 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32> // CHECK-LABEL: func.func @push_extract_through_generic_rank0_operand // CHECK: %[[GENERIC:.+]] = linalg.generic -// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]] +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]] // CHECK: return %[[EXTRACT]] // ----- diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp index d332270468ea8..d45aaf788f9c2 100644 --- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp @@ -33,7 +33,8 @@ struct TestDataLayoutPropagationPass MLIRContext *context = &getContext(); RewritePatternSet patterns(context); linalg::populateDataLayoutPropagationPatterns( - patterns, [](OpOperand *opOperand) { return true; }); + patterns, [](OpOperand *opOperand) { return true; }, + /*poisonPaddingOk=*/true); linalg::ControlPropagationFn controlExtract = [](OpOperand *opOperand) -> bool { Operation *producer = opOperand->get().getDefiningOp();