diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td index f36b41ccf6745..d1ceac0bff19c 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td @@ -57,7 +57,11 @@ class Linalg_RelayoutOp traits = []> : /// tile factors. DenseMap getDimAndTileMapping(); - /// Return the tile sizes as OpFoldResult. + // TODO: Return the folded result. + /// Return the tile sizes as OpFoldResult. Will return the Value + /// of the constant Op, not the constant Attribute. + /// E.g., for %size = arith.constant 1 : i32 will return %size, + /// not 1. SmallVector getMixedTiles(); /// Return the tile sizes as `int64_t`. If a tile size is dynamic diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index e9a8b253eea35..60219335d6a1c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1146,37 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( Attribute oneIdxAttr = rewriter.getIndexAttr(1); Location loc = packOp.getLoc(); - Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); - DenseMap dimAndTileMapping = - packOp.getDimAndTileMapping(); int64_t srcRank = packOp.getSourceRank(); int64_t destRank = packOp.getDestRank(); - int64_t numTiles = destRank - srcRank; + ArrayRef innerDimsPos = packOp.getInnerDimsPos(); + int64_t numberOfTiles = innerDimsPos.size(); - // 1. Extract the inner tile sizes. - // Where possible, values are replaced with constant attributes (to match the - // behaviour of `getPackOpSourceOrPaddedSource`). - SmallVector tileSizes; - for (auto i : llvm::seq(0, srcRank)) { - if (dimAndTileMapping.count(i)) { - // Rather than taking the tile size as is, extact the actual constant - // value Attribute where possible, e.g.: - // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8] - auto [_, tileSize] = - getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter); - tileSizes.push_back(tileSize); - } - } + // 1. Get the input that is going to be packed. If the input requires padding, + // add a padding operation and return that as the input. + Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); // 2. Transpose the input to match the inner tile order: // %init = tensor.empty() // %transposed_tile = linalg.transpose ins(%source_or_padded_source), // outs(%init) // Assumptions made: - // 1. All outer dims are 1 - the corresponding transposition order doesn't + // - All outer dims are 1 - the corresponding transposition order doesn't // matter, but requires all dim indices to be present. + + // 2.1 Get the permutation for linalg.transpose SmallVector srcPermForTranspose; - ArrayRef innerDimPos(packOp.getInnerDimsPos()); for (int64_t i = 0; i < srcRank; i++) { // We assume the `k` dimensions of the inner dim position, where `k` is the // rank of the inner tiling, correspond to the last `k` indices of the @@ -1185,27 +1173,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // rank of the source tensor. For example if we have a source tensor with // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining // indices are [1, 2]. and the transpose will be [1, 2, 3, 0]. - if (llvm::is_contained(innerDimPos, i)) + if (llvm::is_contained(innerDimsPos, i)) continue; srcPermForTranspose.push_back(i); } - srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end()); + srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end()); + + // 2.2 Create the init tensor for linalg.transpose with the correct shape + SmallVector shapeForEmptyOp(srcRank - numberOfTiles, + oneIdxAttr); + shapeForEmptyOp.append(packOp.getMixedTiles()); + + // getMixedTiles() may contain Values pointing to constant ops, not the + // constant attributes. Replace them with a true OpFoldResult. + llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(), + [&](OpFoldResult ofr) { + if (auto val = llvm::dyn_cast(ofr)) + return getAsOpFoldResult(val); + return ofr; + }); LDBG() << "Pack permutation: " << packOp; LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose); + LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp); - // 2.1 Create tensor.empty (init value for TransposeOp) - SmallVector transShapeForEmptyOp(srcRank - numTiles, - oneIdxAttr); - transShapeForEmptyOp.append(tileSizes); - - applyPermutationToVector(transShapeForEmptyOp, - srcPermForTranspose); - Value empty = - tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp, - packOp.getSourceType().getElementType()); + Value empty = tensor::EmptyOp::create( + rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType()); - // 2.2 Create linalg.transpose + // 2.3 Create linalg.transpose auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty, srcPermForTranspose); @@ -1214,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( SmallVector writeStrides(destRank, oneIdxAttr); SmallVector writeOffsets(destRank, zeroIdxAttr); // Outer dims are all 1s! - SmallVector writeSizes(destRank - dimAndTileMapping.size(), - oneIdxAttr); + SmallVector writeSizes(destRank - numberOfTiles, oneIdxAttr); SmallVector writeShape; for (auto tileSize : packOp.getMixedTiles()) { diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir index 17e6c29754f9d..18a09f4c669bb 100644 --- a/mlir/test/Dialect/Linalg/decompose-pack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir @@ -274,3 +274,24 @@ func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer( // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32> // CHECK: return %[[INSERT]] + +// ----- + +// The following example shows a pack operation where the inner dims +// positions are non-adjacent and non-permuted. +func.func @pack_with_non_adjacent_and_non_permuted_inner_dims(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> { + %pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32> + return %pack : tensor<1x1x1x1x8x1xf32> +} + +// CHECK-LABEL: func.func @pack_with_non_adjacent_and_non_permuted_inner_dims +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x1xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[SRC]] : tensor<8x1x1x1xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x1xf32>) +// CHECK-SAME: permutation = [1, 2, 0, 3] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] +// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32> +// CHECK: return %[[INSERT]]