Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][tensor] Fix crash when canonicalizing invalid IR #72888

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Nov 20, 2023

This commit fixes a crash of the canonicalizer when there are slice ops with offset/size SSA values that have a negative constant value. Such ops are invalid if they are reachable and their offsets/sizes should not be folded to static integer values. (But such ops may appear in non-reachable block.)

This commit fixes #71150.

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 20, 2023

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit fixes a crash of the canonicalizer when there are slice ops with offset/size SSA values that have a negative constant value. Such ops are invalid if they are reachable and their offsets/sizes should not be folded to static integer values. (But such ops may appear in non-reachable block.)

This commit partially fixes #71150. The canonicalizer no longer crashes, but invalid IR is still being produced.


Full diff: https://github.com/llvm/llvm-project/pull/72888.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+4-2)
  • (modified) mlir/include/mlir/Interfaces/ViewLikeInterface.h (+2-2)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-3)
  • (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+10-3)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+16)
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 23a366036b9dd6f..c2fbaea726abcbb 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -141,8 +141,10 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
 
 /// Returns "success" when any of the elements in `ofrs` is a constant value. In
 /// that case the value is replaced by an attribute. Returns "failure" when no
-/// folding happened.
-LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs);
+/// folding happened. If `onlyNonNegative` is set, only non-negative constant
+/// values are folded.
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
+                                   bool onlyNonNegative = false);
 
 /// Return the number of iterations for a loop with a lower bound `lb`, upper
 /// bound `ub` and step `step`.
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index a114e9af126f112..931309b0c596296 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -67,8 +67,8 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
     SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedOffsets)) &&
-        failed(foldDynamicIndexList(mixedSizes)) &&
+    if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
+        failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
         failed(foldDynamicIndexList(mixedStrides)))
       return failure();
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e469815496e1832..5bfcb35127b5267 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2361,8 +2361,8 @@ class InsertSliceOpConstantArgumentFolder final
     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedOffsets)) &&
-        failed(foldDynamicIndexList(mixedSizes)) &&
+    if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
+        failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
         failed(foldDynamicIndexList(mixedStrides)))
       return failure();
 
@@ -2497,8 +2497,12 @@ struct InsertSliceOpSourceCastInserter final
                                      srcType.getShape().end());
     for (int64_t i = 0; i < srcType.getRank(); ++i) {
       if (std::optional<int64_t> constInt =
-              getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
+              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
+        // Bail on invalid IR.
+        if (*constInt < 0)
+          return failure();
         newSrcShape[i] = *constInt;
+      }
     }
 
     RankedTensorType newSrcType =
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8a4ccc990331a7f..1cc3b054762a2c1 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -256,13 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
   return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
 }
 
-LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs) {
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
+                                   bool onlyNonNegative) {
   bool valuesChanged = false;
   for (OpFoldResult &ofr : ofrs) {
     if (ofr.is<Attribute>())
       continue;
-    Attribute attr;
-    if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
+    APInt intVal;
+    if (matchPattern(ofr.get<Value>(), m_ConstantInt(&intVal))) {
+      if (intVal.isNegative() && onlyNonNegative)
+        continue;
+      Attribute attr;
+      bool isConstant = matchPattern(ofr.get<Value>(), m_Constant(&attr));
+      (void)isConstant;
+      assert(isConstant && "expected constant value");
       ofr = attr;
       valuesChanged = true;
     }
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ea8c17640d7c143..41bfd6fe7b6eedc 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1925,3 +1925,19 @@ func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init :
 //       CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
 //  CHECK-SAME:       into %[[INIT]]
 //       CHECK:   return %[[UNPACK]]
+
+// -----
+
+// The IR in this test case in invalid. This test tests that the canonicalizer
+// does not crash.
+
+// CHECK-LABEL: func @invalid_slice_ops(
+//       CHECK:   %[[c:.*]] = arith.constant -5 : index
+//       CHECK:   tensor.extract_slice {{.*}}%[[c]]
+//       CHECK:   tensor.insert_slice {{.*}}%[[c]]
+func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?xf32> {
+  %c = arith.constant -5 : index
+  %0 = tensor.extract_slice %t[0][%c][1] : tensor<?xf32> to tensor<?xf32>
+  %1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
+  return %1 : tensor<?xf32>
+}

This commit fixes a crash of the canonicalizer when there are slice ops with offset/size SSA values that have a negative constant value. Such ops are invalid if they are reachable and their offsets/sizes should not be folded to static integer values. (But such ops may appear in non-reachable block.)

This commit partially fixes llvm#71150. The canonicalizer no longer crashes, but invalid IR is still being produced.
@matthias-springer matthias-springer merged commit 68386a7 into llvm:main Nov 21, 2023
2 of 3 checks passed
@nicolasvasilache
Copy link
Contributor

thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
4 participants