
[mlir][tensor] Fix insert and extract slice canonicalization #72885

Closed
wants to merge 3 commits into from

Conversation

rikhuijzer
Member

Fixes #71150 by checking for non-negative dimensions during the InsertSliceOpSourceCastInserter and ExtractSliceOp canonicalizations. Also refactors the repeated checks into a single helper function so that the same explanatory comment does not have to be duplicated at every call site.
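
For context, a minimal sketch of the kind of IR the new checks guard against (adapted from the regression tests added in this PR; the function and value names are illustrative): the size operand folds to the negative constant -1, which must not be folded into a static result type during canonicalization.

func.func @extract_slice_negative_size(%arg0: tensor<8xf32>) -> tensor<?xf32> {
  // The dynamic size folds to -1; the canonicalizer must now bail out
  // instead of inferring an invalid static result type.
  %size = arith.constant -1 : index
  %slice = tensor.extract_slice %arg0[1] [%size] [1] : tensor<8xf32> to tensor<?xf32>
  return %slice : tensor<?xf32>
}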

@llvmbot
Collaborator

llvmbot commented Nov 20, 2023

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Rik Huijzer (rikhuijzer)

Changes

Fixes #71150 by checking for non-negative dimensions during the InsertSliceOpSourceCastInserter and ExtractSliceOp canonicalizations. Also refactors the repeated checks into a single helper function so that the same explanatory comment does not have to be duplicated at every call site.
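
A second sketch (again adapted from the regression tests in the diff below; names are illustrative) shows the insert_slice side, where InsertSliceOpSourceCastInserter would otherwise build a tensor.cast to an invalid static shape:

func.func @insert_slice_negative_size(%dest: tensor<?xf32>, %src: tensor<?xf32>) -> tensor<?xf32> {
  // The size folds to -1; hasNegativeDimension makes the pattern return
  // failure() instead of casting the source to an invalid static type.
  %size = arith.constant -1 : index
  %r = tensor.insert_slice %src into %dest[0] [%size] [1] : tensor<?xf32> into tensor<?xf32>
  return %r : tensor<?xf32>
}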


Full diff: https://github.com/llvm/llvm-project/pull/72885.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+6)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+4-11)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+6)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+24)
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 23a366036b9dd6f..9e39d81e5c4f96a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -128,6 +128,12 @@ std::pair<ArrayAttr, SmallVector<Value>>
 decomposeMixedValues(Builder &b,
                      const SmallVectorImpl<OpFoldResult> &mixedValues);
 
+/// Helper function to check whether the dimensions are non-negative.
+///
+/// This is used to re-check whether dimensions are still non-negative after
+/// constant folding the dynamic dimensions.
+bool hasNegativeDimension(SmallVector<int64_t> values);
+
 /// Helper to sort `values` according to matching `keys`.
 SmallVector<Value>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a2fc954ad07fae8..dd75ed2500306b2 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2621,17 +2621,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  // If one of the offsets or sizes is invalid, fail the canonicalization.
-  // These checks also occur in the verifier, but they are needed here
-  // because some dynamic dimensions may have been constant folded.
-  for (int64_t offset : staticOffsets)
-    if (offset < 0 && !ShapedType::isDynamic(offset))
-      return {};
-  for (int64_t size : staticSizes)
-    if (size < 0 && !ShapedType::isDynamic(size))
-      return {};
-
+  if (hasNegativeDimension(staticOffsets))
+    return {};
+  if (hasNegativeDimension(staticSizes))
+    return {};
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                     staticSizes, staticStrides);
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e469815496e1832..986e40a2e4eb34f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1259,13 +1259,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
     SmallVector<int64_t> newShape;
     operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
 
-    for (int64_t newdim : newShape) {
-      // This check also occurs in the verifier, but we need it here too
-      // since intermediate passes may have replaced some dynamic dimensions
-      // by constants.
-      if (newdim < 0 && !ShapedType::isDynamic(newdim))
+    if (hasNegativeDimension(newShape))
         return failure();
-    }
 
     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
       return failure();
@@ -1801,6 +1796,10 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  if (hasNegativeDimension(staticOffsets))
+    return {};
+  if (hasNegativeDimension(staticSizes))
+    return {};
   return ExtractSliceOp::inferCanonicalRankReducedResultType(
       desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
       staticStrides);
@@ -2370,6 +2369,8 @@ class InsertSliceOpConstantArgumentFolder final
     auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
         insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
         mixedOffsets, mixedSizes, mixedStrides);
+    if (!sourceType)
+      return failure();
     Value toInsert = insertSliceOp.getSource();
     if (sourceType != insertSliceOp.getSourceType()) {
       OpBuilder::InsertionGuard g(rewriter);
@@ -2500,6 +2501,8 @@ struct InsertSliceOpSourceCastInserter final
               getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
         newSrcShape[i] = *constInt;
     }
+    if (hasNegativeDimension(newSrcShape))
+      return failure();
 
     RankedTensorType newSrcType =
         RankedTensorType::get(newSrcShape, srcType.getElementType());
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8a4ccc990331a7f..5d777ad74e9e852 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -200,6 +200,12 @@ decomposeMixedValues(Builder &b,
   return {b.getI64ArrayAttr(staticValues), dynamicValues};
 }
 
+bool hasNegativeDimension(SmallVector<int64_t> values) {
+  return llvm::any_of(values, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value < 0;
+  });
+}
+
 /// Helper to sort `values` according to matching `keys`.
 template <typename K, typename V>
 static SmallVector<V>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ea8c17640d7c143..1c0a2e868475f24 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1102,6 +1102,30 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
 
 // -----
 
+func.func @no_fold_extract_slice_negative_offset(%arg0: tensor<8xf32>) -> tensor<?xf32> {
+  %c-1 = arith.constant -1 : index
+  %e = tensor.extract_slice %arg0[1] [%c-1] [1] : tensor<8xf32> to tensor<?xf32>
+  return %e : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_negative_offset
+// CHECK: tensor.extract_slice
+
+// -----
+
+func.func @no_fold_insert_slice_cast_inserter_negative_offset() -> tensor<?xf32> {
+  %c = arith.constant 0 : index
+  %const = tensor.empty(%c) : tensor<?xf32>
+  %insert_val = tensor.empty(%c) : tensor<?xf32>
+  %c-1 = arith.constant -1 : index
+  %inserted = tensor.insert_slice %insert_val into %const[0][%c-1][1] : tensor<?xf32> into tensor<?xf32>
+  return %inserted : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_fold_insert_slice_cast_inserter_negative_offset
+// CHECK: %[[CAST:.*]] = tensor.cast
+// CHECK: tensor.insert_slice %[[CAST:.+]]
+
+// -----
+
 func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
   %c0 = arith.constant dense<42> : tensor<2x8xi32>
   %0 = tensor.expand_shape %c0 [[0], [1, 2]]

@rikhuijzer
Member Author

It seems @matthias-springer and I happened to be working on the same issue at exactly the same time, even though it was reported two weeks ago (#72888) 🤷‍♂️.

Closing this PR.

@rikhuijzer rikhuijzer closed this Nov 21, 2023
@rikhuijzer rikhuijzer deleted the rh/negative-extract-slice-canon branch December 2, 2023 13:49
@rikhuijzer rikhuijzer restored the rh/negative-extract-slice-canon branch December 2, 2023 15:08
rikhuijzer added a commit that referenced this pull request Dec 6, 2023
This PR fixes #73383 and is another shot at the refactoring proposed in #72885.

---------

Co-authored-by: Kai Sasaki <lewuathe@gmail.com>