[MLIR][Tensor] Fix incorrect operand consumption in expand_shape canonicalization#180705
keshavvinayak01 wants to merge 4 commits
…ization Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
|
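@llvm/pr-subscribers-mlir
Author: Keshav Vinayak Jha (keshavvinayak01)
Changes
Fixes iree-org/iree#23427
The `ConvertToStaticExpandShape` pattern in `TensorOps.cpp` incorrectly skipped operand consumption when the corresponding result dimension was static, even if the `output_shape` attribute specified a dynamic dimension. This caused the operand iterator to go out of sync, so incorrect values were read for subsequent dynamic dimensions.
Also added a lit test in `mlir/test/Dialect/Tensor/canonicalize.mlir` to cover this case.
Full diff: https://github.com/llvm/llvm-project/pull/180705.diff
2 Files Affected: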
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d837947e0dc3b..34e551071f7de 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2207,25 +2207,29 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
for (uint64_t outDim : innerReassoc) {
- if (ShapedType::isStatic(newOutputShape[outDim]))
- continue;
-
- // If the cast's src type is dynamic, don't infer any of the
- // corresponding expanded dimensions. `tensor.expand_shape` requires at
- // least one of the expanded dimensions to be dynamic if the input is
- // dynamic.
- Value val = *outputIt;
- ++outputIt;
- if (ShapedType::isDynamic(castSrcShape[inputDim])) {
- dynamicOutputShape.push_back(val);
- continue;
- }
-
- APInt cst;
- if (matchPattern(val, m_ConstantInt(&cst))) {
- newOutputShape[outDim] = cst.getSExtValue();
- } else {
- dynamicOutputShape.push_back(val);
+ // If the static output shape has a dynamic dim, we must consume an
+ // operand from the input list, even if the result type is static.
+ if (expandOp.getStaticOutputShape()[outDim] == ShapedType::kDynamic) {
+ Value val = *outputIt;
+ ++outputIt;
+ if (ShapedType::isStatic(newOutputShape[outDim]))
+ continue;
+
+ // If the cast's src type is dynamic, don't infer any of the
+ // corresponding expanded dimensions. `tensor.expand_shape` requires
+ // at least one of the expanded dimensions to be dynamic if the input
+ // is dynamic.
+ if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+ dynamicOutputShape.push_back(val);
+ continue;
+ }
+
+ APInt cst;
+ if (matchPattern(val, m_ConstantInt(&cst))) {
+ newOutputShape[outDim] = cst.getSExtValue();
+ } else {
+ dynamicOutputShape.push_back(val);
+ }
}
}
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 7a2d53c0c5850..5b5d1ae6c77ef 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2554,6 +2554,32 @@ func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
// -----
+// CHECK-LABEL: func @fold_expand_of_cast_mixed_shape
+// CHECK-SAME: %[[ARG0:.*]]: tensor<4x8xf32>
+func.func @fold_expand_of_cast_mixed_shape(%arg0: tensor<4x8xf32>) -> (index, index, index) {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ %0 = tensor.cast %arg0 : tensor<4x8xf32> to tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c1, %c4, %c8] : tensor<?x?xf32> into tensor<1x?x?xf32>
+
+ %idx0 = arith.constant 0 : index
+ %idx1 = arith.constant 1 : index
+ %idx2 = arith.constant 2 : index
+
+ %dim0 = tensor.dim %1, %idx0 : tensor<1x?x?xf32>
+ %dim1 = tensor.dim %1, %idx1 : tensor<1x?x?xf32>
+ %dim2 = tensor.dim %1, %idx2 : tensor<1x?x?xf32>
+
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: return %[[C1]], %[[C4]], %[[C8]]
+ return %dim0, %dim1, %dim2 : index, index, index
+}
+
+// -----
+
func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
-> tensor<?x?x?xf32> {
%c1 = arith.constant 1 : index
|
hanhanW
left a comment
I'm surprised that tensor::ExpandShapeOp does not require the number of dynamic dims in output_shape to match the number of dynamic dims in the result type. @MaheshRavishankar is this intended?
I'm not sure whether adding such a verification would break anything, but the IR should be valid. A reasonable canonicalization pattern would be to replace the dynamic value with an IndexAttr (or whatever IntegerAttr) in output_shape. The types still match, so you don't need to create a new op.
But it does not match my expectation when I see the op. The better solution is to fix the op semantics, if this is not intended. If it is intended, it should be documented in tablegen.
@keshavvinayak01 thanks for the patch. You can ignore my other review comments until we clarify the op semantics. I was confused when I saw the IR.
|
An additional note is that maybe @keshavvinayak01 can complete the verification and see if it breaks IREE or not. (It is interesting that tensor dialect does not have codeowners. :p) |
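To make the verification idea concrete, here is a minimal standalone sketch of the check being discussed: the number of dynamic entries in the static `output_shape` list must equal the number of dynamic dims in the result type. It is not MLIR code and not the actual `ExpandShapeOp` verifier. Shapes are modeled as plain `std::vector<int64_t>`, `kDynamic` is a local stand-in for `ShapedType::kDynamic`, and `dynamicDimCountsMatch` is a hypothetical helper name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Local stand-in for ShapedType::kDynamic (assumption: any sentinel works here).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical helper, not part of MLIR: returns true when the static
// output_shape list and the result type have the same number of dynamic dims.
bool dynamicDimCountsMatch(const std::vector<int64_t> &staticOutputShape,
                           const std::vector<int64_t> &resultShape) {
  auto countDynamic = [](const std::vector<int64_t> &shape) {
    return std::count(shape.begin(), shape.end(), kDynamic);
  };
  return countDynamic(staticOutputShape) == countDynamic(resultShape);
}
```

Under this model, the IR from the new lit test (`output_shape [%c1, %c4, %c8]` into `tensor<1x?x?xf32>`) has three dynamic entries in the static output shape but only two dynamic dims in the result type, so a verifier of this form would reject it.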
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Fixes iree-org/iree#23427

The `ConvertToStaticExpandShape` pattern in `TensorOps.cpp` incorrectly skipped operand consumption when the corresponding result dimension was static, even if the `output_shape` attribute specified a dynamic dimension. This caused the operand iterator to go out of sync, so incorrect values were read for subsequent dynamic dimensions.

Also added a lit test in `mlir/test/Dialect/Tensor/canonicalize.mlir` to cover this case.
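The iterator desynchronization can be reproduced outside MLIR with a small model of the loop. The sketch below is a simplification, not the code in `TensorOps.cpp`: `ExpandModel`, `Operand`, and `inferShape` are hypothetical names, an operand is modeled as an optional constant, and `kDynamic` stands in for `ShapedType::kDynamic`. The flag selects between the old ordering (check the result dim first, then consume) and the patched ordering (consume whenever the static output shape entry is dynamic).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Local stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// An output_shape operand; it holds a value only if it is a known constant.
using Operand = std::optional<int64_t>;

// Hypothetical, simplified model of tensor.expand_shape fed by tensor.cast.
struct ExpandModel {
  std::vector<int64_t> castSrcShape;            // shape of the cast's source
  std::vector<std::vector<size_t>> reassoc;     // reassociation indices
  std::vector<int64_t> staticOutputShape;       // kDynamic => has an operand
  std::vector<int64_t> resultShape;             // shape of the result type
  std::vector<Operand> operands;                // dynamic output_shape values
};

// Infers the output shape. With patched == false the operand is consumed only
// when the result dim is dynamic (old behavior); with patched == true it is
// consumed whenever the static output shape entry is dynamic.
std::vector<int64_t> inferShape(const ExpandModel &op, bool patched) {
  std::vector<int64_t> newOutputShape = op.resultShape;
  auto it = op.operands.begin();
  for (size_t inputDim = 0; inputDim < op.reassoc.size(); ++inputDim) {
    for (size_t outDim : op.reassoc[inputDim]) {
      bool resultIsStatic = newOutputShape[outDim] != kDynamic;
      if (patched) {
        if (op.staticOutputShape[outDim] != kDynamic)
          continue; // no operand exists for this dim
      } else if (resultIsStatic) {
        continue; // old behavior: skips without consuming the operand
      }
      assert(it != op.operands.end() && "ran out of output_shape operands");
      Operand val = *it++;
      if (resultIsStatic)
        continue; // operand consumed, nothing left to infer
      if (op.castSrcShape[inputDim] == kDynamic)
        continue; // keep the dim dynamic when the cast source is dynamic
      if (val)
        newOutputShape[outDim] = *val; // fold a constant operand
    }
  }
  return newOutputShape;
}

// Model of the lit test: tensor<4x8xf32> cast to tensor<?x?xf32>, expanded
// with output_shape [%c1, %c4, %c8] into tensor<1x?x?xf32>.
ExpandModel makeLitTestModel() {
  return ExpandModel{{4, 8},
                     {{0, 1}, {2}},
                     {kDynamic, kDynamic, kDynamic},
                     {1, kDynamic, kDynamic},
                     {Operand(1), Operand(4), Operand(8)}};
}
```

In this model, calling `inferShape(makeLitTestModel(), false)` skips `%c1` without consuming it, so dims 1 and 2 pick up the values 1 and 4 and the shape comes out as 1x1x4. With `patched` set to `true` the shape is 1x4x8, matching the constants the lit test expects.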