-
Couldn't load subscription status.
- Fork 15k
[mlir][memref] Fix runtime verification for memref.subview when size dimension value is 0 #164897
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][memref] Fix runtime verification for memref.subview when size dimension value is 0 #164897
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref Author: Hanumanth (Hanumanth04) ChangesPreviously, the runtime verification pass would insert assertion statements with conditions that always evaluate to false for semantically valid The This patch fixes the issue by making the This issue was discovered through a LiteRT model, where a dynamic shape calculation resulted in a zero-sized dimension being passed to module {
memref.global "private" constant @<!-- -->__constant_2xi32 : memref<2xi32> = dense<-1> {alignment = 64 : i64}
memref.global "private" constant @<!-- -->__constant_1xi32 : memref<1xi32> = dense<0> {alignment = 64 : i64}
func.func @<!-- -->simpleRepro(%arg0: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%c0 = arith.constant 0 : index
%c-1 = arith.constant -1 : index
%0 = memref.get_global @<!-- -->__constant_1xi32 : memref<1xi32>
%1 = memref.get_global @<!-- -->__constant_2xi32 : memref<2xi32>
%alloca = memref.alloca() {alignment = 64 : i64} : memref<3xi32>
%subview = memref.subview %alloca[0] [1] [1] : memref<3xi32> to memref<1xi32, strided<[1]>>
memref.copy %0, %subview : memref<1xi32> to memref<1xi32, strided<[1]>>
%subview_0 = memref.subview %alloca[1] [2] [1] : memref<3xi32> to memref<2xi32, strided<[1], offset: 1>>
memref.copy %1, %subview_0 : memref<2xi32> to memref<2xi32, strided<[1], offset: 1>>
%2 = memref.load %alloca[%c0] : memref<3xi32>
%3 = index.casts %2 : i32 to index
%4 = arith.cmpi eq, %3, %c-1 : index
%5 = arith.select %4, %c10, %3 : index
%6 = memref.load %alloca[%c1] : memref<3xi32>
%7 = index.casts %6 : i32 to index
%8 = arith.cmpi eq, %7, %c-1 : index
%9 = arith.select %8, %c4, %7 : index
%10 = memref.load %alloca[%c2] : memref<3xi32>
%11 = index.casts %10 : i32 to index
%12 = arith.cmpi eq, %11, %c-1 : index
%13 = arith.select %12, %c1, %11 : index
%subview_1 = memref.subview %arg0[0, 0, 0] [%5, %9, %13] [1, 1, 1] : memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
return %subview_1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
}
}P.S. This is a similar issue to the one fixed for Full diff: https://github.com/llvm/llvm-project/pull/164897.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 291da1f76ca9b..1979d5b7e6310 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -273,7 +274,9 @@ struct SubViewOpInterface
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
auto metadataOp =
ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
- for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+ for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+ // Reset insertion point to before the operation for each dimension
+ builder.setInsertionPoint(subView);
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(builder, loc,
@@ -290,6 +293,42 @@ struct SubViewOpInterface
std::to_string(i) +
" is out-of-bounds"));
+ // Only verify if size > 0
+ Value sizeIsNonZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+ /*
+ * Split the current block to create the below control flow structure:
+ *
+ * ^preCondBlock:
+ * ... // offset check already done above
+ * %size_nonzero = arith.cmpi sgt, %size, %zero
+ * cf.cond_br %size_nonzero, ^sizeBoundsCheckBlock, ^afterCheckBlock
+ *
+ * ^sizeBoundsCheckBlock:
+ * %last_pos = ... // compute offset + (size-1) * stride
+ * %last_pos_ok = ... // last position bounds check
+ * cf.assert %last_pos_ok, "extract_slice runs out-of-bounds"
+ * cf.br ^afterCheckBlock
+ *
+ * ^afterCheckBlock:
+ * tensor.extract_slice ... // the original operation
+ */
+ Block *preCondBlock = builder.getBlock();
+ Block *afterCheckBlock = preCondBlock->splitBlock(subView);
+
+ // Create the block for conditional size bounds verification.
+ Block *sizeBoundsCheckBlock = builder.createBlock(
+ preCondBlock->getParent(), Region::iterator(afterCheckBlock));
+
+ // Terminate the pre-condition block with the conditional branch.
+ builder.setInsertionPointToEnd(preCondBlock);
+ cf::CondBranchOp::create(builder, loc, sizeIsNonZero,
+ sizeBoundsCheckBlock, afterCheckBlock);
+
+ // Populate the size bounds check block with lastPos verification.
+ builder.setInsertionPointToStart(sizeBoundsCheckBlock);
+
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
@@ -303,6 +342,7 @@ struct SubViewOpInterface
generateErrorMessage(op,
"subview runs out-of-bounds along dimension " +
std::to_string(i)));
+ cf::BranchOp::create(builder, loc, afterCheckBlock);
}
}
};
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 71e813c0a6300..001c435086976 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -38,6 +38,17 @@ func.func @subview_dynamic_rank_reduce(%memref: memref<?x4xf32>, %offset: index,
return
}
+func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>,
+ %dim_0: index,
+ %dim_1: index,
+ %dim_2: index) {
+ %subview = memref.subview %memref[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] :
+ memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to
+ memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ return
+}
+
+
func.func @main() {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
@@ -105,6 +116,14 @@ func.func @main() {
// CHECK-NOT: ERROR: Runtime op verification failed
func.call @subview_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (memref<?x4xf32>, index, index, index) -> ()
+ %alloca_10x4x1 = memref.alloca() : memref<10x4x1xf32>
+ %alloca_10x4x1_dyn_stride = memref.cast %alloca_10x4x1 : memref<10x4x1xf32> to memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ %dim_0 = arith.constant 0 : index
+ %dim_1 = arith.constant 4 : index
+ %dim_2 = arith.constant 1 : index
+ func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2)
+ : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> ()
return
}
|
|
Hi @matthias-springer, could you please look at this PR when you get a chance? This is similar to the fix in #164878. Thanks! |
…dimension value is 0 (llvm#164897) Previously, the runtime verification pass would insert assertion statements with conditions that always evaluate to false for semantically valid `memref.subview` operations where one of the dimensions had a size of 0. The `memref.subview` runtime verification logic was unconditionally generating checks for the position of the last element (`offset + (size - 1) * stride`). When `size` is 0, this causes the assertion condition to always be false, leading to runtime failures even though the operation is semantically valid. This patch fixes the issue by making the `lastPos` check conditional. The offset is always verified, but the endpoint check is only performed when `size > 0` to avoid generating spurious assert statements. This issue was discovered through a LiteRT model, where a dynamic shape calculation resulted in a zero-sized dimension being passed to `memref.subview`. The following is a simplified IR snippet from the model. After running the runtime verification pass, an assertion that always fails is generated because the SSA value `%5` becomes 0. ```mlir module { memref.global "private" constant @__constant_2xi32 : memref<2xi32> = dense<-1> {alignment = 64 : i64} memref.global "private" constant @__constant_1xi32 : memref<1xi32> = dense<0> {alignment = 64 : i64} func.func @simpleRepro(%arg0: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> { %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index %c1 = arith.constant 1 : index %c10 = arith.constant 10 : index %c0 = arith.constant 0 : index %c-1 = arith.constant -1 : index %0 = memref.get_global @__constant_1xi32 : memref<1xi32> %1 = memref.get_global @__constant_2xi32 : memref<2xi32> %alloca = memref.alloca() {alignment = 64 : i64} : memref<3xi32> %subview = memref.subview %alloca[0] [1] [1] : memref<3xi32> to memref<1xi32, strided<[1]>> memref.copy %0, %subview : memref<1xi32> to memref<1xi32, strided<[1]>> %subview_0 = memref.subview %alloca[1] [2] [1] : memref<3xi32> to memref<2xi32, strided<[1], offset: 1>> memref.copy %1, %subview_0 : memref<2xi32> to memref<2xi32, strided<[1], offset: 1>> %2 = memref.load %alloca[%c0] : memref<3xi32> %3 = index.casts %2 : i32 to index %4 = arith.cmpi eq, %3, %c-1 : index %5 = arith.select %4, %c10, %3 : index %6 = memref.load %alloca[%c1] : memref<3xi32> %7 = index.casts %6 : i32 to index %8 = arith.cmpi eq, %7, %c-1 : index %9 = arith.select %8, %c4, %7 : index %10 = memref.load %alloca[%c2] : memref<3xi32> %11 = index.casts %10 : i32 to index %12 = arith.cmpi eq, %11, %c-1 : index %13 = arith.select %12, %c1, %11 : index %subview_1 = memref.subview %arg0[0, 0, 0] [%5, %9, %13] [1, 1, 1] : memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> return %subview_1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> } } ``` P.S. This is a similar issue to the one fixed for `tensor.extract_slice` in llvm#164878 --------- Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>
Previously, the runtime verification pass would insert assertion statements with conditions that always evaluate to false for semantically valid
memref.subviewoperations where one of the dimensions had a size of 0.The
memref.subviewruntime verification logic was unconditionally generating checks for the position of the last element (offset + (size - 1) * stride). Whensizeis 0, this causes the assertion condition to always be false, leading to runtime failures even though the operation is semantically valid.This patch fixes the issue by making the
lastPoscheck conditional. The offset is always verified, but the endpoint check is only performed whensize > 0to avoid generating spurious assert statements.This issue was discovered through a LiteRT model, where a dynamic shape calculation resulted in a zero-sized dimension being passed to
memref.subview. The following is a simplified IR snippet from the model. After running the runtime verification pass, an assertion that always fails is generated because the SSA value%5becomes 0.P.S. This is a similar issue to the one fixed for
tensor.extract_slicein #164878