-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tensor] Fix crash when canonicalizing invalid IR #72888
[mlir][tensor] Fix crash when canonicalizing invalid IR #72888
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit fixes a crash of the canonicalizer when there are slice ops with offset/size SSA values that have a negative constant value. Such ops are invalid if they are reachable and their offsets/sizes should not be folded to static integer values. (But such ops may appear in non-reachable block.) This commit partially fixes #71150. The canonicalizer no longer crashes, but invalid IR is still being produced. Full diff: https://github.com/llvm/llvm-project/pull/72888.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 23a366036b9dd6f..c2fbaea726abcbb 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -141,8 +141,10 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
/// Returns "success" when any of the elements in `ofrs` is a constant value. In
/// that case the value is replaced by an attribute. Returns "failure" when no
-/// folding happened.
-LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs);
+/// folding happened. If `onlyNonNegative` is set, only non-negative constant
+/// values are folded.
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
+ bool onlyNonNegative = false);
/// Return the number of iterations for a loop with a lower bound `lb`, upper
/// bound `ub` and step `step`.
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index a114e9af126f112..931309b0c596296 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -67,8 +67,8 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
// No constant operands were folded, just return;
- if (failed(foldDynamicIndexList(mixedOffsets)) &&
- failed(foldDynamicIndexList(mixedSizes)) &&
+ if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
+ failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
failed(foldDynamicIndexList(mixedStrides)))
return failure();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e469815496e1832..5bfcb35127b5267 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2361,8 +2361,8 @@ class InsertSliceOpConstantArgumentFolder final
SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
// No constant operands were folded, just return;
- if (failed(foldDynamicIndexList(mixedOffsets)) &&
- failed(foldDynamicIndexList(mixedSizes)) &&
+ if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
+ failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
failed(foldDynamicIndexList(mixedStrides)))
return failure();
@@ -2497,8 +2497,12 @@ struct InsertSliceOpSourceCastInserter final
srcType.getShape().end());
for (int64_t i = 0; i < srcType.getRank(); ++i) {
if (std::optional<int64_t> constInt =
- getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
+ getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
+ // Bail on invalid IR.
+ if (*constInt < 0)
+ return failure();
newSrcShape[i] = *constInt;
+ }
}
RankedTensorType newSrcType =
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8a4ccc990331a7f..1cc3b054762a2c1 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -256,13 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
}
-LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs) {
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
+ bool onlyNonNegative) {
bool valuesChanged = false;
for (OpFoldResult &ofr : ofrs) {
if (ofr.is<Attribute>())
continue;
- Attribute attr;
- if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
+ APInt intVal;
+ if (matchPattern(ofr.get<Value>(), m_ConstantInt(&intVal))) {
+ if (intVal.isNegative() && onlyNonNegative)
+ continue;
+ Attribute attr;
+ bool isConstant = matchPattern(ofr.get<Value>(), m_Constant(&attr));
+ (void)isConstant;
+ assert(isConstant && "expected constant value");
ofr = attr;
valuesChanged = true;
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ea8c17640d7c143..41bfd6fe7b6eedc 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1925,3 +1925,19 @@ func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init :
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME: into %[[INIT]]
// CHECK: return %[[UNPACK]]
+
+// -----
+
+// The IR in this test case in invalid. This test tests that the canonicalizer
+// does not crash.
+
+// CHECK-LABEL: func @invalid_slice_ops(
+// CHECK: %[[c:.*]] = arith.constant -5 : index
+// CHECK: tensor.extract_slice {{.*}}%[[c]]
+// CHECK: tensor.insert_slice {{.*}}%[[c]]
+func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?xf32> {
+ %c = arith.constant -5 : index
+ %0 = tensor.extract_slice %t[0][%c][1] : tensor<?xf32> to tensor<?xf32>
+ %1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
+ return %1 : tensor<?xf32>
+}
|
This commit fixes a crash of the canonicalizer when there are slice ops with offset/size SSA values that have a negative constant value. Such ops are invalid if they are reachable and their offsets/sizes should not be folded to static integer values. (But such ops may appear in non-reachable block.) This commit partially fixes llvm#71150. The canonicalizer no longer crashes, but invalid IR is still being produced.
f22d45b
to
6fd3559
Compare
thanks! |
This commit fixes a crash of the canonicalizer when there are slice ops with offset/size SSA values that have a negative constant value. Such ops are invalid if they are reachable and their offsets/sizes should not be folded to static integer values. (But such ops may appear in non-reachable block.)
This commit fixes #71150.