-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tensor] Fix insert and extract slice canonicalization #72885
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir-tensor Author: Rik Huijzer (rikhuijzer) ChangesFixes #71150 by checking for non-negative dimensions during the Full diff: https://github.com/llvm/llvm-project/pull/72885.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 23a366036b9dd6f..9e39d81e5c4f96a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -128,6 +128,12 @@ std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedValues(Builder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues);
+/// Helper function to check whether the dimensions are non-negative.
+///
+/// This is used to re-check whether dimensions are still non-negative after
+/// constant folding the dynamic dimensions.
+bool hasNegativeDimension(SmallVector<int64_t> values);
+
/// Helper to sort `values` according to matching `keys`.
SmallVector<Value>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a2fc954ad07fae8..dd75ed2500306b2 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2621,17 +2621,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
- // If one of the offsets or sizes is invalid, fail the canonicalization.
- // These checks also occur in the verifier, but they are needed here
- // because some dynamic dimensions may have been constant folded.
- for (int64_t offset : staticOffsets)
- if (offset < 0 && !ShapedType::isDynamic(offset))
- return {};
- for (int64_t size : staticSizes)
- if (size < 0 && !ShapedType::isDynamic(size))
- return {};
-
+ if (hasNegativeDimension(staticOffsets))
+ return {};
+ if (hasNegativeDimension(staticSizes))
+ return {};
return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
staticSizes, staticStrides);
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e469815496e1832..986e40a2e4eb34f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1259,13 +1259,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
SmallVector<int64_t> newShape;
operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
- for (int64_t newdim : newShape) {
- // This check also occurs in the verifier, but we need it here too
- // since intermediate passes may have replaced some dynamic dimensions
- // by constants.
- if (newdim < 0 && !ShapedType::isDynamic(newdim))
+ if (hasNegativeDimension(newShape))
return failure();
- }
if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
return failure();
@@ -1801,6 +1796,10 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+ if (hasNegativeDimension(staticOffsets))
+ return {};
+ if (hasNegativeDimension(staticSizes))
+ return {};
return ExtractSliceOp::inferCanonicalRankReducedResultType(
desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
staticStrides);
@@ -2370,6 +2369,8 @@ class InsertSliceOpConstantArgumentFolder final
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
mixedOffsets, mixedSizes, mixedStrides);
+ if (!sourceType)
+ return failure();
Value toInsert = insertSliceOp.getSource();
if (sourceType != insertSliceOp.getSourceType()) {
OpBuilder::InsertionGuard g(rewriter);
@@ -2500,6 +2501,8 @@ struct InsertSliceOpSourceCastInserter final
getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
newSrcShape[i] = *constInt;
}
+ if (hasNegativeDimension(newSrcShape))
+ return failure();
RankedTensorType newSrcType =
RankedTensorType::get(newSrcShape, srcType.getElementType());
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8a4ccc990331a7f..5d777ad74e9e852 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -200,6 +200,12 @@ decomposeMixedValues(Builder &b,
return {b.getI64ArrayAttr(staticValues), dynamicValues};
}
+bool hasNegativeDimension(SmallVector<int64_t> values) {
+ return llvm::any_of(values, [](int64_t value) {
+ return !ShapedType::isDynamic(value) && value < 0;
+ });
+}
+
/// Helper to sort `values` according to matching `keys`.
template <typename K, typename V>
static SmallVector<V>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ea8c17640d7c143..1c0a2e868475f24 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1102,6 +1102,30 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
// -----
+func.func @no_fold_extract_slice_negative_offset(%arg0: tensor<8xf32>) -> tensor<?xf32> {
+ %c-1 = arith.constant -1 : index
+ %e = tensor.extract_slice %arg0[1] [%c-1] [1] : tensor<8xf32> to tensor<?xf32>
+ return %e : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_negative_offset
+// CHECK: tensor.extract_slice
+
+// -----
+
+func.func @no_fold_insert_slice_cast_inserter_negative_offset() -> tensor<?xf32> {
+ %c = arith.constant 0 : index
+ %const = tensor.empty(%c) : tensor<?xf32>
+ %insert_val = tensor.empty(%c) : tensor<?xf32>
+ %c-1 = arith.constant -1 : index
+ %inserted = tensor.insert_slice %insert_val into %const[0][%c-1][1] : tensor<?xf32> into tensor<?xf32>
+ return %inserted : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_fold_insert_slice_cast_inserter_negative_offset
+// CHECK: %[[CAST:.*]] = tensor.cast
+// CHECK: tensor.insert_slice %[[CAST:.+]]
+
+// -----
+
func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
%c0 = arith.constant dense<42> : tensor<2x8xi32>
%0 = tensor.expand_shape %c0 [[0], [1, 2]]
|
Seems like @matthias-springer and I were working on the same issue at exactly the right time even though the issue was reported 2 weeks ago (#72888) 🤷♂️ . Closing this PR. |
Fixes #71150 by checking for non-negative dimensions during the
InsertSliceOpSourceCastInserter
andExtractSliceOp
canonicalizations. Also refactored the logic into one function so that we don't have to write a comment each time.