From 9f031bccdd56b02726055d7744de6be649fbf3fc Mon Sep 17 00:00:00 2001 From: Maximilian Bartel Date: Tue, 23 Sep 2025 08:57:22 +0200 Subject: [PATCH 1/3] (linalg.pack): fix empty tensor assumptions --- .../Dialect/Linalg/Transforms/Transforms.cpp | 17 +++++++---------- mlir/test/Dialect/Linalg/decompose-pack.mlir | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index e9a8b253eea35..69cbc7048f646 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1151,11 +1151,11 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( packOp.getDimAndTileMapping(); int64_t srcRank = packOp.getSourceRank(); int64_t destRank = packOp.getDestRank(); - int64_t numTiles = destRank - srcRank; - // 1. Extract the inner tile sizes. - // Where possible, values are replaced with constant attributes (to match the - // behaviour of `getPackOpSourceOrPaddedSource`). + // 1. Extract the inner tile sizes and the shapes for the tensor.empty op + // before transposing. Where possible, values are replaced with constant + // attributes (to match the behaviour of `getPackOpSourceOrPaddedSource`). + SmallVector transShapeForEmptyOp(srcRank, oneIdxAttr); SmallVector tileSizes; for (auto i : llvm::seq(0, srcRank)) { if (dimAndTileMapping.count(i)) { @@ -1165,6 +1165,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( auto [_, tileSize] = getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter); tileSizes.push_back(tileSize); + transShapeForEmptyOp[i] = tileSize; } } @@ -1194,18 +1195,14 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( LDBG() << "Pack permutation: " << packOp; LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose); - // 2.1 Create tensor.empty (init value for TransposeOp) - SmallVector transShapeForEmptyOp(srcRank - numTiles, - oneIdxAttr); - transShapeForEmptyOp.append(tileSizes); - + // 2.2 Transpose the tensor.empty shapes. applyPermutationToVector(transShapeForEmptyOp, srcPermForTranspose); Value empty = tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp, packOp.getSourceType().getElementType()); - // 2.2 Create linalg.transpose + // 2.3 Create linalg.transpose auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty, srcPermForTranspose); diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir index 17e6c29754f9d..15521d415b8a7 100644 --- a/mlir/test/Dialect/Linalg/decompose-pack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir @@ -274,3 +274,22 @@ func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer( // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32> // CHECK: return %[[INSERT]] + +// ----- + +func.func @pack_with_zero_pos_tile_size(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> { + %pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32> + return %pack : tensor<1x1x1x1x8x1xf32> +} + +// CHECK-LABEL: func.func @pack_with_zero_pos_tile_size +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x1xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[SRC]] : tensor<8x1x1x1xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x1xf32>) +// CHECK-SAME: permutation = [1, 2, 0, 3] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] +// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32> +// CHECK: return %[[INSERT]] \ No newline at end of file From eae9f9946e30104dbb4e7a86b96ed3a735700929 Mon Sep 17 00:00:00 2001 From: Maximilian Bartel Date: Mon, 29 Sep 2025 16:35:46 +0200 Subject: [PATCH 2/3] (linalg.pack): simplify outer dims patterns after review --- .../Dialect/Linalg/IR/LinalgRelayoutOps.td | 3 +- .../Dialect/Linalg/Transforms/Transforms.cpp | 59 +++++++++---------- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td index f36b41ccf6745..5006d815a798a 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td @@ -57,7 +57,8 @@ class Linalg_RelayoutOp traits = []> : /// tile factors. DenseMap getDimAndTileMapping(); - /// Return the tile sizes as OpFoldResult. + /// Return the tile sizes as OpFoldResult. Will return the Value + /// of the constant Op, not the constant Attribute. SmallVector getMixedTiles(); /// Return the tile sizes as `int64_t`. If a tile size is dynamic diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 69cbc7048f646..60219335d6a1c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1146,38 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( Attribute oneIdxAttr = rewriter.getIndexAttr(1); Location loc = packOp.getLoc(); - Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); - DenseMap dimAndTileMapping = - packOp.getDimAndTileMapping(); int64_t srcRank = packOp.getSourceRank(); int64_t destRank = packOp.getDestRank(); + ArrayRef innerDimsPos = packOp.getInnerDimsPos(); + int64_t numberOfTiles = innerDimsPos.size(); - // 1. Extract the inner tile sizes and the shapes for the tensor.empty op - // before transposing. Where possible, values are replaced with constant - // attributes (to match the behaviour of `getPackOpSourceOrPaddedSource`). - SmallVector transShapeForEmptyOp(srcRank, oneIdxAttr); - SmallVector tileSizes; - for (auto i : llvm::seq(0, srcRank)) { - if (dimAndTileMapping.count(i)) { - // Rather than taking the tile size as is, extact the actual constant - // value Attribute where possible, e.g.: - // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8] - auto [_, tileSize] = - getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter); - tileSizes.push_back(tileSize); - transShapeForEmptyOp[i] = tileSize; - } - } + // 1. Get the input that is going to be packed. If the input requires padding, + // add a padding operation and return that as the input. + Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); // 2. Transpose the input to match the inner tile order: // %init = tensor.empty() // %transposed_tile = linalg.transpose ins(%source_or_padded_source), // outs(%init) // Assumptions made: - // 1. All outer dims are 1 - the corresponding transposition order doesn't + // - All outer dims are 1 - the corresponding transposition order doesn't // matter, but requires all dim indices to be present. + + // 2.1 Get the permutation for linalg.transpose SmallVector srcPermForTranspose; - ArrayRef innerDimPos(packOp.getInnerDimsPos()); for (int64_t i = 0; i < srcRank; i++) { // We assume the `k` dimensions of the inner dim position, where `k` is the // rank of the inner tiling, correspond to the last `k` indices of the @@ -1186,21 +1173,32 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // rank of the source tensor. For example if we have a source tensor with // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining // indices are [1, 2]. and the transpose will be [1, 2, 3, 0]. - if (llvm::is_contained(innerDimPos, i)) + if (llvm::is_contained(innerDimsPos, i)) continue; srcPermForTranspose.push_back(i); } - srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end()); + srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end()); + + // 2.2 Create the init tensor for linalg.transpose with the correct shape + SmallVector shapeForEmptyOp(srcRank - numberOfTiles, + oneIdxAttr); + shapeForEmptyOp.append(packOp.getMixedTiles()); + + // getMixedTiles() may contain Values pointing to constant ops, not the + // constant attributes. Replace them with a true OpFoldResult. + llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(), + [&](OpFoldResult ofr) { + if (auto val = llvm::dyn_cast(ofr)) + return getAsOpFoldResult(val); + return ofr; + }); LDBG() << "Pack permutation: " << packOp; LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose); + LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp); - // 2.2 Transpose the tensor.empty shapes. - applyPermutationToVector(transShapeForEmptyOp, - srcPermForTranspose); - Value empty = - tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp, - packOp.getSourceType().getElementType()); + Value empty = tensor::EmptyOp::create( + rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType()); // 2.3 Create linalg.transpose auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty, @@ -1211,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( SmallVector writeStrides(destRank, oneIdxAttr); SmallVector writeOffsets(destRank, zeroIdxAttr); // Outer dims are all 1s! - SmallVector writeSizes(destRank - dimAndTileMapping.size(), - oneIdxAttr); + SmallVector writeSizes(destRank - numberOfTiles, oneIdxAttr); SmallVector writeShape; for (auto tileSize : packOp.getMixedTiles()) { From 1f44ee147a0aabd22bfa25df575a2af5e74a3016 Mon Sep 17 00:00:00 2001 From: Maximilian Bartel Date: Wed, 8 Oct 2025 11:44:33 +0200 Subject: [PATCH 3/3] (linalg.pack): clarify comments and test names --- mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td | 3 +++ mlir/test/Dialect/Linalg/decompose-pack.mlir | 8 +++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td index 5006d815a798a..d1ceac0bff19c 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td @@ -57,8 +57,11 @@ class Linalg_RelayoutOp traits = []> : /// tile factors. DenseMap getDimAndTileMapping(); + // TODO: Return the folded result. /// Return the tile sizes as OpFoldResult. Will return the Value /// of the constant Op, not the constant Attribute. + /// E.g., for %size = arith.constant 1 : i32 will return %size, + /// not 1. SmallVector getMixedTiles(); /// Return the tile sizes as `int64_t`. If a tile size is dynamic diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir index 15521d415b8a7..18a09f4c669bb 100644 --- a/mlir/test/Dialect/Linalg/decompose-pack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir @@ -277,12 +277,14 @@ func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer( // ----- -func.func @pack_with_zero_pos_tile_size(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> { +// The following example shows a pack operation where the inner dims +// positions are non-adjacent and non-permuted. +func.func @pack_with_non_adjacent_and_non_permuted_inner_dims(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> { %pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32> return %pack : tensor<1x1x1x1x8x1xf32> } -// CHECK-LABEL: func.func @pack_with_zero_pos_tile_size +// CHECK-LABEL: func.func @pack_with_non_adjacent_and_non_permuted_inner_dims // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x1xf32> @@ -292,4 +294,4 @@ func.func @pack_with_zero_pos_tile_size(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor // CHECK-SAME: permutation = [1, 2, 0, 3] // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32> -// CHECK: return %[[INSERT]] \ No newline at end of file +// CHECK: return %[[INSERT]]