Skip to content

Conversation

fabrizio-indirli
Copy link
Contributor

In the im2col decomposition, propagate the filter tensor encoding (if specified) through the tensor.collapse_shape op, so that it can be used by the consuming linalg.generic matmul op.

@fabrizio-indirli
Copy link
Contributor Author

CC @GeorgeARM @FranklandJack

@llvmbot
Copy link
Member

llvmbot commented Sep 22, 2025

@llvm/pr-subscribers-mlir

Author: None (fabrizio-indirli)

Changes

In the im2col decomposition, propagate the filter tensor encoding (if specified) through the tensor.collapse_shape op, so that it can be used by the consuming linalg.generic matmul op.


Full diff: https://github.com/llvm/llvm-project/pull/160099.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp (+21-3)
  • (modified) mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir (+29-1)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 108abe800b13e..12e2b6f5c3f0e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -155,10 +155,16 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
 
   Location loc = convOp.getLoc();
 
+  if (!isa<RankedTensorType>(filterType))
+    return rewriter.notifyMatchFailure(
+        convOp, "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   // Reshape output and filter to the LHS and result of a (B)MNK matmul.
   SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
   auto reshapedFilterType =
-      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
+      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
@@ -435,9 +441,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
   auto loc = convOp.getLoc();
   MLIRContext *context = rewriter.getContext();
 
+  if (!isa<RankedTensorType>(filterType))
+    return rewriter.notifyMatchFailure(
+        convOp, "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
   auto reshapedFilterType =
-      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
+      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
@@ -560,11 +572,17 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
 
   Location loc = convOp.getLoc();
 
+  if (!isa<RankedTensorType>(filterType))
+    return rewriter.notifyMatchFailure(
+        convOp, "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   // Reshape output and filter to the LHS and result of a "row-wise" matrix
   // multiplication.
   SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
   auto reshapedFilterType =
-      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 8627fcd2576b9..af911e3b3919a 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -290,7 +290,7 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
 //      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
 //      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-//           CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+//           CHECK-SAME: [#[[MAP0]], #[[MAP1]]], {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
 //                CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
 //                CHECK: linalg.yield %{{.+}} : f32
 //      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
@@ -327,6 +327,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK: func.func @conv2d_decompose_im2col_with_filter_encoding
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>,
+// CHECK-SAME: %[[FILTER:.*]]: tensor<16x3x3x4xf32, 42 : i32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x14x14x16xf32>
+//  CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+  // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x3x3x4xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
+//  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
+//  CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
+//  CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
+//  CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
+func.func @conv2d_decompose_im2col_with_filter_encoding(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32, 42 : i32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+    %0 = linalg.conv_2d_nhwc_fhwc
+      { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+      ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32, 42 : i32>)
+      outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+    return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check for signed extend when the input type is smaller than the accumulator type.
 
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

@llvmbot
Copy link
Member

llvmbot commented Sep 22, 2025

@llvm/pr-subscribers-mlir-linalg

Author: None (fabrizio-indirli)

Changes

In the im2col decomposition, propagate the filter tensor encoding (if specified) through the tensor.collapse_shape op, so that it can be used by the consuming linalg.generic matmul op.


Full diff: https://github.com/llvm/llvm-project/pull/160099.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp (+21-3)
  • (modified) mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir (+29-1)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 108abe800b13e..12e2b6f5c3f0e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -155,10 +155,16 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
 
   Location loc = convOp.getLoc();
 
+  if (!isa<RankedTensorType>(filterType))
+    return rewriter.notifyMatchFailure(
+        convOp, "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   // Reshape output and filter to the LHS and result of a (B)MNK matmul.
   SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
   auto reshapedFilterType =
-      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
+      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
@@ -435,9 +441,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
   auto loc = convOp.getLoc();
   MLIRContext *context = rewriter.getContext();
 
+  if (!isa<RankedTensorType>(filterType))
+    return rewriter.notifyMatchFailure(
+        convOp, "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
   auto reshapedFilterType =
-      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
+      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
@@ -560,11 +572,17 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
 
   Location loc = convOp.getLoc();
 
+  if (!isa<RankedTensorType>(filterType))
+    return rewriter.notifyMatchFailure(
+        convOp, "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   // Reshape output and filter to the LHS and result of a "row-wise" matrix
   // multiplication.
   SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
   auto reshapedFilterType =
-      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 8627fcd2576b9..af911e3b3919a 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -290,7 +290,7 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
 //      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
 //      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-//           CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+//           CHECK-SAME: [#[[MAP0]], #[[MAP1]]], {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
 //                CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
 //                CHECK: linalg.yield %{{.+}} : f32
 //      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
@@ -327,6 +327,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK: func.func @conv2d_decompose_im2col_with_filter_encoding
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>,
+// CHECK-SAME: %[[FILTER:.*]]: tensor<16x3x3x4xf32, 42 : i32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x14x14x16xf32>
+//  CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+  // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x3x3x4xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
+//  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
+//  CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
+//  CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
+//  CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
+func.func @conv2d_decompose_im2col_with_filter_encoding(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32, 42 : i32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+    %0 = linalg.conv_2d_nhwc_fhwc
+      { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+      ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32, 42 : i32>)
+      outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+    return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check for signed extend when the input type is smaller than the accumulator type.
 
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@fabrizio-indirli
Copy link
Contributor Author

Thanks for the review @banach-space ! Hopefully I addressed your comments, let me know should you have more suggestions :)

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LGTM!

I've posted one request, but that's non-blocking and just nice-to-have. No need to wait for me to re-approve if you decide to update this PR.

In the im2col decomposition, propagate the filter tensor encoding
(if specified) through the tensor.collapse_shape op, so that it
can be used by the consuming linalg.generic matmul op.

Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli@arm.com>
Change-Id: I275d6fad0257d9813b9821341a6160144ae983e7
@GeorgeARM GeorgeARM merged commit 5d51e00 into llvm:main Sep 26, 2025
9 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented Sep 30, 2025

LLVM Buildbot has detected a new failure on builder ppc64le-flang-rhel-clang running on ppc64le-flang-rhel-test while building mlir at step 6 "test-build-unified-tree-check-flang".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/40528

Here is the relevant piece of the build log for the reference
Step 6 (test-build-unified-tree-check-flang) failure: 1200 seconds without output running [b'ninja', b'check-flang'], attempting to kill
16.126 [4/29/0] cd /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/runtimes/runtimes-bins && /home/buildbots/llvm-external-buildbots/cmake-3.31.2/bin/cmake --build /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/runtimes/runtimes-bins/ --target libomp-mod --config Release
ninja: no work to do.
16.567 [3/7/24] Building CXX object tools/flang/unittests/Frontend/CMakeFiles/FlangFrontendTests.dir/CompilerInstanceTest.cpp.o
16.712 [3/5/26] Building CXX object tools/flang/unittests/Frontend/CMakeFiles/FlangFrontendTests.dir/CodeGenActionTest.cpp.o
16.971 [2/5/27] Building CXX object tools/flang/unittests/Optimizer/CMakeFiles/FlangOptimizerTests.dir/RTBuilder.cpp.o
18.145 [2/4/28] Linking CXX executable tools/flang/unittests/Common/FlangCommonTests
49.985 [2/3/29] Linking CXX executable tools/flang/unittests/Frontend/FlangFrontendTests
command timed out: 1200 seconds without output running [b'ninja', b'check-flang'], attempting to kill
process killed by signal 9
program finished with exit code -1
elapsedTime=1250.935917

mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
In the im2col decomposition, propagate the filter tensor encoding (if
specified) through the tensor.collapse_shape op, so that it can be used
by the consuming linalg.generic matmul op.

Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli@arm.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants