
[MLIR][Tensor] Fix incorrect operand consumption in expand_shape canonicalization #180705

Closed
keshavvinayak01 wants to merge 4 commits into llvm:main from keshavvinayak01:users/keshavvinayak01/expandshape-folder-fix

@keshavvinayak01
Contributor

Fixes iree-org/iree#23427

The ConvertToStaticExpandShape pattern in TensorOps.cpp incorrectly skipped operand consumption when the corresponding result dimension was static, even if the output_shape attribute specified a dynamic dimension. This caused the operand iterator to go out of sync, so incorrect values were read for subsequent dynamic dimensions.

Also adds a lit test in mlir/test/Dialect/Tensor/canonicalize.mlir that covers this case.
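
The desync can be reproduced outside MLIR with a small standalone model. The sketch below is an illustration only, not the code in TensorOps.cpp: `kDynamic`, `foldOld`, and `foldFixed` are hypothetical names, shapes are plain integer vectors, and every output_shape operand is assumed to be a known constant with a fully static cast source.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic (an assumption; the real constant lives in MLIR).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

using Shape = std::vector<int64_t>;

// Simplified pre-patch loop: a static result dimension is skipped before its
// operand is read, so every later dimension picks up the wrong operand.
Shape foldOld(const Shape &resultShape, const std::vector<int64_t> &operands) {
  Shape folded = resultShape;
  std::size_t next = 0;
  for (std::size_t dim = 0; dim < folded.size(); ++dim) {
    if (folded[dim] != kDynamic)
      continue; // Skips the dimension but leaves its operand unread.
    assert(next < operands.size() && "ran out of operands");
    folded[dim] = operands[next++];
  }
  return folded;
}

// Simplified post-patch loop: whether an operand is consumed depends on the
// static output shape, not on the result type.
Shape foldFixed(const Shape &resultShape, const Shape &staticOutputShape,
                const std::vector<int64_t> &operands) {
  assert(resultShape.size() == staticOutputShape.size());
  Shape folded = resultShape;
  std::size_t next = 0;
  for (std::size_t dim = 0; dim < folded.size(); ++dim) {
    if (staticOutputShape[dim] != kDynamic)
      continue; // No operand was supplied for this dimension.
    assert(next < operands.size() && "ran out of operands");
    int64_t value = operands[next++]; // Always consume, even if unused.
    if (folded[dim] == kDynamic)
      folded[dim] = value;
  }
  return folded;
}
```

With the shapes from the new lit test (result type 1x?x?, three output_shape operands 1, 4, 8), `foldOld` yields 1x1x4 because the operand for dimension 0 is never consumed, while `foldFixed` yields 1x4x8.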

@github-actions

github-actions Bot commented Feb 10, 2026

✅ With the latest revision this PR passed the C/C++ code formatter.

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 marked this pull request as ready for review February 10, 2026 11:46
@llvmbot
Member

llvmbot commented Feb 10, 2026

@llvm/pr-subscribers-mlir

Author: Keshav Vinayak Jha (keshavvinayak01)

Changes

Fixes iree-org/iree#23427

The ConvertToStaticExpandShape pattern in TensorOps.cpp incorrectly skipped operand consumption when the corresponding result dimension was static, even if the output_shape attribute specified a dynamic dimension. This caused the operand iterator to go out of sync, so incorrect values were read for subsequent dynamic dimensions.

Also adds a lit test in mlir/test/Dialect/Tensor/canonicalize.mlir that covers this case.


Full diff: https://github.com/llvm/llvm-project/pull/180705.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+23-19)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+26)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d837947e0dc3b..34e551071f7de 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2207,25 +2207,29 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
 
     for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
       for (uint64_t outDim : innerReassoc) {
-        if (ShapedType::isStatic(newOutputShape[outDim]))
-          continue;
-
-        // If the cast's src type is dynamic, don't infer any of the
-        // corresponding expanded dimensions. `tensor.expand_shape` requires at
-        // least one of the expanded dimensions to be dynamic if the input is
-        // dynamic.
-        Value val = *outputIt;
-        ++outputIt;
-        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
-          dynamicOutputShape.push_back(val);
-          continue;
-        }
-
-        APInt cst;
-        if (matchPattern(val, m_ConstantInt(&cst))) {
-          newOutputShape[outDim] = cst.getSExtValue();
-        } else {
-          dynamicOutputShape.push_back(val);
+        // If the static output shape has a dynamic dim, we must consume an
+        // operand from the input list, even if the result type is static.
+        if (expandOp.getStaticOutputShape()[outDim] == ShapedType::kDynamic) {
+          Value val = *outputIt;
+          ++outputIt;
+          if (ShapedType::isStatic(newOutputShape[outDim]))
+            continue;
+
+          // If the cast's src type is dynamic, don't infer any of the
+          // corresponding expanded dimensions. `tensor.expand_shape` requires
+          // at least one of the expanded dimensions to be dynamic if the input
+          // is dynamic.
+          if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+            dynamicOutputShape.push_back(val);
+            continue;
+          }
+
+          APInt cst;
+          if (matchPattern(val, m_ConstantInt(&cst))) {
+            newOutputShape[outDim] = cst.getSExtValue();
+          } else {
+            dynamicOutputShape.push_back(val);
+          }
         }
       }
     }
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 7a2d53c0c5850..5b5d1ae6c77ef 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2554,6 +2554,32 @@ func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
 
 // -----
 
+// CHECK-LABEL: func @fold_expand_of_cast_mixed_shape
+// CHECK-SAME: %[[ARG0:.*]]: tensor<4x8xf32>
+func.func @fold_expand_of_cast_mixed_shape(%arg0: tensor<4x8xf32>) -> (index, index, index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  %0 = tensor.cast %arg0 : tensor<4x8xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c1, %c4, %c8] : tensor<?x?xf32> into tensor<1x?x?xf32>
+
+  %idx0 = arith.constant 0 : index
+  %idx1 = arith.constant 1 : index
+  %idx2 = arith.constant 2 : index
+
+  %dim0 = tensor.dim %1, %idx0 : tensor<1x?x?xf32>
+  %dim1 = tensor.dim %1, %idx1 : tensor<1x?x?xf32>
+  %dim2 = tensor.dim %1, %idx2 : tensor<1x?x?xf32>
+
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK: return %[[C1]], %[[C4]], %[[C8]]
+  return %dim0, %dim1, %dim2 : index, index, index
+}
+
+// -----
+
 func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
     -> tensor<?x?x?xf32> {
   %c1 = arith.constant 1 : index

@llvmbot
Member

llvmbot commented Feb 10, 2026

@llvm/pr-subscribers-mlir-tensor


Contributor

@hanhanW hanhanW left a comment

I'm surprised that tensor::ExpandShapeOp does not require the number of dynamic dims in output_shape to match the number of dynamic dims in the result type. @MaheshRavishankar is it intended?

I'm not sure if adding such verification breaks anything, but it should be valid. A reasonable canonicalization pattern would replace the dynamic value with an IndexAttr (or whichever IntegerAttr applies) in output_shape. The types still match, so you don't need to create a new op.
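
A rough standalone sketch of that canonicalization idea follows. It is not MLIR code: `MixedShape`, `kDynamic`, and `foldConstantOperands` are hypothetical names, and an operand is modeled as an optional integer that is empty when the value is not a known constant. Each dynamic entry of output_shape whose operand is constant becomes a static entry, and only the non-constant operands are kept.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Stand-in for ShapedType::kDynamic (an assumption; the real constant lives in MLIR).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical model of a mixed static/dynamic output_shape.
struct MixedShape {
  std::vector<int64_t> staticShape;             // kDynamic marks an operand slot
  std::vector<std::optional<int64_t>> operands; // empty = not a known constant
};

// Moves constant operands into the static shape and keeps the rest dynamic.
MixedShape foldConstantOperands(const MixedShape &in) {
  MixedShape out;
  std::size_t next = 0;
  for (int64_t extent : in.staticShape) {
    if (extent != kDynamic) {
      out.staticShape.push_back(extent); // Already static, nothing to fold.
      continue;
    }
    assert(next < in.operands.size() && "missing operand for dynamic entry");
    const std::optional<int64_t> &operand = in.operands[next++];
    if (operand) {
      out.staticShape.push_back(*operand); // Constant: store it statically.
    } else {
      out.staticShape.push_back(kDynamic); // Unknown: stays an operand.
      out.operands.push_back(operand);
    }
  }
  return out;
}
```

In this model the result type is untouched, which matches the point above that the types still agree after the rewrite.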

But it does not match my expectation when I see the op. The better solution is to fix the op semantics, if this is not intended. If it is intended, it should be documented in TableGen.

@keshavvinayak01 thanks for the patch. You can ignore my other review comments until we clarify the op semantics. I was confused when I saw the IR.

Comment thread mlir/test/Dialect/Tensor/canonicalize.mlir
Comment thread mlir/lib/Dialect/Tensor/IR/TensorOps.cpp Outdated
@hanhanW
Contributor

hanhanW commented Feb 10, 2026

An additional note is that maybe @keshavvinayak01 can complete the verification and see if it breaks IREE or not.
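
As a rough idea of what such a verification could check, here is a standalone sketch. It is not the actual ExpandShapeOp verifier: `kDynamic` and `dynamicDimCountsMatch` are hypothetical names, and the rule shown (equal counts of dynamic entries) is only the condition described in the review above, assumed here for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic (an assumption; the real constant lives in MLIR).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Returns true when output_shape and the result type have the same number of
// dynamic entries. Both shapes are expected to have the same rank.
bool dynamicDimCountsMatch(const std::vector<int64_t> &resultShape,
                           const std::vector<int64_t> &staticOutputShape) {
  assert(resultShape.size() == staticOutputShape.size());
  auto numDynamic = [](const std::vector<int64_t> &shape) {
    return std::count(shape.begin(), shape.end(), kDynamic);
  };
  return numDynamic(resultShape) == numDynamic(staticOutputShape);
}
```

Under this rule the IR in the new lit test would be rejected, since its output_shape has three dynamic entries while the result type 1x?x? has two.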

(It is interesting that tensor dialect does not have codeowners. :p)

Comment thread mlir/lib/Dialect/Tensor/IR/TensorOps.cpp Outdated
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Development

Successfully merging this pull request may close these issues.

tensor.expand_shape on dynamic tensor produces wrong runtime dimensions

4 participants