
Conversation

maxbartel (Contributor) commented Sep 23, 2025

The original code seemed to assume that the tiled dimensions of the tensor.empty op created before the transpose are always the last dimensions. However, pack allows any dimension to be tiled.

The easiest way I found to solve this is to prefill the SmallVector with 1s of size (srcRank - numberOfTiles) and then append the tile sizes.

This way I could also get rid of the first loop in the code.
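
For reference, here is a minimal sketch of that initialization, assuming the identifiers from the surrounding Transforms.cpp code (srcRank, numTiles, oneIdxAttr, packOp, srcPermForTranspose); it mirrors the reviewer's suggestion further down this thread:

// Prefill with 1s for the untiled dimensions, then append the
// (possibly dynamic) inner tile sizes.
SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles, oneIdxAttr);
transShapeForEmptyOp.append(packOp.getMixedTiles());
// The transpose permutation is applied to this shape afterwards.
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, srcPermForTranspose);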

llvmbot (Member) commented Sep 23, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (maxbartel)

Changes

The original code seemed to assume that the tiled dimensions of the tensor.empty op created before the transpose are always the last dimensions. However, pack allows any dimension to be tiled.

The easiest way I found to solve this is to prefill the SmallVector with 1s and then replace the tiled dimension's entry with the tile size directly when extracting the tile sizes. That way we do not need to add another for loop.


Full diff: https://github.com/llvm/llvm-project/pull/160246.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+7-10)
  • (modified) mlir/test/Dialect/Linalg/decompose-pack.mlir (+19)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e9a8b253eea35..69cbc7048f646 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1151,11 +1151,11 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
       packOp.getDimAndTileMapping();
   int64_t srcRank = packOp.getSourceRank();
   int64_t destRank = packOp.getDestRank();
-  int64_t numTiles = destRank - srcRank;
 
-  // 1. Extract the inner tile sizes.
-  // Where possible, values are replaced with constant attributes (to match the
-  // behaviour of `getPackOpSourceOrPaddedSource`).
+  // 1. Extract the inner tile sizes and the shapes for the tensor.empty op
+  // before transposing. Where possible, values are replaced with constant
+  // attributes (to match the behaviour of `getPackOpSourceOrPaddedSource`).
+  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> tileSizes;
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
@@ -1165,6 +1165,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
       auto [_, tileSize] =
           getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
       tileSizes.push_back(tileSize);
+      transShapeForEmptyOp[i] = tileSize;
     }
   }
 
@@ -1194,18 +1195,14 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   LDBG() << "Pack permutation: " << packOp;
   LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
 
-  // 2.1 Create tensor.empty (init value for TransposeOp)
-  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
-                                                 oneIdxAttr);
-  transShapeForEmptyOp.append(tileSizes);
-
+  // 2.2 Transpose the tensor.empty shapes.
   applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
                                          srcPermForTranspose);
   Value empty =
       tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
                               packOp.getSourceType().getElementType());
 
-  // 2.2 Create linalg.transpose
+  // 2.3 Create linalg.transpose
   auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
                                                   srcPermForTranspose);
 
diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 17e6c29754f9d..15521d415b8a7 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -274,3 +274,22 @@ func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
 // CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @pack_with_zero_pos_tile_size(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> {
+  %pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
+  return %pack : tensor<1x1x1x1x8x1xf32>
+}
+
+// CHECK-LABEL: func.func @pack_with_zero_pos_tile_size
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x1xf32>
+// CHECK:         %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:      ins(%[[SRC]] : tensor<8x1x1x1xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<1x1x8x1xf32>)
+// CHECK-SAME:      permutation = [1, 2, 0, 3]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
+// CHECK:         return %[[INSERT]]
\ No newline at end of file

@banach-space
Copy link
Contributor

Thanks for the fix!

The original code seemed to assume that the tiled dimensions of the tensor.empty op created before the transpose are always the last dimensions.

I think that it's a bit more subtle:

  • If tileSizes are not transposed, then we need to transpose them.
  • If tileSizes are transposed, then we do not need to transpose them.

In your example, the inner tiles are not transposed:

func.func @pack_with_zero_pos_tile_size(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> {
  %pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
  return %pack : tensor<1x1x1x1x8x1xf32>
}

However, in this example, we do transpose the inner tiles:

func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x1x5x1xf32>, %output: tensor<1x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x1x1x1x2x?xf32> {
  %0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x1x5x1xf32> -> tensor<1x1x1x1x2x?xf32>
  return %0 : tensor<1x1x1x1x2x?xf32>
}

I think that tracking what and when to transpose is fragile. Let me propose a simpler fix:

SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles, oneIdxAttr);
transShapeForEmptyOp.append(packOp.getMixedTiles());

WDYT? Could this work?

maxbartel requested a review from rengolin as a code owner on September 29, 2025.
maxbartel (Contributor, Author) commented

Thanks for the review!

With a bit of glue this works and makes it a lot easier. I also noticed that the first loop is not necessary anymore.

I also added a comment to getMixedTiles, because the OpFoldResult not returning the constant Attribute confused me quite a bit.
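
For context, a hypothetical illustration of that subtlety, using the standard MLIR helper getConstantIntValue; this is only a sketch of the usual idiom, not the comment actually added in the PR:

// Each OpFoldResult from getMixedTiles() holds either an Attribute or a
// Value; a tile size that is statically known is not necessarily handed
// back as a constant Attribute, so callers should not assume it.
for (OpFoldResult tile : packOp.getMixedTiles()) {
  if (std::optional<int64_t> cst = getConstantIntValue(tile)) {
    // Static tile size, available at compile time.
  } else {
    // Dynamic tile size, only available as an SSA Value.
  }
}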

@banach-space Can you give this another pass please?

banach-space (Contributor) left a comment

Fantastic, thanks for the fix and for simplifying this!

I've left some minor suggestions re the comments, but approving as is - the core logic looks good.

MaheshRavishankar (Contributor) commented

cc @Max191, can you take a look here?

Max191 (Contributor) left a comment

LGTM

banach-space merged commit f53b624 into llvm:main on Oct 8, 2025.
9 checks passed